# Source code for fit_neuron.optimize.threshold

"""
threshold
~~~~~~~~~~~~~~~~~~~~
Extraction of threshold parameters from voltage clamp data.
"""

import sys
import os 
import numpy as np
import random 
from numpy import linalg
from numpy import exp, Inf, floor, ceil, zeros 
import multiprocessing
from fit_neuron import data

def get_grad_and_hessian(thresh_param_and_X_cat):
    """
    Takes a subset of the data and returns the gradient and Hessian for this
    subset of the data.  This function is called by
    :func:`par_calc_log_like_update` and allows that function to work in
    parallel, map-reduce style.

    Computes the gradient vector as well as the Hessian matrix of the log
    likelihood w.r.t. the threshold parameters, evaluated at thresh_param.
    The grad and hess can then be used to update the threshold parameters
    via a Newton-Raphson step.

    :param thresh_param_and_X_cat: tuple ``(thresh_param, X_cat)`` where
        ``thresh_param`` is the parameter array and ``X_cat`` is a 2D array
        of X_thresh rows (one row per time step)
    :returns: tuple ``(grad_vec, hess_mat)``
    """
    # BUG FIX: the original used Python-2-only tuple parameter unpacking
    # (removed by PEP 3113); unpack manually so the function still works as
    # a single-argument task for multiprocessing.Pool.map.
    (thresh_param, X_cat) = thresh_param_and_X_cat
    theta_len = len(thresh_param)
    thresh_param = thresh_param.reshape((1, theta_len))
    grad_vec = np.zeros((1, theta_len))
    hess_mat = np.zeros((theta_len, theta_len))
    data_len = len(X_cat)
    for t in range(data_len):
        X_thresh = X_cat[t, :].reshape((1, theta_len))
        # NaN rows mark refractory periods: no likelihood term applies there
        if np.isnan(X_thresh[0, 0]):
            continue
        if t + 1 < data_len:
            if np.isnan(X_cat[t + 1, 0]):
                # a spike occurs at the next index: the log-likelihood term
                # is +w.X, whose gradient is X and whose Hessian vanishes
                grad_vec = grad_vec + X_thresh
                continue
        # no spike here: the log-likelihood term is -exp(w.X)
        exp_val = np.exp(thresh_param.dot(X_thresh.T))[0][0]
        grad_vec = grad_vec - X_thresh * exp_val
        hess_mat = hess_mat - exp_val * X_thresh.T.dot(X_thresh)
    cur_tuple = (grad_vec, hess_mat)
    return cur_tuple
def par_calc_log_like_update(thresh_param, X_cat_split, process_ct):
    """
    Parallel version of the Newton update.  This method is called by
    :func:`max_likelihood`.  The function passes pre-split chunks of the
    data to different processes, collects the results to obtain the
    gradient and Hessian of the whole time series, then performs the
    Newton update.

    :param thresh_param: current threshold parameter array
    :param X_cat_split: list of X_thresh chunks, one per worker
    :param process_ct: number of worker processes
    :returns: tuple ``(gradient, updated parameter array)``
    """
    worker_pool = multiprocessing.Pool(processes=process_ct)
    task_args = [(thresh_param[:], chunk) for chunk in X_cat_split]
    partial_results = worker_pool.map(get_grad_and_hessian, task_args)
    worker_pool.close()
    worker_pool.join()
    # reduce step: accumulate per-chunk gradients and Hessians
    grad_total, hess_total = partial_results[0]
    for (grad_part, hess_part) in partial_results[1:]:
        grad_total = grad_total + grad_part
        hess_total = hess_total + hess_part
    # Newton-Raphson step: theta <- theta - H^{-1} g (pinv for robustness)
    theta_new = thresh_param - np.linalg.pinv(hess_total).dot(grad_total.T).T
    return grad_total, theta_new[0]
