# Source code for pymodelfit.fitgui

#Copyright 2009 Erik Tollerud
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.

"""
This module contains the internals of the FitGui graphical interface for
interactively fitting models to 1D data.
"""
#TODO: change single select to click-to-do-action

from __future__ import division,with_statement
import numpy as np
try: #this is the old-style import - below is for traits 4.x
    from enthought.traits.api import HasTraits,Instance,Int,Float,Bool,Button, \
                                     Event,Property,on_trait_change,Array,List, \
                                     Tuple,Str,Dict,cached_property,Color,Enum, \
                                     TraitError,Undefined,DelegatesTo
    from enthought.traits.ui.api import View,Handler,Item,Label,Group,VGroup, \
                                        HGroup, InstanceEditor,EnumEditor, \
                                        ListEditor, TupleEditor,spring
    from enthought.traits.ui.menu import ModalButtons
    from enthought.chaco.api import Plot,ArrayPlotData,jet,ColorBar,HPlotContainer,\
                                    ColorMapper,LinearMapper,ScatterInspectorOverlay,\
                                    LassoOverlay,AbstractOverlay,ErrorBarPlot, \
                                    ArrayDataSource
    from enthought.chaco.tools.api import PanTool,SelectTool,LassoSelection,ScatterInspector
    from enthought.enable.api import ColorTrait,ComponentEditor
    from enthought.enable.base_tool import KeySpec

    try:
        #I'm not certain when BetterSelectingZoom was implemented...
        from enthought.chaco.tools.api import BetterSelectingZoom as ZoomTool
    except ImportError:
        from enthought.chaco.tools.api import ZoomTool

except ImportError:
    from traits.api import HasTraits,Instance,Int,Float,Bool,Button, \
                                     Event,Property,on_trait_change,Array,List, \
                                     Tuple,Str,Dict,cached_property,Color,Enum, \
                                     TraitError,Undefined,DelegatesTo
    from traitsui.api import View,Handler,Item,Label,Group,VGroup, \
                                        HGroup, InstanceEditor,EnumEditor, \
                                        ListEditor, TupleEditor,spring
    from traitsui.menu import ModalButtons
    from chaco.api import Plot,ArrayPlotData,jet,ColorBar,HPlotContainer,\
                                    ColorMapper,LinearMapper,ScatterInspectorOverlay,\
                                    LassoOverlay,AbstractOverlay,ErrorBarPlot, \
                                    ArrayDataSource
    from chaco.tools.api import PanTool,SelectTool,LassoSelection,ScatterInspector
    from enable.api import ColorTrait,ComponentEditor
    from enable.base_tool import KeySpec

    try:
        #I'm not certain when BetterSelectingZoom was implemented...
        from chaco.tools.api import BetterSelectingZoom as ZoomTool
    except ImportError:
        from chaco.tools.api import ZoomTool




from .core import FunctionModel1D,list_models,get_model_class,get_model_instance
from .utils import binned_weights

class ColorMapperFixSingleVal(ColorMapper):
    """
    A :class:`ColorMapper` that paints every point whose data value equals
    `val` in the fixed colour `coloratval`, overriding the normal mapping.
    """
    #colour applied to points whose data value equals `val`
    coloratval = ColorTrait('black')
    #the data value that receives the fixed colour
    val = 0

    def map_screen(self, data_array):
        colors = super(ColorMapperFixSingleVal,self).map_screen(data_array)
        matches = data_array == self.val
        colors[matches] = self.coloratval_
        return colors

#_cmap = jet
def _cmapblack(range, **traits):
    """
    Build the weight colormap used when zero-weight points are removed.

    Returns a :class:`ColorMapperFixSingleVal`. That class paints points whose
    value equals its `val` (0 by default) in its `coloratval` colour (black by
    default) instead of the segment colour.

    :param range: the data range the colormap maps from
    :param traits: extra traits forwarded to ``from_segment_map``
    :returns: a new :class:`ColorMapperFixSingleVal`
    """
    red = ((0,1,1), (0.3, .8, .8), (0.5, 0, 0), (0.75,0.75, 0.75), (.875,.2,.2),
           (1, 0, 0))
    blue = ((0., 0, 0), (0.3,0, 0), (0.5,0, 0), (0.75,.75, .75),
            (0.875,0.75,0.75), (1, 1, 1))
    green = ((0.,0, 0), (0.3,.8,.8), (0.4, 0.4, 0.4), (0.5,1,1), (0.65,.75, .75),
             (0.75,0.1, 0.1), (1, 0, 0))
    segments = dict(red=red, blue=blue, green=green)
    return ColorMapperFixSingleVal.from_segment_map(segments, range=range, **traits)

def _cmap(range, **traits):
    """
    Build the plain weight colormap used when zero-weight points are kept.

    It uses the same segment map as :func:`_cmapblack`, but returns an
    ordinary :class:`ColorMapper` with no special colour for a single value.

    :param range: the data range the colormap maps from
    :param traits: extra traits forwarded to ``ColorMapper.from_segment_map``
    :returns: a new :class:`ColorMapper`
    """
    red = ((0,1,1), (0.3, .8, .8), (0.5, 0, 0), (0.75,0.75, 0.75), (.875,.2,.2),
           (1, 0, 0))
    blue = ((0., 0, 0), (0.3,0, 0), (0.5,0, 0), (0.75,.75, .75),
            (0.875,0.75,0.75), (1, 1, 1))
    green = ((0.,0, 0), (0.3,.8,.8), (0.4, 0.4, 0.4), (0.5,1,1), (0.65,.75, .75),
             (0.75,0.1, 0.1), (1, 0, 0))
    segments = dict(red=red, blue=blue, green=green)
    return ColorMapper.from_segment_map(segments, range=range, **traits)

