# Andre Anjos <andre.anjos@idiap.ch>
# Wed 11 May 2011 09:16:39 CEST
"""The ScoreToolKit (or simply "stk") provides functionality to load TABULA
RASA conformant score files, for either plotting DET curves or for the
validation of multi-file score matching.
"""
import math
import numpy
# Values required for linear scale => gaussian scale conversion at __ppndf__()

# Threshold on abs(p - 0.5): at or below it __ppndf__() uses the central
# rational approximation (A/B coefficients), above it the tail approximation
# (C/D coefficients).
__SPLIT__ = 0.42

# Numerator coefficients of the central-region approximation.
__A0__ = 2.5066282388
__A1__ = -18.6150006252
__A2__ = 41.3911977353
__A3__ = -25.4410604963

# Denominator coefficients of the central-region approximation.
__B1__ = -8.4735109309
__B2__ = 23.0833674374
__B3__ = -21.0622410182
__B4__ = 3.1308290983

# Numerator coefficients of the tail approximation.
__C0__ = -2.7871893113
__C1__ = -2.2979647913
__C2__ = 4.8501412713
__C3__ = 2.3212127685

# Denominator coefficients of the tail approximation.
__D1__ = 3.5438892476
__D2__ = 1.6370678189

# Smallest offset used by __ppndf__() to keep probabilities strictly inside
# the open interval (0, 1). NOTE(review): the value looks like the IEEE
# double-precision machine epsilon - confirm before changing.
__EPS__ = 2.2204e-16
def load_file(filename, no_labels=False):
    """Loads a score set from a single file to memory.

    Blank lines and lines whose first non-blank character is ``#`` are
    ignored. Every other line must hold exactly five whitespace-separated
    fields, the last of which must be convertible to ``float``.

    Keyword parameters:

    filename
      path to the score file; it is opened in text mode and closed before
      this function returns (also on errors)

    no_labels
      if set, the model and test labels (positions 1 and 3) are replaced by
      ``None`` in the returned tuples. Only useful for plotting.

    Returns a python list of tuples containing the following fields:

    [0]
      claimed identity (string)
    [1]
      model label (string), ``None`` when ``no_labels`` is set
    [2]
      real identity (string)
    [3]
      test label (string), ``None`` when ``no_labels`` is set
    [4]
      score (float)

    Raises ``SyntaxError`` if a line does not have five fields or if its
    score cannot be converted to float. The reported line number is 1-based,
    matching what a text editor shows.
    """
    retval = []
    # The context manager guarantees the handle is released even when a
    # malformed line aborts the parsing.
    with open(filename, 'rt') as handle:
        for lineno, line in enumerate(handle, 1):
            stripped = line.strip()
            if not stripped or stripped[0] == '#':
                continue  # empty or comment

            # split() without arguments already discards surrounding blanks
            field = stripped.split()
            if len(field) != 5:
                raise SyntaxError('Line %d of file "%s" is invalid: %s' %
                                  (lineno, filename, line))

            # Only the float conversion is guarded, so unrelated errors are
            # not misreported as score problems.
            try:
                score = float(field[4])
            except ValueError:
                raise SyntaxError(
                    'Cannot convert score to float at line %d of file "%s": %s'
                    % (lineno, filename, line))

            if no_labels:  # only useful for plotting
                retval.append((field[0], None, field[2], None, score))
            else:
                retval.append((field[0], field[1], field[2], field[3], score))
    return retval
def split(data):
    """Separates loaded scores into negative and positive sets.

    An entry of ``data`` (a tuple list as returned by load_file()) is a
    positive when its claimed identity (position 0) equals its real identity
    (position 2); otherwise it is a negative. Only the score (position 4) of
    each entry is kept.

    Returns 2 numpy arrays of type float64 as a tuple with (negatives,
    positives)
    """
    negatives = numpy.array(
        [entry[4] for entry in data if entry[0] != entry[2]],
        dtype='float64')
    positives = numpy.array(
        [entry[4] for entry in data if entry[0] == entry[2]],
        dtype='float64')
    return (negatives, positives)