class StochasticThresh():
    r"""
    Class for the threshold component of a spiking neuron.  The main
    functionality of this class is to determine spike times stochastically.

    :param dt: time step.
    :keyword t_bins: a list that defines the :math:`b_i` that define
        indicator functions :math:`I_{[0,b_i]}(t)` (see below).
    :keyword volt_adapt_time_const: a list that defines the time constants
        :math:`r_i` of the voltage chasing currents (see below).

    The stochastic neuron has the following *hazard rate*:

    .. math::
        h(t) = \exp \left({\bf w}_t^{\top} {\bf X}_t (t) \right)

    where

    .. math::
        {\bf X}_t (t) = [V(t),1,I_1(t),\ldots,I_m(t),Q_1(t),\ldots,Q_l(t)]^{\top}

    (this ordering matches :meth:`update_X_arr` and :meth:`set_param`),
    the :math:`I_i(t) = I_{[0,b_i]}(t)` terms are the indicator variables,
    and the :math:`Q_j(t)` terms are probability currents which shall be
    referred to as *voltage chasing currents*.  These currents give the
    stochastic spike emission process a component that adapts to the
    history of the voltage.  The equations used for the voltage chasing
    currents are:

    .. math::
        \frac{dQ_i}{dt} = r_i (V - Q_i)

    After the neuron spikes, the voltage chasing currents are reset to the
    value of the voltage immediately following the spike:

    .. math::
        Q_i \gets V_r

    The hazard rate is computed at each time step and compared to a
    uniformly distributed random number to determine whether the neuron
    spikes here.  The computation of :math:`{\bf X}_t (t)` at each time
    step is done by :meth:`update_X_arr`, while the inner product with the
    parameter vector :math:`{\bf w}_t` and the random decision of whether
    a spike occurs or not is taken by :meth:`update`.
    """

    def __init__(self, t_bins, volt_adapt_time_const=None, dt=0.0001,
                 thresh_param=None, thresh_param_dict=None):
        # BUG FIX: the original used mutable default arguments ([] and {}),
        # which are shared across instances; use None sentinels instead.
        if volt_adapt_time_const is None:
            volt_adapt_time_const = []
        if thresh_param_dict is None:
            thresh_param_dict = {}
        self.dt = dt
        self.t_bins = t_bins
        self.bin_count = len(t_bins)
        # max amount of time during which we care about previous spikes, see self.update
        self.t_max = t_bins[-1]
        # elapsed times since recent spikes (only those younger than t_max are kept)
        self.t_hist = []
        self.X_arr = None
        # these are the time constants of the voltage chasing parameters
        self.volt_adapt_time_const = volt_adapt_time_const
        # the actual value of the currents, which will be a numpy array
        # (None means "re-initialize on next voltage sample")
        self.volt_adapt_currents = None
        # these parameters take the internal values and map them to a spike probability
        self.thresh_param = thresh_param
        self.thresh_param_dict = thresh_param_dict
        # current value of the spiking probability
        self.spike_prob = None
        # length of X_arr: [V, 1] + one slot per bin + one per chasing current
        self.param_ct = len(t_bins) + len(volt_adapt_time_const) + 2

    def set_param(self, param):
        """
        After the optimization is done, this method allows the final value
        of the parameter array to be set, and parses this parameter array
        into a parameter dictionary (stored as ``self.param_dict``).
        """
        self.thresh_param = param
        param_dict = {'full_thresh_param': list(self.thresh_param),
                      'v_param': self.thresh_param[0],
                      '1_param': self.thresh_param[1],
                      'thresh_shape_param': list(self.thresh_param[2:2 + self.bin_count]),
                      'volt_adapt_param': list(self.thresh_param[2 + self.bin_count:]),
                      'volt_adapt_time_const': self.volt_adapt_time_const,
                      "dt": self.dt,
                      't_bins': self.t_bins,
                      }
        self.param_dict = param_dict

    def _compute_ind_arr(self):
        """
        Computes array of integer-valued indicator functions that only
        depend on the spike history: entry i counts the past spikes that
        occurred no longer than t_bins[i] ago.  The coefficients of the
        parameter array will do an inner product with this indicator array
        to specify the shape of the 'spike induced threshold'.
        """
        ind_arr = np.zeros((self.bin_count))
        for (t_ind, t_bin) in enumerate(self.t_bins):
            ind_arr[t_ind] = len([t for t in self.t_hist if t <= t_bin])
        return ind_arr

    def _compute_bin_number(self, t):
        """
        Which interval of self.t_bins does t belong to?
        """
        bin_number = 0
        if t < self.t_bins[0]:
            return bin_number
        for t_cur in self.t_bins[1:]:
            bin_number += 1
            if t < t_cur:
                break
        return bin_number

    def reset(self, V=None):
        """
        This method resets the neuron to a resting state.  The spiking
        history is erased and the voltage adaptation currents are either
        erased or are set to the current value of the voltage itself.

        :keyword V: value to which voltage chasing currents are reset
        """
        self.t_hist = []
        self.spike_prob = None
        self.X_arr = np.nan
        # use identity tests: ==/!= None on arrays is element-wise in numpy
        if self.volt_adapt_time_const is not None:
            if V is not None:
                # if V is specified, we set all the adaptation currents to be exactly at V
                self.volt_adapt_currents = np.array([V for k in self.volt_adapt_time_const])
            else:
                # just set to None; _update_volt_dep_currents will
                # initialize the currents when the time comes
                self.volt_adapt_currents = None

    def _update_volt_dep_currents(self, V):
        """
        Updates voltage chasing currents.  These currents help the
        threshold to incorporate voltage dependence.  The degree to which
        these currents will contribute to the actual spike probability is
        determined by the threshold parameter array.
        """
        # if we don't even have time constants for volt adapt, then forget about it
        if self.volt_adapt_time_const is not None:
            # BUG FIX: the original tested "self.volt_adapt_currents == None";
            # on a numpy array that is an element-wise comparison whose truth
            # value raises ValueError — use "is None" instead.
            if self.volt_adapt_currents is None:
                # currents were cleared (by a spike or reset): re-initialize at V
                self.volt_adapt_currents = np.array([V for k in self.volt_adapt_time_const])
            else:
                # forward-Euler step of dQ/dt = r (V - Q)
                for (ind, k) in enumerate(self.volt_adapt_time_const):
                    old_volt_adapt_val = self.volt_adapt_currents[ind]
                    self.volt_adapt_currents[ind] = old_volt_adapt_val + self.dt * k * (V - old_volt_adapt_val)
        return self.volt_adapt_currents

    def update_X_arr(self, V):
        """
        Updates X_arr (the vector with which we do a dot product with the
        threshold parameters to calculate the spiking probability).

        :param V: membrane voltage, or None/NaN while the neuron is spiking
        :returns: the updated X_arr (all-NaN while the neuron is spiking)
        """
        # if the voltage is None or NaN, which means the neuron is spiking,
        # don't worry about calculating a spike probability
        if V is None or np.isnan(V):
            X_arr = np.array([np.nan] * self.param_ct)
            self.X_arr = X_arr
            return X_arr
        # we only update the spiking history if the neuron
        # is not currently spiking
        self._update_hist_by_dt()
        volt_adapt_arr = self._update_volt_dep_currents(V)
        ind_arr = self._compute_ind_arr()
        X_arr = np.concatenate(([V, 1], ind_arr, volt_adapt_arr))
        self.X_arr = X_arr
        return X_arr

    def _update_hist_by_dt(self):
        """
        Updates self.t_hist by dt, dropping the oldest entry once it
        exceeds t_max (entries are ordered oldest-first, so only the head
        can expire on any single step).
        """
        if self.t_hist:
            self.t_hist = [t + self.dt for t in self.t_hist]
            if self.t_hist[0] >= self.t_max:
                self.t_hist.pop(0)

    def update(self, V):
        """
        Updates inner state and returns True if there is a spike, and
        False if there is no spike.

        :raises ValueError: if self.thresh_param has not been set
        """
        X_arr = self.update_X_arr(V)
        # if neuron is already spiking, then don't have the threshold
        # spike again in its refractory period
        if V is None or np.isnan(V):
            self.spike_prob = 0
            return False
        # BUG FIX: "self.thresh_param == None" on a numpy array is an
        # element-wise comparison whose truth value raises; use "is None".
        if self.thresh_param is None:
            raise ValueError("Threshold parameter must be specified in self.thresh_param in order to compute spike probability!")
        hazard = exp(self.thresh_param.dot(X_arr))
        # TO DO: replace this approximation with:
        #spike_prob = 1 - exp(-1.0 * hazard)
        self.spike_prob = hazard
        if random.random() < self.spike_prob:
            self._spike()
            return True
        else:
            return False

    def _spike(self):
        # add a spike to spike history and clear the voltage chasing currents.
        # once the neuron is done spiking the voltage chasing currents will be
        # reset to the current voltage of the neuron, and then will be free to
        # chase the neuron's current voltage.
        self.t_hist.append(0.0)
        self.volt_adapt_currents = None
