#
# Copyright John Reid 2007, 2008, 2009. 2013
#
"""
Code to implement ROC point/curve calculation and plotting.
"""
import math, numpy as N
from itertools import chain, groupby
import logging
import warnings
logging.basicConfig(level=logging.INFO)
class RocCalculator(object):
    """
    Calculates specificities and sensitivities from counts of true and false
    positives and negatives.

    Source: wikipedia - Fawcett (2004)

    Instances are ordered by increasing sensitivity and, when sensitivities
    are equal, by decreasing specificity. Two instances compare equal when
    both statistics are equal, even if the underlying counts differ.
    """

    def __init__(self, tp=0, fp=0, tn=0, fn=0):
        # Number of true positives.
        self.tp = tp
        # Number of false positives.
        self.fp = fp
        # Number of true negatives.
        self.tn = tn
        # Number of false negatives.
        self.fn = fn

    @property
    def total_positive(self):
        "The total number of positive test cases."
        return self.tp + self.fn

    @property
    def total_negative(self):
        "The total number of negative test cases."
        return self.tn + self.fp

    def __cmp__(self, other):
        """
        Three-way comparison returning the integers -1, 0 or 1.

        Sensitivity is compared first; ties are broken so that the point with
        the higher specificity sorts first.
        """
        differences = (
            self.sensitivity() - other.sensitivity(),
            other.specificity() - self.specificity(),
        )
        for diff in differences:
            if diff < 0.:
                return -1
            if diff > 0.:
                return 1
        return 0

    # Python 3 ignores __cmp__, so the rich comparisons are spelt out
    # explicitly. They all delegate to __cmp__ to keep a single ordering.

    def __eq__(self, other):
        if not isinstance(other, RocCalculator):
            return NotImplemented
        return self.__cmp__(other) == 0

    def __ne__(self, other):
        if not isinstance(other, RocCalculator):
            return NotImplemented
        return self.__cmp__(other) != 0

    def __lt__(self, other):
        if not isinstance(other, RocCalculator):
            return NotImplemented
        return self.__cmp__(other) < 0

    def __le__(self, other):
        if not isinstance(other, RocCalculator):
            return NotImplemented
        return self.__cmp__(other) <= 0

    def __gt__(self, other):
        if not isinstance(other, RocCalculator):
            return NotImplemented
        return self.__cmp__(other) > 0

    def __ge__(self, other):
        if not isinstance(other, RocCalculator):
            return NotImplemented
        return self.__cmp__(other) >= 0

    # Defining __eq__ would make instances unhashable on Python 3. Keep the
    # identity based hash so that existing code storing instances in sets or
    # as dictionary keys carries on working.
    __hash__ = object.__hash__

    def distance(self, other):
        "Measure of distance between points."
        return (self.sensitivity()-other.sensitivity()) * (other.specificity()-self.specificity())

    def __call__(self, truth, prediction):
        "Updates this ROC calculator with one truth/prediction pair"
        if prediction:
            if truth:
                self.tp += 1
            else:
                self.fp += 1
        else:
            if truth:
                self.fn += 1
            else:
                self.tn += 1

    def __add__(self, rhs):
        "Add this RocCalculator to another and return the result."
        return RocCalculator(
            tp=self.tp + rhs.tp,
            fp=self.fp + rhs.fp,
            tn=self.tn + rhs.tn,
            fn=self.fn + rhs.fn,
        )

    def normalise(self, rhs=None):
        """
        Normalise this RocCalculator in place so that tp+tn+fp+fn=1.

        :param rhs: Ignored. Only accepted so that callers that already pass
            an argument keep working.
        """
        total = float(self.tp+self.tn+self.fp+self.fn)
        self.tp /= total
        self.tn /= total
        self.fp /= total
        self.fn /= total

    @staticmethod
    def always_predict_true():
        "A RocCalculator for a predictor that always predicts True"
        return RocCalculator(tp=1, fp=1, tn=0, fn=0)

    @staticmethod
    def always_predict_false():
        "A RocCalculator for a predictor that always predicts False"
        return RocCalculator(tp=0, fp=0, tn=1, fn=1)

    def sensitivity(self):
        "TP/(TP+FN). Returns 1.0 when the denominator is zero."
        denominator = self.tp + self.fn
        if denominator:
            return float(self.tp)/denominator
        else:
            return 1.0

    true_positive_rate = tpr = hit_rate = recall = sensitivity

    def specificity(self):
        "TN/(TN+FP). Returns 1.0 when the denominator is zero."
        denominator = self.tn + self.fp
        if denominator:
            return float(self.tn)/denominator
        else:
            return 1.0

    def false_positive_rate(self):
        "FP/(TN+FP)"
        return 1.0 - self.specificity()

    fpr = false_positive_rate

    def positive_predictive_value(self):
        "TP/(TP+FP). Returns 1.0 when the denominator is zero."
        denominator = self.tp + self.fp
        if denominator:
            return float(self.tp)/denominator
        else:
            return 1.0

    precision = positive_predictive_value

    def negative_predictive_value(self):
        "TN/(TN+FN). Returns 1.0 when the denominator is zero."
        denominator = self.tn + self.fn
        if denominator:
            return float(self.tn)/denominator
        else:
            return 1.0

    def accuracy(self):
        "(TP+TN)/(TP+TN+FP+FN). Returns 1.0 when the denominator is zero."
        denominator = self.tp + self.tn + self.fn + self.fp
        if denominator:
            return float(self.tp+self.tn)/denominator
        else:
            return 1.0

    def performance_coefficient(self):
        "TP/(TP+FN+FP). Returns 1.0 when the denominator is zero."
        denominator = self.tp + self.fn + self.fp
        if denominator:
            return float(self.tp)/denominator
        else:
            return 1.0

    def correlation_coefficient(self):
        "(TP.TN-FN.FP)/sqrt((TP+FN)(TN+FP)(TP+FP)(TN+FN)) see: Burset & Guigo"
        denominator = math.sqrt((self.tp+self.fn)*(self.tn+self.fn)*(self.tp+self.fp)*(self.tn+self.fp))
        numerator = self.tp*self.tn-self.fn*self.fp
        if denominator:
            return numerator/denominator
        else:
            if 0.0 == numerator:
                return 0.0
            else:
                return 1.0

    def __str__(self):
        return '''TP: %d; FP: %d; TN: %d; FN: %d
sensitivity: %.3f TP/(TP+FN)
specificity: %.3f TN/(TN+FP)
positive predictive value: %.3f TP/(TP+FP)
performance coefficient: %.3f TP/(TP+FN+FP)
correlation coefficient: %.3f (TP.TN-FN.FP)/sqrt((TP+FN)(TN+FP)(TP+FP)(TN+FN))''' % (
            self.tp, self.fp, self.tn, self.fn,
            self.sensitivity(),
            self.specificity(),
            self.positive_predictive_value(),
            self.performance_coefficient(),
            self.correlation_coefficient(),
        )