class TraitedModel(HasTraits):
    """
    Traits wrapper that exposes a :class:`FunctionModel1D` to the
    :class:`FitGui` user interface.

    When the default view is built, every model parameter ``p`` is mirrored
    as a ``Float`` trait ``p`` plus a ``Bool`` trait ``fixfit_p``. The flag
    marks the parameter as held fixed during fits. Firing `fitdata` fits the
    wrapped model. Any exception raised by the fit is stored in
    `lastfitfailure` instead of propagating.
    """
    #the wrapped model (None means "no model selected")
    model = Instance(FunctionModel1D,allow_none=True)
    modelname = Property(Str)
    #fire to copy the model's current parameter values into the param traits
    updatetraitparams = Event
    #fired with the changed trait name (or True) whenever a parameter changes
    paramchange = Event
    #fire with the data to fit - see _fitdata_fired for the accepted forms
    fitdata = Event
    fittype = Property(Str)
    fittypes = Property
    #exception raised by the most recent fit, or None if it succeeded
    lastfitfailure = Instance(Exception,allow_none=True)

    def __init__(self,model,**traits):
        """
        :param model:
            The model to wrap. It may be None, a :class:`FunctionModel1D`
            instance, a model class (instantiated with no arguments), or a
            model name string (resolved with :func:`get_model_instance`).

        kwargs are applied as additional traits.
        """
        from inspect import isclass

        super(TraitedModel,self).__init__(**traits)

        if isinstance(model,basestring):
            model = get_model_instance(model)
        elif isclass(model):
            model = model()
        self.model = model

    def default_traits_view(self):
        """
        Build the parameter-editing view.

        For each model parameter, this adds a ``Float`` trait initialised from
        the model. It also adds a ``fixfit_<param>`` ``Bool`` trait,
        initialised from the model class' ``fixedpars``. Both are hooked to
        `_param_change_handler`.
        """
        if self.model is None:
            g = Group()
            g.content.append(Label('No Model Selected'))
        else:
            g = Group(label=self.modelname,show_border=True,orientation='vertical')
            hg = HGroup(Item('fittype',label='Fit Technique',
                             editor=EnumEditor(name='fittypes')))
            g.content.append(hg)
            gp = HGroup(scrollable=True)
            for p in self.model.params:
                gi = Group(orientation='horizontal',label=p)
                self.add_trait(p,Float)
                setattr(self,p,getattr(self.model,p))
                self.on_trait_change(self._param_change_handler,p)
                gi.content.append(Item(p,show_label=False))

                ffp = 'fixfit_'+p
                self.add_trait(ffp,Bool)
                #default to fixed if the parameter is a class-level fixed parameter
                setattr(self,ffp,p in self.model.__class__.fixedpars)
                self.on_trait_change(self._param_change_handler,ffp)
                gi.content.append(Item(ffp,label='Fix?'))

                gp.content.append(gi)
            g.content.append(gp)

        return View(g,buttons=['Apply','Revert','OK','Cancel'])

    def _param_change_handler(self,name,new):
        """
        Push an edited parameter value into the model and fire `paramchange`.

        The ``fixfit_<param>`` flags only select which parameters are held
        fixed at fit time (see `_fitdata_fired`). They are therefore not
        written onto the model.
        """
        if not name.startswith('fixfit_'):
            setattr(self.model,name,new)
        self.paramchange = name

    def _updatetraitparams_fired(self):
        """Copy the model's parameter values into the traits and fire `paramchange`."""
        m = self.model
        for p in m.params:
            setattr(self,p,getattr(m,p))
        self.paramchange = True

    def _fitdata_fired(self,new):
        """
        Fit the wrapped model to the data given as the event value.

        Accepted event values:

        * a 2-sequence ``(x, y)``
        * a 3-sequence ``(x, y, weights)``
        * a mapping of ``fitData`` keyword arguments; any of ``x``, ``y`` or
          ``weights`` that is missing is taken from the previous fit's data
        * ``True`` to refit the previous fit's data

        Parameters whose ``fixfit_<param>`` trait is set are passed as
        ``fixedpars`` unless the mapping form already supplies them. On
        success, the parameter traits are refreshed and `lastfitfailure` is
        cleared. On failure, the exception is stored in `lastfitfailure`.

        :raises ValueError:
            If the event value is unusable, or previous-fit data is needed
            but unavailable.
        """
        from operator import isSequenceType,isMappingType

        if self.model is None:
            return

        #NOTE(review): the previous fit's data is read from `model.fiteddata`
        #here, while FitGui/fit_data read `model.data` - confirm both exist
        if isSequenceType(new) and len(new) == 2:
            kw = {'x':new[0],'y':new[1]}
        elif isSequenceType(new) and len(new) == 3:
            kw = {'x':new[0],'y':new[1],'weights':new[2]}
        elif isMappingType(new):
            kw = dict(new)

            #add any missing pieces - into the fit arguments, not the caller's mapping
            for i,k in enumerate(('x','y','weights')):
                if k not in kw:
                    if self.model.fiteddata:
                        kw[k] = self.model.fiteddata[i]
                    else:
                        raise ValueError('no pre-fitted data available')
        elif new is True:
            if self.model.fiteddata:
                fd = self.model.fiteddata
                kw = {'x':fd[0],'y':fd[1],'weights':fd[2]}
            else:
                raise ValueError('No data to fit')
        else:
            raise ValueError('unusable fitdata event input')

        if 'fixedpars' not in kw:
            kw['fixedpars'] = [tn[len('fixfit_'):] for tn in self.traits()
                               if tn.startswith('fixfit_') and getattr(self,tn)]
        try:
            self.model.fitData(**kw)
            self.updatetraitparams = True
            self.lastfitfailure = None
        except Exception as e:
            #deliberate: the GUI reports fit failures via `lastfitfailure`
            self.lastfitfailure = e

    def _get_modelname(self):
        if self.model is None:
            return 'None'
        else:
            return self.model.__class__.__name__

    def _get_fittype(self):
        if self.model is None:
            return None
        else:
            return self.model.fittype

    def _set_fittype(self,val):
        self.model.fittype = val

    def _get_fittypes(self):
        #no model means no fitting techniques to offer
        if self.model is None:
            return []
        return self.model.fittypes

