1
2 import unittest, tempfile
3
4 from data import Dataset, AnchorDataset, TrainAnchorDataset
5 import numpy as np
6 import pandas as pd
7
# One fixed-seed generator shared by every test in this module so the random
# fixtures are reproducible between runs. It is consumed statefully (setUp
# draws from it and it is handed to TrainAnchorDataset), so the values a test
# sees depend on what earlier draws consumed.
rng = np.random.RandomState(seed=10)
12 self.X = rng.rand( 100, 800 )
13
15 "0-mean and 1-variance features"
16
17
def check_norm(x, m, v):
    """Assert that every column of *x* has mean 0 and std 1 (or std 0).

    Nested helper that closes over the enclosing test's ``self`` (a
    ``unittest.TestCase``) and checks the ``(x, m, v)`` triple returned by
    ``Dataset.normalize_features``.

    :param x: normalized data, indexed as ``x[:, f]``; assumed 2-D
              (samples x features) -- TODO confirm for other callers.
    :param m: per-feature statistics; shape must equal ``x.shape[1:]``
              (presumably the column means -- confirm against Dataset).
    :param v: per-feature statistics with the same shape as ``m``
              (presumably the column spread). A zero entry marks a feature
              with no spread, whose normalized column must have std 0
              instead of 1.
    """
    self.assertEqual(m.shape, v.shape)
    self.assertEqual(m.shape, tuple(x.shape[1:]))

    for f in range(x.shape[1]):
        col = x[:, f]
        mean, std = col.mean(), col.std()
        self.assertAlmostEqual(
            mean, 0, msg="feature %d: mean %r, expected 0" % (f, mean))
        # A feature with zero spread (e.g. a column the caller zeroed out)
        # cannot be scaled to unit variance; it must stay flat instead.
        expected_std = 1 if v[f] != 0 else 0
        self.assertAlmostEqual(
            std, expected_std,
            msg="feature %d: std %r, expected %r" % (f, std, expected_std))
31
32 X = self.X
33 check_norm(*Dataset.normalize_features(X))
34 X[:,0] = 0
35 print X[:,0]
36 check_norm( *Dataset.normalize_features(X) )
37
39 "features in [0,1]"
40
41 print self.X
42 fx = Dataset.fit_features(self.X)
43 self.assertAlmostEqual( fx.min(), -1 )
44 self.assertAlmostEqual( fx.max(), 1 )
45 self.assertEqual(fx.shape, self.X.shape)
46
51
54 (n, tr, w) = (rng.randint(10, 100), rng.randint(2, 5),
55 rng.randint(200, 500))
56 self.bs = rng.randint(2, n/4)
57
58 self.X = rng.rand(n, tr, w)
59 self.Y = rng.rand(n,)
60 self.T = np.array( map(lambda v: (v > 0 and [1] or [0])[0], self.Y),
61 np.int )
62 self.T[0] = 0; self.T[1] = 1
63
64 self.gnames = map(lambda i: "gene%d"%i, range(self.X.shape[0]))
65 self.tracks = map(lambda i: "track%d"%i, range(self.X.shape[1]))
66 self.width = map(lambda i: "pos%d"%i, range(self.X.shape[2]))
67 self.labels = ("R", "I")
68
69 self.pX = pd.Panel( self.X, items=self.gnames,
70 major_axis = self.tracks, minor_axis = self.width)
71
72 self.dfT = pd.DataFrame({"label_code" : self.T,
73 "label_name" : map(lambda v: self.labels[v], self.T)})
74 self.sY = pd.Series( self.Y, index = self.gnames)
75
76
77 self.labds = AnchorDataset(self.pX, self.sY, self.dfT)
78 self.ds = AnchorDataset(self.pX, self.sY, None)
79
81 self.assertEqual(set( self.labds.label_names), set(self.labels))
82 self.assertEqual(set( self.labds.track_names), set(self.tracks))
83
84 self.assertEqual(self.ds.label_names, None)
85 self.assertEqual(self.ds.track_names, self.tracks)
86
88 "dataset on theano shared vars"
89
90 self.assertTrue( np.all( self.ds.shX.get_value() == self.ds.X ) )
91
92
94
95 ds = TrainAnchorDataset(self.pX, self.sY, self.dfT, self.bs)
96
97
98
99 self.assertEqual(ds.train_batches + ds.valid_batches,
100 range(ds.n_batches))
101
102
103 ds = TrainAnchorDataset(self.pX, self.sY, self.dfT, self.bs, rng=rng)
104 self.assertNotEqual(ds.train_batches + ds.valid_batches,
105 range(ds.n_batches))
106 self.assertEqual(set( ds.train_batches + ds.valid_batches),
107 set( range(ds.n_batches)) )
108
109
110 @unittest.SkipTest
112 ds = AnchorDataset(self.pX, None, None, self.bs)
113
114 self.assertEqual( 5 * list( ds.iter_train(1) ), 5 * ds.train_batches )
115 self.assertEqual( 7 * list( ds.iter_valid(1) ), 7 * ds.valid_batches )
116
117 self.assertEqual( list( ds.iter_train(5) ), 5 * ds.train_batches )
118 self.assertEqual( list( ds.iter_valid(7) ), 7 * ds.valid_batches )
119
121 from archive import __SPEC_SEP__, __HDF_SUFFIX__
122
123 with tempfile.NamedTemporaryFile(suffix="."+__HDF_SUFFIX__) as fd:
124 path = __SPEC_SEP__.join( (fd.name, "empty") )
125 lds = self.labds
126 lds.dump(path)
127 ods = AnchorDataset._from_archive(path, True)
128
129 self.assertEqual(ods.label_names, lds.label_names)
130 self.assertEqual(ods.track_names, lds.track_names)
131
132 self.assertAlmostEqual( np.max( np.abs( ods.X - lds.X ) ), 0 )
133 self.assertTrue( np.all(ods.sY == lds.sY) )
134 self.assertTrue( np.all(ods.dfT == lds.dfT) )
135
136 ods = AnchorDataset._from_archive(path, False)
137 ldsX, m, sd = lds.normalize_features(lds.X.reshape(self.X.shape[0], -1))
138 self.assertAlmostEqual( np.max( np.abs( ods.X - ldsX.reshape(ods.X.shape) ) ), 0 )
139
140
141 with tempfile.NamedTemporaryFile(suffix="."+__HDF_SUFFIX__) as fd:
142 path = __SPEC_SEP__.join( (fd.name, "empty") )
143 lds = TrainAnchorDataset( self.pX, self.sY, self.dfT, self.bs )
144 lds.dump(path)
145 ods = TrainAnchorDataset._from_archive(path, False, self.bs)
146
147 self.assertEqual(ods.train_batches, lds.train_batches)
148 self.assertEqual(ods.valid_batches, lds.valid_batches)
149