def update_roc(roc, truth_prediction_iterable):
    """
    Feed every (truth, prediction) pair of the iterable to the ROC calculator.

    :param roc: A callable taking (truth, prediction), e.g. a RocCalculator.
    :param truth_prediction_iterable: An iterable of (truth, prediction) pairs.
    """
    for actual, predicted in truth_prediction_iterable:
        roc(actual, predicted)
def get_new_roc_parameter(rocs, for_specificity=True):
    """
    Takes a list of (parameter, roc) tuples and returns a new parameter that
    should be tested next.

    It chooses this parameter by sorting the list and taking the mid-point
    between the adjacent parameters with the largest absolute difference
    between their specificities or sensitivities (depending on the
    for_specificity parameter).

    :param rocs: A list of (parameter, roc) tuples. It is sorted in place.
    :param for_specificity: True to compare specificities, False to compare
        sensitivities.
    :returns: The mid-point of the chosen pair of adjacent parameters.
    :raises ValueError: If there are fewer than 2 entries in rocs.
    """
    rocs.sort()
    if len(rocs) < 2:
        raise ValueError(
            'Need at least 2 (parameter, roc) tuples to choose a new parameter.')

    def statistic(roc):
        "The statistic of the ROC point that we are comparing."
        if for_specificity:
            return roc.specificity()
        return roc.sensitivity()

    # Pair each entry with its successor: (difference in statistic, mid-point)
    diffs = [
        (abs(statistic(lower_roc) - statistic(upper_roc)), (lower + upper) / 2)
        for (lower, lower_roc), (upper, upper_roc)
        in zip(rocs, rocs[1:])
    ]
    return max(diffs)[1]
