Source code for flask_mab.storage
"""
Defines various storage engines for the MAB
interface
"""
import json
import flask_mab.bandits
[docs]class BanditEncoder(json.JSONEncoder):
"""Json serializer for Bandits"""
[docs] def default(self, obj):
if isinstance(obj, flask_mab.bandits.Bandit):
dict_repr = obj.__dict__
dict_repr['bandit_type'] = obj.__class__.__name__
return dict_repr
return json.JSONEncoder.default(self, obj)
[docs]class BanditDecoder(json.JSONDecoder):
"""Json Marshaller for Bandits"""
[docs] def decode(self, obj):
dict_repr = json.loads(obj)
for key in dict_repr.keys():
if 'bandit_type' not in dict_repr[key].keys():
raise TypeError("Serialized object is not a valid bandit")
dict_repr[key] = flask_mab.bandits.Bandit.fromdict(dict_repr[key])
return dict_repr
[docs]class BanditStorage(object):
"""The base interface for a storage engine, implements no-ops for tests
"""
[docs] def save(self, bandits):
pass
[docs] def load(self):
return {}
[docs]class JSONBanditStorage(BanditStorage):
"""Json based file storage
Saves to local file
"""
def __init__(self, filepath):
self.file_handle = filepath
[docs] def flush(self):
open(self.file_handle, 'w').truncate()
[docs] def save(self, bandits):
json_bandits = json.dumps(bandits, indent=4, cls=BanditEncoder)
open(self.file_handle, 'w').write(json_bandits)
[docs] def load(self):
try:
with open(self.file_handle, 'r') as bandit_file:
bandits = bandit_file.read()
return json.loads(bandits, cls=BanditDecoder)
except (ValueError, IOError):
return {}