from __future__ import with_statement
import imp
import inspect
import os
import sys
from attest import ast, statistics
from attest.codegen import to_source, SourceGenerator
# Names exported by ``from <this module> import *``.
__all__ = ('COMPILES_AST', 'ExpressionEvaluator', 'TestFailure',
           'assert_hook', 'AssertTransformer', 'AssertImportHook')
# Probe whether compile() accepts a parsed AST object directly; a TypeError
# from the call is taken to mean that it does not.
try:
    compile(ast.parse('pass'), '<string>', 'exec')
    COMPILES_AST = True
except TypeError:
    COMPILES_AST = False
class ExpressionEvaluator(SourceGenerator):
    """Evaluates `expr` in the context of `globals` and `locals`, expanding
    the values of variables and the results of binary operations, but
    keeping comparison and boolean operators.

    .. testsetup::

        from attest import ExpressionEvaluator

    >>> var = 1 + 2
    >>> value = ExpressionEvaluator('var == 5 - 3', globals(), locals())
    >>> repr(value)
    '(3 == 2)'
    >>> bool(value)
    False

    :param expr: Source text of the expression to evaluate.
    :param globals: Global namespace used for every evaluation.
    :param locals: Local namespace used for every evaluation.

    .. versionadded:: 0.5

    """

    def __init__(self, expr, globals, locals):
        self.expr = expr
        self.globals, self.locals = globals, locals
        # Fragments of the expanded expression; joined by __repr__.
        self.result = []
        # Only the first statement of `expr` is looked at, and it has to be
        # an expression statement (a node with a ``value`` attribute).
        node = ast.parse(self.expr).body[0].value
        self.visit(node)

    def __repr__(self):
        """Return the expanded expression as a single string."""
        return ''.join(self.result)

    def __str__(self):
        """Return the original source and its expansion on separate lines."""
        return '\n'.join((self.expr, repr(self)))

    def __bool__(self):
        """Return the truth value of evaluating the original expression."""
        return bool(eval(self.expr, self.globals, self.locals))

    # Python 2 looks up __nonzero__, Python 3 looks up __bool__; provide
    # both so that ``bool(value)`` reflects the expression on either.
    __nonzero__ = __bool__

    def eval(self, node):
        """Evaluate the AST `node` by turning it back into source first."""
        return eval(to_source(node), self.globals, self.locals)

    def write(self, s):
        """Append ``str(s)`` to the expanded expression."""
        self.result.append(str(s))

    def visit_Name(self, node):
        """Write the ``__name__`` of the named object when it has a
        non-empty one, otherwise its ``repr``."""
        value = self.eval(node)
        if getattr(value, '__name__', None):
            self.write(value.__name__)
        else:
            self.write(repr(value))

    def generic_visit(self, node):
        """Write the ``repr`` of the value `node` evaluates to."""
        self.write(repr(self.eval(node)))

    # These node types are collapsed to the repr of their evaluated value
    # instead of being expanded further.
    visit_BinOp = visit_Subscript = generic_visit
    visit_ListComp = visit_GeneratorExp = generic_visit
    visit_SetComp = visit_DictComp = generic_visit
    visit_Call = visit_Attribute = generic_visit
class TestFailure(AssertionError):
    """Extended :exc:`AssertionError` used by the assert hook; it keeps the
    evaluated expression available as ``value``.

    :param value: The asserted expression evaluated with
        :class:`ExpressionEvaluator`.
    :param msg: Optional message passed to the assertion.

    .. versionadded:: 0.5

    """

    def __init__(self, value, msg=''):
        # The message alone becomes the exception's argument.
        AssertionError.__init__(self, msg)
        self.value = value
def assert_hook(expr, msg='', globals=None, locals=None):
    """Like `assert`, but using :class:`ExpressionEvaluator`. If
    you import this in test modules and the :class:`AssertImportHook` is
    installed (which it is automatically the first time you import from
    :mod:`attest`), `assert` statements are rewritten as a call to
    this.

    The import must be a top-level *from* import, example::

        from attest import Tests, assert_hook

    :param expr: Source text of the expression being asserted.
    :param msg: Optional message attached to the failure.
    :param globals: Global namespace for the evaluation; defaults to the
        caller's globals.
    :param locals: Local namespace for the evaluation; defaults to the
        caller's locals.
    :raises TestFailure: If the expression evaluates to a false value.

    .. versionadded:: 0.5

    """
    statistics.assertions += 1
    if globals is None or locals is None:
        # Only the immediate caller's frame is needed, so fetch just that
        # one instead of materialising the whole stack via inspect.stack().
        caller = sys._getframe(1)
        try:
            if globals is None:
                globals = caller.f_globals
            if locals is None:
                locals = caller.f_locals
        finally:
            # Do not keep the frame object alive longer than necessary.
            del caller
    value = ExpressionEvaluator(expr, globals, locals)
    if not value:
        raise TestFailure(value, msg)