class NewModelSelector(HasTraits):
    """
    Dialog state for choosing a new model class in :class:`FitGui`.

    `modelnames` holds the sorted names of the available models, preceded by
    the placeholder 'No Model'. `modelargnum` is the number of extra
    parameters passed to models that take a variable number of parameters.
    """
    modelnames = List
    selectedname = Str('No Model')
    modelargnum = Int(2)
    #model class for `selectedname`, or None for the 'No Model' placeholder
    selectedmodelclass = Property(depends_on='selectedname')
    #whether the selected model class takes a variable number of parameters
    isvarargmodel = Property(depends_on='selectedname')

    traits_view = View(Item('selectedname',label='Model Name:',editor=EnumEditor(name='modelnames')),
                       Item('modelargnum',label='Extra Parameters:',enabled_when='isvarargmodel'),
                       buttons=['OK','Cancel'])

    def __init__(self,include_models=None,exclude_models=None,**traits):
        """
        :param include_models: passed to :func:`list_models` to restrict the offered models
        :param exclude_models: passed to :func:`list_models` to exclude models

        kwargs are applied as additional traits.
        """
        super(NewModelSelector,self).__init__(**traits)

        names = list(list_models(include_models,exclude_models,FunctionModel1D))
        names.sort()
        #the placeholder always comes first, independent of how it would sort
        names.insert(0,'No Model')
        self.modelnames = names

    def _get_selectedmodelclass(self):
        n = self.selectedname
        if n == 'No Model':
            return None
        else:
            return get_model_class(n)

    def _get_isvarargmodel(self):
        cls = self.selectedmodelclass

        if cls is None:
            return False
        else:
            return cls.isVarnumModel()

#class WeightFillOverlay(AbstractOverlay):
#    weightval = Float(0)
#    color = ColorTrait('black')
#    plot = Instance(Plot)

#    def overlay(self, component, gc, view_bounds=None, mode="normal"):
#        from enthought.chaco.scatterplot import render_markers

#        plot = self.component
#        scatter = plot.plots['data'][0]
#        if not plot or not scatter or not scatter.index or not scatter.value:
#            return

#        w = plot.data.get_data('weights')
#        inds = w==self.weightval

#        index_data = scatter.index.get_data()
#        value_data = scatter.value.get_data()
#        screen_pts = scatter.map_screen(np.array([index_data[inds],value_data[inds]]).T)
#        screen_pts = screen_pts+[plot.x,plot.y]

#        props = ('line_width','marker_size','marker')
#        markerprops = dict([(prop,getattr(scatter,prop)) for prop in props])

#        markerprops['color']=self.color_
#        markerprops['outline_color']=self.color_

#        if markerprops.get('marker', None) == 'custom':
#            markerprops['custom_symbol'] = scatter.custom_symbol

#        gc.save_state()
#        gc.clip_to_rect(scatter.x+plot.x, scatter.y+plot.y, scatter.width, scatter.height)
#        render_markers(gc, screen_pts, **markerprops)
#        gc.restore_state()

class FGHandler(Handler):
    """
    traitsui handler for :class:`FitGui`. It opens the traits editor of the
    data or model plot renderer when the corresponding button is pressed.
    """

    def _edit_renderer(self, info, plotname):
        """Open the traits editor for the first renderer of plot `plotname`."""
        #derive the window kind from the name of the UI's rebuild function
        #(e.g. 'ui_live' -> 'live') #TODO:not hack!
        kind = info.ui.rebuild.__name__.replace('ui_','')
        renderer = info.object.plot.plots[plotname][0]
        renderer.edit_traits(parent=info.ui.control, kind=kind)

    def object_datasymb_changed(self,info):
        self._edit_renderer(info, 'data')

    def object_modline_changed(self,info):
        self._edit_renderer(info, 'model')


#Display labels for the FitGui.scattertool Enum values. The "N:" prefixes set
#the order in which traitsui's EnumEditor lists the choices.
_SCATTERTOOL_LABELS = {None:'1:No Selection',
                       'clicktoggle':'3:Toggle Select',
                       'clicksingle':'2:Single Select',
                       'clickimmediate':'7:Immediate',
                       'lassoadd':'4:Add with Lasso',
                       'lassoremove':'5:Remove with Lasso',
                       'lassoinvert':'6:Invert with Lasso'}


