#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Manuel Guenther <Manuel.Guenther@idiap.ch>
# @date: Thu May 24 10:41:42 CEST 2012
#
# Copyright (C) 2011-2012 Idiap Research Institute, Martigny, Switzerland
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
"""This module provides the Dataset interface allowing the user to query the
LFW database.
"""
import six
from bob.db.base import utils
from .models import *
from sqlalchemy.orm import aliased
from .driver import Interface
import bob.db.base
SQLITE_FILE = Interface().files()[0]
[docs]class Database(bob.db.base.SQLiteDatabase):
"""The dataset class opens and maintains a connection opened to the Database.
It provides many different ways to probe for the characteristics of the data
and for the data itself inside the database.
"""
def __init__(self, original_directory=None, original_extension='.jpg', annotation_type=None):
# call base class constructor
super(Database, self).__init__(SQLITE_FILE, File,
original_directory, original_extension)
self.m_valid_protocols = ('view1', 'fold1', 'fold2', 'fold3',
'fold4', 'fold5', 'fold6', 'fold7', 'fold8', 'fold9', 'fold10')
self.m_valid_groups = ('world', 'dev', 'eval')
self.m_valid_purposes = ('enroll', 'probe')
self.m_valid_classes = ('matched', 'client', 'unmatched', 'impostor')
self.m_subworld_counts = {'onefolds': 1, 'twofolds': 2, 'threefolds': 3,
'fourfolds': 4, 'fivefolds': 5, 'sixfolds': 6, 'sevenfolds': 7}
self.m_valid_types = ('restricted', 'unrestricted')
self.m_valid_annotation_types = ('idiap', 'funneled')
if annotation_type is not None:
self.m_annotation_type = self.check_parameter_for_validity(
annotation_type, "annotation type", self.m_valid_annotation_types)
else:
self.m_annotation_type = None
def __eval__(self, fold):
return int(fold[4:])
def __dev__(self, eval):
# take the two parts of the training set (the ones before the eval set)
# for dev
return ((eval + 7) % 10 + 1, (eval + 8) % 10 + 1)
def __dev_for__(self, fold):
return ["fold%d" % f for f in self.__dev__(self.__eval__(fold))]
def __world_for__(self, fold, subworld):
# the training sets for each fold are composed of all folds
# except the given one and the previous
eval = self.__eval__(fold)
dev = self.__dev__(eval)
world_count = self.m_subworld_counts[subworld]
world = []
for i in range(world_count):
world.append((eval + i) % 10 + 1)
return ["fold%d" % f for f in world]
[docs] def protocol_names(self):
"""Returns the names of the valid protocols."""
return self.m_valid_protocols
[docs] def groups(self, protocol=None):
"""Returns the groups, which are available in the database."""
if protocol != 'view1':
return self.m_valid_groups
else:
return self.m_valid_groups[:2]
[docs] def subworld_names(self, protocol=None):
"""Returns all valid sub-worlds for the fold.. protocols; for view1 an empty list is returned."""
if protocol != 'view1':
return self.m_subworld_counts.keys()
else:
return []
[docs] def world_types(self):
"""Returns the valid types of worlds: ('restricted', 'unrestricted')."""
return self.m_valid_types
[docs] def annotation_types(self):
"""Queries the database for the available types of annotations."""
s = set([a.annotation_type for a in self.query(Annotation)])
return [str(t) for t in s]
[docs] def clients(self, protocol=None, groups=None, subworld='sevenfolds', world_type='unrestricted'):
"""Returns a list of Client objects for the specific query by the user.
Keyword Parameters:
protocol
The protocol to consider; one of: ('view1', 'fold1', ..., 'fold10'), or None
groups
The groups to which the clients belong; one or several of: ('world', 'dev', 'eval')
Note: the 'eval' group does not exist for protocol 'view1'.
subworld
The subset of the training data. Has to be specified if groups includes 'world'
and protocol is one of 'fold1', ..., 'fold10'.
It might be exactly one of ('onefolds', 'twofolds', ..., 'sevenfolds').
Ignored for group 'dev' and 'eval'.
world_type
One of ('restricted', 'unrestricted'). If 'restricted' (the default), only the
clients that are used in one of the training pairs are returned. For 'unrestricted',
all training people are returned.
Ignored for group 'dev' and 'eval'.
Returns: A list containing all Client objects which have the desired properties.
"""
protocols = self.check_parameters_for_validity(
protocol, 'protocol', self.m_valid_protocols)
groups = self.check_parameters_for_validity(
groups, 'group', self.m_valid_groups)
if subworld != None:
subworld = self.check_parameter_for_validity(
subworld, 'sub-world', list(self.m_subworld_counts.keys()))
world_type = self.check_parameter_for_validity(
world_type, 'training type', self.m_valid_types)
queries = []
# List of the clients
for protocol in protocols:
if protocol == 'view1':
if 'world' in groups:
if world_type == 'restricted':
queries.append(
self.query(Client).join(File).join((Pair, or_(File.id == Pair.enroll_file_id, File.id == Pair.probe_file_id))).
filter(Pair.protocol == 'train').
order_by(Client.id))
else:
queries.append(
self.query(Client).join(File).join(People).
filter(People.protocol == 'train').
order_by(Client.id))
if 'dev' in groups:
queries.append(
self.query(Client).join(File).join(People).
filter(People.protocol == 'test').
order_by(Client.id))
else:
if 'world' in groups:
# select training set for the given fold
trainset = self.__world_for__(protocol, subworld)
if world_type == 'restricted':
queries.append(
self.query(Client).join(File).join((Pair, or_(File.id == Pair.enroll_file_id, File.id == Pair.probe_file_id))).
filter(Pair.protocol.in_(trainset)).
order_by(Client.id))
else:
queries.append(
self.query(Client).join(File).join(People).
filter(People.protocol.in_(trainset)).
order_by(Client.id))
if 'dev' in groups:
# select development set for the given fold
devset = self.__dev_for__(protocol)
queries.append(
self.query(Client).join(File).join(People).
filter(People.protocol.in_(devset)).
order_by(Client.id))
if 'eval' in groups:
queries.append(
self.query(Client).join(File).join(People).
filter(People.protocol == protocol).
order_by(Client.id))
# all queries are made; now collect the clients
retval = []
for query in queries:
for client in query:
retval.append(client)
return self.uniquify(retval)
[docs] def models(self, protocol=None, groups=None):
"""Returns a list of File objects (there are multiple models per client) for the specific query by the user.
For the 'dev' and 'eval' groups, the first element of each pair is extracted.
Keyword Parameters:
protocol
The protocol to consider; one of: ('view1', 'fold1', ..., 'fold10'), or None
groups
The groups to which the clients belong; one or several of: ('dev', 'eval')
The 'eval' group does not exist for protocol 'view1'.
Returns: A list containing all File objects which have the desired properties.
"""
protocols = self.check_parameters_for_validity(
protocol, 'protocol', self.m_valid_protocols)
groups = self.check_parameters_for_validity(
groups, 'group', ('dev', 'eval'))
# the restricted case...
queries = []
# List of the models
for protocol in protocols:
if protocol == 'view1':
if 'dev' in groups:
queries.append(\
# enroll files
self.query(File).join((Pair, File.id == Pair.enroll_file_id)).\
filter(Pair.protocol == 'test'))
else:
if 'dev' in groups:
# select development set for the given fold
devset = self.__dev_for__(protocol)
queries.append(
self.query(File).join((Pair, File.id == Pair.enroll_file_id)).
filter(Pair.protocol.in_(devset)))
if 'eval' in groups:
queries.append(
self.query(File).join((Pair, File.id == Pair.enroll_file_id)).
filter(Pair.protocol == protocol))
# all queries are made; now collect the files
retval = []
for query in queries:
retval.extend([file for file in query])
return self.uniquify(retval)
[docs] def model_ids(self, protocol=None, groups=None):
"""Returns a list of model ids for the specific query by the user.
For the 'dev' and 'eval' groups, the first element of each pair is extracted.
Keyword Parameters:
protocol
The protocol to consider; one of: ('view1', 'fold1', ..., 'fold10'), or None
groups
The groups to which the clients belong; one or several of: ('dev', 'eval')
The 'eval' group does not exist for protocol 'view1'.
Returns: A list containing all model ids which have the desired properties.
"""
return [file.id for file in self.models(protocol, groups)]
[docs] def get_client_id_from_file_id(self, file_id, **kwargs):
"""Returns the client_id (real client id) attached to the given file_id
Keyword Parameters:
file_id
The file_id to consider
Returns: The client_id attached to the given file_id
"""
self.assert_validity()
q = self.query(File).\
filter(File.id == file_id)
assert q.count() == 1
return q.first().client_id
[docs] def get_client_id_from_model_id(self, model_id, **kwargs):
"""Returns the client_id (real client id) attached to the given model id
Keyword Parameters:
model_id
The model to consider
type
One of ('restricted', 'unrestricted'). If the type 'restricted' is given,
model_ids will be handled as file ids, if type is 'unrestricted', model ids
will be client ids.
Returns: The client_id attached to the given model
"""
# since there is one model per file, we can re-use the function above.
return self.get_client_id_from_file_id(model_id)
[docs] def objects(self, protocol=None, model_ids=None, groups=None, purposes=None, subworld='sevenfolds', world_type='unrestricted'):
"""Returns a list of File objects for the specific query by the user.
Keyword Parameters:
protocol
The protocol to consider ('view1', 'fold1', ..., 'fold10'), or None
groups
The groups to which the objects belong ('world', 'dev', 'eval')
purposes
The purposes of the objects ('enroll', 'probe')
subworld
The subset of the training data. Has to be specified if groups includes 'world'
and protocol is one of 'fold1', ..., 'fold10'.
It might be exactly one of ('onefolds', 'twofolds', ..., 'sevenfolds').
world_type
One of ('restricted', 'unrestricted'). If 'restricted', only the files that
are used in one of the training pairs are used. For 'unrestricted', all
files of the training people are returned.
model_ids
Only retrieves the objects for the provided list of model ids.
If 'None' is given (this is the default), no filter over the model_ids is performed.
Note that the combination of 'world' group and 'model_ids' should be avoided.
Returns: A list of File objects considering all the filtering criteria.
"""
protocols = self.check_parameters_for_validity(
protocol, "protocol", self.m_valid_protocols)
groups = self.check_parameters_for_validity(
groups, "group", self.m_valid_groups)
purposes = self.check_parameters_for_validity(
purposes, "purpose", self.m_valid_purposes)
world_type = self.check_parameter_for_validity(
world_type, 'training type', self.m_valid_types)
if subworld != None:
subworld = self.check_parameter_for_validity(
subworld, 'sub-world', list(self.m_subworld_counts.keys()))
if(isinstance(model_ids, six.string_types)):
model_ids = (model_ids,)
queries = []
probe_queries = []
file_alias = aliased(File)
for protocol in protocols:
if protocol == 'view1':
if 'world' in groups:
# training files of view1
if world_type == 'restricted':
queries.append(
self.query(File).join((Pair, or_(File.id == Pair.enroll_file_id, File.id == Pair.probe_file_id))).
filter(Pair.protocol == 'train'))
else:
queries.append(
self.query(File).join(People).
filter(People.protocol == 'train'))
if 'dev' in groups:
# test files of view1
if 'enroll' in purposes:
queries.append(
self.query(File).join((Pair, File.id == Pair.enroll_file_id)).
filter(Pair.protocol == 'test'))
if 'probe' in purposes:
probe_queries.append(
self.query(File).
join((Pair, File.id == Pair.probe_file_id)).
join((file_alias, Pair.enroll_file_id == file_alias.id)).
filter(Pair.protocol == 'test'))
else:
# view 2
if 'world' in groups:
# world set of current fold of view 2
trainset = self.__world_for__(protocol, subworld)
if world_type == 'restricted':
queries.append(
self.query(File).join((Pair, or_(File.id == Pair.enroll_file_id, File.id == Pair.probe_file_id))).
filter(Pair.protocol.in_(trainset)))
else:
queries.append(
self.query(File).join(People).
filter(People.protocol.in_(trainset)))
if 'dev' in groups:
# development set of current fold of view 2
devset = self.__dev_for__(protocol)
if 'enroll' in purposes:
queries.append(
self.query(File).join((Pair, File.id == Pair.enroll_file_id)).
filter(Pair.protocol.in_(devset)))
if 'probe' in purposes:
probe_queries.append(
self.query(File).
join((Pair, File.id == Pair.probe_file_id)).
join((file_alias, file_alias.id == Pair.enroll_file_id)).
filter(Pair.protocol.in_(devset)))
if 'eval' in groups:
# evaluation set of current fold of view 2; this is the REAL fold
if 'enroll' in purposes:
queries.append(
self.query(File).join((Pair, File.id == Pair.enroll_file_id)).
filter(Pair.protocol == protocol))
if 'probe' in purposes:
probe_queries.append(
self.query(File).
join((Pair, File.id == Pair.probe_file_id)).
join((file_alias, file_alias.id == Pair.enroll_file_id)).
filter(Pair.protocol == protocol))
retval = []
for query in queries:
if model_ids and len(model_ids):
query = query.filter(File.id.in_(model_ids))
retval.extend([file for file in query])
for query in probe_queries:
if model_ids and len(model_ids):
query = query.filter(file_alias.id.in_(model_ids))
for probe in query:
retval.append(probe)
return self.uniquify(retval)
[docs] def pairs(self, protocol=None, groups=None, classes=None, subworld='sevenfolds'):
"""Queries a list of Pair's of files.
Keyword Parameters:
protocol
The protocol to consider ('view1', 'fold1', ..., 'fold10')
groups
The groups to which the objects belong ('world', 'dev', 'eval')
classes
The classes to which the pairs belong ('matched', 'unmatched'), or ('client', 'impostor')
subworld
The subset of the training data. Has to be specified if groups includes 'world'
and protocol is one of 'fold1', ..., 'fold10'.
It might be exactly one of ('onefolds', 'twofolds', ..., 'sevenfolds').
Returns: A list of Pair's considering all the filtering criteria.
"""
def default_query():
return self.query(Pair).\
join((File1, File1.id == Pair.enroll_file_id)).\
join((File2, File2.id == Pair.probe_file_id))
protocol = self.check_parameter_for_validity(
protocol, "protocol", self.m_valid_protocols)
groups = self.check_parameters_for_validity(
groups, "group", self.m_valid_groups)
classes = self.check_parameters_for_validity(
classes, "class", self.m_valid_classes)
if subworld != None:
subworld = self.check_parameter_for_validity(
subworld, 'sub-world', list(self.m_subworld_counts.keys()))
queries = []
File1 = aliased(File)
File2 = aliased(File)
if protocol == 'view1':
if 'world' in groups:
queries.append(default_query().filter(Pair.protocol == 'train'))
if 'dev' in groups:
queries.append(default_query().filter(Pair.protocol == 'test'))
else:
if 'world' in groups:
trainset = self.__world_for__(protocol, subworld)
queries.append(default_query().filter(Pair.protocol.in_(trainset)))
if 'dev' in groups:
devset = self.__dev_for__(protocol)
queries.append(default_query().filter(Pair.protocol.in_(devset)))
if 'eval' in groups:
queries.append(default_query().filter(Pair.protocol == protocol))
retval = []
for query in queries:
if 'matched' not in classes and 'client' not in classes:
query = query.filter(Pair.is_match == False)
if 'unmatched' not in classes and 'impostor' not in classes:
query = query.filter(Pair.is_match == True)
for pair in query:
retval.append(pair)
return retval
[docs] def annotations(self, file, annotation_type=None):
"""Returns the annotations for the given file id as a dictionary, e.g. {'reye':(y,x), 'leye':(y,x)}.
Keyword parameters:
file_id
The ``File`` object for which you want to retrieve the annotations
annotation_type
The type of annotations ('idiap', 'funneled').
If not specified, and if not given in the constructor, all annotations are taken, which might to cause an assertion error.
"""
self.assert_validity()
if annotation_type is None:
annotation_type = self.m_annotation_type
annotation_type = self.check_parameters_for_validity(
annotation_type, "annotation type", self.m_valid_annotation_types)
query = self.query(Annotation).filter(Annotation.annotation_type.in_(
annotation_type)).join(File).filter(File.id == file.id)
assert query.count() == 1
annotation = query.first()
# return the annotations as returned by the call function of the
# Annotation object
return annotation()
[docs] def t_model_ids(self, protocol, groups='dev', **kwargs):
"""Returns the list of model ids used for T-Norm of the given protocol for the given group that satisfy your query."""
return self.uniquify(self.tmodel_ids(protocol=protocol, groups=groups, **kwargs))
[docs] def t_enroll_files(self, protocol, model_id, groups='dev', **kwargs):
"""Returns the list of T-Norm model enrollment File objects from the given model id of the given protocol for the given group that satisfy your query."""
return self.uniquify(self.tobjects(protocol=protocol, groups=groups, model_ids=(model_id,), **kwargs))
[docs] def z_probe_files(self, protocol, groups='dev', **kwargs):
"""Returns the list of Z-Norm probe File objects to probe the model with the given model id of the given protocol for the given group that satisfy your query."""
return self.uniquify(self.zobjects(protocol=protocol, groups=groups, **kwargs))