Source code for plotpy.widgets.fit

# -*- coding: utf-8 -*-
#
# Copyright © 2009-2010 CEA
# Pierre Raybaut
# Licensed under the terms of the CECILL License
# (see plotpy/__init__.py for details)

# pylint: disable=C0103

"""
plotpy.widgets.fit
------------------

The `fit` module provides an interactive curve fitting widget/dialog allowing:
    * to fit data manually (by moving sliders)
    * or automatically (with standard optimization algorithms 
      provided by :py:mod:`scipy`).

Example
~~~~~~~

.. literalinclude:: /../plotpy/tests/fit.py
   :start-after: SHOW
   :end-before: Workaround for Sphinx v0.6 bug: empty 'end-before' directive

.. image:: /images/screenshots/fit.png

Reference
~~~~~~~~~

.. autofunction:: guifit

.. autoclass:: FitDialog
   :members:
   :inherited-members:
.. autoclass:: FitParam
   :members:
   :inherited-members:
.. autoclass:: AutoFitParam
   :members:
   :inherited-members:
"""

from __future__ import division, print_function

from guidata.qt.QtGui import (QGridLayout, QLabel, QSlider, QPushButton,
                              QLineEdit, QDialog, QVBoxLayout, QHBoxLayout,
                              QWidget, QDialogButtonBox)
from guidata.qt.QtCore import Qt
from guidata.qt import PYQT5

import numpy as np
from numpy import inf # Do not remove this import (used by optimization funcs)

import guidata
from guidata.utils import update_dataset, restore_dataset
from guidata.qthelpers import create_groupbox
from guidata.configtools import get_icon
from guidata.dataset.datatypes import DataSet
from guidata.dataset.dataitems import (StringItem, FloatItem, IntItem,
                                       ChoiceItem, BoolItem)

# Local imports
from plotpy.config import _
from plotpy.builder import make
from plotpy.plot import CurveWidgetMixin