def max_likelihood(X_thresh_list, spike_ind_list, thresh_init, process_ct=None, iter_max=20, stopping_criteria=0.01):
    """
    Performs maximum likelihood optimization of the threshold parameters
    via Newton-Raphson iterations distributed over several processes.

    :param X_thresh_list: list of X_thresh matrices, one per sweep
    :param spike_ind_list: list of spike index arrays, one per sweep
    :param thresh_init: initial threshold parameter array
    :keyword process_ct: number of worker processes; if None, use CPU count
    :keyword iter_max: maximum number of Newton iterations
    :keyword stopping_criteria: stop once the gradient L2 norm drops below this
    :returns: optimized threshold parameter array
    """
    cur_param = thresh_init
    # we need to unpack some variables for better parallel consumption
    X_arr_cat = None
    prev_total_ind_ct = 0
    for (ind, X_thresh) in enumerate(X_thresh_list):
        # BUG FIX: the original tested "X_arr_cat == None"; once X_arr_cat
        # is a numpy array that is an element-wise comparison whose truth
        # value raises ValueError — use "is None" for the identity test.
        if X_arr_cat is None:
            X_arr_cat = X_thresh
            spike_ind_arr_cat = spike_ind_list[ind]
        else:
            X_arr_cat = np.vstack((X_arr_cat, X_thresh))
            spike_ind_arr_cat = np.hstack((spike_ind_arr_cat, spike_ind_list[ind] + prev_total_ind_ct))
        prev_total_ind_ct += len(X_thresh)
    # how many workers are going to be used
    if process_ct is None:
        process_ct = multiprocessing.cpu_count()
    # splitting data into chunks, one per worker
    X_len = len(X_arr_cat)
    ind_vec = np.linspace(0, X_len, process_ct + 1)
    ind_list = [int(this_ind) for this_ind in ind_vec[1:-1]]
    X_cat_split = np.split(X_arr_cat, ind_list, 0)
    # single-argument print() form works under both Python 2 and 3
    print("Starting max likelihood optimization...")
    for _ in range(iter_max):
        [grad, cur_param] = par_calc_log_like_update(cur_param, X_cat_split, process_ct)
        grad_norm = np.linalg.norm(grad)
        if grad_norm < stopping_criteria:
            break
    return cur_param
