Source code for flask_mab.storage

"""
Defines various storage engines for the MAB
interface
"""

import json
import flask_mab.bandits

class BanditEncoder(json.JSONEncoder):
    """JSON serializer for Bandits."""

    def default(self, obj):
        """Return a JSON-serializable representation of ``obj``.

        Instances of ``flask_mab.bandits.Bandit`` are turned into a dict
        holding their instance attributes plus a ``bandit_type`` key set to
        the name of the instance's class. Any other object is delegated to
        ``json.JSONEncoder.default``, which raises ``TypeError`` for
        unsupported types.
        """
        if isinstance(obj, flask_mab.bandits.Bandit):
            # Work on a shallow copy: writing into obj.__dict__ directly
            # would add a ``bandit_type`` attribute to the live bandit as a
            # side effect of serializing it.
            dict_repr = dict(obj.__dict__)
            dict_repr['bandit_type'] = obj.__class__.__name__
            return dict_repr
        return json.JSONEncoder.default(self, obj)
class BanditDecoder(json.JSONDecoder):
    """JSON marshaller for Bandits."""

    def decode(self, obj):
        """Parse a JSON document mapping names to serialized bandits.

        ``obj`` is the JSON text. Every value of the top-level object must
        itself be an object containing a ``bandit_type`` key; each one is
        passed to ``flask_mab.bandits.Bandit.fromdict`` and the result
        replaces the raw entry in the returned dict.

        Raises ``TypeError`` when the document is not an object or when an
        entry is not an object carrying ``bandit_type``. Malformed JSON
        raises ``ValueError`` from ``json.loads``.
        """
        dict_repr = json.loads(obj)
        # Reject non-object payloads up front so callers get the intended
        # TypeError rather than an AttributeError from a missing ``keys``.
        if not isinstance(dict_repr, dict):
            raise TypeError("Serialized object is not a valid bandit")
        # Iterate over a snapshot because entries are replaced in the loop.
        for key, raw_bandit in list(dict_repr.items()):
            if not isinstance(raw_bandit, dict) or 'bandit_type' not in raw_bandit:
                raise TypeError("Serialized object is not a valid bandit")
            dict_repr[key] = flask_mab.bandits.Bandit.fromdict(raw_bandit)
        return dict_repr
class BanditStorage(object):
    """Base interface for a storage engine.

    Every operation here is a no-op, which makes this class usable as a
    stand-in storage for tests.
    """

    def flush(self):
        """Clear the storage; this implementation does nothing."""
        return None

    def save(self, bandits):
        """Persist ``bandits``; this implementation discards them."""
        return None

    def load(self):
        """Return the stored bandits; this implementation has none."""
        return dict()
class JSONBanditStorage(BanditStorage):
    """JSON based file storage.

    Saves to a local file.
    """

    def __init__(self, filepath):
        """Remember ``filepath`` as the location of the JSON file."""
        # Despite the attribute name, this holds the path handed to
        # ``open``, not an open file object.
        self.file_handle = filepath

    def flush(self):
        """Empty the storage file, creating it if it does not exist."""
        # The context manager closes the handle deterministically instead
        # of leaving that to garbage collection.
        with open(self.file_handle, 'w') as bandit_file:
            bandit_file.truncate()

    def save(self, bandits):
        """Serialize ``bandits`` with ``BanditEncoder`` and overwrite the file."""
        # Encode before opening so a serialization error does not leave a
        # freshly truncated file behind.
        json_bandits = json.dumps(bandits, indent=4, cls=BanditEncoder)
        # Closing via ``with`` guarantees the data is flushed to disk.
        with open(self.file_handle, 'w') as bandit_file:
            bandit_file.write(json_bandits)

    def load(self):
        """Read the file and decode its contents with ``BanditDecoder``.

        Returns an empty dict when the file cannot be read or does not
        contain valid JSON. A ``TypeError`` raised by the decoder for an
        invalid bandit entry is not caught here.
        """
        try:
            with open(self.file_handle, 'r') as bandit_file:
                bandits = bandit_file.read()
            return json.loads(bandits, cls=BanditDecoder)
        except (ValueError, IOError):
            # Missing/unreadable file or malformed JSON: start empty.
            return {}
Fork me on GitHub