Source code for facereclib.databases.DatabaseBob

#!/usr/bin/env python
# vim: set fileencoding=utf-8 :
# @author: Manuel Guenther <Manuel.Guenther@idiap.ch>
# @date: Wed Oct  3 10:31:51 CEST 2012
#
# Copyright (C) 2011-2012 Idiap Research Institute, Martigny, Switzerland
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 3 of the License.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.


from .Database import Database, DatabaseZT

[docs]class DatabaseBob (Database): """This class can be used whenever you have a database that follows the default Bob database interface.""" def __init__( self, database, # The bob database that is used all_files_options = {}, # additional options for the database query that can be used to extract all files extractor_training_options = {}, # additional options for the database query that can be used to extract the training files for the extractor training projector_training_options = {}, # additional options for the database query that can be used to extract the training files for the extractor training enroller_training_options = {}, # additional options for the database query that can be used to extract the training files for the extractor training check_original_files_for_existence = False, **kwargs # The default parameters of the base class ): """ Parameters of the constructor of this database: database the bob.db.___ database that provides the actual interface image_directory The directory where the original images are stored. image_extension The file extension of the original images. all_files_options Options passed to the database query used to retrieve all data. extractor_training_options Options passed to the database query used to retrieve the images for the extractor training. projector_training_options Options passed to the database query used to retrieve the images for the projector training. enroller_training_options Options passed to the database query used to retrieve the images for the enroller training. check_original_files_for_existence Enables the test for the original data files when querying the database. kwargs The arguments of the base class """ Database.__init__( self, **kwargs ) self.m_database = database self.original_directory = database.original_directory self.all_files_options = all_files_options self.extractor_training_options = extractor_training_options self.projector_training_options = projector_training_options self.enroller_training_options = enroller_training_options self.check_existence = check_original_files_for_existence self._kwargs = kwargs def __str__(self): """This function returns a string containing all parameters of this class (and its derived class).""" params = ", ".join(["%s=%s" % (key, value) for key, value in self._kwargs.items()]) params += ", original_directory=%s, original_extension=%s" % (self.original_directory, self.original_extension) if self.all_files_options: params += ", all_files_options=%s"%self.all_files_options if self.extractor_training_options: params += ", extractor_training_options=%s"%self.extractor_training_options if self.projector_training_options: params += ", projector_training_options=%s"%self.projector_training_options if self.enroller_training_options: params += ", enroller_training_options=%s"%self.enroller_training_options return "%s(%s)" % (str(self.__class__), params)
[docs] def uses_probe_file_sets(self): """Defines if, for the current protocol, the database uses several probe files to generate a score.""" return self.protocol != 'None' and self.m_database.provides_file_set_for_protocol(self.protocol)
[docs] def all_files(self, groups = None): """Returns all File objects of the database for the current protocol. If the current protocol is 'None' (a string), None (NoneType) will be used instead""" files = self.m_database.objects(protocol = self.protocol if self.protocol != 'None' else None, groups = groups, **self.all_files_options) return self.sort(files)
[docs] def training_files(self, step = None, arrange_by_client = False): """Returns all training File objects of the database for the current protocol.""" if step is None: training_options = self.all_files_options elif step == 'train_extractor': training_options = self.extractor_training_options elif step == 'train_projector': training_options = self.projector_training_options elif step == 'train_enroller': training_options = self.enroller_training_options else: raise ValueError("The given step '%s' must be one of ('train_extractor', 'train_projector', 'train_enroller')" % step) files = self.sort(self.m_database.objects(protocol = self.protocol, groups = 'world', **training_options)) if arrange_by_client: return self.arrange_by_client(files) else: return files
[docs] def test_files(self, groups = ['dev']): """Returns the test files (i.e., enrollment and probe files) for the given groups.""" return self.sort(self.m_database.test_files(protocol = self.protocol, groups = groups, **self.all_files_options))
[docs] def model_ids(self, group = 'dev'): """Returns the model ids for the given group and the current protocol.""" if hasattr(self.m_database, 'model_ids'): return sorted(self.m_database.model_ids(protocol = self.protocol, groups = group)) else: return sorted([model.id for model in self.m_database.models(protocol = self.protocol, groups = group)])
[docs] def client_id_from_model_id(self, model_id, group = 'dev'): """Returns the client id for the given model id.""" if hasattr(self.m_database, 'get_client_id_from_model_id'): return self.m_database.get_client_id_from_model_id(model_id) else: return model_id
[docs] def enroll_files(self, model_id, group = 'dev'): """Returns the list of enrollment File objects for the given model id.""" files = self.m_database.objects(protocol = self.protocol, groups = group, model_ids = (model_id,), purposes = 'enroll', **self.all_files_options) return self.sort(files)
[docs] def probe_files(self, model_id = None, group = 'dev'): """Returns the list of probe File objects (for the given model id, if given).""" if model_id: files = self.m_database.objects(protocol = self.protocol, groups = group, model_ids = (model_id,), purposes = 'probe', **self.all_files_options) else: files = self.m_database.objects(protocol = self.protocol, groups = group, purposes = 'probe', **self.all_files_options) return self.sort(files)
[docs] def probe_file_sets(self, model_id = None, group = 'dev'): """Returns the list of probe File objects (for the given model id, if given).""" if model_id: file_sets = self.m_database.object_sets(protocol = self.protocol, groups = group, model_ids = (model_id,), purposes = 'probe', **self.all_files_options) else: file_sets = self.m_database.object_sets(protocol = self.protocol, groups = group, purposes = 'probe', **self.all_files_options) return self.sort(file_sets)
[docs] def annotations(self, file): """Returns the annotations for the given File object, if available.""" return self.m_database.annotations(file)
[docs] def original_file_names(self, files): """Returns the full path of the original data of the given File objects.""" return self.m_database.original_file_names(files, self.check_existence)
[docs]class DatabaseBobZT (DatabaseBob, DatabaseZT): """This class can be used whenever you have a database that follows the default Bob database interface defining file lists for ZT score normalization.""" def __init__( self, z_probe_options = {}, # Limit the z-probes **kwargs ): # call base class constructor, passing all the parameters to it DatabaseBob.__init__(self, z_probe_options = z_probe_options, **kwargs) self.m_z_probe_options = z_probe_options
[docs] def all_files(self, groups = ['dev']): """Returns all File objects of the database for the current protocol. If the current protocol is 'None' (a string), None (NoneType) will be used instead""" files = self.m_database.objects(protocol = self.protocol if self.protocol != 'None' else None, groups = groups, **self.all_files_options) # add all files that belong to the ZT-norm for group in groups: if group == 'world': continue files += self.m_database.tobjects(protocol = self.protocol if self.protocol != 'None' else None, groups = group, model_ids = None) files += self.m_database.zobjects(protocol = self.protocol if self.protocol != 'None' else None, groups = group, **self.m_z_probe_options) return self.sort(files)
[docs] def t_model_ids(self, group = 'dev'): """Returns the T-Norm model ids for the given group and the current protocol.""" if hasattr(self.m_database, 'tmodel_ids'): return sorted(self.m_database.tmodel_ids(protocol = self.protocol, groups = group)) else: return sorted([model.id for model in self.m_database.tmodels(protocol = self.protocol, groups = group)])
[docs] def t_enroll_files(self, model_id, group = 'dev'): """Returns the list of enrollment File objects for the given T-Norm model id.""" files = self.m_database.tobjects(protocol = self.protocol, groups = group, model_ids = (model_id,)) return self.sort(files)
[docs] def z_probe_files(self, group = 'dev'): """Returns the list of Z-probe File objects.""" files = self.m_database.zobjects(protocol = self.protocol, groups = group, **self.m_z_probe_options) return self.sort(files)
[docs] def z_probe_file_sets(self, group = 'dev'): """Returns the list of Z-probe Fileset objects.""" file_sets = self.m_database.zobject_sets(protocol = self.protocol, groups = group, **self.m_z_probe_options) return self.sort(file_sets)