# Source code for tl.testing.thread
# Copyright (c) 2011-2012 Thomas Lotze
# See also LICENSE.txt
"""Working with threads in test code.
"""
import sys
import threading
import unittest
class ThreadAwareTestCase(unittest.TestCase):
    """TestCase implementation that adds some features for handling threads.

    While the default ``unittest.TestCase`` doesn't concern itself with
    threading at all, this implementation adds the following:

    - interplay with threads such that exceptions and failures that occur in a
      thread get collected along those from the main thread

    - reporting of threads left behind by a test

    - a convenience helper for running code in a daemon thread

    - a convenience helper for counting threads started during the current
      test method's execution

    The ``ThreadAwareTestCase`` is instantiated the same way as
    ``unittest.TestCase``.

    """

    def __init__(self, *args, **kwargs):
        # Keyword arguments are forwarded too, so that ``methodName`` may be
        # passed by name just like with ``unittest.TestCase`` itself.
        super(ThreadAwareTestCase, self).__init__(*args, **kwargs)
        # A thread class bound to this test, so that threads created from it
        # can report to ``self.result``.
        self.ExceptionReportingThread = ExceptionReportingThread.bind(self)

    def run(self, result=None):
        """Thread-aware wrapper of ``unittest.TestCase.run``; see there.

        Stores the test result object on ``self.result`` (creating a default
        one if none was passed), runs the test inside a ``ThreadJoiner`` and
        afterwards passes any threads left behind to the
        ``report_threads_left_behind`` hook.

        Returns whatever ``unittest.TestCase.run`` returns.

        .. sphinx-autodoc-skip

        """
        self.result = self.defaultTestResult() if result is None else result
        with ThreadJoiner(0, check_alive=False) as joiner:
            # Remember the threads that existed before the test started;
            # ``active_count`` needs them while the test is running.
            self._before = joiner.before
            try:
                result = super(ThreadAwareTestCase, self).run(self.result)
            finally:
                # Clean up even if running the test was interrupted.
                del self._before
        # ``left_behind`` is only available after the joiner has exited.
        if joiner.left_behind:
            self.report_threads_left_behind(joiner.left_behind)
        return result

    def active_count(self):
        """Count threads started during the current test method's execution.

        Only usable while the test is being run by ``run``, as it relies on
        the set of threads recorded there.

        Returns an integer.

        """
        return len(set(threading.enumerate()) - self._before)

    def run_in_thread(self, func, report=True):
        """Run the given function in a daemon thread that handles exceptions.

        Using this thread helper is necessary in order for the thread to play
        with ``ThreadAwareTestCase``'s reporting of errors and failures that
        occurred in threads.

        Passing a false value for the ``report`` option causes unhandled
        exceptions in the thread to be swallowed instead of reported. This
        prevents them from appearing in the middle of more interesting test
        output in the case that they are of no concern themselves.

        This method starts the thread in daemon mode and blocks until the
        thread has actually started executing.

        Returns the thread object.

        """
        factory = (self.ExceptionReportingThread if report
                   else ExceptionSilencingThread)
        thread = factory(target=func)
        thread.daemon = True
        thread.start()
        return thread

    def report_threads_left_behind(self, threads):
        """Hook reporting threads left behind by a test to the test runner.

        This hook is meant to do whatever the test runner being used needs to
        be done to display information on the collection of threads passed to
        this hook. The default implementation is meant for a console test
        runner: it formats the threads as a piece of text using the
        ``format_report_on_threads_left_behind`` hook and prints that to
        standard output.

        """
        msg = self.format_report_on_threads_left_behind(threads)
        sys.stdout.write(msg)

    def format_report_on_threads_left_behind(self, threads):
        """Hook formatting a text report on threads left behind by a test.

        The default implementation names this test and lists the ``repr`` of
        each of the threads passed in. Override it to change the wording.

        Returns a string ending in a newline.

        """
        listing = ', '.join(repr(thread) for thread in threads)
        return (
            'The following test left new threads behind: %s\n'
            'New thread(s): %s\n' % (self, listing))