def farfrr(negatives, positives, threshold):
    """Computes the false acceptance and false rejection rates at a threshold.

    Keyword parameters:

    negatives
      numpy array with the negative class scores

    positives
      numpy array with the positive class scores

    threshold
      decision threshold; a negative scoring at or above it counts as a false
      acceptance, a positive scoring strictly below it counts as a false
      rejection

    Returns a tuple (far, frr) with both rates expressed as fractions of the
    respective set size.
    """
    accepted_negatives = negatives[negatives >= threshold]
    rejected_positives = positives[positives < threshold]
    far = len(accepted_negatives) / float(len(negatives))
    frr = len(rejected_positives) / float(len(positives))
    return (far, frr)
def evalROC(negatives, positives, points):
    """Evaluates the ROC curve.

    This method evaluates the ROC curve given a set of positives and
    negatives, returning two numpy arrays containing the FARs and the FRRs.

    Keyword parameters:

    negatives
      numpy array with the negative class scores

    positives
      numpy array with the positive class scores

    points
      number of thresholds to evaluate; they are evenly spaced between the
      smallest and the largest score found in both sets (both included)

    Returns a tuple (far, frr) of float64 numpy arrays, each with exactly
    ``points`` entries, as computed by farfrr() at every threshold.
    """
    minval = min(min(negatives), min(positives))
    maxval = max(max(negatives), max(positives))

    # linspace() yields exactly `points` thresholds. The earlier arange()
    # with a floating-point step could produce one sample more or less
    # because of rounding, and failed when points == 1 or when all scores
    # were identical (zero step).
    thresholds = numpy.linspace(minval, maxval, points)

    far = numpy.zeros(thresholds.shape, dtype='float64')
    frr = numpy.zeros(thresholds.shape, dtype='float64')
    for i, threshold in enumerate(thresholds):
        far[i], frr[i] = farfrr(negatives, positives, threshold)
    return far, frr
def __ppndf__(p):
    """Converts a linear scale to a "Gaussian" scale

    Method based on the NIST evaluation code (DETware version 2.1).

    Keyword parameters:

    p
      a probability; values at or outside the [0, 1] boundaries are clamped
      to ``__EPS__`` and ``1 - __EPS__`` before conversion

    Returns the converted value (float): negative for p < 0.5, zero at 0.5
    and positive for p > 0.5.

    Raises ``RuntimeError`` if the intermediate tail value is not strictly
    positive (a defensive check; the clamping above should prevent it).
    """
    # Keep p strictly inside (0, 1) so the logarithm below is defined.
    if p >= 1.0:
        p = 1 - __EPS__
    if p <= 0.0:
        p = __EPS__

    q = p - 0.5
    if abs(q) <= __SPLIT__:
        # Central region: rational approximation in q; the sign is carried
        # by the leading q factor.
        r = q * q
        retval = q * (((__A3__ * r + __A2__) * r + __A1__) * r + __A0__) \
            / ((((__B4__ * r + __B3__) * r + __B2__) * r + __B1__) * r + 1.0)
    else:
        # Tails: work with the distance to the nearest boundary.
        if q > 0.0:
            r = 1.0 - p
        else:
            r = p
        if r <= 0.0:
            raise RuntimeError('ERROR Found r = %g\n' % r)
        r = math.sqrt((-1.0) * math.log(r))
        retval = (((__C3__ * r + __C2__) * r + __C1__) * r + __C0__) \
            / ((__D2__ * r + __D1__) * r + 1.0)
        # The tail formula is computed on a magnitude; restore the sign for
        # the lower tail.
        if q < 0:
            retval *= -1.0
    return retval
def evalDET(negatives, positives, points):
    """Evaluates the DET curve.

    The FAR and FRR values are first computed with evalROC() and then every
    value is mapped through __ppndf__().

    Keyword parameters:

    negatives
      numpy array with the negative class scores

    positives
      numpy array with the positive class scores

    points
      number of points requested, passed unchanged to evalROC()

    Returns a tuple of two float64 numpy arrays with the converted FARs and
    the converted FRRs, in this order.
    """
    far, frr = evalROC(negatives, positives, points)
    converted_far = numpy.array([__ppndf__(value) for value in far],
                                dtype='float64')
    converted_frr = numpy.array([__ppndf__(value) for value in frr],
                                dtype='float64')
    return (converted_far, converted_frr)
