Source code for hessianfree.optimizers

from __future__ import print_function

import warnings
from collections import defaultdict

import numpy as np


class Optimizer(object):
    """Base class for optimizers.

    Each optimizer has a ``self.net`` attribute that will be set
    automatically when the optimizer is added to a network (referring to
    that network)."""

    def __init__(self):
        self.net = None

    def compute_update(self, printing=False):
        """Compute a weight update for the current batch.

        It can be assumed that the batch has already been stored in
        ``net.inputs`` and ``net.targets``, and the nonlinearity
        activations/derivatives for the batch are cached in
        ``net.activations`` and ``net.d_activations``.

        :param bool printing: if True, print out data about the optimization
        """

        raise NotImplementedError()
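

# A minimal sketch of how the interface above can be extended: a momentum
# variant of gradient descent, assuming only the attributes the optimizers
# in this module already use (``net.calc_grad`` and ``net.W``). This class
# is illustrative only, not part of the library.
class MomentumSGD(Optimizer):
    """Gradient descent with classical momentum (illustrative sketch)."""

    def __init__(self, l_rate=1, momentum=0.9):
        super(MomentumSGD, self).__init__()
        self.l_rate = l_rate
        self.momentum = momentum
        self.velocity = None

    def compute_update(self, printing=False):
        grad = self.net.calc_grad()
        if self.velocity is None:
            self.velocity = np.zeros_like(self.net.W)
        # accumulate an exponentially decaying sum of past gradient steps
        self.velocity = self.momentum * self.velocity - self.l_rate * grad
        return self.velocity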


class HessianFree(Optimizer):
    """Use Hessian-free optimization to compute the weight update.

    :param int CG_iter: maximum number of CG iterations to run per epoch
    :param float init_damping: the initial value of the Tikhonov damping
    :param bool plotting: if True then collect data for plotting (actual
        plotting handled in parent network)
    """

    def __init__(self, CG_iter=250, init_damping=1, plotting=True):
        super(HessianFree, self).__init__()

        self.CG_iter = CG_iter
        self.init_delta = None
        self.damping = init_damping
        self.plotting = plotting
        self.plots = defaultdict(list)

    def compute_update(self, printing=False):
        """Compute a weight update for the current batch.

        :param bool printing: if True, print out data about the optimization
        """

        err = self.net.error()  # note: don't reuse previous error
        # (different batch)

        # compute gradient
        grad = self.net.calc_grad()

        if printing:
            print("initial err", err)
            print("grad norm", np.linalg.norm(grad))

        # run CG
        if self.init_delta is None:
            self.init_delta = np.zeros_like(self.net.W)
        deltas = self.conjugate_gradient(self.init_delta * 0.95, grad,
                                         iters=self.CG_iter,
                                         printing=printing and
                                         self.net.debug)

        if printing:
            print("CG steps", deltas[-1][0])

        self.init_delta = deltas[-1][1]  # note: don't backtrack this

        # CG backtracking
        new_err = np.inf
        for j in range(len(deltas) - 1, -1, -1):
            prev_err = self.net.error(self.net.W + deltas[j][1])
            # note: we keep using the cached inputs, not rerunning the plant
            # (if there is one). that is, we are evaluating whether the
            # update improves on those inputs, not whether it improves the
            # overall objective. we could do the latter instead, but it
            # makes things more prone to instability.
            if prev_err > new_err:
                break
            delta = deltas[j][1]
            new_err = prev_err
        else:
            j -= 1

        if printing:
            print("using iteration", deltas[j + 1][0])
            print("backtracked err", new_err)

        # update damping parameter (compare improvement predicted by
        # quadratic model to the actual improvement in the error)
        quad = (0.5 * np.dot(self.calc_G(delta, damping=self.damping),
                             delta) +
                np.dot(grad, delta))
        improvement_ratio = ((new_err - err) / quad) if quad != 0 else 1
        if improvement_ratio < 0.25:
            self.damping *= 1.5
        elif improvement_ratio > 0.75:
            self.damping *= 0.66

        if printing:
            print("improvement_ratio", improvement_ratio)
            print("damping", self.damping)

        # line search to find learning rate
        l_rate = 1.0
        min_improv = min(1e-2 * np.dot(grad, delta), 0)
        for _ in range(60):
            # check if the improvement is greater than the minimum
            # improvement we would expect based on the starting gradient
            if new_err <= err + l_rate * min_improv:
                break

            l_rate *= 0.8
            new_err = self.net.error(self.net.W + l_rate * delta)
        else:
            # no good update, so skip this iteration
            l_rate = 0.0
            new_err = err

        if printing:
            print("min_improv", min_improv)
            print("l_rate", l_rate)
            print("l_rate err", new_err)
            print("improvement", new_err - err)

        if self.plotting:
            self.plots["training error (log)"] += [new_err]
            self.plots["learning rate"] += [l_rate]
            self.plots["damping (log)"] += [self.damping]
            self.plots["CG iterations"] += [deltas[-1][0]]
            self.plots["backtracked steps"] += [deltas[-1][0] -
                                                deltas[j + 1][0]]

        return l_rate * delta

    def conjugate_gradient(self, init_delta, grad, iters=250,
                           printing=False):
        """Find minimum of quadratic approximation using conjugate gradient
        algorithm."""

        if self.net.debug:
            self.net.check_grad(grad)

        store_iter = 5
        store_mult = 1.3
        deltas = []
        grad = -grad  # note the negation; some CG formulations flip this sign
        vals = np.zeros(iters, dtype=self.net.dtype)

        if self.net.use_GPU:
            from pycuda import gpuarray
            base_grad = gpuarray.to_gpu(grad)
            delta = gpuarray.to_gpu(init_delta)
            G_dir = gpuarray.zeros(grad.shape, dtype=self.net.dtype)
            self.calc_G = self.net.GPU_calc_G

            def dot(a, b):
                return gpuarray.dot(a, b).get()

            def get(x):
                return x.get(pagelocked=True)
        else:
            base_grad = grad
            delta = init_delta
            G_dir = np.zeros_like(grad)
            self.calc_G = self.net.calc_G
            dot = np.dot
            get = np.copy

        residual = base_grad.copy()
        residual -= self.calc_G(delta, damping=self.damping, out=G_dir)
        res_norm = dot(residual, residual)
        direction = residual.copy()

        for i in range(iters):
            if printing:
                print("-" * 20)
                print("CG iteration", i)
                print("delta norm", np.linalg.norm(get(delta)))
                print("direction norm", np.linalg.norm(get(direction)))

            self.calc_G(direction, damping=self.damping, out=G_dir)

            # calculate step size
            step = res_norm / dot(direction, G_dir)
            if not np.isfinite(step):
                warnings.warn("Non-finite step value (%f)" % step)
                break

            if printing:
                print("G_dir norm", np.linalg.norm(get(G_dir)))
                print("step", step)

            if self.net.debug:
                tmp_G_dir = get(G_dir)
                tmp_dir = get(direction)
                self.net.check_G(tmp_G_dir, tmp_dir, self.damping)

                assert np.isfinite(step)
                assert step >= 0
                assert (np.linalg.norm(np.dot(tmp_dir, tmp_G_dir)) >=
                        np.linalg.norm(np.dot(tmp_dir,
                                              self.net.calc_G(tmp_dir,
                                                              damping=0))))

            # update weight delta
            delta += step * direction

            # update residual
            residual -= step * G_dir
            new_res_norm = dot(residual, residual)

            if new_res_norm < 1e-20:
                # early termination (mainly to prevent numerical errors);
                # the main termination condition is below.
                break

            # update direction
            beta = new_res_norm / res_norm
            direction *= beta
            direction += residual

            res_norm = new_res_norm

            # store deltas for backtracking
            if i == store_iter:
                deltas += [(i, get(delta))]
                store_iter = int(store_iter * store_mult)

            # Martens termination conditions
            vals[i] = -0.5 * dot(residual + base_grad, delta)

            gap = max(int(0.1 * i), 10)

            if printing:
                print("termination val", vals[i])

            if (i > gap and vals[i - gap] < 0 and
                    (vals[i] - vals[i - gap]) / vals[i] < 5e-6 * gap):
                break

        deltas += [(i, get(delta))]

        return deltas
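

# For reference, a self-contained NumPy sketch of the same linear conjugate
# gradient iteration used by ``conjugate_gradient`` above, applied to an
# explicit matrix rather than the network's implicit Gauss-Newton product.
# It solves (A + damping*I) x = b for symmetric positive-definite A; in the
# method above, A corresponds to the curvature matrix and b to the negative
# gradient. Illustration only, not part of the library.
def _cg_sketch(A, b, damping=0.0, iters=50, tol=1e-10):
    x = np.zeros_like(b)
    residual = b - (np.dot(A, x) + damping * x)
    direction = residual.copy()
    res_norm = np.dot(residual, residual)
    for _ in range(iters):
        # curvature-vector product for the current search direction
        G_dir = np.dot(A, direction) + damping * direction
        step = res_norm / np.dot(direction, G_dir)
        x += step * direction
        residual -= step * G_dir
        new_res_norm = np.dot(residual, residual)
        if new_res_norm < tol:
            break
        # conjugate the next direction against previous ones
        direction = residual + (new_res_norm / res_norm) * direction
        res_norm = new_res_norm
    return x

# The damping update in ``compute_update`` is a Levenberg-Marquardt style
# heuristic. As a worked example: with err=1.0, new_err=0.9, and a quadratic
# model predicting quad=-0.2, improvement_ratio = (0.9 - 1.0) / -0.2 = 0.5,
# so the damping is left unchanged; a ratio below 0.25 would scale it by 1.5
# (trust the model less) and a ratio above 0.75 by 0.66 (trust it more).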


class SGD(Optimizer):
    """Compute weight update using first-order gradient descent.

    :param l_rate: learning rate to apply to weight updates
    :param plotting: if True then collect data for plotting (actual
        plotting handled in parent network)
    """

    def __init__(self, l_rate=1, plotting=False):
        super(SGD, self).__init__()

        self.l_rate = l_rate
        self.plotting = plotting
        self.plots = defaultdict(list)

    def compute_update(self, printing=False):
        """Compute a weight update for the current batch.

        :param bool printing: if True, print out data about the optimization
        """

        grad = self.net.calc_grad()

        if self.net.debug:
            self.net.check_grad(grad)

        if printing:
            train_err = self.net.error()
            print("training error", train_err)

            # note: for SGD we'll just do the plotting when we print (since
            # we're going to be doing a lot more, and smaller, updates)
            if self.plotting:
                self.plots["training error"] += [train_err]

        return -self.l_rate * grad
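

# A minimal usage sketch of the ``compute_update`` contract: the parent
# network is expected to set ``optimizer.net`` and apply the returned update
# to its weights. The stub network below is hypothetical (the real network
# classes live elsewhere in the package); it exposes only the attributes the
# optimizers above touch, and minimizes ||W||^2 so the loop runs standalone.
if __name__ == "__main__":
    class _StubNet(object):
        def __init__(self):
            self.W = np.ones(3)
            self.debug = False

        def error(self, W=None):
            W = self.W if W is None else W
            return 0.5 * np.dot(W, W)

        def calc_grad(self):
            # gradient of 0.5*||W||^2 is just W
            return self.W

    net = _StubNet()
    opt = SGD(l_rate=0.1)
    opt.net = net  # normally set by the network when the optimizer is added
    for _ in range(10):
        net.W += opt.compute_update(printing=False)
    print("final error", net.error())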