# Source code for infpy.decision.test.genepy_test
#
# Copyright John Reid 2007, 2010
#
import unittest, logging
from infpy.genepy import Population
from infpy.decision import Context, OrdinalAttribute, EnumerativeAttribute, ContinuousAttribute, DecisionTreeSpecies, count_nodes, log_tree
def ord(data):
    """Return the first element of *data* (used as the 'ordinal' attribute).

    Note: within this module the name shadows the builtin ``ord``; it is kept
    because the attribute definitions in ``GenepyTest`` refer to it.
    """
    return data[0]
def enum(data):
    """Return the second element of *data* (used as the 'enumerative' attribute)."""
    return data[1]
def cont(data):
    """Return the third element of *data* (used as the 'continuous' attribute)."""
    return data[2]
def _make_random_data(size=30):
    """Return *size* random ``(attributes, outcome)`` examples.

    Each example pairs a tuple ``(randint(0, 9), randint(0, 3),
    gauss(0.0, 1.0))`` with an outcome drawn as ``randint(0, 2)``.
    """
    from random import gauss, randint

    # Built in a function rather than by a comprehension in the class body:
    # on Python 3 such a comprehension cannot see names bound in the class
    # body, so a class-level import would not be visible to it.
    return [
        ((randint(0, 9), randint(0, 3), gauss(0.0, 1.0)), randint(0, 2))
        for _ in range(size)
    ]


class GenepyTest(unittest.TestCase):
    """Test case for decision tree genepy interface."""

    # Define the context, i.e. which attributes the data have and their
    # classification.  Each attribute is given the module-level accessor that
    # picks the matching element out of a data tuple.
    context = Context()
    context.attributes.append(OrdinalAttribute('ordinal', ord, 10))
    context.attributes.append(EnumerativeAttribute('enumerative', enum, 4))
    context.attributes.append(ContinuousAttribute('continuous', cont, 0.0, 1.0))
    context.outcomes = [0, 1, 2]

    # Generate some random data, shared by every test in this class.
    data = _make_random_data()

    species = DecisionTreeSpecies(context)
    species.data = data

    def fitness_fn(self, individual):
        """Return how many examples in ``self.data`` *individual* misclassifies.

        :param individual: callable taking an attribute tuple and returning a
            predicted outcome.
        :returns: the number of ``(x, y)`` pairs with ``y != individual(x)``.
        """
        return sum(1 for x, y in self.data if y != individual(x))

    def test_initialisation(self):
        """A random individual can be created from the species."""
        self.species.random_individual()

    def test_mutation(self):
        """Placeholder: the mutation call below is disabled in this file."""
        pass  # self.species.mutate(self.species.random_individual())

    def test_combination(self):
        """Two random individuals can be mated."""
        self.species.mate(
            self.species.random_individual(),
            self.species.random_individual()
        )

    def test_call(self):
        """A fresh random individual can be called on every data point."""
        for x, _y in self.data:
            self.species.random_individual()(x)

    def test_overall(self):
        """Run a population of 100 for 100 generations, logging progress."""
        pop = Population(size=100, species=self.species, fitness_fn=self.fitness_fn)
        pop.post_generation_process = self.species.prune_individuals
        for gen_idx in range(100):
            # execute a generation
            pop.generation()

            # and log how good the best individual is
            tree_sizes = [count_nodes(individual) for individual in pop.individuals]
            logging.debug(
                "%3d; best fitness=%2d; average tree size: %3.3f",
                gen_idx,
                int(pop.most_fit.fitness),
                float(sum(tree_sizes)) / len(tree_sizes)
            )
            # NOTE(review): assumed to run once per generation; confirm that
            # the tree dump was not meant to happen only after the last one.
            log_tree(pop.most_fit, logging.getLogger(), level=logging.DEBUG)
if __name__ == "__main__":
    # Enable DEBUG output so the logging.debug calls in the tests are shown
    # when this file is run as a script.
    logging.basicConfig(level=logging.DEBUG)
    unittest.main()