def plotDET(negatives, positives, filename='det.pdf', points=100,
            limits=None, title='DET Curve', labels=None, colour=False):
    """Plots Detection Error Trade-off (DET) curve

    One curve is drawn per (negatives[i], positives[i]) pair on the current
    matplotlib figure, which is then saved to ``filename``.

    Keyword parameters:

    positives
      a sequence of numpy.array's, one per curve, with the positive class
      scores in float64 format

    negatives
      a sequence of numpy.array's, one per curve, with the negative class
      scores in float64 format

    filename
      the output filename where to save the plot. If not specified, we output
      to 'det.pdf'

    points
      an (optional) number of points to use for the plot. Defaults to 100.

    limits
      an (optional) tuple containing 4 elements that determine the minimum and
      maximum values to plot, in the order (FRR min, FRR max, FAR min, FAR
      max). Values are strings in percent and have to exist in the internal
      desiredLabels variable. Defaults to ('0.001', '99.999', '0.001',
      '99.999').

    title
      an (optional) string containing a title to be printed on the top of the
      plot

    labels
      an (optional) list of labels for a legend. If None or empty, the legend
      is suppressed

    colour
      flag determining if the plot is coloured or monochrome. By default we
      plot in monochrome scale.

    Raises ``SyntaxError`` if any entry of ``limits`` is not one of the
    supported labels.

    Returns None.
    """
    import matplotlib.pyplot as mpl

    # NOTE(review): the current figure is reused, so calling this function
    # repeatedly without closing the figure scales its height again.
    figure = mpl.gcf()
    figure.set_figheight(figure.get_figheight()*1.3)

    # Tick positions (as fractions) and the matching labels (in percent);
    # both lists are index-aligned.
    desiredTicks = [
        "0.00001", "0.00002", "0.00005", "0.0001", "0.0002", "0.0005",
        "0.001", "0.002", "0.005", "0.01", "0.02", "0.05", "0.1", "0.2",
        "0.4", "0.6", "0.8", "0.9", "0.95", "0.98", "0.99", "0.995", "0.998",
        "0.999", "0.9995", "0.9998", "0.9999", "0.99995", "0.99998",
        "0.99999",
    ]
    desiredLabels = [
        "0.001", "0.002", "0.005", "0.01", "0.02", "0.05", "0.1", "0.2",
        "0.5", "1", "2", "5", "10", "20", "40", "60", "80", "90", "95", "98",
        "99", "99.5", "99.8", "99.9", "99.95", "99.98", "99.99", "99.995",
        "99.998", "99.999",
    ]

    # Available styles: please note that we plot up to the number of styles
    # available. So, for coloured plots, we can go up to 6 lines in a single
    # plot. For grayscaled ones, up to 12. Extra curves are silently dropped
    # by zip() below. If you need more plots just extend the lists below.
    colourStyle = [
        ((0, 0, 0), '-', 1),        # black
        ((0, 0, 1.0), '--', 1),     # blue
        ((0.8, 0.0, 0.0), '-.', 1), # red
        ((0, 0.6, 0.0), ':', 1),    # green
        ((0.5, 0.0, 0.5), '-', 1),  # magenta
        ((0.3, 0.3, 0.0), '--', 1), # orange
    ]

    grayStyle = [
        ((0, 0, 0), '-', 1),         # black
        ((0, 0, 0), '--', 1),        # black
        ((0, 0, 0), '-.', 1),        # black
        ((0, 0, 0), ':', 1),         # black
        ((0.3, 0.3, 0.3), '-', 1),   # gray
        ((0.3, 0.3, 0.3), '--', 1),  # gray
        ((0.3, 0.3, 0.3), '-.', 1),  # gray
        ((0.3, 0.3, 0.3), ':', 1),   # gray
        ((0.6, 0.6, 0.6), '-', 2),   # lighter gray
        ((0.6, 0.6, 0.6), '--', 2),  # lighter gray
        ((0.6, 0.6, 0.6), '-.', 2),  # lighter gray
        ((0.6, 0.6, 0.6), ':', 2),   # lighter gray
    ]

    if not limits:
        limits = ('0.001', '99.999', '0.001', '99.999')

    # check limits
    for k in limits:
        if k not in desiredLabels:
            raise SyntaxError(
                'Unsupported limit %s. Please use one of %s'
                % (k, desiredLabels))

    if colour:
        style = colourStyle
    else:
        style = grayStyle

    # FRR goes on the horizontal axis, FAR on the vertical one.
    if labels:
        for neg, pos, lab, sty in zip(negatives, positives, labels, style):
            ppfar, ppfrr = evalDET(neg, pos, points)
            mpl.plot(ppfrr, ppfar, label=lab, color=sty[0], linestyle=sty[1],
                     linewidth=sty[2])
    else:
        for neg, pos, sty in zip(negatives, positives, style):
            ppfar, ppfrr = evalDET(neg, pos, points)
            mpl.plot(ppfrr, ppfar, color=sty[0], linestyle=sty[1],
                     linewidth=sty[2])

    fr_minIndex = desiredLabels.index(limits[0])
    fr_maxIndex = desiredLabels.index(limits[1])
    fa_minIndex = desiredLabels.index(limits[2])
    fa_maxIndex = desiredLabels.index(limits[3])

    # convert into DET scale
    pticks = [__ppndf__(float(v)) for v in desiredTicks]

    ax = mpl.gca()

    mpl.axis([pticks[fr_minIndex], pticks[fr_maxIndex],
              pticks[fa_minIndex], pticks[fa_maxIndex]])

    # The slices exclude the upper limit itself, so no tick is drawn exactly
    # on the upper axis boundary.
    ax.set_xticks(pticks[fr_minIndex:fr_maxIndex])
    ax.set_xticklabels(desiredLabels[fr_minIndex:fr_maxIndex],
                       size='x-small', rotation='vertical')
    ax.set_yticks(pticks[fa_minIndex:fa_maxIndex])
    ax.set_yticklabels(desiredLabels[fa_minIndex:fa_maxIndex],
                       size='x-small')

    if title:
        mpl.title(title)
    mpl.grid(True)
    mpl.xlabel('False Rejection Rate [in %]')
    mpl.ylabel('False Acceptance Rate [in %]')
    if labels:
        mpl.legend()
    mpl.savefig(filename, dpi=300)  # saves current figure to file
