# Source code for module infpy.exp.test.exp_family_test

# Copyright John Reid 2007

import unittest, infpy.exp, scipy.stats, numpy, math
from numpy import log, outer, dot, asarray, zeros, empty, identity
from numpy.linalg import det, inv
from scipy.special import digamma

# Default finite-difference step used by check_gradient: the square root of
# double-precision machine epsilon (exactly 2**-26).
_epsilon = math.sqrt(numpy.finfo(numpy.float64).eps)

def check_gradient(f, df, x0, tol=2e-2, text='', epsilon=None):
    """
    Checks the gradient of f matches the gradient calculated by df at x0.

    The reference gradient is a finite-difference estimate from
    scipy.optimize.approx_fprime.

    @param f: function whose gradient is checked.
    @param df: function returning the analytic gradient of f.
    @param x0: point at which the two gradients are compared.
    @param tol: maximum allowed Euclidean distance between the gradients.
    @param text: prefix for the assertion message.
    @param epsilon: finite-difference step; None means the module-level
        default step _epsilon.
    @raise AssertionError: if the distance is not below tol.
    """
    from scipy.optimize import approx_fprime
    step = _epsilon if epsilon is None else epsilon
    approx = approx_fprime(x0, f, step)
    calculated = df(x0)
    difference = numpy.asarray(calculated) - approx
    # Norm of the flattened difference: unlike builtin sum() + math.sqrt this
    # also works for 0-d and multi-dimensional (Jacobian-shaped) differences.
    diff = numpy.linalg.norm(numpy.ravel(difference))
    assert diff < tol, \
        "%s: Gradients don't match\nx0: %s\ndistance: %f\ncalculated: %s\napprox: %s\ndiff: %s" % (
            text, x0, diff, calculated, approx, difference
        )
def check_close(a, b, tol=1e-8):
    """
    Checks whether a and b are close in squared Euclidean distance.

    Exactly equal inputs are accepted up front. This covers matching
    infinities (e.g. -1.#INF values from discrete distributions), whose
    difference would be nan.

    @param a: scalar or array-like.
    @param b: scalar or array-like of the same shape as a.
    @param tol: threshold on the SUM OF SQUARED element-wise differences
        (a squared distance, not a distance).
    @return: True if a equals b or their squared distance is below tol.
    """
    a = numpy.asarray(a)
    b = numpy.asarray(b)
    # special case for -1.#INF for discrete distribution: inf - inf is nan,
    # so test exact equality before subtracting.
    try:
        if numpy.all(a == b):
            return True
    except (TypeError, ValueError):
        # Incomparable inputs: fall through to the subtraction below.
        pass
    diff = a - b
    # Sum of squares over every element. A dot product of diff with itself is
    # only the squared norm for 1-D input; for 2-D it is a matrix product.
    diff_size = numpy.sum(diff * diff)
    return bool(diff_size < tol)
def check_is_close(u, v, tol=1e-4, strong_test=True, dont_test_zeros=False):
    """
    Checks that u and v agree element-wise to within a relative tolerance.

    For each element, |u - v| is divided by |u| and by |v|. The strong test
    requires both relative differences to be <= tol. The weak test
    (strong_test=False) requires at least one of them to be. This strong/weak
    terminology matches the Boost.Test floating-point comparison; the
    reference originally cited here was lost when the file was extracted.

    @param u: scalar or array-like.
    @param v: scalar or array-like with the same shape as u.
    @param tol: relative tolerance.
    @param strong_test: require both relative differences to be within tol.
    @param dont_test_zeros: accept any difference measured relative to a
        value that is exactly 0.0.
    @return: True if u and v are close.
    @raise AssertionError: if u and v have different shapes.
    """
    u = numpy.asarray(u)
    v = numpy.asarray(v)
    assert u.shape == v.shape, 'u and v are not the same shape: %s != %s' % (str(u.shape), str(v.shape))
    if (u == v).all():
        return True
    abs_u = numpy.fabs(u)
    abs_v = numpy.fabs(v)
    # Zeros in abs_u/abs_v yield inf or nan relative differences, which fail
    # the <= test; silence the associated divide/invalid RuntimeWarnings.
    with numpy.errstate(divide='ignore', invalid='ignore'):
        diff = numpy.fabs(u - v)
        check_u = diff / abs_u <= tol
        check_v = diff / abs_v <= tol
    if dont_test_zeros:
        # if value is 0.0 allow any difference
        check_u = numpy.logical_or(check_u, abs_u == 0.0)
        check_v = numpy.logical_or(check_v, abs_v == 0.0)
    if strong_test:
        return bool(numpy.logical_and(check_u, check_v).all())
    return bool(numpy.logical_or(check_u, check_v).all())
