Source code for piecash.sa_extra

from __future__ import division, unicode_literals
from __future__ import print_function

import datetime
import logging
import sys
import unicodedata

import pytz
import tzlocal
from sqlalchemy import types, Table, MetaData, ForeignKeyConstraint, event, create_engine
from sqlalchemy.dialects import sqlite
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.ext.declarative import as_declarative
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import sessionmaker, object_session

# import yaml

# Python 2/3 compatibility alias for the arbitrary-precision integer type:
# ``int`` on Python 3, the builtin ``long`` on Python 2. It is bound as a
# module-level name so it can be imported from this module.
if sys.version_info[0] >= 3:
    long = int
else:
    long = long  # noqa: F821 -- Python 2 builtin, rebound as a module attribute


def __init__blocked(self, *args, **kwargs):
    """Refuse to build an object through its constructor.

    Used as the default constructor of ``DeclarativeBase``. Whatever the
    arguments, it raises ``NotImplementedError`` naming the class of ``self``.
    """
    type_name = self.__class__.__name__
    template = ("Objects of type {} cannot be created from scratch "
                "(only read)")
    raise NotImplementedError(template.format(type_name))


@as_declarative(constructor=__init__blocked)
class DeclarativeBase(object):
    """Declarative base of the mapped objects.

    ``as_declarative`` installs ``__init__blocked`` as the default constructor, so
    instantiating a class that does not define its own ``__init__`` raises
    ``NotImplementedError``. ``__unirepr__`` is not defined here: subclasses are
    expected to provide it, as ``__str__``, ``__repr__`` and ``__unicode__`` all
    delegate to it.
    """

    @property
    def book(self):
        """Return the gnc book holding the object.

        :return: the ``book`` of the session the object belongs to, or ``None``
            when the object is not attached to a session
        """
        s = object_session(self)
        return s and s.book

    def object_to_validate(self, change):
        """Yield the objects to validate when the object is modified.

        :param change: kind of modification, one of "new", "deleted" or "dirty"

        For instance, if the object is a Split, if it changes, we want to revalidate
        not the split but its transaction and its lot (if any): split.object_to_validate
        should yield both split.transaction and split.lot.

        This default implementation yields nothing.
        """
        # ``return`` before ``yield`` makes this an empty generator.
        return
        yield

    def validate(self):
        """Validate the object; must be reimplemented for objects requiring validation.

        :raises NotImplementedError: always, in this base implementation
        """
        raise NotImplementedError(self)

    def get_all_changes(self):
        """Return the change record of this object.

        The record is looked up by ``id(self)`` in ``self.book.session._all_changes``.
        When there is none, a record marking the object as "unchanged" is returned.
        """
        try:
            return self.book.session._all_changes[id(self)]
        except KeyError:
            return {"STATE_CHANGES": ["unchanged"],
                    "OBJECT": self}

    if sys.version_info[0] >= 3:
        def __str__(self):
            return self.__unirepr__()

        def __repr__(self):
            return self.__unirepr__()

    else:
        def __str__(self):
            # Python 2: return bytes; NFKD splits accented characters into base
            # character + combining mark, and the non-ASCII marks are then dropped.
            return unicodedata.normalize('NFKD', self.__unirepr__()).encode('ascii', 'ignore')

        def __repr__(self):
            # Python 2: escape non-ASCII characters rather than dropping them.
            return self.__unirepr__().encode('ascii', errors='backslashreplace')

    def __unicode__(self):
        return self.__unirepr__()


# Timezone in which datetimes are stored in the database.
utc = pytz.UTC
# Local timezone. Naive datetimes handed to _DateTime are interpreted in it,
# and datetimes read back through _DateTime are converted to it.
tz = tzlocal.get_localzone()


@compiles(sqlite.DATE, 'sqlite')
def compile_date(element, compiler, **kw):
    """Render the DDL type of sqlite DATE columns as ``TEXT(8)``.

    Eight characters fit the separator-less ``YYYYMMDD`` storage format used by ``_Date``.
    """
    ddl_type = "TEXT(8)"
    return ddl_type


@compiles(sqlite.DATETIME, 'sqlite')
def compile_datetime(element, compiler, **kw):
    """Render the DDL type of sqlite DATETIME columns as ``TEXT(14)``.

    Fourteen characters fit the separator-less ``YYYYMMDDHHMMSS`` storage format
    used by ``_DateTime``. The function is named distinctly from the DATE
    compiler ``compile_date`` so the two module-level names do not shadow each
    other; SQLAlchemy only relies on the ``@compiles`` registration.
    """
    return "TEXT(14)"