[docs]class AutoFitParam(DataSet): xmin = FloatItem("xmin") xmax = FloatItem("xmax") method = ChoiceItem(_("Method"), [ ("simplex", "Simplex"), ("powel", "Powel"), ("bfgs", "BFGS"), ("l_bfgs_b", "L-BFGS-B"), ("cg", _("Conjugate Gradient")), ("lq", _("Least squares")), ], default="lq") err_norm = StringItem("enorm", default=2.0, help=_("for simplex, powel, cg and bfgs norm used " "by the error function")) xtol = FloatItem("xtol", default=0.0001, help=_("for simplex, powel, least squares")) ftol = FloatItem("ftol", default=0.0001, help=_("for simplex, powel, least squares")) gtol = FloatItem("gtol", default=0.0001, help=_("for cg, bfgs")) norm = StringItem("norm", default="inf", help=_("for cg, bfgs. inf is max, -inf is min"))
class FitParamDataSet(DataSet): name = StringItem(_("Name")) value = FloatItem(_("Value"), default=0.0) min = FloatItem(_("Min"), default=-1.0) max = FloatItem(_("Max"), default=1.0).set_pos(col=1) steps = IntItem(_("Steps"), default=5000) format = StringItem(_("Format"), default="%.3f").set_pos(col=1) logscale = BoolItem(_("Logarithmic"), _("Scale")) unit = StringItem(_("Unit"), default="").set_pos(col=1)
[docs]class FitParam(object): def __init__(self, name, value, min, max, logscale=False, steps=5000, format='%.3f', size_offset=0, unit=''): self.name = name self.value = value self.min = min self.max = max self.logscale = logscale self.steps = steps self.format = format self.unit = unit self.prefix_label = None self.lineedit = None self.unit_label = None self.slider = None self.button = None self._widgets = [] self._size_offset = size_offset self._refresh_callback = None self.dataset = FitParamDataSet(title=_("Curve fitting parameter"))
[docs] def copy(self): """Return a copy of this fitparam""" return self.__class__(self.name, self.value, self.min, self.max, self.logscale, self.steps, self.format, self._size_offset, self.unit)
def create_widgets(self, parent, refresh_callback): self._refresh_callback = refresh_callback self.prefix_label = QLabel() font = self.prefix_label.font() font.setPointSize(font.pointSize()+self._size_offset) self.prefix_label.setFont(font) self.button = QPushButton() self.button.setIcon(get_icon('settings.png')) self.button.setToolTip( _("Edit '%s' fit parameter properties") % self.name) self.button.clicked.connect(lambda: self.edit_param(parent)) self.lineedit = QLineEdit() self.lineedit.editingFinished.connect(self.line_editing_finished) self.unit_label = QLabel(self.unit) self.slider = QSlider() self.slider.setOrientation(Qt.Horizontal) self.slider.setRange(0, self.steps-1) self.slider.valueChanged.connect(self.slider_value_changed) self.update(refresh=False) self.add_widgets([self.prefix_label, self.lineedit, self.unit_label, self.slider, self.button]) def add_widgets(self, widgets): self._widgets += widgets def get_widgets(self): return self._widgets def set_scale(self, state): self.logscale = state > 0 self.update_slider_value() def set_text(self, fmt=None): style = "<span style=\'color: #444444\'><b>%s</b></span>" self.prefix_label.setText(style % self.name) if self.value is None: value_str = '' else: if fmt is None: fmt = self.format value_str = fmt % self.value self.lineedit.setText(value_str) self.lineedit.setDisabled( self.value == self.min and self.max == self.min) def line_editing_finished(self): try: self.value = float(self.lineedit.text()) except ValueError: self.set_text() self.update_slider_value() self._refresh_callback() def slider_value_changed(self, int_value): if self.logscale: total_delta = np.log10(1+self.max-self.min) self.value = self.min+10**(total_delta*int_value/(self.steps-1))-1 else: total_delta = self.max-self.min self.value = self.min+total_delta*int_value/(self.steps-1) self.set_text() self._refresh_callback() def update_slider_value(self): if (self.value is None or self.min is None or self.max is None): self.slider.setEnabled(False) if self.slider.parent() and self.slider.parent().isVisible(): self.slider.show() elif self.value == self.min and self.max == self.min: self.slider.hide() else: self.slider.setEnabled(True) if self.slider.parent() and self.slider.parent().isVisible(): self.slider.show() if self.logscale: value_delta = max([np.log10(1+self.value-self.min), 0.]) total_delta = np.log10(1+self.max-self.min) else: value_delta = self.value-self.min total_delta = self.max-self.min intval = int(self.steps*value_delta/total_delta) self.slider.blockSignals(True) self.slider.setValue(intval) self.slider.blockSignals(False) def edit_param(self, parent): update_dataset(self.dataset, self) if self.dataset.edit(parent=parent): restore_dataset(self.dataset, self) if self.value > self.max: self.max = self.value if self.value < self.min: self.min = self.value self.update() def update(self, refresh=True): self.unit_label.setText(self.unit) self.slider.setRange(0, self.steps-1) self.update_slider_value() self.set_text() if refresh: self._refresh_callback()
def add_fitparam_widgets_to(layout, fitparams, refresh_callback, param_cols=1): row_contents = [] row_nb = 0 col_nb = 0 for i, param in enumerate(fitparams): param.create_widgets(layout.parent(), refresh_callback) widgets = param.get_widgets() w_colums = len(widgets)+1 row_contents += [(widget, row_nb, j+col_nb*w_colums) for j, widget in enumerate(widgets)] col_nb += 1 if col_nb == param_cols: row_nb += 1 col_nb = 0 for widget, row, col in row_contents: layout.addWidget(widget, row, col) if fitparams: for col_nb in range(param_cols): layout.setColumnStretch(1+col_nb*w_colums, 5) if col_nb > 0: layout.setColumnStretch(col_nb*w_colums-1, 1) class FitWidgetMixin(CurveWidgetMixin): def __init__(self, wintitle="plotpy plot", icon="plotpy.svg", toolbar=False, options=None, panels=None, param_cols=1, legend_anchor='TR', auto_fit=True): if wintitle is None: wintitle = _('Curve fitting') self.x = None self.y = None self.fitfunc = None self.fitargs = None self.fitkwargs = None self.fitparams = None self.autofit_prm = None self.data_curve = None self.fit_curve = None self.legend = None self.legend_anchor = legend_anchor self.xrange = None self.show_xrange = False self.param_cols = param_cols self.auto_fit_enabled = auto_fit self.button_list = [] # list of buttons to be disabled at startup self.fit_layout = None self.params_layout = None CurveWidgetMixin.__init__(self, wintitle=wintitle, icon=icon, toolbar=toolbar, options=options, panels=panels) self.refresh() # QWidget API -------------------------------------------------------------- def resizeEvent(self, event): QWidget.resizeEvent(self, event) self.get_plot().replot() # CurveWidgetMixin API ----------------------------------------------------- def setup_widget_layout(self): self.fit_layout = QHBoxLayout() self.params_layout = QGridLayout() params_group = create_groupbox(self, _("Fit parameters"), layout=self.params_layout) if self.auto_fit_enabled: auto_group = self.create_autofit_group() self.fit_layout.addWidget(auto_group) self.fit_layout.addWidget(params_group) self.plot_layout.addLayout(self.fit_layout, 1, 0) vlayout = QVBoxLayout(self) vlayout.addWidget(self.toolbar) vlayout.addLayout(self.plot_layout) self.setLayout(vlayout) def create_plot(self, options): CurveWidgetMixin.create_plot(self, options) for plot in self.get_plots(): plot.SIG_RANGE_CHANGED.connect(self.range_changed) # Public API --------------------------------------------------------------- def set_data(self, x, y, fitfunc=None, fitparams=None, fitargs=None, fitkwargs=None): if self.fitparams is not None and fitparams is not None: self.clear_params_layout() self.x = x self.y = y if fitfunc is not None: self.fitfunc = fitfunc if fitparams is not None: self.fitparams = fitparams if fitargs is not None: self.fitargs = fitargs if fitkwargs is not None: self.fitkwargs = fitkwargs self.autofit_prm = AutoFitParam(title=_("Automatic fitting options")) self.autofit_prm.xmin = x.min() self.autofit_prm.xmax = x.max() self.compute_imin_imax() if self.fitparams is not None and fitparams is not None: self.populate_params_layout() self.refresh() def set_fit_data(self, fitfunc, fitparams, fitargs=None, fitkwargs=None): if self.fitparams is not None: self.clear_params_layout() self.fitfunc = fitfunc self.fitparams = fitparams self.fitargs = fitargs self.fitkwargs = fitkwargs self.populate_params_layout() self.refresh() def clear_params_layout(self): for i, param in enumerate(self.fitparams): for widget in param.get_widgets(): if widget is not None: self.params_layout.removeWidget(widget) widget.hide() def populate_params_layout(self): add_fitparam_widgets_to(self.params_layout, self.fitparams, self.refresh, param_cols=self.param_cols) def create_autofit_group(self): auto_button = QPushButton(get_icon('apply.png'), _("Run"), self) auto_button.clicked.connect(self.autofit) autoprm_button = QPushButton(get_icon('settings.png'), _("Settings"), self) autoprm_button.clicked.connect(self.edit_parameters) xrange_button = QPushButton(get_icon('xrange.png'), _("Bounds"), self) xrange_button.setCheckable(True) xrange_button.toggled.connect(self.toggle_xrange) auto_layout = QVBoxLayout() auto_layout.addWidget(auto_button) auto_layout.addWidget(autoprm_button) auto_layout.addWidget(xrange_button) self.button_list += [auto_button, autoprm_button, xrange_button] return create_groupbox(self, _("Automatic fit"), layout=auto_layout) def get_fitfunc_arguments(self): """Return fitargs and fitkwargs""" fitargs = self.fitargs if self.fitargs is None: fitargs = [] fitkwargs = self.fitkwargs if self.fitkwargs is None: fitkwargs = {} return fitargs, fitkwargs def refresh(self, slider_value=None): """Refresh Fit Tool dialog box""" # Update button states enable = self.x is not None and self.y is not None \ and self.x.size > 0 and self.y.size > 0 \ and self.fitfunc is not None and self.fitparams is not None \ and len(self.fitparams) > 0 for btn in self.button_list: btn.setEnabled(enable) if not enable: # Fit widget is not yet configured return fitargs, fitkwargs = self.get_fitfunc_arguments() yfit = self.fitfunc(self.x, [p.value for p in self.fitparams], *fitargs, **fitkwargs) plot = self.get_plot() if self.legend is None: self.legend = make.legend(anchor=self.legend_anchor) plot.add_item(self.legend) if self.xrange is None: self.xrange = make.range(0., 1.) plot.add_item(self.xrange) self.xrange.set_range(self.autofit_prm.xmin, self.autofit_prm.xmax) self.xrange.setVisible(self.show_xrange) if self.data_curve is None: self.data_curve = make.curve([], [], _("Data"), color="b", linewidth=2) plot.add_item(self.data_curve) self.data_curve.set_data(self.x, self.y) if self.fit_curve is None: self.fit_curve = make.curve([], [], _("Fit"), color="r", linewidth=2) plot.add_item(self.fit_curve) self.fit_curve.set_data(self.x, yfit) plot.replot() plot.disable_autoscale() def range_changed(self, xrange_obj, xmin, xmax): self.autofit_prm.xmin, self.autofit_prm.xmax = xmin, xmax self.compute_imin_imax() def toggle_xrange(self, state): self.xrange.setVisible(state) plot = self.get_plot() plot.replot() if state: plot.set_active_item(self.xrange) self.show_xrange = state def edit_parameters(self): if self.autofit_prm.edit(parent=self): self.xrange.set_range(self.autofit_prm.xmin, self.autofit_prm.xmax) plot = self.get_plot() plot.replot() self.compute_imin_imax() def compute_imin_imax(self): self.i_min = self.x.searchsorted(self.autofit_prm.xmin) self.i_max = self.x.searchsorted(self.autofit_prm.xmax, side='right') def errorfunc(self, params): x = self.x[self.i_min:self.i_max] y = self.y[self.i_min:self.i_max] fitargs, fitkwargs = self.get_fitfunc_arguments() return y - self.fitfunc(x, params, *fitargs, **fitkwargs) def autofit(self): meth = self.autofit_prm.method x0 = np.array([p.value for p in self.fitparams]) if meth == "lq": x = self.autofit_lq(x0) elif meth=="simplex": x = self.autofit_simplex(x0) elif meth=="powel": x = self.autofit_powel(x0) elif meth=="bfgs": x = self.autofit_bfgs(x0) elif meth=="l_bfgs_b": x = self.autofit_l_bfgs(x0) elif meth=="cg": x = self.autofit_cg(x0) else: return for v, p in zip(x, self.fitparams): p.value = v self.refresh() for prm in self.fitparams: prm.update() def get_norm_func(self): prm = self.autofit_prm err_norm = eval(prm.err_norm) def func(params): err = np.linalg.norm(self.errorfunc(params), err_norm) return err return func def autofit_simplex(self, x0): prm = self.autofit_prm from scipy.optimize import fmin x = fmin(self.get_norm_func(), x0, xtol=prm.xtol, ftol=prm.ftol) return x def autofit_powel(self, x0): prm = self.autofit_prm from scipy.optimize import fmin_powell x = fmin_powell(self.get_norm_func(), x0, xtol=prm.xtol, ftol=prm.ftol) return x def autofit_bfgs(self, x0): prm = self.autofit_prm from scipy.optimize import fmin_bfgs x = fmin_bfgs(self.get_norm_func(), x0, gtol=prm.gtol, norm=eval(prm.norm)) return x def autofit_l_bfgs(self, x0): prm = self.autofit_prm bounds = [(p.min, p.max) for p in self.fitparams] from scipy.optimize import fmin_l_bfgs_b x, _f, _d = fmin_l_bfgs_b(self.get_norm_func(), x0, pgtol=prm.gtol, approx_grad=1, bounds=bounds) return x def autofit_cg(self, x0): prm = self.autofit_prm from scipy.optimize import fmin_cg x = fmin_cg(self.get_norm_func(), x0, gtol=prm.gtol, norm=eval(prm.norm)) return x def autofit_lq(self, x0): prm = self.autofit_prm def func(params): err = self.errorfunc(params) return err from scipy.optimize import leastsq x, _ier = leastsq(func, x0, xtol=prm.xtol, ftol=prm.ftol) return x def get_values(self): """Convenience method to get fit parameter values""" return [param.value for param in self.fitparams] class FitWidget(QWidget, FitWidgetMixin): def __init__(self, wintitle=None, icon="plotpy.svg", toolbar=False, options=None, parent=None, panels=None, param_cols=1, legend_anchor='TR', auto_fit=False): QWidget.__init__(self, parent) FitWidgetMixin.__init__(self, wintitle, icon, toolbar, options, panels, param_cols, legend_anchor, auto_fit)
[docs]class FitDialog(QDialog, FitWidgetMixin): def __init__(self, wintitle=None, icon="plotpy.svg", edit=True, toolbar=False, options=None, parent=None, panels=None, param_cols=1, legend_anchor='TR', auto_fit=False): if not PYQT5: QDialog.__init__(self, parent) self.edit = edit self.button_layout = None if PYQT5: super(FitDialog, self).__init__(parent, wintitle=wintitle, icon=icon, toolbar=toolbar, options=options, panels=panels, param_cols=param_cols, legend_anchor=legend_anchor, auto_fit=auto_fit) else: FitWidgetMixin.__init__(self, wintitle, icon, toolbar, options, panels, param_cols, legend_anchor, auto_fit) self.setWindowFlags(Qt.Window) def setup_widget_layout(self): FitWidgetMixin.setup_widget_layout(self) if self.edit: self.install_button_layout() def install_button_layout(self): bbox = QDialogButtonBox(QDialogButtonBox.Ok | QDialogButtonBox.Cancel) bbox.accepted.connect(self.accept) bbox.rejected.connect(self.reject) self.button_list += [bbox.button(QDialogButtonBox.Ok)] self.button_layout = QHBoxLayout() self.button_layout.addStretch() self.button_layout.addWidget(bbox) vlayout = self.layout() vlayout.addSpacing(10) vlayout.addLayout(self.button_layout)
[docs]def guifit(x, y, fitfunc, fitparams, fitargs=None, fitkwargs=None, wintitle=None, title=None, xlabel=None, ylabel=None, param_cols=1, auto_fit=True, winsize=None, winpos=None): """GUI-based curve fitting tool""" _app = guidata.qapplication() # win = FitWidget(wintitle=wintitle, toolbar=True, # param_cols=param_cols, auto_fit=auto_fit, # options=dict(title=title, xlabel=xlabel, ylabel=ylabel)) win = FitDialog(edit=True, wintitle=wintitle, toolbar=True, param_cols=param_cols, auto_fit=auto_fit, options=dict(title=title, xlabel=xlabel, ylabel=ylabel)) win.set_data(x, y, fitfunc, fitparams, fitargs, fitkwargs) if winsize is not None: win.resize(*winsize) if winpos is not None: win.move(*winpos) if win.exec_(): return win.get_values() # win.show() # _app.exec_() # return win.get_values()
if __name__ == "__main__": x = np.linspace(-10, 10, 1000) y = np.cos(1.5*x)+np.random.rand(x.shape[0])*.2 def fit(x, params): a, b = params return np.cos(b*x)+a a = FitParam("Offset", 1., 0., 2.) b = FitParam("Frequency", 1.05, 0., 10., logscale=True) params = [a, b] values = guifit(x, y, fit, params, auto_fit=True) print(values) print([param.value for param in params])