def generate_roc_points(rocs, sort_negative_first=True):
    """
    Yield the ROC points in rocs, preceded by the "always predict false"
    point and followed by the "always predict true" point.

    If sort_negative_first is true, each point (including the final "always
    predict true" point) is preceded by an extra point that combines the tp
    and fn counts of the previous point with the fp and tn counts of the
    current one. This gives a step-function like ROC curve rather than a
    smoothed curve.

    Deprecated: a DeprecationWarning is issued when iteration starts.
    """
    warnings.warn('DEPRECATED: use infpy.roc.all_rocs_from_thresholds()', DeprecationWarning)
    previous = RocCalculator.always_predict_false()
    yield previous
    for current in chain(rocs, (RocCalculator.always_predict_true(),)):
        if sort_negative_first:
            # Step point: TPR of the previous point, FPR of the current one
            step = RocCalculator(
                tp=previous.tp, fp=current.fp, tn=current.tn, fn=previous.fn)
            yield step
        yield current
        previous = current
def plot_roc_points(rocs, **plot_kwds):
    """
    Plots TPR versus FPR for the ROCs in rocs. Adds points at (0,0) and (1,1)
    (via generate_roc_points).

    Deprecated: a DeprecationWarning is issued on every call.

    :param rocs: A sequence of ROCs.
    :param plot_kwds: All extra keyword arguments are passed to the pylab.plot call.
    :returns: The result of pylab.plot call.
    """
    warnings.warn('DEPRECATED: use infpy.roc.plot_rocpoints()', DeprecationWarning)
    from pylab import plot
    extended_rocs = list(generate_roc_points(rocs))
    # Build real lists: map() is a lazy iterator on Python 3
    tprs = [roc.tpr() for roc in extended_rocs]
    fprs = [roc.fpr() for roc in extended_rocs]
    return plot(fprs, tprs, **plot_kwds)
def plot_rocpoints(rocpoints, fillargs=None, **plot_kwds):
    """
    Plots TPR versus FPR for the ROCs in rocpoints.

    :param rocpoints: An iterable of ROCs. It may be a generator: it is only
        traversed once.
    :param fillargs: If not None, a dict of keyword arguments for a
        pylab.fill_between call that fills the area under the curve.
    :param plot_kwds: All extra keyword arguments are passed to the pylab.plot call.
    :returns: The result of pylab.plot call.
    """
    from pylab import plot, fill_between
    # Materialise once so that generators can be passed in and so that
    # pylab receives real sequences rather than lazy iterators.
    rocpoints = list(rocpoints)
    tprs = [roc.tpr() for roc in rocpoints]
    fprs = [roc.fpr() for roc in rocpoints]
    if fillargs is not None:
        fill_between(fprs, tprs, **fillargs)
    return plot(fprs, tprs, **plot_kwds)
def plot_precision_versus_recall(rocs, **plot_kwds):
    """
    Plots precision (y-axis) versus recall (x-axis) for the ROCs in rocs.
    The points are plotted in the order given; no extra points are added.

    :param rocs: A sequence of ROCs.
    :param plot_kwds: All extra keyword arguments are passed to the pylab.plot call.
    :returns: The result of pylab.plot call.
    """
    from pylab import plot
    recalls = []
    precisions = []
    for roc in rocs:
        recalls.append(roc.recall())
        precisions.append(roc.precision())
    return plot(recalls, precisions, **plot_kwds)
def plot_precision_recall(roc_thresholds, recall_plot_kwds=None, precision_plot_kwds=None, plot_fn=None):
    """
    Plots recall and precision against threshold for the given ROCs.

    The keyword dictionaries passed in are not modified: the default label,
    color and linestyle are merged into copies.

    :param roc_thresholds: An iterable of tuples (ROC, threshold).
    :param recall_plot_kwds: Passed to the plot call for the recall.
        Defaults: label 'Recall', color 'blue', linestyle ':'.
    :param precision_plot_kwds: Passed to the plot call for the precision.
        Defaults: label 'Precision', color 'maroon', linestyle '--'.
    :param plot_fn: Function used to plot. Use pylab.semilogx for log scale threshold axis.
        Defaults to pylab.plot.
    :returns: The result of 2 plot calls as a tuple (recall, precision).
    """
    if plot_fn is None:
        from pylab import plot
        plot_fn = plot

    # Materialise once so that generators can be used as input
    roc_thresholds = list(roc_thresholds)
    thresholds = [t for roc, t in roc_thresholds]

    # Start from the defaults and let the caller's settings override them
    recall_kwds = {'label': 'Recall', 'color': 'blue', 'linestyle': ':'}
    recall_kwds.update(recall_plot_kwds or {})
    recall_result = plot_fn(
        thresholds,
        [roc.recall() for roc, t in roc_thresholds],
        **recall_kwds
    )

    precision_kwds = {'label': 'Precision', 'color': 'maroon', 'linestyle': '--'}
    precision_kwds.update(precision_plot_kwds or {})
    precision_result = plot_fn(
        thresholds,
        [roc.precision() for roc, t in roc_thresholds],
        **precision_kwds
    )
    return recall_result, precision_result