# Build AST nodes on 2.5 more easily
def _build(node, **kwargs):
    """Instantiate `node` without arguments and set every keyword argument
    as an attribute on the new instance.

    :param node: Callable (typically a node class) taking no arguments.
    :param kwargs: Attribute names mapped to the values to assign.
    :returns: The newly created instance.

    """
    instance = node()
    # items() rather than iteritems(): the latter exists on Python 2 only.
    for key, value in kwargs.items():
        setattr(instance, key, value)
    return instance
class AssertImportHookEnabledDescriptor(object):
    """Read-only descriptor that reports whether ``sys.meta_path`` holds an
    instance of the class it is accessed through."""

    def __get__(self, instance, owner):
        for hook in sys.meta_path:
            if isinstance(hook, owner):
                return True
        return False
class AssertImportHook(object):
    """An :term:`importer` that transforms imported modules with
    :class:`AssertTransformer`.

    Instances can also be used as context managers, in which case the hook
    is installed on entry and removed again on exit.

    .. versionadded:: 0.5

    """

    #: Class property, :const:`True` if the hook is enabled.
    enabled = AssertImportHookEnabledDescriptor()

    @classmethod
    def enable(cls):
        """Enable the import hook."""
        # Remove existing instances first so that at most one is installed.
        cls.disable()
        sys.meta_path.insert(0, cls())

    @classmethod
    def disable(cls):
        """Disable the import hook."""
        # Slice assignment mutates the existing list in place.
        sys.meta_path[:] = [ih for ih in sys.meta_path
                            if not isinstance(ih, cls)]

    def __init__(self):
        # Maps module name -> ((file, pathname, description), search path),
        # filled in by find_module() and read by the loader methods.
        self._cache = {}

    def __enter__(self):
        sys.meta_path.insert(0, self)

    def __exit__(self, exc_type, exc_value, traceback):
        sys.meta_path.remove(self)

    def find_module(self, name, path=None):
        """Locate module `name` and remember where it was found.

        :returns: This hook (acting as the loader) when the module was
            found, otherwise :const:`None`.

        """
        lastname = name.rsplit('.', 1)[-1]
        try:
            self._cache[name] = imp.find_module(lastname, path), path
        except ImportError:
            return
        return self

    def load_module(self, name):
        """Load module `name`, rewriting it with :class:`AssertTransformer`
        when source is available and the transformer asks for a rewrite.

        :returns: The loaded module.
        :raises ImportError: If the module was not located by
            :meth:`find_module` first, or if building the rewritten module
            fails.

        """
        if name in sys.modules:
            return sys.modules[name]
        source, filename, newpath = self.get_source(name)
        (fd, fn, info), path = self._cache[name]
        try:
            if source is None:
                return imp.load_module(name, fd, fn, info)
            transformer = AssertTransformer(source, filename)
            if not transformer.should_rewrite:
                # The cached file may already have been read and closed by
                # get_source(), so look the module up again for loading.
                lastname = name.rsplit('.', 1)[-1]
                fresh, fn, info = imp.find_module(lastname, path)
                try:
                    return imp.load_module(name, fresh, fn, info)
                finally:
                    # imp.load_module() leaves closing to the caller.
                    if fresh is not None:
                        fresh.close()
            try:
                return transformer.make_module(name, newpath)
            except Exception:
                # sys.exc_info() instead of ``except Exception, err`` so
                # that the statement parses on Python 2.5 and 3.x alike.
                err = sys.exc_info()[1]
                raise ImportError('cannot import %s: %s' % (name, err))
        finally:
            # Release the handle opened by find_module(); closing a file
            # that get_source() has already closed is harmless.
            if fd is not None:
                fd.close()

    def get_source(self, name):
        """Read the source code for a module located by :meth:`find_module`.

        :returns: A ``(code, filename, newpath)`` tuple. All three are
            :const:`None` unless the module is a source file, a compiled
            file or a package directory; `newpath` is only set for
            packages.
        :raises ImportError: If `name` was not located by
            :meth:`find_module`.

        """
        try:
            (fd, fn, info), path = self._cache[name]
        except KeyError:
            raise ImportError(name)
        code = filename = newpath = None
        kind = info[2]
        if kind == imp.PY_SOURCE:
            filename = fn
            with fd:
                code = fd.read()
        elif kind == imp.PY_COMPILED:
            # Dropping the last character of the compiled file's path
            # presumably yields the matching source file ('.pyc' -> '.py').
            filename = fn[:-1]
            with open(filename, 'U') as f:
                code = f.read()
        elif kind == imp.PKG_DIRECTORY:
            filename = os.path.join(fn, '__init__.py')
            newpath = [fn]
            with open(filename, 'U') as f:
                code = f.read()
        return code, filename, newpath