# Source code for pymodelfit.multifitgui

#Copyright 2009 Erik Tollerud
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

"""
This module contains the internals for the FitGui gui.
"""
#TODO: change single select to click-to-do-action

from __future__ import division,with_statement
import numpy as np

try: #this is the old-style import - below is for traits 4.x
    from enthought.traits.api import HasTraits,Bool,Button,Float,Int,Color, \
                                     Instance,Tuple,Array,List,Dict,Str,Property, \
                                    on_trait_change,cached_property
    from enthought.traits.ui.api import View,VGroup,HGroup,Item,TupleEditor, \
                                        ListEditor

    from enthought.tvtk.pyface.scene_editor import SceneEditor
    from enthought.mayavi.tools.mlab_scene_model import MlabSceneModel
    from enthought.mayavi.core.ui.mayavi_scene import MayaviScene
except ImportError:

    from traits.api import HasTraits,Bool,Button,Float,Int,Color, \
                                     Instance,Tuple,Array,List,Dict,Str,Property, \
                                    on_trait_change,cached_property
    from traitsui.api import View,VGroup,HGroup,Item,TupleEditor, \
                                        ListEditor

    from tvtk.pyface.scene_editor import SceneEditor
    from mayavi.tools.mlab_scene_model import MlabSceneModel
    from mayavi.core.ui.mayavi_scene import MayaviScene


from .fitgui import FitGui,TraitedModel

