testsuite.py

changeset 62
f0a6bf48b05e
parent 47
4da025d0b283
child 63
8949af6a4279
--- a/testsuite.py	Wed Jun 05 00:33:50 2019 +0300
+++ b/testsuite.py	Sat Jun 08 01:32:25 2019 +0300
@@ -1,21 +1,46 @@
 from warnings import warn
 
-def report_element(bad_object, type, error_name, args):
-    return {
-        'type': type,
-        'object': bad_object,
-        'name': error_name,
-        'args': args,
-    }
+class ProblemType:
+    severities = ['error', 'notice'] # in descending order
+    def __init__(self, name, severity, message):
+        if severity not in ProblemType.severities:
+            raise ValueError(str.format(
+                'bad severity {severity!r}',
+                severity = severity,
+           ))
+        self.name = name
+        self.severity = severity
+        self.message = message
+    def __call__(self, bad_object, **args):
+        return Problem(
+            problem_class = self,
+            bad_object = bad_object,
+            **args,
+        )
 
-def warning(bad_object, error_name, **args):
-    return report_element(bad_object, 'warning', error_name, args)
+class Problem:
+    def __init__(self, problem_class, bad_object, **args):
+        self.problem_class = problem_class
+        self.severity = problem_class.severity
+        self.object = bad_object
+        self.args = args
+    def __str__(self):
+        if callable(self.problem_class.message):
+            return self.problem_class.message(**self.args)
+        else:
+            return self.problem_class.message
 
-def error(bad_object, error_name, **args):
-    return report_element(bad_object, 'error', error_name, args)
+def problem_type(problem_name, **args):
+    def wrapper(function):
+        if not hasattr(function, 'ldcheck_problem_types'):
+            function.ldcheck_problem_types = {}
+        new_type = ProblemType(name = problem_name, **args)
+        function.ldcheck_problem_types[problem_name] = new_type
+        return function
+    return wrapper
 
-def notice(bad_object, error_name, **args):
-    return report_element(bad_object, 'notice', error_name, args)
+def report_problem(problem_name, *, bad_object, **args):
+    return {'type': problem_name, 'bad-object': bad_object, 'args': args}
 
 def name_of_package(package):
     if isinstance(package, tuple):
@@ -34,51 +59,34 @@
         for result in walk_packages(tests.__path__)
     )
 
-def do_manifest_integrity_checks(test_suite, module):
-    '''
-        Runs integrity checks on a given module's manifest.
-    '''
-    def check_for_extra_keys():
-        extra_keys = module.manifest.keys() - test_suite.keys()
-        if extra_keys:
-            warn(str.format(
-                '{}: extra keys in manifest: {}',
-                module.__name__,
-                ', '.join(map(str, extra_keys))
-            ))
-    def check_for_manifest_duplicates():
-        for key in test_suite.keys():
-            duplicates = module.manifest[key].keys() & test_suite[key].keys()
-            if duplicates:
-                warn(str.format(
-                    '{}: redefined {} in manifests: {}',
-                    module.__name__,
-                    key,
-                    duplicates,
-                ))
-    check_for_extra_keys()
-    check_for_manifest_duplicates()
-
 def load_tests():
     '''
         Imports test modules and combines their manifests into a test suite.
     '''
-    test_suite = {'tests': {}, 'messages': {}}
+    test_suite = {'tests': []}
     for module_name in test_discovery():
         from importlib import import_module
         module = import_module(module_name)
         if hasattr(module, 'manifest'):
-            do_manifest_integrity_checks(test_suite, module)
             # Merge the data from the manifest
-            for key in module.manifest.keys() & test_suite.keys():
-                test_suite[key].update(module.manifest[key])
+            test_suite['tests'] += module.manifest['tests']
         else:
             warn(str.format('Module {} does not have a manifest', module_name))
+    test_suite['tests'].sort(key = lambda f: f.__name__)
     return test_suite
 
 def problem_key(problem):
-    problem_hierarchy = ['error', 'warning', 'notice']
-    return (problem_hierarchy.index(problem['type']), problem['line-number'])
+    rank = ProblemType.severities.index(problem.severity) # sort by severity
+    return (rank, problem.line_number)
+
+def build_problem(test_function, problem_params):
+    problem_name = problem_params['type']
+    problem_type = test_function.ldcheck_problem_types[problem_name]
+    problem_object = problem_type(
+        bad_object = problem_params['bad-object'],
+        **problem_params['args'],
+    )
+    return problem_object
 
 def check_model(model, test_suite = None):
     if not test_suite:
@@ -88,36 +96,38 @@
         element: (i, i + 1)
         for i, element in enumerate(model.body)
     }
-    for test_name, test_function in test_suite['tests'].items():
-        for problem in test_function(model):
-            problem['body-index'], problem['line-number'] \
-                = line_numbers[problem['object']]
-            del problem['object']
+    for test_function in test_suite['tests']:
+        for problem_params in test_function(model):
+            problem = build_problem(test_function, problem_params)
+            # add line numbers to the problem
+            problem.body_index, problem.line_number \
+                = line_numbers[problem.object]
+            problem.object = None
             problems.append(problem)
     return {
         'passed': not any(
-            problem['type'] == 'error'
+            problem.severity == 'error'
             for problem in problems
         ),
         'problems': sorted(problems, key = problem_key),
     }
 
 def problem_text(problem, test_suite):
-    message = test_suite['messages'][problem['name']]
+    message = problem.problem_class.message
     if callable(message):
-        message = message(**problem['args'])
+        message = message(**problem.args)
     return message
 
 def format_report_html(report, model, test_suite):
     messages = []
     for problem in report['problems']:
-        ldraw_code = model.body[problem['body-index']].textual_representation()
+        ldraw_code = model.body[problem.body_index].textual_representation()
         message = str.format(
             '<li class="{problem_type}">{model_name}:{line_number}:'
             '{problem_type}: {message}<br />{ldraw_code}</li>',
             model_name = model.name,
-            line_number = problem['line-number'],
-            problem_type = problem['type'],
+            line_number = problem.line_number,
+            problem_type = problem.severity,
             message = problem_text(problem, test_suite),
             ldraw_code = ldraw_code,
         )
@@ -129,22 +139,20 @@
     colorama.init()
     messages = []
     for problem in report['problems']:
-        if problem['type'] == 'error':
+        if problem.severity == 'error':
             text_colour = colorama.Fore.LIGHTRED_EX
-        elif problem['type'] == 'warning':
-            text_colour = colorama.Fore.LIGHTYELLOW_EX
-        elif problem['type'] == 'notice':
+        elif problem.severity == 'notice':
             text_colour = colorama.Fore.LIGHTBLUE_EX
         else:
             text_colour = ''
-        ldraw_code = model.body[problem['body-index']].textual_representation()
+        ldraw_code = model.body[problem.body_index].textual_representation()
         message = str.format(
             '{text_colour}{model_name}:{line_number}: {problem_type}: {message}'
             '{colour_reset}\n\t{ldraw_code}',
             text_colour = text_colour,
             model_name = model.name,
-            line_number = problem['line-number'],
-            problem_type = problem['type'],
+            line_number = problem.line_number,
+            problem_type = problem.severity,
             message = problem_text(problem, test_suite),
             colour_reset = colorama.Fore.RESET,
             ldraw_code = ldraw_code,
@@ -152,6 +160,10 @@
         messages.append(message)
     return '\n'.join(messages)
 
+def all_warning_types(test_suite):
+    for test_function in test_suite['tests']:
+        yield from test_function.ldcheck_problem_types.values()
+
 if __name__ == '__main__':
     from pprint import pprint
     pprint(load_tests())

mercurial