# Source code for lspi.lspi

# -*- coding: utf-8 -*-
"""Contains main interface to LSPI algorithm."""

from copy import copy

import numpy as np


def learn(data, initial_policy, solver, epsilon=10**-5, max_iterations=10):
    r"""Run LSPI policy iteration until the weights converge.

    Repeatedly calls ``solver.solve`` to produce new weight vectors,
    stopping once the L2-norm of the weight change drops to ``epsilon``
    or below, or after ``max_iterations`` passes.

    Parameters
    ----------
    data:
        Opaque input forwarded to ``solver.solve``. Typically a list of
        samples, but any object the solver accepts works (e.g. a model
        for model-based learning).
    initial_policy: Policy
        Starting policy. It is copied before iteration begins, so the
        caller's policy object is left untouched.
    solver: Solver
        Object implementing ``solve(data, policy)``; usually an LSTDQ
        implementation (steepest descent, direct linear solve, etc.).
    epsilon: float
        Convergence threshold on the L2-norm of the weight change.
        Must be strictly positive.
    max_iterations: int
        Hard cap on the number of iterations, since convergence below
        ``epsilon`` is not guaranteed. Must be strictly positive.

    Return
    ------
    Policy
        The converged policy, or the policy from the final iteration if
        ``max_iterations`` was reached first.

    Raises
    ------
    ValueError
        If ``epsilon`` is <= 0.
    ValueError
        If ``max_iterations`` is <= 0.
    """
    if epsilon <= 0:
        raise ValueError('epsilon must be > 0: %g' % epsilon)
    if max_iterations <= 0:
        raise ValueError('max_iterations must be > 0: %d' % max_iterations)

    # Work on a copy so mutating the weight vector leaves the caller's
    # policy object untouched.
    curr_policy = copy(initial_policy)

    for _ in range(max_iterations):
        new_weights = solver.solve(data, curr_policy)
        delta = np.linalg.norm(new_weights - curr_policy.weights)
        curr_policy.weights = new_weights
        # Converged: the weight update moved less than the threshold.
        if delta <= epsilon:
            break

    return curr_policy