class MultiFitGui(HasTraits):
    """
    GUI that holds one :class:`FitGui` per pair of data columns, plus an
    optional Mayavi 3D view of three of the columns.

    data should be c x N where c is the number of data columns/axes and N is
    the number of points
    """
    # --- 3D plot state ---------------------------------------------------
    # gate for `_do_3d`; switched on when the 3D window is opened
    doplot3d = Bool(False)
    show3d = Button('Show 3D Plot')
    replot3d = Button('Replot 3D')
    # 0 means "not chosen yet": `_do_3d` then derives a value from the data
    scalefactor3d = Float(0)
    do3dscale = Bool(False)
    # number of samples used to draw the model curve in the 3D view
    nmodel3d = Int(1024)
    usecolor3d = Bool(False)
    color3d = Color((0,0,0))
    scene3d = Instance(MlabSceneModel,())
    # axis *names* (values of `axisnames`) shown as x, y and z in the 3D view
    plot3daxes = Tuple(('x','y','z'))

    # --- data and per-pair fitters ---------------------------------------
    data = Array(shape=(None,None))
    weights = Array(shape=(None,))
    # (x index, y index) into `data` for each FitGui in `fgs`
    curveaxes = List(Tuple(Int,Int))
    # row index of `data` -> display name
    axisnames = Dict(Int,Str)
    # display name -> row index of `data`
    invaxisnames = Property(Dict,depends_on='axisnames')

    fgs = List(Instance(FitGui))

    traits_view = View(VGroup(Item('fgs',
                                   editor=ListEditor(use_notebook=True,
                                                     page_name='.plotname'),
                                   style='custom',show_label=False),
                              Item('show3d',show_label=False)),
                       resizable=True,height=900,buttons=['OK','Cancel'],
                       title='Multiple Model Data Fitters')

    plot3d_view = View(VGroup(Item('scene3d',
                                   editor=SceneEditor(scene_class=MayaviScene),
                                   show_label=False,resizable=True),
                              Item('plot3daxes',
                                   editor=TupleEditor(cols=3,labels=['x','y','z']),
                                   label='Axes'),
                              HGroup(Item('do3dscale',label='Scale by weight?'),
                                     Item('scalefactor3d',label='Point scale'),
                                     Item('nmodel3d',label='Nmodel')),
                              HGroup(Item('usecolor3d',label='Use color?'),
                                     Item('color3d',label='Relation Color',
                                          enabled_when='usecolor3d')),
                              Item('replot3d',show_label=False),
                              springy=True),
                       resizable=True,height=800,width=800,
                       title='Multiple Model3D Plot')

    def __init__(self,data,names=None,models=None,weights=None,dofits=True,**traits):
        """
        :param data: The data arrays
        :type data: sequence of c equal-length arrays (length N)
        :param names:
            Names for the data axes, or a single comma-separated string of
            names.  If None, defaults are 'x','y'(,'z') for 2 or 3 columns
            and the column index as a string otherwise.
        :type names: sequence of strings, length c
        :param models:
            The models to fit for each pair, each given as whatever
            :class:`TraitedModel` accepts (the original documentation
            mentions strings or ParametricModel objects).
        :type models: sequence of models, length c-1
        :param weights: the weights for each point or None for no weights
        :type weights: array-like of size N or None
        :param dofits:
            If True, the data will be fit to the models when the object is
            created, otherwise the models will be passed in as-is (or as
            created).
        :type dofits: bool

        extra keyword arguments get passed in as new traits

        :raises ValueError:
            If `data` is not 2D with at least 2 columns, or if `names` or
            `models` do not match the number of columns.

        Example (from the original documentation)::

            MultiFitGui((r[finmask],m[finmask],l[finmask]),names='rh,Mh,Lh',
                        weights=w[finmask],models=models,dofits=False)

        """
        super(MultiFitGui,self).__init__(**traits)
        self._lastcurveaxes = None

        # asarray only copies when it has to (np.array(..., copy=False)
        # raises under NumPy 2 when a copy is required)
        data = np.asarray(data)

        # validate before any trait is assigned so a bad input cannot
        # trigger the change handlers below
        if data.ndim != 2:
            raise ValueError('data must be 2D (columns x points)')
        if data.shape[0] < 2:
            raise ValueError('Must have at least 2 columns')

        if weights is None:
            self.weights = np.ones(data.shape[1])
        else:
            self.weights = np.array(weights)

        # assigning data fires _data_changed, which sets curveaxes and
        # thereby builds the FitGui objects in `fgs`
        self.data = data

        try:
            string_types = basestring  # Python 2
        except NameError:
            string_types = str  # Python 3
        if isinstance(names,string_types):
            names = names.split(',')

        ncols = len(data)
        if names is None:
            if ncols == 2:
                self.axisnames = {0:'x',1:'y'}
            elif ncols == 3:
                self.axisnames = {0:'x',1:'y',2:'z'}
            else:
                # keys must be the row indices, not the rows themselves
                self.axisnames = dict((i,str(i)) for i in range(ncols))
        elif len(names) == ncols:
            self.axisnames = dict(enumerate(names))
        else:
            raise ValueError("names don't match data")

        #default to using 0th axis as parametric
        self.curveaxes = [(0,i) for i in range(1,ncols)]

        if models is not None:
            if len(models) != ncols-1:
                raise ValueError("models don't match data")
            for i,m in enumerate(models):
                fg = self.fgs[i]
                newtmodel = TraitedModel(m)
                if dofits:
                    fg.tmodel = newtmodel
                    fg.fitmodel = True #should happen automatically, but this makes sure
                else:
                    # re-apply the parameters the model was created with
                    # after attaching it to the FitGui
                    oldpard = newtmodel.model.pardict
                    fg.tmodel = newtmodel
                    fg.tmodel.model.pardict = oldpard
                # NOTE(review): this second trigger is kept from the original
                # code; it repeats the one in the branch above
                if dofits:
                    fg.fitmodel = True

    def _label_axes(self,fg,ax):
        """
        Set the x/y axis titles of FitGui `fg` from the names of the data
        axes in the pair `ax`; an axis without a name gets an empty title.
        """
        fg.plot.x_axis.title = self.axisnames.get(ax[0],'')
        fg.plot.y_axis.title = self.axisnames.get(ax[1],'')

    def _data_changed(self):
        # pair every other column with column 0
        self.curveaxes = [(0,i) for i in range(1,len(self.data))]

    def _axisnames_changed(self):
        for ax,fg in zip(self.curveaxes,self.fgs):
            self._label_axes(fg,ax)
        names = self.axisnames
        # with only two axes, the second one doubles as the 3D z axis
        third = names[2] if len(names) > 2 else names[1]
        self.plot3daxes = (names[0],names[1],third)

    @on_trait_change('curveaxes[]')
    def _curveaxes_update(self,names,old,new):
        """
        Keep `fgs` in step with `curveaxes`: build all FitGui objects the
        first time, afterwards replace only those whose axis pair changed.
        A set of pairs that does not use every data column exactly covers
        is rejected by restoring the previous pairs.
        """
        used = set()
        for t in self.curveaxes:
            used.add(t[0])
            used.add(t[1])
        if used != set(range(len(self.data))):
            # restoring re-enters this handler, which then finds no changes
            self.curveaxes = self._lastcurveaxes
            return #TODO:check for recursion

        if self._lastcurveaxes is None:
            self.fgs = [FitGui(self.data[t[0]],self.data[t[1]],weights=self.weights)
                        for t in self.curveaxes]
            for ax,fg in zip(self.curveaxes,self.fgs):
                self._label_axes(fg,ax)
        else:
            for i,t in enumerate(self.curveaxes):
                if self._lastcurveaxes[i] != t:
                    fg = FitGui(self.data[t[0]],self.data[t[1]],weights=self.weights)
                    self.fgs[i] = fg
                    self._label_axes(fg,t)

        # store a snapshot, not the live trait list: an alias would change
        # along with in-place edits, hiding them from the comparison above
        # and making the restore a no-op
        self._lastcurveaxes = list(self.curveaxes)

    def _show3d_fired(self):
        self.edit_traits(view='plot3d_view')
        self.doplot3d = True
        self.replot3d = True

    def _plot3daxes_changed(self):
        self.replot3d = True

    @on_trait_change('weights',post_init=True)
    def weightsChanged(self):
        """
        Push the current `weights` to every FitGui, switching each one to
        the 'custom' weight type first if it is not already using it.
        """
        for fg in self.fgs:
            if fg.weighttype != 'custom':
                fg.weighttype = 'custom'
            fg.weights = self.weights

    def _model_chain(self,xi,target):
        """
        Return the models that lead from data axis `xi` to data axis
        `target`, in application order, by walking `curveaxes` backwards
        from `target`.  Raises KeyError if no pair has the current axis as
        its second element.
        """
        chain = []
        curri = target
        while curri != xi:
            for i,(i1,i2) in enumerate(self.curveaxes):
                if curri == i2:
                    curri = i1
                    chain.insert(0,self.fgs[i].tmodel.model)
                    break
            else:
                raise KeyError
        return chain

    # on_trait_change takes all names in ONE comma-separated string; its
    # second positional parameter is `post_init`, so the names must not be
    # passed as separate positional arguments.
    @on_trait_change('data,fgs,replot3d,weights',post_init=True)
    def _do_3d(self):
        """
        Redraw the 3D scene: the data points (colored by weight) and, where
        the fitted models allow it, the curve obtained by chaining the
        models from the x axis to the y and z axes.  Does nothing until the
        3D window has been opened (`doplot3d`).
        """
        if not self.doplot3d:
            return

        M = self.scene3d.mlab
        try:
            # a name in plot3daxes that is not an axis name raises KeyError
            xi = self.invaxisnames[self.plot3daxes[0]]
            yi = self.invaxisnames[self.plot3daxes[1]]
            zi = self.invaxisnames[self.plot3daxes[2]]

            x,y,z = self.data[xi],self.data[yi],self.data[zi]
            w = self.weights

            M.clf()
            if self.scalefactor3d == 0:
                # no scale chosen yet: derive one from the data extent and
                # remember it
                sf = x.max()-x.min()
                sf *= y.max()-y.min()
                sf *= z.max()-z.min()
                sf = sf/len(x)/5
                self.scalefactor3d = sf
            else:
                sf = self.scalefactor3d
            glyph = M.points3d(x,y,z,w,scale_factor=sf)
            glyph.glyph.scale_mode = 0 if self.do3dscale else 1
            M.axes(xlabel=self.plot3daxes[0],
                   ylabel=self.plot3daxes[1],
                   zlabel=self.plot3daxes[2])

            try:
                xs = np.linspace(np.min(x),np.max(x),self.nmodel3d)

                #find sequence of models to go from x to y and z
                ymods = self._model_chain(xi,yi)
                zmods = self._model_chain(xi,zi)

                ys = xs
                for m in ymods:
                    ys = m(ys)
                zs = xs
                for m in zmods:
                    zs = m(zs)

                if self.usecolor3d:
                    c = (self.color3d[0]/255,
                         self.color3d[1]/255,
                         self.color3d[2]/255)
                    M.plot3d(xs,ys,zs,color=c)
                else:
                    M.plot3d(xs,ys,zs,np.arange(len(xs)))
            except (KeyError,TypeError):
                M.text(0.5,0.75,'Underivable relation')
        except KeyError:
            M.clf()
            M.text(0.25,0.25,'Data problem')

    @cached_property
    def _get_invaxisnames(self):
        # .items() works on both Python 2 and 3 (iteritems is Python 2 only)
        return dict((v,k) for k,v in self.axisnames.items())