class FitGui(HasTraits):
    """
    This class represents the fitgui application state.
    """

    plot = Instance(Plot)
    colorbar = Instance(ColorBar)
    plotcontainer = Instance(HPlotContainer)
    tmodel = Instance(TraitedModel,allow_none=False)
    nomodel = Property
    newmodel = Button('New Model...')
    fitmodel = Button('Fit Model')
    showerror = Button('Fit Error')
    updatemodelplot = Button('Update Model Plot')
    autoupdate = Bool(True)
    data = Array(dtype=float,shape=(2,None))
    weights = Array
    weighttype = Enum(('custom','equal','lin bins','log bins'))
    weightsvary = Property(Bool)
    weights0rem = Bool(True)
    modelselector = NewModelSelector
    ytype = Enum(('data and model','residuals'))

    zoomtool = Instance(ZoomTool)
    pantool = Instance(PanTool)

    scattertool = Enum(None,'clicktoggle','clicksingle','clickimmediate',
                       'lassoadd','lassoremove','lassoinvert')
    selectedi = Property #indices of the selected objects
    weightchangesel = Button('Set Selection To')
    weightchangeto = Float(1.0)
    delsel = Button('Delete Selected')
    unselectonaction = Bool(True)
    clearsel = Button('Clear Selections')
    lastselaction = Str('None')

    datasymb = Button('Data Symbol...')
    modline = Button('Model Line...')

    savews = Button('Save Weights')
    loadws = Button('Load Weights')
    _savedws = Array

    plotname = Property

    updatestats = Event
    chi2 = Property(Float,depends_on='updatestats')
    chi2r = Property(Float,depends_on='updatestats')

    nmod = Int(1024)
    modelpanel = View

    panel_view = View(VGroup(
                       Item('plot', editor=ComponentEditor(),show_label=False),
                       HGroup(Item('tmodel.modelname',show_label=False,style='readonly'),
                              Item('nmod',label='Number of model points'),
                              Item('updatemodelplot',show_label=False,enabled_when='not autoupdate'),
                              Item('autoupdate',label='Auto?'))
                      ),
                    title='Model Data Fitter'
                    )

    selection_view = View(Group(
                           Item('scattertool',label='Selection Mode',
                                editor=EnumEditor(values=_SCATTERTOOL_LABELS)),
                           Item('unselectonaction',label='Clear Selection on Action?'),
                           Item('clearsel',show_label=False),
                           Item('weightchangesel',show_label=False),
                           Item('weightchangeto',label='Weight'),
                           Item('delsel',show_label=False)
                         ),title='Selection Options')

    traits_view = View(VGroup(
                        HGroup(Item('object.plot.index_scale',label='x-scaling',
                                    enabled_when='object.plot.index_mapper.range.low>0 or object.plot.index_scale=="log"'),
                               spring,
                               Item('ytype',label='y-data'),
                               Item('object.plot.value_scale',label='y-scaling',
                                    enabled_when='object.plot.value_mapper.range.low>0 or object.plot.value_scale=="log"')
                              ),
                        Item('plotcontainer', editor=ComponentEditor(),show_label=False),
                        HGroup(VGroup(HGroup(Item('weighttype',label='Weights:'),
                                             Item('savews',show_label=False),
                                             #the truth value of a multi-element numpy array is
                                             #ambiguous, so test the length explicitly
                                             Item('loadws',enabled_when='len(_savedws) > 0',show_label=False)),
                                      Item('weights0rem',label='Remove 0-weight points for fit?'),
                                      HGroup(Item('newmodel',show_label=False),
                                             Item('fitmodel',show_label=False),
                                             Item('showerror',show_label=False,enabled_when='tmodel.lastfitfailure'),
                                             VGroup(Item('chi2',label='Chi2:',style='readonly',format_str='%6.6g',visible_when='tmodel.model is not None'),
                                                    Item('chi2r',label='reduced:',style='readonly',format_str='%6.6g',visible_when='tmodel.model is not None'))
                                             ),
                                      springy=False),
                               spring,
                               VGroup(HGroup(Item('autoupdate',label='Auto?'),
                                             Item('updatemodelplot',show_label=False,enabled_when='not autoupdate')),
                                      Item('nmod',label='Nmodel'),
                                      HGroup(Item('datasymb',show_label=False),Item('modline',show_label=False)),
                                      springy=False),
                               springy=True),
                        '_',
                        HGroup(Item('scattertool',label='Selection Mode',
                                    editor=EnumEditor(values=_SCATTERTOOL_LABELS)),
                               Item('unselectonaction',label='Clear Selection on Action?'),
                               Item('clearsel',show_label=False),
                               Item('weightchangesel',show_label=False),
                               Item('weightchangeto',label='Weight'),
                               Item('delsel',show_label=False),
                               ),
                        Item('tmodel',show_label=False,style='custom',editor=InstanceEditor(kind='subpanel'))
                        ),
                  handler=FGHandler(),
                  resizable=True,
                  title='Data Fitting',
                  buttons=['OK','Cancel'],
                  width=700,
                  height=900
                  )

    def __init__(self,xdata=None,ydata=None,weights=None,model=None,
                 include_models=None,exclude_models=None,fittype=None,**traits):
        """
        :param xdata: the first dimension of the data to be fit
        :type xdata: array-like
        :param ydata: the second dimension of the data to be fit
        :type ydata: array-like
        :param weights:
            The weights to apply to the data. They are statistically
            interpreted as inverse errors (*not* inverse variance). They may
            take any of the following forms:

            * None for equal weights
            * an array of points that must match `ydata`
            * a 2-sequence of arrays (xierr,yierr) such that xierr matches the
              `xdata` and yierr matches `ydata`
            * a function called as f(params) that returns an array of weights
              that match one of the above two conditions

            NOTE(review): weights are copied with ``np.array`` below, so the
            callable form does not appear to be handled here - confirm.
        :param model: the initial model to use to fit this data
        :type model:
            None, string, or :class:`pymodelfit.core.FunctionModel1D`
            instance.
        :param include_models:
            With `exclude_models`, specifies which models should be available
            in the "new model" dialog (see `models.list_models` for syntax).
        :param exclude_models:
            With `include_models`, specifies which models should be available
            in the "new model" dialog (see `models.list_models` for syntax).
        :param fittype:
            The fitting technique for the initial fit (see
            :class:`pymodelfit.core.FunctionModel`).
        :type fittype: string

        kwargs are passed in as any additional traits to apply to the
        application.

        :raises ValueError:
            If `xdata` or `ydata` is missing and the model carries no data.
        """
        self.modelpanel = View(Label('empty'),kind='subpanel',title='model editor')

        self.tmodel = TraitedModel(model)

        if model is not None and fittype is not None:
            self.tmodel.model.fittype = fittype

        #fall back on data stored in the model for anything not given
        if xdata is None or ydata is None:
            if not hasattr(self.tmodel.model,'data') or self.tmodel.model.data is None:
                raise ValueError('data not provided and no data in model')
            if xdata is None:
                xdata = self.tmodel.model.data[0]
            if ydata is None:
                ydata = self.tmodel.model.data[1]
            if weights is None:
                weights = self.tmodel.model.data[2]

        self.on_trait_change(self._paramsChanged,'tmodel.paramchange')

        self.modelselector = NewModelSelector(include_models,exclude_models)

        self.data = [xdata,ydata]

        if weights is None:
            self.weights = np.ones_like(xdata)
            self.weighttype = 'equal'
        else:
            self.weights = np.array(weights,copy=True)
            self.savews = True

        #multi-row (error-bar) weights are collapsed to one value per point for colouring
        weights1d = self.weights
        while len(weights1d.shape)>1:
            weights1d = np.sum(weights1d**2,axis=0)

        pd = ArrayPlotData(xdata=self.data[0],ydata=self.data[1],weights=weights1d)
        self.plot = plot = Plot(pd,resizable='hv')

        self.scatter = plot.plot(('xdata','ydata','weights'),name='data',
                                 color_mapper=_cmapblack if self.weights0rem else _cmap,
                                 type='cmap_scatter', marker='circle')[0]

        self.errorplots = None

        if not isinstance(model,FunctionModel1D):
            self.fitmodel = True

        self.updatemodelplot = False #force plot update - generates xmod and ymod
        plot.plot(('xmod','ymod'),name='model',type='line',line_style='dash',color='black',line_width=2)
        #remove the line plot from the x_mapper source so only the data is tied to the scaling
        del plot.x_mapper.range.sources[-1]

        self.on_trait_change(self._rangeChanged,'plot.index_mapper.range.updated')

        self.pantool = PanTool(plot,drag_button='left')
        plot.tools.append(self.pantool)
        self.zoomtool = ZoomTool(plot)
        self.zoomtool.prev_state_key = KeySpec('a')
        self.zoomtool.next_state_key = KeySpec('s')
        plot.overlays.append(self.zoomtool)

        self.scattertool = None
        self.scatter.overlays.append(ScatterInspectorOverlay(self.scatter,
                        hover_color = "black",
                        selection_color="black",
                        selection_outline_color="red",
                        selection_line_width=2))

        #the colorbar needs the colour mapper itself (not its data range) to
        #draw the same colours as the scatter points
        self.colorbar = colorbar = ColorBar(index_mapper=LinearMapper(range=plot.color_mapper.range),
                                            color_mapper=plot.color_mapper,
                                            plot=plot,
                                            orientation='v',
                                            resizable='v',
                                            width = 30,
                                            padding = 5)
        colorbar.padding_top = plot.padding_top
        colorbar.padding_bottom = plot.padding_bottom
        colorbar._axis.title = 'Weights'

        self.plotcontainer = container = HPlotContainer(use_backbuffer=True)
        container.add(plot)
        container.add(colorbar)

        super(FitGui,self).__init__(**traits)

        self.on_trait_change(self._scale_change,'plot.value_scale,plot.index_scale')

        if weights is not None and len(weights)==2:
            self.weightsChanged() #update error bars

    def _weights0rem_changed(self,old,new):
        """Swap the scatter colormap (black zero-weight points vs. plain)."""
        if new:
            self.plot.color_mapper = _cmapblack(self.plot.color_mapper.range)
        else:
            self.plot.color_mapper = _cmap(self.plot.color_mapper.range)
        if self.colorbar is not None:
            #keep the colorbar showing the mapping actually used by the plot
            self.colorbar.color_mapper = self.plot.color_mapper
        self.plot.request_redraw()

    def _paramsChanged(self):
        self.updatemodelplot = True

    def _nmod_changed(self):
        self.updatemodelplot = True

    def _rangeChanged(self):
        self.updatemodelplot = True

    def _scale_change(self):
        self.plot.request_redraw()

    def _updatemodelplot_fired(self,new):
        """
        Recompute the model curve (or the residuals) shown in the plot.

        A falsy event value forces the update. Any other value is ignored
        while `autoupdate` is off.
        """
        #If the plot has not been generated yet, just skip the update
        if self.plot is None:
            return

        #if False (e.g. button click), update regardless, otherwise check for autoupdate
        if new and not self.autoupdate:
            return

        mod = self.tmodel.model
        pdata = self.plot.data
        if self.ytype == 'data and model':
            if mod is not None:
                #sample the model over the currently visible x-range
                xl = self.plot.index_range.low
                xh = self.plot.index_range.high
                if self.plot.index_scale=="log":
                    xmod = np.logspace(np.log10(xl),np.log10(xh),self.nmod)
                else:
                    xmod = np.linspace(xl,xh,self.nmod)
                ymod = self.tmodel.model(xmod)
                pdata.set_data('xmod',xmod)
                pdata.set_data('ymod',ymod)
            else:
                pdata.set_data('xmod',[])
                pdata.set_data('ymod',[])
        elif self.ytype == 'residuals':
            if mod is not None:
                pdata.set_data('xmod',[])
                pdata.set_data('ymod',[])
                #residuals set the ydata instead of setting the model
                res = mod.residuals(*self.data)
                pdata.set_data('ydata',res)
            else:
                self.ytype = 'data and model'
        else:
            raise TraitError('invalid Enum value for ytype: %r'%self.ytype)

    def _fitmodel_fired(self):
        """
        Fit the current model to the data and weights.

        When `weights0rem` is set and the weights match the data shape,
        zero-weight points are dropped before fitting.
        """
        from warnings import warn

        preaup = self.autoupdate
        try:
            #suppress intermediate plot updates while the fit runs
            self.autoupdate = False
            xd,yd = self.data
            kwd = {'x':xd,'y':yd}
            if self.weights is not None:
                w = self.weights
                if self.weights0rem:
                    if xd.shape == w.shape:
                        m = w!=0
                        w = w[m]
                        kwd['x'] = kwd['x'][m]
                        kwd['y'] = kwd['y'][m]
                    elif np.any(w==0):
                        warn("can't remove 0-weighted points if weights don't match data")
                kwd['weights'] = w
            self.tmodel.fitdata = kwd
        finally:
            self.autoupdate = preaup

        self.updatemodelplot = True
        self.updatestats = True

    def _newmodel_fired(self,newval):
        """
        Replace the current model.

        A model name, instance, or class given as the event value is used
        directly. Otherwise the model-selection dialog is shown, and a model
        chosen there is fitted immediately.
        """
        from inspect import isclass

        if isinstance(newval,basestring) or isinstance(newval,FunctionModel1D) \
           or (isclass(newval) and issubclass(newval,FunctionModel1D)):
            self.tmodel = TraitedModel(newval)
            return

        if not self.modelselector.edit_traits(kind='modal').result:
            return #cancelled

        cls = self.modelselector.selectedmodelclass
        if cls is None:
            self.tmodel = TraitedModel(None)
        elif self.modelselector.isvarargmodel:
            self.tmodel = TraitedModel(cls(self.modelselector.modelargnum))
            self.fitmodel = True
        else:
            self.tmodel = TraitedModel(cls())
            self.fitmodel = True

    def _showerror_fired(self,evt):
        """Show the exception from the last failed fit in a dialog."""
        if self.tmodel.lastfitfailure:
            ex = self.tmodel.lastfitfailure
            dialog = HasTraits(s=ex.__class__.__name__+': '+str(ex))
            view = View(Item('s',style='custom',show_label=False),
                        resizable=True,buttons=['OK'],title='Fitting error message')
            dialog.edit_traits(view=view)

    @cached_property
    def _get_chi2(self):
        #best-effort display value: 0 when there is no model or no fit data
        try:
            return self.tmodel.model.chi2Data()[0]
        except Exception:
            return 0

    @cached_property
    def _get_chi2r(self):
        #best-effort display value: 0 when there is no model or no fit data
        try:
            return self.tmodel.model.chi2Data()[1]
        except Exception:
            return 0

    def _get_nomodel(self):
        return self.tmodel.model is None

    def _get_weightsvary(self):
        w = self.weights
        return np.any(w!=w[0]) if len(w)>0 else False

    def _get_plotname(self):
        xlabel = self.plot.x_axis.title
        ylabel = self.plot.y_axis.title
        if xlabel == '' and ylabel == '':
            return ''
        else:
            return xlabel+' vs '+ylabel

    def _set_plotname(self,val):
        """
        Set the axis titles from ``'xlabel vs ylabel'``, ``'xlabel - ylabel'``,
        or a 2-sequence of titles.
        """
        if isinstance(val,basestring):
            parts = val.split('vs')
            if len(parts) == 1:
                parts = val.split('-')
            val = [v.strip() for v in parts]
        self.plot.x_axis.title = val[0]
        self.plot.y_axis.title = val[1]

    #selection-related
    def _scattertool_changed(self,old,new):
        """
        Install the scatter selection tool for the new selection mode.

        The pan tool uses the left mouse button when selection is off and the
        right button otherwise.
        """
        if new is None:
            self.pantool.drag_button='left'
        else:
            self.pantool.drag_button='right'

        if old is not None and 'lasso' in old:
            if new is not None and 'lasso' in new:
                #same lasso tool - only the combination mode changes
                self.lassomode = new.replace('lasso','')
                return
            else:
                #TODO:test
                self.scatter.tools[-1].on_trait_change(self._lasso_handler,
                                            'selection_changed',remove=True)
                del self.scatter.overlays[-1]
                del self.lassomode
        elif old == 'clickimmediate':
            self.scatter.index.on_trait_change(self._immediate_handler,
                                            'metadata_changed',remove=True)

        self.scatter.tools = []
        if new is None:
            pass
        elif 'click' in new:
            smodemap = {'clickimmediate':'single','clicksingle':'single',
                        'clicktoggle':'toggle'}
            self.scatter.tools.append(ScatterInspector(self.scatter,
                                      selection_mode=smodemap[new]))
            if new == 'clickimmediate':
                self.clearsel = True
                self.scatter.index.on_trait_change(self._immediate_handler,
                                                   'metadata_changed')
        elif 'lasso' in new:
            lasso_selection = LassoSelection(component=self.scatter,
                                    selection_datasource=self.scatter.index)
            self.scatter.tools.append(lasso_selection)
            lasso_overlay = LassoOverlay(lasso_selection=lasso_selection,
                                         component=self.scatter)
            self.scatter.overlays.append(lasso_overlay)
            self.lassomode = new.replace('lasso','')
            lasso_selection.on_trait_change(self._lasso_handler,
                                            'selection_changed')
            lasso_selection.on_trait_change(self._lasso_handler,
                                            'selection_completed')
            lasso_selection.on_trait_change(self._lasso_handler,
                                            'updated')
        else:
            raise TraitError('invalid scattertool value')

    def _weightchangesel_fired(self):
        """Set the weight of every selected point to `weightchangeto`."""
        self.weights[self.selectedi] = self.weightchangeto
        if self.unselectonaction:
            self.clearsel = True

        self._sel_alter_weights()
        self.lastselaction = 'weightchangesel'

    def _delsel_fired(self):
        """'Delete' the selected points by setting their weights to 0."""
        self.weights[self.selectedi] = 0
        if self.unselectonaction:
            self.clearsel = True

        self._sel_alter_weights()
        self.lastselaction = 'delsel'

    def _sel_alter_weights(self):
        """Switch to custom weights after an in-place edit and refresh the plot."""
        if self.weighttype != 'custom':
            self._customweights = self.weights
            self.weighttype = 'custom'
        self.weightsChanged()

    def _clearsel_fired(self,event):
        #a list event value replaces the selection; anything else clears it
        if isinstance(event,list):
            self.scatter.index.metadata['selections'] = event
        else:
            self.scatter.index.metadata['selections'] = list()

    def _lasso_handler(self,name,new):
        """Combine lasso selections with click selections per `lassomode`."""
        if name == 'selection_changed':
            lassomask = self.scatter.index.metadata['selection'].astype(int)
            clickmask = np.zeros_like(lassomask)
            clickmask[self.scatter.index.metadata['selections']] = 1

            if self.lassomode == 'add':
                mask = clickmask | lassomask
            elif self.lassomode == 'remove':
                mask = clickmask & ~lassomask
            elif self.lassomode == 'invert':
                mask = np.logical_xor(clickmask,lassomask)
            else:
                raise TraitError('lassomode is in invalid state')

            self.scatter.index.metadata['selections'] = list(np.where(mask)[0])
        elif name == 'selection_completed':
            self.scatter.overlays[-1].visible = False
        elif name == 'updated':
            self.scatter.overlays[-1].visible = True
        else:
            raise ValueError('traits event name %s invalid'%name)

    def _immediate_handler(self):
        """In immediate mode, repeat the last selection action on a clicked point."""
        sel = self.selectedi
        if len(sel) > 1:
            self.clearsel = True
            raise TraitError('selection error in immediate mode - more than 1 selection')
        elif len(sel)==1:
            if self.lastselaction != 'None':
                setattr(self,self.lastselaction,True)
            del sel[0]

    def _savews_fired(self):
        self._savedws = self.weights.copy()

    def _loadws_fired(self):
        self.weights = self._savedws
        #re-copy so the restored weights and the saved copy do not alias
        self._savews_fired()

    def _get_selectedi(self):
        return self.scatter.index.metadata['selections']

    @on_trait_change('data,ytype',post_init=True)
    def dataChanged(self):
        """
        Updates the application state if the fit data are altered - the GUI
        will know if you give it a new data array, but not if the data is
        changed in-place.
        """
        pd = self.plot.data
        #TODO:make set_data apply to both simultaneously?
        pd.set_data('xdata',self.data[0])
        pd.set_data('ydata',self.data[1])

        self.updatemodelplot = False

    @on_trait_change('weights',post_init=True)
    def weightsChanged(self):
        """
        Updates the application state if the weights/error bars for this
        model are changed - the GUI will automatically do this if you give it
        a new set of weights array, but not if they are changed in-place.
        """
        weights = self.weights

        #remove error bars left over from a previous call
        #TODO:switch this to updating error bar data/visibility changing
        if getattr(self,'errorplots',None) is not None:
            self.plot.remove(self.errorplots[0])
            self.plot.remove(self.errorplots[1])
            self.errorplots = None

        if len(weights.shape)==2 and weights.shape[0]==2:
            #weights are inverse errors, so the error bar half-widths are 1/w
            xerr,yerr = 1/weights

            high = ArrayDataSource(self.scatter.index.get_data()+xerr)
            low = ArrayDataSource(self.scatter.index.get_data()-xerr)
            ebpx = ErrorBarPlot(orientation='v',
                               value_high = high,
                               value_low = low,
                               index = self.scatter.value,
                               value = self.scatter.index,
                               index_mapper = self.scatter.value_mapper,
                               value_mapper = self.scatter.index_mapper
                            )
            self.plot.add(ebpx)

            high = ArrayDataSource(self.scatter.value.get_data()+yerr)
            low = ArrayDataSource(self.scatter.value.get_data()-yerr)
            ebpy = ErrorBarPlot(value_high = high,
                               value_low = low,
                               index = self.scatter.index,
                               value = self.scatter.value,
                               index_mapper = self.scatter.index_mapper,
                               value_mapper = self.scatter.value_mapper
                            )
            self.plot.add(ebpy)

            self.errorplots = (ebpx,ebpy)

        while len(weights.shape)>1:
            weights = np.sum(weights**2,axis=0)
        self.plot.data.set_data('weights',weights)
        self.plot.plots['data'][0].color_mapper.range.refresh()

        #only show the colorbar when the weights actually differ
        if self.weightsvary:
            if self.colorbar not in self.plotcontainer.components:
                self.plotcontainer.add(self.colorbar)
                self.plotcontainer.request_redraw()
        elif self.colorbar in self.plotcontainer.components:
            self.plotcontainer.remove(self.colorbar)
            self.plotcontainer.request_redraw()

    def _weighttype_changed(self, name, old, new):
        """Regenerate the weights for the chosen weighting scheme."""
        if old == 'custom':
            self._customweights = self.weights

        if new == 'custom':
            self.weights = self._customweights
        elif new == 'equal':
            self.weights = np.ones_like(self.data[0])
        elif new == 'lin bins':
            self.weights = binned_weights(self.data[0],10,False)
        elif new == 'log bins':
            self.weights = binned_weights(self.data[0],10,True)
        else:
            raise TraitError('Invalid Enum value on weighttype')

    def getModelInitStr(self):
        """
        Generates a python code string that can be used to generate a model
        with parameters matching the model in this :class:`FitGui`.

        :returns: initializer string
        """
        mod = self.tmodel.model
        if mod is None:
            return 'None'

        parstrs = [p+'='+str(v) for p,v in mod.pardict.items()]
        if mod.__class__._pars is None:
            #varargs need to have the first argument give the right number
            varcount = len(mod.params)-len(mod.__class__._statargs)
            parstrs.insert(0,str(varcount))
        return '%s(%s)'%(mod.__class__.__name__,','.join(parstrs))

    def getModelObject(self):
        """
        Gets the underlying object representing the model for this fit.

        :returns: The :class:`pymodelfit.core.FunctionModel1D` object.
        """
        return self.tmodel.model