# NOTE(review): this class was recovered from an HTML source listing. The
# "[docs]" markers are rendering residue. Every place where the class
# referred to the family under test was blanked during extraction,
# together with the characters immediately following it (e.g. an attribute,
# a call, or a trailing ':' / ')'). That reference was presumably
# `self.family` (see __init__). The blanks are left exactly as extracted and
# are flagged below; this code cannot run until they are restored.
[docs]class CheckExpFamily(object):
    """
    Checks various methods/properties of an exponential family.

    Each check_* method asserts one property of the family given to the
    constructor; check_all() runs them one after another.
    """

    def __init__(self, family):
        # NOTE(review): assignment target lost; presumably `self.family`.
        = family
        "The exponential family being checked."

    [docs] def check_attributes(self):
        "Checks the class has the correct attributes."
        # NOTE(review): the first argument of every hasattr() below was lost;
        # presumably it was the family under test.
        assert hasattr(, 'T')
        assert hasattr(, 'eta')
        assert hasattr(, 'A')
        assert hasattr(, 'h')
        assert hasattr(, 'theta')
        assert hasattr(, 'x')
        assert hasattr(, 'dimension')
        assert hasattr(, '_p_truth')
        assert hasattr(, '_typical_xs')
        assert hasattr(, '_typical_thetas')

    [docs] def check_xs(self):
        """
        Checks the family converts between x and T ok.

        Tuple-valued x are compared component by component; other x are
        compared as a whole with check_is_close.
        """
        # NOTE(review): the loop iterable (with its ':') and the right-hand
        # sides of T / test_x were lost; presumably the family's typical xs,
        # T(x) and the x recovered from T. The trailing len/closeness
        # assertions are also missing their family-derived operands.
        for x in
            T =
            if tuple == type(x):
                test_x =
                for x_i, test_x_i in zip(x, test_x):
                    assert check_is_close(x_i, test_x_i), str(x)
            else:
                assert check_is_close(x,, str(x)
            assert len(T) ==
            assert check_is_close(T,, str(T)

    [docs] def check_dimensions(self):
        """
        Checks the family has the correct length for A(), T() and eta().

        The shapes of T (per typical x), eta and A (per typical theta) are
        compared against expected shape tuples.
        """
        # NOTE(review): loop iterables, right-hand sides and the contents of
        # each expected shape tuple `(,)` were lost in extraction, as was the
        # first %s argument of every message.
        for x in
            T =
            T_shape = T.shape
            expected_T_shape = (,)
            assert T_shape == expected_T_shape, "%s: T: incorrect shape/dimensions: %s != %s" % (
                , str(T_shape), str(expected_T_shape)
            )
        for theta in
            eta =
            eta_shape = eta.shape
            expected_eta_shape = (,)
            assert eta_shape == expected_eta_shape, "%s: eta: incorrect shape/dimensions: %s != %s" % (
                , str(eta_shape), str(expected_eta_shape)
            )
            A_shape =
            expected_A_shape = (,)
            assert A_shape == expected_A_shape, "%s: A: Incorrect normalisation dimension: %s != %s" % (
                , str(A_shape), str(expected_A_shape)
            )

    [docs] def check_thetas(self):
        """
        Checks the family converts between eta and theta ok.

        Uses the weak closeness test with zeros ignored.
        """
        # NOTE(review): loop iterable, eta's right-hand side and the second
        # operands of both assertions were lost in extraction.
        for theta in
            eta =
            assert len(eta) ==
            assert check_is_close(eta,, strong_test=False, dont_test_zeros=True), str(eta)

    [docs] def check_against_truth(self):
        """
        Checks the family pdf agrees with the independently calculated pdf.

        Every (x, theta) pair is checked with the weak closeness test, with
        zeros ignored.
        """
        # NOTE(review): both loop iterables, the callees producing
        # `calculated` and `truth`, and the first message argument were lost.
        for x in
            for theta in
                calculated =, theta)
                truth =, theta)
                #print x, theta, calculated, truth
                assert check_is_close(calculated, truth, strong_test=False, dont_test_zeros=True), \
                    '%s: True pdf does not match calculated pdf\nx:%s\ntheta:%s\ncalculated:%f\ntruth:%f' % (
                        str(, str(x), str(theta), calculated, truth
                    )

    [docs] def check_entropy_truth(self):
        """
        Checks the family entropy agrees with the independently calculated entropy.

        Uses the strong closeness test with the default tolerance.
        """
        # NOTE(review): loop iterable and the right-hand sides of eta,
        # calculated and truth were lost in extraction.
        for theta in
            eta =
            calculated =
            truth =
            assert check_is_close(calculated, truth), \
                '%s: True entropy does not match calculated entropy\ntheta:%s; eta:%s; calculated:%f; truth:%f' % (
                    str(, str(theta), str(eta), calculated, truth
                )

    [docs] def check_exp_T(self, size=10000):
        """
        Checks the expectation of sufficient statistics.

        For each typical theta, the mean of `size` sampled rows (summed over
        axis 0, divided by size) is compared with the calculated expectation
        using the weak test with tol=.4.
        """
        # NOTE(review): loop iterable, eta's right-hand side, the sampling
        # call and the calculated expectation were lost in extraction.
        for theta in
            eta =
            # A disabled check (partly lost in extraction) compared the
            # gradient of the log normalisation factor at eta with the
            # expected sufficient statistics, via check_gradient.
            # check the expectation of the sufficient statistic by sampling
            exp_T =, size).sum(axis=0)/size
            assert check_is_close(exp_T,, tol=.4, strong_test=False, dont_test_zeros=True), \
                '%s: expected T != sampled T\ntheta:%s\neta:%s\nsampled:%s\ncalculated:%s' % (
                    str(, str(theta), str(eta), str(exp_T), str(
                )

    def _sample_KL(self, eta1, eta2, samples):
        # Monte-Carlo estimate used by check_KL: the average over `samples`
        # of (term at eta1) - (term at eta2).
        # NOTE(review): the per-sample callees were lost in extraction;
        # presumably log-densities of each sample under eta1 and eta2.
        return sum((, eta1) -, eta2)) for sample in samples) / len(samples)

    def _check_KL(self, eta1, eta2, samples):
        # Returns (calculated KL, sampled KL estimate) for eta1 vs eta2.
        # NOTE(review): the callee producing calculated_KL was lost.
        calculated_KL =, eta2)
        sampled_KL = self._sample_KL(eta1, eta2, samples)
        return calculated_KL, sampled_KL

    [docs] def check_KL(self, size=1000):
        """
        Checks the KL divergence by sampling.

        For each typical theta1, `size` samples are drawn once and reused for
        every theta2. The calculated KL must match the sampled estimate (weak
        test, tol=.3). When eta1 == eta2 it must also be ~0. This is tested as
        1.0 vs KL + 1. so that check_is_close's relative tolerance acts as an
        absolute bound of about 1e-10.
        """
        # NOTE(review): loop iterables, eta1/eta2 right-hand sides, the
        # sampling call and the first message arguments were lost.
        for theta1 in
            eta1 =
            samples =, size)
            for theta2 in
                eta2 =
                calculated_KL, sampled_KL = self._check_KL(eta1, eta2, samples)
                if (eta1 == eta2).all():
                    # KL should be 0 if eta1 == eta2
                    assert check_is_close(
                        1.0, calculated_KL + 1., tol=1e-10, strong_test=True, dont_test_zeros=False
                    ), '%s: KL should be 0.0\ntheta1:%s\ntheta2:%s\ncalculated:%s' % (
                        str(, str(theta1), str(theta2), str(calculated_KL)
                    )
                # check the KL by sampling
                assert check_is_close(
                    sampled_KL, calculated_KL, tol=.3, strong_test=False, dont_test_zeros=True
                ), '%s: expected KL != sampled KL\ntheta1:%s\ntheta2:%s\nsampled:%s\ncalculated:%s' % (
                    str(, str(theta1), str(theta2), str(sampled_KL), str(calculated_KL)
                )
        # NOTE(review): the original indentation of this return was lost. If it
        # sits after the loops, it reports only the last pair, and it would raise
        # NameError if there were no typical thetas — confirm.
        return sampled_KL - calculated_KL

    [docs] def check_vectorisation(self):
        """
        Checks the family handles multiple xs and etas at once.

        Element-wise results are compared against the vectorised
        T, x, theta, optional exp_T, and A results.
        """
        # NOTE(review): the guard condition (with its ':'), every vectorised
        # call and several closing parentheses were lost in extraction.
        # NOTE(review): the check_is_close(...) results for xs, thetas and A
        # below are discarded (not asserted), so those comparisons can never
        # fail. They should presumably be wrapped in `assert`.
        if
            xs = asarray(
            thetas =
            etas = empty((len(thetas),
            for k, theta in enumerate(thetas):
                etas[k] =
            Ts =
            for x, T in zip(xs, Ts):
                assert ( == T).all()
            _xs =
            for x, _x in zip(xs, _xs):
                check_is_close(x, _x)
            _thetas =
            for theta, _theta in zip(thetas, _thetas):
                check_is_close(theta, _theta)
            if hasattr(, 'exp_T'):
                exp_Ts =
                for eta, exp_T in zip(etas, exp_Ts):
                    assert ( == exp_T).all()
            As =
            for A, eta in zip(As, etas):
                check_is_close(A,

    [docs] def check_all(self):
        """
        Runs all checks.

        The exp_T, KL and entropy checks run only when the family has the
        corresponding attribute.
        """
        self.check_attributes()
        self.check_xs()
        self.check_thetas()
        self.check_dimensions()
        self.check_against_truth()
        self.check_vectorisation()
        # NOTE(review): the hasattr() targets below were lost; presumably the
        # family under test.
        if hasattr(, 'exp_T'):
            self.check_exp_T()
        if hasattr(, 'KL'):
            self.check_KL()
        if hasattr(, '_entropy_truth'):
            self.check_entropy_truth()
[docs]class GeneralisedExpFamilyTests(unittest.TestCase): """ Test case for all exponential families with default __init__ arguments. """ def _test_subclasses(self, cls): "Test all the subclasses of cls." for subclass in cls.__subclasses__(): #print subclass CheckExpFamily(subclass()).check_all() self._test_subclasses(subclass)
[docs] def test(self): "Test all the subclasses of infpy.exp.ExponentialFamily." self._test_subclasses(infpy.exp.ExponentialFamily)
class CheckConjugatePrior(object):
    """
    Checks various methods/properties of a conjugate prior.
    """

    def __init__(self, conj_prior):
        # The conjugate prior being checked. The checks below read its
        # `likelihood`, `prior` and `strength_dimension` attributes and its
        # exp_likelihood_log_normalisation_factor() method.
        self.conj_prior = conj_prior

    def check_conjugacy(self):
        """
        Test that the conjugate prior is set up correctly.

        For each typical theta of the likelihood, the first
        `strength_dimension` entries of the prior's T(theta) must equal
        -A_vec(eta(theta)) of the likelihood.
        """
        likelihood = self.conj_prior.likelihood
        for theta in likelihood._typical_thetas:
            eta = likelihood.eta(theta)
            A = likelihood.A_vec(eta)
            _lambda = self.conj_prior.prior.T(theta)
            _lambda_2 = _lambda[:self.conj_prior.strength_dimension]
            # Relative tolerance 1e-5. It must be positive: the previous
            # `1.-5` evaluated to -4.0, so only exact equality could pass.
            assert check_is_close(
                _lambda_2, -A, tol=1e-5, strong_test=True, dont_test_zeros=False
            ), \
                'lambda_2 != -A\neta=%s\nlambda=%s\nlambda_2=%s\nA=%s' % (
                    str(eta), str(_lambda), str(_lambda_2), str(A)
                )

    def check_log_normalisation_factor_expectation(self, sample_size=1000):
        """
        Test that the calculated expectation of the likelihood's log
        normalisation factor under the prior matches a Monte-Carlo estimate
        obtained by sampling from the prior.

        @param sample_size: number of prior samples drawn per typical theta.
        """
        for theta in self.conj_prior.prior._typical_thetas:
            eta = self.conj_prior.prior.eta(theta)
            calculated_lnf = self.conj_prior.exp_likelihood_log_normalisation_factor(eta)
            samples = self.conj_prior.prior.sample(eta, size=sample_size)
            sampled_lnf = 0.
            for sample in samples:
                # Skip the leading strength_dimension entries of each sample;
                # the remainder is passed to the likelihood's A().
                sampled_lnf += self.conj_prior.likelihood.A(sample[self.conj_prior.strength_dimension:])
            sampled_lnf /= sample_size
            assert check_is_close(
                calculated_lnf, sampled_lnf, tol=.1, strong_test=False, dont_test_zeros=True
            ), \
                '<log A(eta)>: expected != sampled\neta:%s\ncalculated:%s\nsampled:%s' % (
                    str(eta), str(calculated_lnf), str(sampled_lnf)
                )

    def check_all(self):
        """
        Run all checks on this conjugate prior.
        """
        self.check_conjugacy()
        self.check_log_normalisation_factor_expectation()
[docs]class GeneralisedConjPriorTests(unittest.TestCase): """ Test case for all conjugate priors. """ def _test_subclasses(self, cls): "Test all the subclasses of cls." for subclass in cls.__subclasses__(): #print subclass CheckConjugatePrior(subclass()).check_all() self._test_subclasses(subclass)
[docs] def test(self): "Test all the subclasses of infpy.exp.ConjugatePrior." self._test_subclasses(infpy.exp.ConjugatePrior)
class MvnConjPriorTests(unittest.TestCase):
    """
    Tests conjugate prior of MVN distribution.
    """

    def test_expectations_by_sampling(self, sample_size=1000):
        """
        Test that sampled expectations needed for the expected MVN log
        normalisation factor match their closed forms.

        W is drawn from a 2-D Wishart with parameters (nu, S). Given W, mu is
        drawn from an MVN with mean mu_0 and precision kappa_0 * W. The sample
        means are compared with:
            <mu'.W.mu> = k/kappa_0 + nu * mu_0'.S.mu_0
            <log|W|>   = sum_{i<k} digamma((nu-i)/2) + log|S| + k*log(2)
        """
        from infpy.exp import WishartExpFamily, MvnExpFamily
        k = 2
        wishart = WishartExpFamily(k)
        mvn = MvnExpFamily(k)
        for kappa_0 in [1., 2.]:
            for mu_0 in [[0., 0.], [5., -3.]]:
                for nu, S in WishartExpFamily._typical_thetas:
                    calculated_mu_W_mu = k / kappa_0 + nu * dot(mu_0, dot(S, mu_0))
                    calculated_log_W = sum(digamma((nu - i) / 2.) for i in range(k)) + log(det(S)) + k * log(2.)
                    sampled_W = wishart.sample(eta=wishart.eta((nu, S)), size=sample_size)
                    sampled_mu_W_mu = 0.0
                    sampled_log_W = 0.0
                    for T_W in sampled_W:
                        W = wishart.x(T_W)
                        sampled_log_W += log(det(W))
                        precision = kappa_0 * W
                        mu = mvn.x(mvn.sample(eta=mvn.eta((mu_0, precision)), size=1)[0])
                        sampled_mu_W_mu += dot(mu, dot(W, mu))
                    sampled_mu_W_mu /= sample_size
                    sampled_log_W /= sample_size
                    # The failure message previously said log|W| here, although
                    # this assertion checks <mu'.W.mu>.
                    assert check_is_close(
                        calculated_mu_W_mu, sampled_mu_W_mu, tol=.1, strong_test=False, dont_test_zeros=True
                    ), \
                        "mu'.W.mu: expected != sampled\nkappa_0:%s\nmu_0:%s\nnu:%s\nS:%s\nsampled:%s\ncalculated:%s" % (
                            str(kappa_0), str(mu_0), str(nu), str(S), str(sampled_mu_W_mu), str(calculated_mu_W_mu)
                        )
                    assert check_is_close(
                        calculated_log_W, sampled_log_W, tol=.3, strong_test=False, dont_test_zeros=True
                    ), \
                        'log|W|: expected != sampled\nkappa_0:%s\nmu_0:%s\nnu:%s\nS:%s\nsampled:%s\ncalculated:%s' % (
                            str(kappa_0), str(mu_0), str(nu), str(S), str(sampled_log_W), str(calculated_log_W)
                        )
class HighDMvnTest(unittest.TestCase):
    """
    Test case for high dimensional multivariate normal.
    """

    def setUp(self):
        # Create MVN exponential families of dimension 2 through 9 inclusive
        # (range's upper bound is exclusive, so 10-D is NOT included).
        self.families = [
            infpy.exp.MvnExpFamily(k=dim)
            for dim in range(2, 10)
        ]
        # give each family a typical set of xs and thetas
        for family in self.families:
            family._typical_xs = [zeros((family.k,)) + 1.1]
            family._typical_thetas = [
                (zeros((family.k,)) + 1.01, identity(family.k))
            ]

    def test_exp_family_std_methods(self):
        "Run the generic CheckExpFamily checks on every high-dimensional family."
        for family in self.families:
            CheckExpFamily(family).check_all()

    def test_expectations_by_sampling(self, size=1000):
        """
        Test some expected values by sampling from the distribution.

        theta is unpacked as (mu, W). The sample mean of x is compared with mu,
        and the sample mean of x.x' with mu.mu' + inv(W).
        """
        for family in self.families:
            for theta in family._typical_thetas:
                eta = family.eta(theta)
                sample = family.sample(eta, size)
                xs = [family.x(T) for T in sample]
                mu, W = theta

                # check expected value of x
                sampled_x = sum(xs) / size
                exp_x = mu
                assert check_is_close(exp_x, sampled_x, tol=.3, strong_test=False, dont_test_zeros=True), \
                    'x: expected != sampled\ntheta:%s\neta:%s\nsampled:%s\ncalculated:%s' % (
                        str(theta), str(eta), str(sampled_x), str(exp_x)
                    )

                # check expected value of x.x'
                sampled_x2 = sum(outer(x, x) for x in xs) / size
                exp_x2 = outer(mu, mu) + inv(W)
                assert check_is_close(exp_x2, sampled_x2, tol=.2, strong_test=False, dont_test_zeros=True), \
                    'x.x\': expected != sampled\ntheta:%s\neta:%s\nsampled:%s\ncalculated:%s' % (
                        str(theta), str(eta), str(sampled_x2), str(exp_x2)
                    )
if __name__ == "__main__":
    # Run every TestCase in this module. To debug a single check, call it
    # directly instead, e.g. GeneralisedExpFamilyTests('test').debug(),
    # GeneralisedConjPriorTests('test').debug(),
    # HighDMvnTest('test_exp_family_std_methods').debug() or
    # CheckExpFamily(infpy.exp.GaussianExpFamily()).check_all().
    unittest.main()