Package dimer :: Package nnet :: Module nnet_tests
[hide private]
[frames] | [no frames]

Source Code for Module dimer.nnet.nnet_tests

 1   
 2  import unittest 
 3   
 4  import numpy as np 
 5  rng = np.random.RandomState() 
 6  import pandas as pd 
 7   
 8  from . import adjust_lr, shift_data, scale_data, fit_data 
 9   
class TestLr(unittest.TestCase):
    """Unit tests for ``adjust_lr`` and the Panel preprocessing helpers.

    The fixture is a seeded 15 x 2 x 10 array of uniform [0, 1) samples.
    It is wrapped in a ``pd.Panel`` whose items / major / minor axes are
    labelled ``anchor_*`` / ``track_*`` / ``position_*``.
    """

    def setUp(self):
        # Seeded so every run sees the same fixture (the original also
        # seeded here, but then sampled cells from an unseeded global rng).
        rng = np.random.RandomState(10)
        self.rx = rng.rand(15, 2, 10)
        n_anchors, n_tracks, n_positions = self.rx.shape

        # pd.Panel was removed in pandas 0.25. Build it only when available,
        # so the Panel-free test below still runs on newer pandas.
        if hasattr(pd, "Panel"):
            self.xpn = pd.Panel(
                self.rx,
                items=["anchor_%d" % i for i in range(n_anchors)],
                major_axis=["track_%d" % i for i in range(n_tracks)],
                minor_axis=["position_%d" % i for i in range(n_positions)],
            )
        else:
            self.xpn = None

    def _require_panel(self):
        """Return the Panel fixture, or skip the current test if pandas lacks Panel."""
        if self.xpn is None:
            self.skipTest("pandas.Panel is not available in this pandas version")
        return self.xpn

    def test_seqmonotonicity(self):
        """adjust_lr lowers the rate only when the last history step increased.

        A single-element history and non-increasing last steps ([5, 3],
        [5, 5]) must leave the rate at 0.1; an increasing step ([4, 5])
        must reduce it.
        """
        self.assertEqual(adjust_lr([4], 0.1), 0.1)

        self.assertLess(adjust_lr([4, 5], 0.1), 0.1)
        self.assertEqual(adjust_lr([5, 3], 0.1), 0.1)
        self.assertEqual(adjust_lr([5, 5], 0.1), 0.1)

    def test_shift(self):
        """shift_data zero-centres each (track, position) cell across anchors
        and returns the per-cell means it removed."""
        xpn = self._require_panel()
        # Snapshot taken before the call, presumably in case shift_data
        # modifies its argument in place.
        oxpn = xpn.copy()
        shifted, meandf = shift_data(xpn)

        # Check every cell instead of a random sample, which could be empty.
        np.testing.assert_allclose(shifted.values.mean(axis=0), 0, atol=1e-7)
        # np.asarray gives positional [track, position] indexing whether
        # meandf is an ndarray or a DataFrame. On a DataFrame, df[i, j]
        # would be a column-label lookup.
        np.testing.assert_allclose(
            oxpn.values.mean(axis=0), np.asarray(meandf), atol=1e-7
        )

    def test_scale(self):
        """scale_data gives each (track, position) cell unit std across anchors
        and returns the per-cell standard deviations it divided by."""
        xpn = self._require_panel()
        oxpn = xpn.copy()
        scaled, sddf = scale_data(xpn)

        # ndarray.std uses ddof=0, as the original check did.
        # NOTE(review): confirm scale_data uses the same ddof.
        np.testing.assert_allclose(scaled.values.std(axis=0), 1, atol=1e-7)
        np.testing.assert_allclose(
            oxpn.values.std(axis=0), np.asarray(sddf), atol=1e-7
        )

    def test_fit(self):
        """fit_data maps every value into the closed interval [-1, 1]."""
        fitted = fit_data(self._require_panel())

        self.assertLessEqual(np.abs(fitted.values).max(), 1)