def auc(rocpoints):
    """
    Calculate the area under the ROC points using the trapezium rule.

    Consecutive points are joined by straight lines. The absolute difference
    in FPR is used so the points may be ordered by increasing or decreasing
    FPR.

    :param rocpoints: An iterable of objects with tpr() and fpr() methods.
    :returns: The area as a float; 0.0 for fewer than 2 points.
    """
    area = 0.
    last = None
    for point in rocpoints:
        # No trapezium can be formed until we have seen 2 points
        if last is not None:
            area += ((point.tpr() + last.tpr())
                     / 2 * abs(last.fpr() - point.fpr()))
        last = point
    return area
def area_under_curve(rocs, include_0_0=True, include_1_1=True):
    """
    Calculate the area under a ROC curve using the trapezium rule.

    Deprecated: a DeprecationWarning is issued on every call.

    :param rocs: The ROC points. Any iterable (including a generator) of
        objects with specificity() and sensitivity() methods.
    :param include_0_0: True to include extra point for origin of ROC curve.
    :param include_1_1: True to include extra point at (1,1) in ROC curve.
    :returns: The area under the ROC curve given by the ROC points.
    """
    warnings.warn('DEPRECATED: use infpy.roc.auc()', DeprecationWarning)
    # Collect (x, y) = (1 - specificity, sensitivity) in a single pass
    points = []
    if include_0_0:
        points.append((0., 0.))
    points.extend((1. - roc.specificity(), roc.sensitivity()) for roc in rocs)
    if include_1_1:
        points.append((1., 1.))
    area = 0.
    for (last_x, last_y), (x, y) in zip(points, points[1:]):
        area += (x-last_x) * (y+last_y) / 2
    return area
def plot_random_classifier(**kwargs):
    """
    Draw a random classifier on a ROC plot, i.e. the line from (0,0) to (1,1).
    Black dotted line by default.

    :param kwargs: Passed to the pylab.plot call. 'color' and 'linestyle'
        override the defaults if given.
    """
    from pylab import plot
    kwargs.setdefault('color', 'black')
    kwargs.setdefault('linestyle', ':')
    diagonal = [0, 1]
    plot(diagonal, [0, 1], **kwargs)
def label_plot():
    """Set the x and y axis labels of the current pylab plot for a ROC curve."""
    from pylab import xlabel, ylabel
    xlabel('1 - specificity: 1-TN/(TN+FP)')
    ylabel('sensitivity: TP/(TP+FN)')
def label_precision_versus_recall():
    """
    Set the axis labels of the current pylab plot for a precision versus
    recall plot and fix both axes to the range 0 to 1.
    """
    from pylab import xlabel, ylabel, xlim, ylim
    xlabel('Recall: TP/(TP+FN)')
    ylabel('Precision: TP/(TP+FP)')
    xlim(0, 1)
    ylim(0, 1)
def label_precision_recall():
    """Set the x and y axis labels of the current pylab plot for a precision-recall plot."""
    from pylab import xlabel, ylabel
    xlabel('threshold')
    ylabel('precision/recall')
def count_threshold_classifications(thresholds, value):
    """
    Take a list of thresholds (in sorted order) and count how many would be
    classified positive and negative at the given value.

    A threshold counts as positive when it is strictly greater than value
    and as negative when it is less than or equal to value.

    :returns: (num_positive, num_negative).
    """
    import bisect
    num_negative = bisect.bisect_right(thresholds, value)
    num_positive = len(thresholds) - num_negative
    return num_positive, num_negative
def roc_for_threshold(positive_thresholds, negative_thresholds, value):
    """
    Take lists of positive and negative thresholds (in sorted order)
    and calculate a ROC point for the given value.

    :returns: A RocCalculator built from the counts given by
        count_threshold_classifications for each list.
    """
    positive_counts = count_threshold_classifications(positive_thresholds, value)
    negative_counts = count_threshold_classifications(negative_thresholds, value)
    true_positives, false_negatives = positive_counts
    false_positives, true_negatives = negative_counts
    return RocCalculator(
        true_positives, false_positives, true_negatives, false_negatives)
def make_roc_from_threshold_fn(positive_thresholds, negative_thresholds):
    """
    Bind the two threshold lists to roc_for_threshold.

    :returns: A function that calculates a ROC point given a threshold.
    """
    return lambda value: roc_for_threshold(
        positive_thresholds, negative_thresholds, value)