def fit_data(*args,**kwargs):
    """
    Fit a 2d data set using the :class:`FitGui` interface. A GUI application
    instance must already exist (e.g. interactive mode of ipython).

    This function is modal and will block until the user hits "OK" or
    "Cancel". If non-blocking behavior is desired, create a :class:`FitGui`
    object and call :meth:`FitGui.edit_traits`.

    The following forms for input arguments are accepted:

    * fit_data(xdata,ydata)
    * fit_data(xdata,ydata,model)
    * fit_data(model)
        This form requires a :class:`FunctionModel1D` object that includes data

    :param xdata: the first dimension of the data to be fit
    :type xdata: array-like
    :param ydata: the second dimension of the data to be fit
    :type ydata: array-like
    :param model: the initial model to use to fit this data
    :type model:
        None, string, or :class:`pymodelfit.core.FunctionModel1D` instance

    kwargs are passed into the fitgui initializer

    :returns:
        The model, or None if fitting is cancelled or no model is assigned in
        the GUI.
    :raises TypeError:
        If 0 or more than 3 positional arguments are given, or if `model` is
        given both positionally and as a keyword.
    :raises ValueError: If only a model is given and it has no data.

    **Examples**

    >>> from numpy.random import randn
    >>> fit_data(randn(100),randn(100)) #doctest: +SKIP

    This will bring up 100 normally-distributed points with no initial
    fitting model.

    >>> from numpy.random import randn
    >>> fit_data(randn(100),randn(100),'linear') #doctest: +SKIP

    This will bring up 100 normally-distributed points with a best-fit linear
    model.

    >>> from numpy.random import randn,rand
    >>> fit_data(randn(100),randn(100),'linear',weights=rand(100)) #doctest: +SKIP

    This will bring up 100 normally-distributed points with a best-fit linear
    model with the points weighted by uniform random values.

    >>> from numpy import tile
    >>> from numpy.random import randn,rand
    >>> fit_data(randn(100),randn(100),'linear',weights=tile(rand(100),2).reshape((2,100)),fittype='yerr') #doctest: +SKIP

    This will bring up 100 normally-distributed points with a linear model
    with the points weighted by a uniform random number (interpreted as
    inverse error), fit using the yerr algorithm instead of the default
    least-squares.
    """
    kwargs = dict(kwargs) #copy so the caller's dictionary is never modified
    if len(args) == 2:
        xdata,ydata = args
        kwargs.setdefault('model',None)
    elif len(args) == 3:
        if 'model' in kwargs:
            raise TypeError("got two values for 'model' argument")
        xdata,ydata,kwargs['model'] = args
    elif len(args) == 1:
        if 'model' in kwargs:
            raise TypeError("got two values for 'model' argument")
        xdata = ydata = None
        kwargs['model'] = args[0]
        if kwargs['model'].data is None:
            raise ValueError('cannot fit_data for a model with no data')
    else:
        raise TypeError('fit_data takes 1,2, or 3 arguments (%i given)'%len(args))

    #FitGui.__init__ already performs the initial fit whenever `model` is not
    #a FunctionModel1D instance, so no extra fit is requested here
    fg = FitGui(xdata,ydata,**kwargs)

    ui = fg.edit_traits(kind='livemodal')
    #edit_traits returns a UI object; its `result` says whether "OK" was chosen
    if ui.result:
        return fg.getModelObject()
    else:
        return None