:mod:`nolearn.dbn` ------------------ API ~~~ .. automodule:: nolearn.dbn .. autoclass:: DBN :special-members: :members: Example: MNIST ~~~~~~~~~~~~~~ Let's train 2-layer neural network to do digit recognition on the `MNIST dataset `_. We first load the MNIST dataset, and split it up into a training and a test set: .. code-block:: python from sklearn.cross_validation import train_test_split from sklearn.datasets import fetch_mldata mnist = fetch_mldata('MNIST original') X_train, X_test, y_train, y_test = train_test_split( mnist.data / 255.0, mnist.target) We then configure a neural network with 300 hidden units, a learning rate of ``0.3`` and a learning rate decay of ``0.9``, which is the number that the learning rate will be multiplied with after each epoch. .. code-block:: python from nolearn.dbn import DBN clf = DBN( [X_train.shape[1], 300, 10], learn_rates=0.3, learn_rate_decays=0.9, epochs=10, verbose=1, ) Let us now train our network for 10 epochs. This will take around five minutes on a CPU: .. code-block:: python clf.fit(X_train, y_train) After training, we can use our trained neural network to predict the examples in the test set. We'll observe that our model has an accuracy of around **97.5%**. .. code-block:: python from sklearn.metrics import classification_report from sklearn.metrics import zero_one_score y_pred = clf.predict(X_test) print "Accuracy:", zero_one_score(y_test, y_pred) print "Classification report:" print classification_report(y_test, y_pred) Example: Iris ~~~~~~~~~~~~~ In this example, we'll train a neural network for classification on the `Iris flower data set `_. Due to the small number of examples, an SVM will typically perform better, but let us still see if our neural network is up to the task: .. code-block:: python from sklearn.cross_validation import cross_val_score from sklearn.datasets import load_iris from sklearn.preprocessing import scale iris = load_iris() clf = DBN( [4, 4, 3], learn_rates=0.3, epochs=50, ) scores = cross_val_score(clf, scale(iris.data), iris.target, cv=10) print "Accuracy: %0.2f (+/- %0.2f)" % (scores.mean(), scores.std() / 2) This will print something like:: Accuracy: 0.97 (+/- 0.03) Example: CIFAR-10 ~~~~~~~~~~~~~~~~~ In this example, we'll train a neural network to do image classification using a subset of the `CIFAR-10 dataset `_. We assume that you have the Python version of the CIFAR-10 dataset downloaded and available in your working directory. We'll use only the first three batches of the dataset; the first two for training, the third one for testing. Let us load the dataset: .. code-block:: python import cPickle import numpy as np def load(name): with open(name, 'rb') as f: return cPickle.load(f) dataset1 = load('data_batch_1') dataset2 = load('data_batch_2') dataset3 = load('data_batch_3') data_train = np.vstack([dataset1['data'], dataset2['data']]) labels_train = np.hstack([dataset1['labels'], dataset2['labels']]) data_train = data_train.astype('float') / 255. labels_train = labels_train data_test = dataset3['data'].astype('float') / 255. labels_test = np.array(dataset3['labels']) We can now train our network. We'll configure the network so that it has 1024 units in the hidden layer, i.e. ``[3072, 1024, 10]``. We'll train our network for 50 epochs, which will take a while if you're not using `CUDAMat `_. .. code-block:: python n_feat = data_train.shape[1] n_targets = labels_train.max() + 1 net = DBN( [n_feat, n_feat / 3, n_targets], epochs=50, learn_rates=0.03, verbose=1, ) net.fit(data_train, labels_train) Finally, we'll look at our network's performance: .. code-block:: python from sklearn.metrics import classification_report from sklearn.metrics import confusion_matrix expected = labels_test predicted = net.predict(data_test) print "Classification report for classifier %s:\n%s\n" % ( net, classification_report(expected, predicted)) print "Confusion matrix:\n%s" % confusion_matrix(expected, predicted) You should see an f1-score of **0.49** and a confusion matrix that looks something like this:: air aut bir cat dee dog fro hor shi tru [[459 48 66 39 91 21 5 39 182 44] airplane [ 28 584 12 31 23 22 8 29 117 188] automobile [ 49 13 279 101 244 124 31 71 37 16] bird [ 20 16 54 363 106 255 38 70 36 39] cat [ 33 10 79 81 596 66 15 75 26 9] deer [ 16 23 57 232 103 448 17 82 26 25] dog [ 10 18 70 179 212 106 304 32 21 26] frog [ 20 8 40 80 125 98 10 575 21 38] horse [ 54 49 10 29 43 25 4 9 707 31] ship [ 28 129 9 48 33 36 10 57 118 561]] truck We should be able to improve on this score by using the full dataset and by training longer.