def rocs_from_thresholds(positive_thresholds, negative_thresholds, num_points=32):
    """
    Takes 2 sorted lists: one list is of the thresholds required to classify the positive examples as positive
    and the other list is of the thresholds required to classify the negative examples as positive.

    The ROC points are calculated at num_points evenly spaced thresholds
    running from the largest threshold down to the smallest.

    Deprecated: a DeprecationWarning is issued on every call.

    :returns: A list of ROC points.
    """
    warnings.warn('DEPRECATED: use infpy.roc.all_rocs_from_thresholds()', DeprecationWarning)
    min_threshold = min(positive_thresholds[0], negative_thresholds[0])
    max_threshold = max(positive_thresholds[-1], negative_thresholds[-1])
    roc_at = make_roc_from_threshold_fn(positive_thresholds, negative_thresholds)
    # A list comprehension rather than map(): map() is lazy on Python 3 and
    # a real list is promised.
    return [
        roc_at(threshold)
        for threshold
        in N.linspace(min_threshold, max_threshold, num_points)[::-1]
    ]
def pick_roc_thresholds(roc_for_threshold_fn, min_threshold, max_threshold, num_points=32):
    """
    Tries to pick thresholds to give a smooth ROC curve.

    Starts with the ROC points for min_threshold and max_threshold and then
    repeatedly adds the mid-point threshold of the 2 adjacent ROC points
    that are furthest apart (as measured by their distance() method).

    :param roc_for_threshold_fn: Function that returns a ROC point for a threshold.
    :param min_threshold: The smallest threshold to use.
    :param max_threshold: The largest threshold to use.
    :param num_points: Thresholds are added until there are this many points.
        At least 2 points are always returned.
    :returns: A list of (roc point, threshold) tuples sorted by increasing
        sensitivity, then decreasing specificity, then increasing threshold.
    """
    def sort_key(roc_threshold):
        "Explicit ordering so that sorting does not rely on ROC points being comparable."
        roc, threshold = roc_threshold
        return (roc.sensitivity(), -roc.specificity(), threshold)

    def add_threshold(threshold):
        "Calculate the ROC point and add to list."
        rocs.append((roc_for_threshold_fn(threshold), threshold))
        rocs.sort(key=sort_key)

    def compare_2_points(x1, x2):
        "Compare 2 ROC points to see how far apart they are."
        rp1, t1 = x1
        rp2, t2 = x2
        return (rp1.distance(rp2), (t1+t2)/2.)

    rocs = []
    add_threshold(min_threshold)
    add_threshold(max_threshold)
    while len(rocs) < num_points:
        # find best new threshold: the mid-point of the widest gap
        _, new_threshold = max(
            compare_2_points(x1, x2) for x1, x2 in zip(rocs, rocs[1:]))
        add_threshold(new_threshold)
    return rocs
def create_rocs_from_thresholds(positive_thresholds, negative_thresholds, num_points=32):
    """
    Takes 2 sorted lists: one list is of the thresholds required to classify the positive examples as positive
    and the other list is of the thresholds required to classify the negative examples as positive.

    The thresholds are picked by pick_roc_thresholds between the smallest
    and largest thresholds in the 2 lists.

    Deprecated: a DeprecationWarning is issued on every call.

    :returns: A list of tuples (ROC point, threshold).
    """
    warnings.warn('DEPRECATED: use infpy.roc.all_rocs_from_thresholds()', DeprecationWarning)
    roc_fn = make_roc_from_threshold_fn(positive_thresholds, negative_thresholds)
    lowest = min(positive_thresholds[0], negative_thresholds[0])
    highest = max(positive_thresholds[-1], negative_thresholds[-1])
    return pick_roc_thresholds(
        roc_fn,
        min_threshold=lowest,
        max_threshold=highest,
        num_points=num_points
    )
def picked_rocs_from_thresholds(positive_thresholds, negative_thresholds, num_points=32):
    """
    Takes 2 sorted lists: one list is of the thresholds required to classify the positive examples as positive
    and the other list is of the thresholds required to classify the negative examples as positive.

    Same as create_rocs_from_thresholds but with the thresholds dropped
    from the result.

    Deprecated: a DeprecationWarning is issued on every call.

    :returns: A list of ROC points.
    """
    warnings.warn('DEPRECATED: use infpy.roc.all_rocs_from_thresholds()', DeprecationWarning)
    roc_thresholds = create_rocs_from_thresholds(
        positive_thresholds, negative_thresholds, num_points=num_points)
    picked = []
    for roc, _threshold in roc_thresholds:
        picked.append(roc)
    return picked
