Package dimer :: Package nnet :: Module config_spec
[hide private]
[frames] | [no frames]

Source Code for Module dimer.nnet.config_spec

  1  """classes to handle configuration files""" 
  2   
  3  import os, logging 
  4  from ConfigParser import SafeConfigParser 
  5  from collections import namedtuple 
  6   
  7  import numpy as np 
  8  import pandas as pd 
  9   
 10  from .. import archive 
 11   
 12  logging.getLogger(__name__) 
 13  log = logging.getLogger() 
def _open_cfg(path):
    """Parse the configuration file at *path*.

    Raises ValueError when the file cannot be read/parsed.
    """
    parser = SafeConfigParser()
    loaded = parser.read([path])
    # SafeConfigParser.read returns the list of files it managed to parse
    if len(loaded) == 0:
        raise ValueError("cannot load %s" % path)
    return parser
20
21 -def _check_exist_f(s):
22 if not os.path.isfile(s): 23 import warnings 24 warnings.warn("non-existent input file '%s'" % s) 25 return s
26
27 -def _check_get_pairs(s, tp=int):
28 items = map(tp, s.split()) 29 p = [] 30 npairs = len(items) / 2 31 for i in range(npairs): 32 p.append( (items[i*2], items[i*2+1]) ) 33 return tuple(p)
34
35 -def _check_get_singles(s, tp=int):
36 return tuple( map(tp, s.split()) )
37
class CfgFactory( object ):
    """Abstract mixin providing a factory method from a config file.

    To subclass (alongside a namedtuple base), define:

    _types: converter callables, one per namedtuple field
    _section: section of the config file to read the fields from
    """

    @classmethod
    def _from_settings(cls, path, new=tuple.__new__, len=len):
        'make a new object from a settings file'

        if not len(cls._fields) == len(cls._types):
            raise ValueError("a type (found %d) / key (found %d) needed for %s" % (len(cls._fields),
                                                                                   len(cls._types), str(cls)))

        cfg = _open_cfg(path)
        # fix: `lambda (k, t): ...` is Python-2-only tuple-parameter syntax
        # (removed by PEP 3113); the comprehension is equivalent and also
        # avoids relying on map() returning a list
        iterable = [t(cfg.get(cls._section, k))
                    for k, t in zip(cls._fields, cls._types)]
        result = new(cls, iterable)
        if len(result) != len(cls._fields):
            raise TypeError('expected %d arguments, got %d' % (len(cls._fields),
                                                               len(result)))
        result._check_consistency()
        return result

    def _check_consistency(self):
        "check as much as you can that values of params make sense"

        pass
68
## TODO: change this to CnnModelSpec
class ModelSpec (namedtuple('MetaParams', ("nkerns rfield pool lreg_size")), CfgFactory):
    """CNN architecture specification, read from the [model] section."""
    __slots__ = ()

    _types = (_check_get_singles, _check_get_pairs, _check_get_pairs, int)
    _section = "model"

    def _check_consistency(self):
        """Every conv layer needs one entry in nkerns, rfield and pool."""
        same_depth = (len(self.nkerns) == len(self.pool)
                      and len(self.nkerns) == len(self.rfield))
        if not same_depth:
            raise ValueError(" len(self.nkerns) != len(self.pool) or len(self.nkerns) != len(self.rfield)")

    @property
    def cp_arch(self):
        """The (nkerns, rfield, pool) triple describing the conv/pool stack."""
        return (self.nkerns, self.rfield, self.pool)
class MtrainSpec( namedtuple("MtrainSpec", "batch_size l1_rate l2_rate lr tau momentum_mult nepochs minepochs patience"), CfgFactory ):
    """Model-training hyper-parameters, read from the [modtrain] section."""
    __slots__ = ()

    _types = (int, float, float, float, int, float, int, int, int)
    _section = "modtrain"

    def _check_consistency(self):
        """Ensure the epoch bounds are ordered sensibly."""
        if self.minepochs > self.nepochs:
            # fix: apply '%' interpolation -- the original passed the values as
            # extra ValueError args (logging-call style), so the "%d" message
            # was never actually formatted
            raise ValueError("minepochs (%d) > nepochs (%d)"
                             % (self.minepochs, self.nepochs))