def fit_data_multi(data,names=None,weights=None,models=None):
    """
    fit a data set consisting of a variety of curves simultaneously. A GUI
    application instance must already exist (e.g. interactive mode of
    ipython)

    :param data: The data, c x N (c columns/axes of N points each), c >= 2
    :param names: Axis names, passed on to :class:`MultiFitGui`
    :param weights: Per-point weights or None, passed on to :class:`MultiFitGui`
    :param models:
        One model per fitted pair, i.e. c-1 models (the same count
        :class:`MultiFitGui` requires), or None

    :raises ValueError:
        If `data` is not 2D with at least 2 columns, or if the number of
        models is not one less than the number of columns.

    returns a tuple of models e.g. [xvsy,xvsz], or None if the dialog was
    not accepted
    """
    # accept any array-like, not only ndarrays
    data = np.asarray(data)

    if data.ndim != 2 or data.shape[0] < 2:
        raise ValueError('data must be 2D with first dimension >=2')

    # MultiFitGui fits column 0 against each other column, so it needs
    # c-1 models; checking for c here made every call with models fail
    if models is not None and len(models) != data.shape[0]-1:
        raise ValueError('Number of models must be one less than the number of data sets')

    mfg = MultiFitGui(data,names,models,weights=weights)

    ui = mfg.edit_traits(kind='livemodal')

    # the returned UI object is itself always truthy; the dialog outcome
    # (OK vs. Cancel) is carried by its `result` attribute
    if ui and ui.result:
        return tuple(fg.tmodel.model for fg in mfg.fgs)
    return None