class _DateTime(types.TypeDecorator):
    """Used to customise the DateTime type for sqlite (ie without the separators as in gnucash).

    Values are written in UTC; naive datetimes are first interpreted in the local
    timezone ``tz``. Values read back are converted to ``tz``. On sqlite they are
    stored as ``YYYYMMDDHHMMSS`` strings; other dialects use a plain ``DateTime``.
    """
    impl = types.TypeEngine
    # The type takes no constructor arguments, so it is safe for SQLAlchemy's
    # statement cache (SQLAlchemy >= 1.4 warns when this is not declared).
    cache_ok = True

    def load_dialect_impl(self, dialect):
        if dialect.name == "sqlite":
            # separator-less 14-character format, as used by gnucash
            return sqlite.DATETIME(
                storage_format="%(year)04d%(month)02d%(day)02d%(hour)02d%(minute)02d%(second)02d",
                regexp=r"(\d{4})(\d{2})(\d{2})(\d{2})(\d{2})(\d{2})",
            )
        else:
            return types.DateTime()

    def process_bind_param(self, value, engine):
        """Convert a Python datetime into the UTC datetime sent to the database.

        :param value: a ``datetime.datetime`` (naive values are taken as local
            time in ``tz``) or ``None``
        :return: ``value`` converted to UTC, or ``None``
        :raises AssertionError: if ``value`` is not a ``datetime.datetime``
            (only when assertions are enabled)
        """
        if value is not None:
            assert isinstance(value, datetime.datetime), "value {} is not of type datetime.datetime but type {}".format(
                value, type(value))
            if value.tzinfo is None:
                # pytz zones must be attached with localize() to get the right
                # UTC offset. Zones without localize() (e.g. the zoneinfo objects
                # returned by newer tzlocal releases) are attached with replace().
                localize = getattr(tz, "localize", None)
                if localize is not None:
                    value = localize(value)
                else:
                    value = value.replace(tzinfo=tz)
            if value.microsecond != 0:
                logging.warning("A datetime has been given with microseconds which are not saved in the database")
            return value.astimezone(utc)

    def process_result_value(self, value, engine):
        """Convert a datetime read from the database (stored in UTC) to ``tz``.

        :return: a timezone-aware datetime in ``tz``, or ``None``
        """
        if value is not None:
            # pytz's localize() refuses aware datetimes, so only attach UTC when naive.
            if value.tzinfo is None:
                value = utc.localize(value)
            return value.astimezone(tz)


class _Date(types.TypeDecorator):
    """Used to customise the Date type for sqlite (ie without the separators as in gnucash).

    On sqlite dates are stored as ``YYYYMMDD`` strings; other dialects use a plain ``Date``.
    """
    impl = types.TypeEngine
    # NOTE(review): not referenced by the code visible here; kept for compatibility.
    is_sqlite = False
    # The type takes no constructor arguments, so it is safe for SQLAlchemy's
    # statement cache (SQLAlchemy >= 1.4 warns when this is not declared).
    cache_ok = True

    def load_dialect_impl(self, dialect):
        if dialect.name == "sqlite":
            # separator-less 8-character format, as used by gnucash
            return sqlite.DATE(
                storage_format="%(year)04d%(month)02d%(day)02d",
                regexp=r"(\d{4})(\d{2})(\d{2})"
            )
        else:
            return types.Date()


def mapped_to_slot_property(col, slot_name, slot_transform=lambda x: x):
    """Create a hybrid property that mirrors a mapped column into a slot.

    Assumes the attribute in the class has the same name as the table column with
    "_" prepended (e.g. column ``name`` -> attribute ``_name``).

    :param col: the mapped column; accessed on the class, the property evaluates to it
    :param slot_name: name of the slot kept in sync with the attribute
    :param slot_transform: transformation applied to the value before storing it in
        the slot; when it returns ``None`` the slot is removed instead
    :return: a ``hybrid_property``
    """
    col_name = "_{}".format(col.name)

    def fget(self):
        return getattr(self, col_name)

    def fset(self, value):
        v = slot_transform(value)
        if v is None:
            if slot_name in self:
                del self[slot_name]
        else:
            self[slot_name] = v
        # the column attribute always receives the untransformed value
        setattr(self, col_name, value)

    def expr(cls):
        return col

    return hybrid_property(
        fget=fget,
        fset=fset,
        expr=expr,
    )