def all_rocs_from_thresholds(
    positive_thresholds,
    negative_thresholds,
    negative_first=True
):
    """
    Takes 2 sorted lists (smallest to largest):
    one list is of the thresholds required to classify the positive examples
    as positive and the other list is of the thresholds required to classify
    the negative examples as positive.

    The first point yielded has every example predicted positive. The
    thresholds are then visited from smallest to largest and one point is
    yielded per group of thresholds.

    :param negative_first: If true, positive and negative examples that share
        a threshold form separate groups (the positive group is processed
        first). Otherwise all examples sharing a threshold form one group.
    :returns: Yields all the ROC points. Note that they are returned in the
        opposite order to some of the other methods in this module.
    :raises ValueError: If either list is empty. As this is a generator the
        error is raised when iteration starts.
    """
    import heapq

    total_positive = len(positive_thresholds)
    total_negative = len(negative_thresholds)
    if not total_positive:
        raise ValueError('Need to have at least one positive prediction')
    if not total_negative:
        raise ValueError('Need to have at least one negative prediction')

    # Below the lowest threshold everything is predicted positive
    tp, fp = total_positive, total_negative
    yield RocCalculator(tp, fp, 0, 0)

    # Tag each threshold with 1 for a negative example and 0 for a positive
    # one and merge the 2 sorted streams into one.
    tagged_negatives = ((t, 1) for t in negative_thresholds)
    tagged_positives = ((t, 0) for t in positive_thresholds)
    merged = heapq.merge(tagged_negatives, tagged_positives)

    # With no key function the groups are keyed on the whole
    # (threshold, tag) tuple, otherwise just on the threshold.
    group_key = None if negative_first else (lambda tagged: tagged[0])

    for _key, members in groupby(merged, key=group_key):
        # Everything in this group is no longer predicted positive
        tags = [is_negative for _threshold, is_negative in members]
        num_negatives = sum(tags)
        fp -= num_negatives
        tp -= len(tags) - num_negatives
        yield RocCalculator(
            tp, fp, total_negative - fp, total_positive - tp)

    # everything should be classified as false after seeing all the thresholds
    assert tp == 0
    assert fp == 0
def resize_negative_examples(positive_thresholds, negative_thresholds, num_negative=50):
    """
    Reduce the positive and negative thresholds such that there are just 50 (or num_negative) negative examples.

    The last num_negative negative thresholds are kept. The positive
    thresholds are trimmed accordingly: only those strictly greater than the
    smallest remaining negative threshold are kept.

    :returns: (positive_thresholds, negative_thresholds) after trimming.
    :raises RuntimeError: If there are fewer than num_negative negative thresholds.
    """
    from bisect import bisect_right
    num_available = len(negative_thresholds)
    if num_negative > num_available:
        raise RuntimeError('Not enough negative examples (%d). Requested %d' % (num_available, num_negative))
    kept_negatives = negative_thresholds[-num_negative:]
    first_kept_positive = bisect_right(positive_thresholds, kept_negatives[0])
    return positive_thresholds[first_kept_positive:], kept_negatives
def auc50_wrong(positive_thresholds, negative_thresholds, num_negative=50, num_points=32):
    """
    Calculate the AUC50 as in Gribskov & Robinson 'Use of ROC analysis to evaluate sequence pattern matching'

    ROC points are picked for thresholds between the num_negative'th largest
    negative threshold and the largest threshold overall. The area is
    calculated without the extra point at (1,1).

    NOTE(review): the function name suggests this variant is known to be
    incorrect - confirm before using it.

    Deprecated: a DeprecationWarning is issued on every call.

    :returns: (auc50, list of (ROC point, threshold) tuples).
    :raises RuntimeError: If there are fewer than num_negative negative thresholds.
    """
    warnings.warn('DEPRECATED: use infpy.roc.auc50_from_rocpoints()', DeprecationWarning)
    num_available = len(negative_thresholds)
    if num_negative > num_available:
        raise RuntimeError('Not enough negative examples (%d). Requested %d' % (num_available, num_negative))
    roc_fn = make_roc_from_threshold_fn(positive_thresholds, negative_thresholds)
    lowest = negative_thresholds[-num_negative]
    highest = max(positive_thresholds[-1], negative_thresholds[-1])
    roc_thresholds = pick_roc_thresholds(
        roc_fn,
        min_threshold=lowest,
        max_threshold=highest,
        num_points=num_points
    )
    picked_rocs = [roc for roc, t in roc_thresholds]
    area = area_under_curve(picked_rocs, include_1_1=False)
    return area, roc_thresholds
