This module contains various provisions for cross-validation.
The main functions in this module are:
Module author: Marc Claesen
Selects the subset specified by indices from collection.
>>> select([0, 1, 2, 3, 4], [1, 3])
[1, 3]
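In the doctest above the values happen to equal their positions, so it does not show whether `indices` refers to positions or to values. The following minimal sketch assumes positional indexing; `select_sketch` is a hypothetical stand-in written for illustration, not the module's actual implementation.

```python
def select_sketch(collection, indices):
    """Return the elements of `collection` at the given positions (assumed semantics)."""
    return [collection[i] for i in indices]


# With non-numeric data it is clear that positions, not values, are selected.
letters = ["a", "b", "c", "d", "e"]
picked = select_sketch(letters, [1, 3])
print(picked)
```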
Returns a list containing a random permutation of the elements of data.
| Parameters: | data – an iterable containing the elements to permute over |
|---|---|
| Returns: | a list containing the permuted entries of data |
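As a rough illustration of the behaviour described above, the sketch below shuffles a copy of the input with the standard library. The name `random_permutation_sketch` is hypothetical, and the module's own implementation may differ.

```python
import random


def random_permutation_sketch(data):
    """Return a new list with the entries of `data` in random order."""
    items = list(data)      # copy, so the caller's data is left untouched
    random.shuffle(items)   # in-place shuffle of the copy
    return items


original = [10, 20, 30, 40]
shuffled = random_permutation_sketch(original)
```

The result holds the same entries as the input; only their order may change.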
Function decorator to perform cross-validation as configured.
Parameters: |
|
---|---|
Returns: | a cross_validated_callable with the proper configuration. |
This resulting decorator must be used on a function with the following signature (+ potential other arguments):
Parameters: |
|
---|
y_train and y_test must be available if the y argument to this function is not None.
These four keyword arguments (x_train, y_train, x_test, y_test) will be bound upon decoration. Further arguments remain free (e.g. hyperparameter names).
>>> data = list(range(5))
>>> @cross_validated(x=data, num_folds=5, folds=[[[i] for i in range(5)]], aggregator=identity)
... def f(x_train, x_test, a):
...     return x_test[0] + a
>>> f(a=1)
[1, 2, 3, 4, 5]
>>> f(1)
[1, 2, 3, 4, 5]
>>> f(a=2)
[2, 3, 4, 5, 6]
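To make the binding of `x_train` and `x_test` more concrete, here is a simplified sketch of how such a decorator could work. It is not the library's implementation: `cross_validated_sketch` is a hypothetical name, it takes a single flat list of folds (no repeated iterations), it ignores `y`, and it only accepts the free arguments as keywords.

```python
import functools


def cross_validated_sketch(x, folds, aggregator):
    """Hypothetical, simplified cross-validation decorator (assumptions noted above)."""
    def decorator(f):
        @functools.wraps(f)
        def wrapper(**kwargs):
            scores = []
            for k, test_idx in enumerate(folds):
                # Every index outside the current fold goes to the training set.
                train_idx = [i for j, fold in enumerate(folds) if j != k for i in fold]
                x_train = [x[i] for i in train_idx]
                x_test = [x[i] for i in test_idx]
                scores.append(f(x_train=x_train, x_test=x_test, **kwargs))
            return aggregator(scores)
        return wrapper
    return decorator


data = list(range(5))
folds = [[i] for i in range(5)]


# aggregator=list returns the per-fold scores unchanged.
@cross_validated_sketch(x=data, folds=folds, aggregator=list)
def f(x_train, x_test, a):
    return x_test[0] + a


per_fold = f(a=1)
```

Passing an averaging function as `aggregator` would collapse the per-fold scores into one number instead.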
Generates folds for a given number of rows.
Parameters: |
|
---|---|
Returns: | a list of folds, each fold is a list of instance indices |
>>> folds = generate_folds(num_rows=6, num_folds=2, clusters=[[1, 2]], strata=[[3,4]])
>>> len(folds)
2
>>> i1 = [idx for idx, fold in enumerate(folds) if 1 in fold]
>>> i2 = [idx for idx, fold in enumerate(folds) if 2 in fold]
>>> i1 == i2
True
>>> i3 = [idx for idx, fold in enumerate(folds) if 3 in fold]
>>> i4 = [idx for idx, fold in enumerate(folds) if 4 in fold]
>>> i3 == i4
False
Warning
Instances in strata are not necessarily spread out over all folds. Some folds may already be full due to clusters. This effect should be negligible.
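The interaction between clusters and strata can be illustrated with a small sketch. This is one possible strategy under stated assumptions, not the algorithm the module uses: whole clusters are placed first, members of each stratum are dealt round-robin over the folds, and the remaining rows fill the smallest folds. The name `generate_folds_sketch` and its `seed` parameter are hypothetical.

```python
import random


def generate_folds_sketch(num_rows, num_folds, clusters=None, strata=None, seed=None):
    """Hypothetical fold generator: clusters stay together, strata are spread out."""
    rng = random.Random(seed)
    folds = [[] for _ in range(num_folds)]
    assigned = set()

    # 1. Each cluster goes, as a whole, into the currently smallest fold.
    for cluster in clusters or []:
        min(folds, key=len).extend(cluster)
        assigned.update(cluster)

    # 2. Members of a stratum are dealt round-robin from a random starting fold.
    for stratum in strata or []:
        members = [i for i in stratum if i not in assigned]
        rng.shuffle(members)
        offset = rng.randrange(num_folds)
        for k, idx in enumerate(members):
            folds[(offset + k) % num_folds].append(idx)
            assigned.add(idx)

    # 3. Remaining rows are shuffled and used to even out the fold sizes.
    rest = [i for i in range(num_rows) if i not in assigned]
    rng.shuffle(rest)
    for idx in rest:
        min(folds, key=len).append(idx)
    return folds


folds = generate_folds_sketch(6, 2, clusters=[[1, 2]], strata=[[3, 4]], seed=0)
fold_of = {i: k for k, fold in enumerate(folds) for i in fold}
```

In this sketch the folds are already uneven after step 2, because the cluster occupies one fold; step 3 then restores the balance when enough unassigned rows remain.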
Constructs a list of strata (lists) based on unique values of labels.
| Parameters: | labels – iterable, identical values will end up in identical strata |
|---|---|
| Returns: | the strata, as a list of lists |
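A minimal sketch of this grouping follows. It assumes that each stratum holds the indices of the instances sharing a label, which matches how strata are passed to generate_folds above; `strata_by_labels_sketch` is a hypothetical name, and the order of strata in the real function may differ.

```python
def strata_by_labels_sketch(labels):
    """Group instance indices by label value (assumed semantics)."""
    groups = {}
    for idx, label in enumerate(labels):
        groups.setdefault(label, []).append(idx)
    # Dicts preserve insertion order, so strata appear in order of first occurrence.
    return list(groups.values())


strata = strata_by_labels_sketch([1, 1, 2, 1, 2, 3])
print(strata)
```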
Computes the mean of each position across the tuples in the given list.
| Parameters: | list_of_measures (list) – a list of tuples to compute means from |
|---|---|
| Returns: | a list containing the means |
This function can be used as an aggregator in cross_validated(), when multiple performance measures are being returned by the wrapped function.
>>> list_mean([(1, 4), (2, 5), (3, 6)])
[2.0, 5.0]
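The doctest above averages the first entries (1, 2, 3) and the second entries (4, 5, 6) separately. A short sketch of that column-wise mean, using the hypothetical name `list_mean_sketch`, could look like this:

```python
def list_mean_sketch(list_of_measures):
    """Average each position across a list of equally long tuples."""
    # zip(*...) transposes the rows: [(1, 4), (2, 5)] -> (1, 2), (4, 5)
    return [sum(column) / len(column) for column in zip(*list_of_measures)]


means = list_mean_sketch([(1, 4), (2, 5), (3, 6)])
print(means)
```

When the wrapped function returns several measures per fold, for example an accuracy and a runtime, this yields one averaged value per measure.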