# Source code for neurolab.train
# -*- coding: utf-8 -*-
"""
Train algorithms based on gradient descent
==========================================
.. autofunction:: train_gd
.. autofunction:: train_gdm
.. autofunction:: train_gda
.. autofunction:: train_gdx
.. autofunction:: train_rprop
Train algorithms based on the Winner-Take-All rule
==================================================
.. autofunction:: train_wta
.. autofunction:: train_cwta
Train algorithms based on scipy.optimize
========================================
.. autofunction:: train_bfgs
.. autofunction:: train_cg
.. autofunction:: train_ncg
Train algorithms for LVQ networks
=================================
.. autofunction:: train_lvq
Delta rule
==========
.. autofunction:: train_delta
"""
from . import gd, spo, wta, lvq, delta
import functools
def trainer(Train, **kwargs):
    """Wrap a train algorithm class into a callable training functor.

    A ``neurolab.core.Trainer`` is built around ``Train``. The class's
    ``__doc__``, ``__name__`` and ``__module__`` are then copied onto the
    resulting object, so the functor carries the train class's documentation
    and identity (e.g. for the ``autofunction`` entries listed in this
    module's docstring).

    :Parameters:
        Train: class
            Train algorithm class to wrap (e.g. ``gd.TrainGD``).
        kwargs: keyword arguments
            Forwarded unchanged to ``Trainer``.
    :Returns:
        The ``Trainer`` instance with the copied metadata.
    """
    # Local import; presumably done to avoid an import cycle with
    # neurolab.core -- TODO confirm.
    from neurolab.core import Trainer

    functor = Trainer(Train, **kwargs)
    # Copy identity/documentation attributes in the same order as before.
    for attr_name in ('__doc__', '__name__', '__module__'):
        setattr(functor, attr_name, getattr(Train, attr_name))
    return functor
# Initialize the main train functors. Each public ``train_*`` name wraps one
# train class via ``trainer`` and matches an ``autofunction`` entry in the
# module docstring.

# Gradient-descent based algorithms (module ``gd``).
# NOTE(review): wrappers for gd.TrainGD2 and gd.TrainRpropM were previously
# present only as commented-out code and are not exported; confirm those
# classes exist and are stable before adding them here.
train_gd = trainer(gd.TrainGD)
train_gdm = trainer(gd.TrainGDM)
train_gda = trainer(gd.TrainGDA)
train_gdx = trainer(gd.TrainGDX)
train_rprop = trainer(gd.TrainRprop)

# Algorithms built on scipy.optimize (module ``spo``).
train_bfgs = trainer(spo.TrainBFGS)
train_cg = trainer(spo.TrainCG)
train_ncg = trainer(spo.TrainNCG)

# Winner-Take-All rule algorithms (module ``wta``).
train_wta = trainer(wta.TrainWTA)
train_cwta = trainer(wta.TrainCWTA)

# LVQ network training (module ``lvq``).
train_lvq = trainer(lvq.TrainLVQ)

# Delta rule (module ``delta``).
train_delta = trainer(delta.TrainDelta)