class _ExceptionHandlingThread(threading.Thread):
    """Base class of any threads that need their own run method.

    Subclasses provide a ``_run`` method which is executed in place of
    ``run`` when the thread is started; it is expected to set the
    ``started`` event once it is executing.

    .. sphinx-autodoc-skip

    """

    def start(self):
        """Reporting-aware re-implementation of ``Thread.start``.

        Temporarily substitutes ``_run`` for ``run``, starts the thread and
        waits up to one second for the ``started`` event. Raises a
        ``RuntimeError`` if the event has not been set by then. The previous
        ``run`` is put back in either case.

        .. sphinx-autodoc-skip

        """
        original_run = self.run
        self.started = threading.Event()
        try:
            # The new thread will look up ``self.run``; make it find ``_run``.
            self.run = self._run
            super(_ExceptionHandlingThread, self).start()
            self.started.wait(1)
            if self.started.is_set():
                return
            raise RuntimeError('Time-out waiting for thread to start.')
        finally:
            self.run = original_run
class ExceptionReportingThread(_ExceptionHandlingThread):
    """Thread implementation which reports exceptions to a test result object.

    The reporting-thread class is not useful if instantiated directly; it must
    be bound to an instance of the ``ThreadAwareTestCase`` first. It gets thus
    access to the test result object, to which it will report if an exception
    occurs in the thread's business code. An ``AssertionError`` will be
    reported as a test failure, any other exception will be reported as a test
    error.

    """

    # The test this class is bound to; set on the subclass made by ``bind``.
    _test = None

    def _run(self):
        # Signal ``start`` that the thread is executing before doing any work.
        self.started.set()
        try:
            super(ExceptionReportingThread, self).run()
        except Exception:
            exc_info = sys.exc_info()
            result = self._test.result
            # Assertion errors count as failures, everything else as errors.
            if issubclass(exc_info[0], AssertionError):
                result.addFailure(self._test, exc_info)
            else:
                result.addError(self._test, exc_info)

    @classmethod
    def bind(cls, test):
        """Bind the class to a specific test.

        Binding the reporting-thread class to a test is necessary for
        instances of it to be able to report errors and failures to the test
        result.

        Returns a subclass of the reporting-thread class.

        """
        class ExceptionReportingThread(cls):
            _test = test

        return ExceptionReportingThread
class ExceptionSilencingThread(_ExceptionHandlingThread):
    """Thread implementation which swallows unhandled exceptions.

    Swallowing exceptions is useful if it is clear that they are of no
    interest and shall not mix with more useful test output.

    """

    def _run(self):
        # Signal ``start`` that the thread is executing before doing any work.
        self.started.set()
        run_target = super(ExceptionSilencingThread, self).run
        try:
            run_target()
        except Exception:
            # Dropping the exception is the whole point of this class.
            pass
class ThreadJoiner(object):
    """Context manager that tries to join any threads started by its suite.

    This context manager is instantiated with a mandatory ``timeout``
    parameter and an optional ``check_alive`` switch. The time-out is applied
    when joining each of the new threads started while executing the context
    manager's code suite. If ``check_alive`` has a true value (the default),
    a ``RuntimeError`` is raised if a thread is still alive after the attempt
    to join timed out.

    Returns an instance of itself upon entering. This instance has a
    ``before`` attribute that is a collection of all threads active when the
    manager was entered. After the manager exited, the instance has another
    attribute, ``left_behind``, that is a list, sorted by thread name, of any
    threads that could not be joined within the time-out period. It is set
    even if exiting raised the ``RuntimeError`` described above.

    """

    def __init__(self, timeout, check_alive=True):
        self.timeout = timeout
        self.check_alive = check_alive

    def __enter__(self):
        # Snapshot of the threads that are not this manager's concern.
        self.before = set(threading.enumerate())
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            for thread in set(threading.enumerate()) - self.before:
                thread.join(self.timeout)
                if self.check_alive and thread.is_alive():
                    raise RuntimeError('Timeout joining thread %r' % thread)
        finally:
            # Record the stragglers no matter how the loop ended, so that
            # ``left_behind`` always exists once the manager has exited.
            self.left_behind = sorted(
                set(threading.enumerate()) - self.before,
                key=lambda t: t.name)