def checkModalities(data1, filename1, data2, filename2, presorted=False):
    """Double-checks score files for fusion.

    This method checks two score files to make sure they match w.r.t. the
    number of clients, impostors and models. It is equivalent to making sure
    the first 4 columns of such files contain the same fields, after ordering.

    Parameters:

    data1
      The pre-loaded data set using load_file()

    filename1
      The first score file name (string), only used in error messages

    data2
      The (second) pre-loaded data set using load_file()

    filename2
      The second score file name (string), only used in error messages

    presorted
      A flag indicating if the files have been pre-sorted (boolean). If not
      set, **both input lists are sorted in place** before the comparison.

    Here is how to sort your score files using shell utilities ``sort`` and
    ``uniq``:

    .. code-block:: sh

      $ sort my-scores.txt | uniq > my-sorted-scores.txt

    Raises ``SyntaxError`` if the two sets differ in length or if, at some
    position, the first four fields of the entries differ.

    Returns None.
    """

    def checkModalitiesSorted(s1, f1, s2, f2):
        """Compares two sorted sets entry by entry, ignoring the scores."""
        if len(s1) != len(s2):
            raise SyntaxError(
                'Lengths differ between "%s" and "%s"' % (f1, f2))

        for v1, v2 in zip(s1, s2):
            if v1[0:4] != v2[0:4]:
                raise SyntaxError(
                    'Entry "%s" is not available on "%s"' % (v1, f2))

    if not presorted:
        # In-place sort kept on purpose: callers may rely on the lists being
        # ordered after this call.
        data1.sort()
        data2.sort()

    checkModalitiesSorted(data1, filename1, data2, filename2)