def compute_log_likelihood(X_thresh_list, spike_ind_list, thresh_param):
    """
    Computes log likelihood of the spike trains given the parameters.

    :param X_thresh_list: list of X_thresh matrices, one array for each sweep
    :param spike_ind_list: list of spike indices, one array for each sweep
    :param thresh_param: threshold parameter array
    :returns: log likelihood of the time series given parameter array
    """
    total = 0.0
    for (X_thresh, spike_ind) in zip(X_thresh_list, spike_ind_list):
        for (t, row) in enumerate(X_thresh):
            # NaN rows mark refractory periods: skip them
            if np.isnan(row[0]):
                continue
            log_hazard = thresh_param.dot(row)
            if t + 1 in spike_ind:
                # spike at the next index contributes +w.X
                total += log_hazard
            else:
                # no spike contributes -exp(w.X)
                total -= np.exp(log_hazard)
    return total
def estimate_thresh_parameters(subthresh_obj, thresh_obj, raw_data, process_ct=None, max_lik_iter_max=25, thresh_param_init=None, stopping_criteria=0.01):
    r"""
    Estimates threshold parameters that fit the raw data, and the particular
    form of the threshold and subthreshold components of the model.

    :param subthresh_obj: :class:`fit_neuron.optimize.subthreshold.Voltage` instance
    :param thresh_obj: :class:`fit_neuron.optimize.threshold.StochasticThresh` instance
    :param raw_data: :class:`fit_neuron.data.my_data.RawData` instance
    :param max_lik_iter_max: maximum number of max likelihood iterations
    :param process_ct: number of processors we want to use to distribute max
        likelihood; if None then use CPU count.
    :param stopping_criteria: minimum :math:`L^2` norm of gradient of log
        likelihood below which we stop optimization.
    :returns: array of threshold parameters
    """
    # single-argument print() form works under both Python 2 and 3
    print("Estimating threshold parameters...")
    param_ct = thresh_obj.param_ct
    # BUG FIX: the original tested "thresh_param_init == None"; if the caller
    # passes a numpy array this is an element-wise comparison whose truth
    # value raises ValueError — use "is None" for the identity test.
    if thresh_param_init is None:
        thresh_param_init = np.ones((param_ct)) * 0.001
    X_thresh_list = []
    spike_ind_list = data.extract_spikes.spk_from_bio(raw_data.membrane_voltage_list)
    for (ind, sweep) in enumerate(raw_data):
        subthresh_obj.reset()
        thresh_obj.reset()
        arr_len = len(sweep.membrane_voltage)
        spike_ind = spike_ind_list[ind]
        V = sweep.membrane_voltage[0]
        X_thresh = np.zeros((arr_len, param_ct))
        for (t, Ie) in enumerate(sweep.input_current):
            V = subthresh_obj.update(V, Ie)
            new_X = thresh_obj.update_X_arr(V)
            X_thresh[t, :] = new_X
            # force the model to spike when the biological neuron spikes
            if t + 1 in spike_ind:
                subthresh_obj.spike()
                thresh_obj._spike()
        X_thresh_list.append(X_thresh)
    thresh_param = max_likelihood(X_thresh_list,
                                  spike_ind_list,
                                  thresh_param_init,
                                  process_ct=process_ct,
                                  iter_max=max_lik_iter_max,
                                  stopping_criteria=stopping_criteria)
    log_l = compute_log_likelihood(X_thresh_list, spike_ind_list, thresh_param)
    print("Log likelihood: " + str(log_l))
    return thresh_param