Package datk :: Package core :: Module tester
[hide private]
[frames | no frames]

Source Code for Module datk.core.tester

  1  from threading import Thread, Lock 
  2  from time import sleep, time 
  3  from distalgs import Process, Algorithm 
  4   
  5  TIMEOUT = 5 
  6   
  7  _lock = Lock() 
  8  _num_tests = 0 
  9  _failed_tests = set() 
 10   
def test(f=None, timeout=TIMEOUT, main_thread=False, test=True):
    """
    Decorator function test to run distributed algorithm tests in a safe environment. Logs failed tests.

    The decorated test runs immediately, at decoration time, while holding the
    module lock; the undecorated function itself is returned. Usable bare
    (@test) or with arguments (@test(timeout=10)).

    @param f: the test (a function) to run.
    @param timeout: the number of seconds to allow the test to run, before timing it out (causing it to fail).
        None waits for the test indefinitely. Not used when main_thread is True.
    @param main_thread: True iff the test cannot run on a thread other than the main thread.
    @param test: If false, skips testing this function. Useful because it can be set to default to false,
        and then set to True for a select few tests currently being tested.
    @return: f when f is given; otherwise a decorator that runs its argument and returns it.
    """
    if not test:
        # Skip: hand back the function untouched, or an identity decorator
        # when used as @test(test=False).
        return (lambda g: g) if f is None else f

    def test_decorator(f):
        def test_f():
            global _num_tests
            try:
                f()
            except Exception:
                _failed_tests.add(f.__name__)
                print_with_underline("TEST " + f.__name__ + " FAILED.")
                # Bare raise preserves the original traceback (raise e loses it on Python 2).
                raise
            finally:
                _num_tests += 1

        # Only one decorated test runs at a time under this lock. Note that a
        # timed-out daemon thread may keep running after the lock is released.
        with _lock:
            if main_thread:
                status = "Running test " + f.__name__ + " on main thread."
                print(status)
                test_f()
                print("#" * len(status))
            else:
                t = Thread(target=test_f)
                # Daemon threads do not keep the interpreter alive, so a hung
                # test cannot block process exit.
                t.daemon = True

                start_time = time()
                t.start()
                t.join(timeout)
                elapsed = time() - start_time
                # join() returns silently on timeout; the thread is still alive
                # exactly when the test did not finish in time.
                if t.is_alive():
                    _failed_tests.add(f.__name__)
                    print_with_underline(f.__name__ + " TIMED OUT AFTER " + str(timeout) + "s")
                else:
                    print_with_underline(f.__name__ + " RAN IN " + str(elapsed) + "s")

    if f is None:
        return test_decorator
    test_decorator(f)
    return f
def summarize():
    """
    Called at the end of a test suite. Prints out a summary of failed tests.

    Prints the number of tests run, the number that failed and their sorted
    names, then resets both counters so the next suite starts from zero.
    """
    global _num_tests
    global _failed_tests

    # One pre-formatted string prints identically on Python 2 and Python 3.
    print("%s tests ran with %s failures: %s"
          % (_num_tests, len(_failed_tests), sorted(_failed_tests)))

    _num_tests = 0
    _failed_tests = set()
74 75 import matplotlib.pyplot as plt
def benchmark(Algorithm, Network, test):
    """
    Benchmarks the Algorithm on a given class of Networks. Samples variable network size, and plots results.

    Network sizes n = 2, 4, 8, ... are sampled. For each n the algorithm is run
    on several freshly built networks. Sampling stops once a run needs 10000 or
    more rounds or messages, or once n reaches 500. Two matplotlib figures are
    created (time and communication complexity), each showing every run as a
    blue "data" point and the per-size mean as a red "means" point. This
    function does not call plt.show().

    @param Algorithm: a subclass of SynchronousAlgorithm, the algorithm to test.
    @param Network: a subclass of Network, the network on which to benchmark the algorithm.
    @param test: a function that may throw an assertion error
    """
    import sys
    from itertools import groupby

    def progress(text):
        """Write text to stdout without a newline and flush so it shows immediately."""
        sys.stdout.write(text)
        sys.stdout.flush()

    def sample(Algorithm, Network, test):
        """
        Runs the Algorithm on Networks of the given type, varying n.
        After every execution, runs test on the resultant Network.

        @param Algorithm: a subclass of SynchronousAlgorithm, the algorithm to test.
        @param Network: a subclass of Network, the network on which to benchmark the algorithm.
        @param test: a function that may throw an assertion error
        @return: (size, time, comm) where size is a list of values of network size (one entry per run),
            and time and comm are lists of corresponding values of time and communication complexities;
            None if test raised AssertionError.
        """
        sizes = []
        times = []
        comms = []
        n, lgn = 2, 1
        max_time = 0
        max_comm = 0
        progress("Sampling n = ...")
        while max(max_time, max_comm) < 10000 and n < 500:
            # Progress: erase the trailing "..." and append the current n.
            separator = "" if not sizes else ", "
            progress("\b\b\b" + separator + str(n) + "...")

            cur_times = []
            cur_comms = []
            for _ in range(max(4, 2 + lgn)):
                A = Algorithm(params={'draw': False, 'verbosity': Algorithm.SILENT})
                x = Network(n)
                A(x)
                try:
                    test(x)
                except AssertionError:
                    progress(" Algorithm Failed\n")
                    return None
                cur_times.append(A.r)
                cur_comms.append(A.message_count)

            # Record every run, not only the last one, so that averages() has
            # several samples per n to average.
            sizes.extend([n] * len(cur_times))
            times.extend(cur_times)
            comms.extend(cur_comms)
            max_time = max([max_time] + cur_times)
            max_comm = max([max_comm] + cur_comms)

            # TODO: decide whether more samples are needed for this n, based on cur_times and cur_comms variance
            n *= 2
            lgn += 1
        progress(" DONE\n")
        return sizes, times, comms

    def averages(x, y):
        """
        Groups x's with the same value, averages corresponding y values.

        @param x: A sorted list of x values (only adjacent equal values are grouped)
        @param y: A list of corresponding y values
        @return: (x grouped by value, corresponding mean y values); ([], []) for empty input

        Example:

            averages([1,1,2,2,2,3], [5,6,3,5,1,8]) --> ([1, 2, 3], [5.5, 3.0, 8.0])
        """
        new_x = []
        new_y = []
        for x_val, pairs in groupby(zip(x, y), key=lambda pair: pair[0]):
            ys = [pair[1] for pair in pairs]
            new_x.append(x_val)
            new_y.append(sum(ys) / float(len(ys)))
        return new_x, new_y

    def plot(x, y, title):
        """Scatter-plots the points (x[i], y[i]) and their per-x means on a new figure."""
        _, ax = plt.subplots()

        x_ave, y_ave = averages(x, y)

        ax.scatter(x, y, label="data", color='b')
        ax.scatter(x_ave, y_ave, label="means", color='r')

        # left/bottom are the primary names; xmin/ymin are legacy aliases.
        ax.set_xlim(left=0)
        ax.set_ylim(bottom=0)
        ax.legend(loc='center left', bbox_to_anchor=(1, 0.5))
        ax.set_title(title)
        ax.set_xlabel(Network.__name__ + ' size')

    data = sample(Algorithm, Network, test)
    if data is None:
        return
    sizes, times, comms = data
    plot(sizes, times, Algorithm.__name__ + ' Time Complexity')
    plot(sizes, comms, Algorithm.__name__ + ' Communication Complexity')