def auc50(
    positive_thresholds,
    negative_thresholds,
    num_negative=50,
    num_points=32):
    """
    Calculate the AUC50 as in Gribskov & Robinson:
    'Use of ROC analysis to evaluate sequence pattern matching'

    The ROC points are calculated using only the last num_negative negative
    thresholds, for thresholds picked between the smallest and largest
    thresholds of the full lists.

    Deprecated: a DeprecationWarning is issued on every call.

    :returns: (auc50, list of (ROC point, threshold) tuples).
    :raises RuntimeError: If there are fewer than num_negative negative thresholds.
    """
    warnings.warn('DEPRECATED: use infpy.roc.auc50_from_rocpoints()', DeprecationWarning)
    num_available = len(negative_thresholds)
    if num_negative > num_available:
        raise RuntimeError('Not enough negative examples (%d). Requested %d' % (num_available, num_negative))
    roc_fn = make_roc_from_threshold_fn(
        positive_thresholds, negative_thresholds[-num_negative:])
    lowest = min(positive_thresholds[0], negative_thresholds[0])
    highest = max(positive_thresholds[-1], negative_thresholds[-1])
    roc_thresholds = pick_roc_thresholds(
        roc_fn,
        min_threshold=lowest,
        max_threshold=highest,
        num_points=num_points
    )
    area = area_under_curve([roc for roc, t in roc_thresholds])
    return area, roc_thresholds
def plot_rocpoint(rocpoint, **plotargs):
    """
    Plot a single rocpoint. Typically used to indicate where the last point for the AUC50
    calculation is.

    :param rocpoint: An object with fpr() and tpr() methods.
    :param plotargs: Passed to the pylab.plot call. By default the point is
        drawn as a black 'x' of markersize 6; 'color', 'marker' and
        'markersize' may be given to override these defaults.
    """
    from pylab import plot
    # Defaults rather than fixed keywords so that callers can override them
    # instead of triggering a "multiple values for keyword" TypeError.
    plotargs.setdefault('color', 'black')
    plotargs.setdefault('marker', 'x')
    plotargs.setdefault('markersize', 6)
    plot(
        [rocpoint.fpr()],
        [rocpoint.tpr()],
        **plotargs
    )
def bisect_rocs(rocpoints, predicate, start=0, end=None):
    """
    Return the index into rocpoints for first rocpoint with
    predicate(rocpoint) is True and start <= index < end. Assumes
    rocpoints are sorted w.r.t. predicate, i.e. once the predicate is True
    for a rocpoint it is True for all subsequent ones.

    :param rocpoints: An indexable sequence.
    :param predicate: A function of one rocpoint returning a truth value.
    :param start: The first index to consider.
    :param end: One past the last index to consider. Defaults to len(rocpoints).
    :returns: The index of the first rocpoint satisfying the predicate or
        end if no rocpoint in the range satisfies it.
    :raises ValueError: If end <= start.
    """
    if end is None:
        end = len(rocpoints)
    if end <= start:
        raise ValueError('Start (%d) must be less than end (%d).' % (start, end))
    # Invariant: the predicate is False for every index in [start, lower)
    # and True for every index in [upper, end).
    lower, upper = start, end
    while lower < upper:
        middle = (lower + upper) // 2  # integer division: must be an index
        if predicate(rocpoints[middle]):
            # Look in lower half
            upper = middle
        else:
            # Look in upper half
            lower = middle + 1
    return lower
def restrict_false_positives(rocpoints, max_fp=50):
    """
    Yield the ROC points while the number of false positives is no more than
    max_fp.

    Iteration stops after a point with exactly max_fp false positives. If a
    point has more than max_fp false positives, a replacement point is
    yielded in its place with fp set to max_fp and the surplus false
    positives moved into tn, and iteration then stops.

    :param rocpoints: ROC points ordered by increasing false positives.
    :param max_fp: The maximum number of false positives.
    :raises ValueError: If a point has fewer false positives than the point
        before it.
    """
    previous = None
    for point in rocpoints:
        # check the rocpoints are in the correct sorted order
        if previous is not None and previous.fp > point.fp:
            raise ValueError(
                'ROC points should be ordered by increasing false positives')
        within_limit = point.fp <= max_fp
        if not within_limit:
            surplus = point.fp - max_fp
            yield RocCalculator(
                point.tp, max_fp, point.tn - max_fp + point.fp, point.fn)
            return
        yield point
        if point.fp == max_fp:
            return
        previous = point