# NOTE(review): the two classes below are kept commented out -- apparently
# retired alternatives for discriminative-training and deliverables sections.
#class DtrainSpec( namedtuple("DtrainSpec", "batch_size l1_rate l2_rate lr tau momentum_mult nepochs minepochs patience"), CfgFactory ):
#    __slots__ = ()
#
#    _types = (int, float, float, float, int, float, int, int, int)
#    _section = "dttrain"
#
#    def _check_consistency(self):
#        pass
#
#class OutputSpec( namedtuple("OutputSpec", "oid weight_log learn_log model_info best_inputX best_inputY best_example_inputX best_example_inputY"), CfgFactory ):
#    __slots__ = ()
#
#    _types = (str, str, str, str, str, str, str, str)
#    _section = "deliverables"
#
#    def _check_consistency(self):
#        pass

class AESpec(namedtuple("AutoEncoderSpec", "rec_error minepochs maxepochs batch_size noise"), CfgFactory):
    """Auto-encoder pre-training parameters, read from the [ae] section."""
    __slots__ = ()

    # converters for: rec_error, minepochs, maxepochs, batch_size, noise
    _types = (float, int, int, int, float)
    _section = "ae"

    def _check_consistency(self):
        # no cross-field constraints for the auto-encoder parameters
        pass
121
class DataSpec( namedtuple("DataSpec", "dataname tracks width train_s valid_s labels track_names label_names batch_size train_batches valid_batches") ):
    """Dataset layout plus the train/validation batch split."""
    __slots__ = ()

    @classmethod
    def batches_from_data(cls, tot_size, batch_s, valid_s, valid_idx, rng):
        """Split batch indices into (train_batches, valid_batches).

        tot_size: number of examples in the dataset
        batch_s: batch size (must fit inside the smaller split)
        valid_s: validation fraction, strictly between 0 and 1
        valid_idx: start index of the validation slice in the batch list
        rng: random source with a shuffle() method

        Raises ValueError on inconsistent arguments.
        """
        # fix: '%' interpolation applied (the original passed the values as
        # extra ValueError args, logging style, so messages never formatted)
        if valid_s <= 0 or valid_s >= 1:
            raise ValueError("valid_s (%f) should be between (0, 1) " % valid_s)

        if batch_s > tot_size * min(valid_s, 1 - valid_s):
            raise ValueError("batch size (%d) too big > min(valid_s=%d, train_s=%d)"
                             % (batch_s, tot_size * valid_s, tot_size * (1 - valid_s)))

        # // for Python 3 int division; list() so slicing/shuffle work on both
        all_batches = list(range(tot_size // batch_s))
        # fix: slicing never raises IndexError, so the original
        # try/except IndexError was unreachable; check the bound explicitly
        if not 0 <= valid_idx <= len(all_batches) - 1:
            raise ValueError("valid_idx (%d) should be between 0 and %d"
                             % (valid_idx, len(all_batches) - 1))
        valid_batches = all_batches[valid_idx:valid_idx + int(len(all_batches) * valid_s)]
        train_batches = list(set(all_batches) - set(valid_batches))
        assert set(train_batches + valid_batches) == set(all_batches)
        assert len(set(train_batches) & set(valid_batches)) == 0
        rng.shuffle(train_batches)
        rng.shuffle(valid_batches)
        logging.info("train batches: %s", str(train_batches))
        logging.info("valid batches: %s", str(valid_batches))
        return (train_batches, valid_batches)

    @classmethod
    def _from_archive(cls, path, batch_s, rng, valid_s, valid_idx):
        """Initialize from an archive.

        path: dataset specification path
        batch_s: batch size
        rng: random source used to shuffle the batch lists
        valid_s: validation fraction (e.g. 0.25)
        valid_idx: start index of the validation batch slice
        """
        # X panel: items are gene names, major axis the track names
        objX = archive.load_object(archive.archname(path),
                                   archive.basename(path) + "/X")
        objT = archive.load_object(archive.archname(path),
                                   archive.basename(path) + "/T")

        gene_names = objX.items
        track_names = objX.major_axis
        width = objX.minor_axis.shape[0]
        label_names = np.unique(objT["label_name"].values).tolist()

        train_batches, valid_batches = cls.batches_from_data(len(gene_names),
                                                             batch_s, valid_s,
                                                             valid_idx, rng)
        return cls._make((archive.basename(path), len(track_names), width,
                          batch_s * len(train_batches), batch_s * len(valid_batches),
                          len(label_names), track_names, label_names, batch_s,
                          train_batches, valid_batches))
184 185 186 if __name__ == "__main__": 187 pass 188