def pure_slot_property(slot_name, slot_transform=lambda x: x):
    """
    Create a property (class must have slots) that maps to a slot

    :param slot_name: name of the slot
    :param slot_transform: transformation to operate before assigning value; when it
        returns ``None`` the slot is removed instead
    :return: a ``hybrid_property`` reading ``self[slot_name].value`` (``None`` when the
        slot does not exist) and writing ``self[slot_name]``
    """

    def fget(self):
        # return None if the slot does not exist. alternative could be to raise an exception
        try:
            slot = self[slot_name]
        except KeyError:
            return None
        return slot.value

    def fset(self, value):
        v = slot_transform(value)
        if v is None:
            if slot_name in self:
                del self[slot_name]
        else:
            self[slot_name] = v

    return hybrid_property(
        fget=fget,
        fset=fset,
    )
def kvp_attribute(name, to_gnc, from_gnc, default=None):
    """Create a property backed by the key/value slot ``name`` of its owner.

    The owner must support ``obj[name]`` (raising KeyError when absent and
    returning an item with a ``.value`` attribute), ``obj[name] = x`` and
    ``del obj[name]``.

    :param name: slot key
    :param to_gnc: converts the Python value into what is assigned to the slot
    :param from_gnc: converts the slot's ``.value`` back into a Python value
    :param default: value returned when the slot is absent; assigning a value equal
        to it removes the slot instead of storing it
    :return: a ``property``
    """

    def getter(self):
        try:
            slot = self[name]
        except KeyError:
            return default
        # Convert outside the try so that a KeyError raised by from_gnc itself
        # surfaces instead of being mistaken for a missing slot.
        return from_gnc(slot.value)

    def setter(self, value):
        if value == default:
            # The default is represented by the absence of the slot.
            try:
                del self[name]
            except KeyError:
                pass
        else:
            self[name] = to_gnc(value)

    return property(getter, setter)
def get_foreign_keys(metadata, engine):
    """
    Retrieve all foreign keys from metadata bound to an engine

    Every table named in ``metadata`` is reflected from ``engine`` into a fresh,
    separate ``MetaData`` (``metadata`` itself is only read for its table names), and
    the ``ForeignKeyConstraint`` objects of the reflected tables are yielded.

    :param metadata: metadata whose table names drive the reflection
    :param engine: engine to reflect the tables from
    :return: generator of reflected ``ForeignKeyConstraint`` objects
    """
    reflected_metadata = MetaData()
    for table_name in list(metadata.tables):
        # autoload_with alone triggers reflection; the legacy autoload=True flag
        # is deprecated in SQLAlchemy 1.4 and rejected by 2.0.
        table = Table(
            table_name,
            reflected_metadata,
            autoload_with=engine,
        )
        for constraint in table.constraints:
            if isinstance(constraint, ForeignKeyConstraint):
                yield constraint
Session = sessionmaker(autoflush=True)


def create_piecash_engine(uri_conn, **kwargs):
    """Create the SQLAlchemy engine for ``uri_conn``.

    :param uri_conn: database URI passed to ``create_engine``
    :param kwargs: extra keyword arguments forwarded to ``create_engine``
    :return: the new engine
    """
    # NOTE: pysqlite's default BEGIN/COMMIT handling is left untouched (the sqlite
    # "connect"/"begin" listeners that used to be registered here had fully
    # commented-out bodies). See SQLAlchemy's pysqlite "serializable isolation /
    # transactional DDL" recipe if explicit transaction control is ever needed.
    return create_engine(uri_conn, **kwargs)


class ChoiceType(types.TypeDecorator):
    """Integer column exposed as one value of a fixed set of choices.

    ``choices`` maps each integer stored in the database to its Python-side value.
    """
    impl = types.INTEGER()

    def __init__(self, choices, **kw):
        self.choices = dict(choices)
        super(ChoiceType, self).__init__(**kw)

    def process_bind_param(self, value, dialect):
        """Return the first database key whose choice equals ``value``.

        :raises ValueError: if ``value`` is not one of the choices
        """
        for key, choice in self.choices.items():
            if choice == value:
                return key
        # Format each choice so the message can be built even for non-string
        # choices (joining them raw would raise TypeError and hide this error).
        raise ValueError("Value '{}' is not in choices [{}]".format(
            value, ", ".join("{}".format(choice) for choice in self.choices.values())))

    def process_result_value(self, value, dialect):
        """Return the choice stored under the database key ``value``."""
        return self.choices[value]