def auc50_from_rocpoints(
    rocpoints,
    max_fp=50):
    """
    Calculate the AUC50 as in Gribskov & Robinson:
    'Use of ROC analysis to evaluate sequence pattern matching'

    This is the area (see auc) under the ROC points after they have been
    limited to max_fp false positives by restrict_false_positives.

    :param rocpoints: ROC points ordered by increasing false positives.
    :param max_fp: The maximum number of false positives.
    """
    limited_points = restrict_false_positives(rocpoints, max_fp)
    return auc(limited_points)
if '__main__' == __name__:
    # Self-test and demonstration: checks all_rocs_from_thresholds on a few
    # tiny inputs and then plots ROC curves for 2 simulated methods.
    import os
    import numpy.random as R, pylab as P

    def check_points(points):
        "Sanity check ROC points as produced by all_rocs_from_thresholds."
        # The points start at (1, 1) and finish at (0, 0) so neither the FPR
        # nor the TPR may increase from one point to the next.
        last = None
        for point in points:
            logging.debug(point)
            if last is not None:
                assert point.fpr() <= last.fpr()
                assert point.tpr() <= last.tpr()
            last = point
        # check endpoints
        assert 1 == points[ 0].fpr() == points[ 0].tpr()
        assert 0 == points[-1].fpr() == points[-1].tpr()

    P.close('all')
    P.figure()

    # Try one positive, one negative, negative first
    positive = [.5]
    negative = [.5]
    logging.info('Positive=%s; Negative=%s', positive, negative)
    points = list(all_rocs_from_thresholds(positive, negative, True))
    check_points(points)
    logging.info(
        'Got %4d ROC points, AUC=%.3f, AUC50=%.3f',
        len(points), auc(points), auc50_from_rocpoints(points[::-1]))
    assert 3 == len(points)

    # Try one positive, one negative with tied thresholds grouped together
    positive = [.5]
    negative = [.5]
    logging.info('Positive=%s; Negative=%s', positive, negative)
    points = list(all_rocs_from_thresholds(positive, negative, False))
    check_points(points)
    logging.info(
        'Got %4d ROC points, AUC=%.3f, AUC50=%.3f',
        len(points), auc(points), auc50_from_rocpoints(points[::-1]))
    assert 2 == len(points)

    # Try one positive, one negative with different thresholds
    positive = [.5]
    negative = [.6]
    logging.info('Positive=%s; Negative=%s', positive, negative)
    points = list(all_rocs_from_thresholds(positive, negative, False))
    check_points(points)
    logging.info(
        'Got %4d ROC points, AUC=%.3f, AUC50=%.3f',
        len(points), auc(points), auc50_from_rocpoints(points[::-1]))
    assert 3 == len(points)

    # Try a few: 2 simulated methods scored against the same negatives
    R.seed(1)
    positive1 = R.normal(size=500, loc=.6, scale=.4)
    positive1.sort()
    positive2 = R.normal(size=500, loc=.8, scale=.8)
    positive2.sort()
    negative = R.normal(size=700, loc=.2, scale=.4)
    negative.sort()
    P.figure(figsize=(6, 6))
    plot_random_classifier()
    colors = ['green', 'blue']
    for idx, positive in enumerate((positive1, positive2)):
        color = colors[idx]
        points = list(all_rocs_from_thresholds(positive, negative, True))
        check_points(points)
        method = idx + 1
        label = 'method %d' % method
        logging.info(
            '%s: Got %4d ROC points, AUC=%.3f, AUC50=%.3f',
            label,
            len(points),
            auc(points),
            auc50_from_rocpoints(points[::-1]))
        plot_rocpoints(points, label=label, color=color)
        # Shade the part of the curve that contributes to the AUC50. The
        # points are reversed as restrict_false_positives expects them
        # ordered by increasing false positives.
        restricted_points = list(restrict_false_positives(points[::-1]))
        plot_rocpoints(
            restricted_points,
            label=None,
            fillargs={
                'alpha' : 0.3,
                'edgecolor' : (0,0,0,0),
                'facecolor' : color,
            },
            color=(0,0,0,0),
        )
    P.legend(loc='lower right')
    label_plot()
    # savefig does not create missing directories
    if not os.path.isdir('output'):
        os.makedirs('output')
    P.savefig('output/ROC.eps')
    P.savefig('output/ROC.png')
    P.savefig('output/ROC.pdf')
    P.close('all')