# Source code for pysym.core

# -*- coding: utf-8 -*-

from __future__ import (absolute_import, division, print_function)

import functools
import itertools
import math
import operator
import warnings
import weakref


def _wrap_numbers(func):
    """Decorate *func* so each positional argument goes through ``Number.make``.

    Plain numbers (and objects with a ``dtype``) become ``Number`` instances
    before *func* is called; anything else is passed through unchanged.
    Keyword arguments are not supported by the wrapper.
    """
    @functools.wraps(func)
    def wrapper(*args):
        make = Number.make
        converted = tuple(map(make, args))
        return func(*converted)
    return wrapper


class _deprecated(object):
    """Decorator factory that warns every time the decorated callable runs.

    Parameters
    ----------
    msg : str
        Appended to ``'Deprecation warning: '`` to form the warning text.

    The warning uses the default category (``UserWarning``) so it stays
    visible without special warning filters.
    """

    def __init__(self, msg):
        self.msg = msg

    def __call__(self, func):
        @functools.wraps(func)
        def f(*args, **kwargs):
            # stacklevel=2 attributes the warning to the code that called
            # the deprecated function, not to this wrapper.
            warnings.warn('Deprecation warning: ' + self.msg, stacklevel=2)
            return func(*args, **kwargs)
        return f


def collect(sorted_args, collect_to):
    """Collapse runs of the *same object* into ``collect_to`` products.

    Consecutive items that are identical (``is``) are replaced by
    ``collect_to.create((item, Number(count)))``; singletons are kept as-is.

    Parameters
    ----------
    sorted_args : sequence
        Items, normally sorted so that identical objects are adjacent.
    collect_to : class
        Class whose ``create`` builds the "item repeated count times" node.

    Returns
    -------
    ``sorted_args`` itself when it has at most one element, otherwise a
    tuple of the collected items.
    """
    if len(sorted_args) <= 1:
        return sorted_args

    collected = []

    def flush(item, multiplicity):
        if multiplicity == 1:
            collected.append(item)
        elif multiplicity > 1:
            collected.append(collect_to.create((item, Number(multiplicity))))

    current = sorted_args[0]
    run_length = 1
    for item in sorted_args[1:]:
        if item is current:
            run_length += 1
        else:
            flush(current, run_length)
            current = item
            run_length = 1
    flush(current, run_length)
    return tuple(collected)
def merge(args, mrg_cls=None):
    """Flatten nested instances of *mrg_cls* into one flat list.

    Every element that is an instance of *mrg_cls* is replaced by its own
    ``args``, repeatedly, until no such element remains.  With
    ``mrg_cls=None`` the input is returned unchanged; otherwise a new list
    is returned.
    """
    if mrg_cls is None:
        return args
    flattened = list(args)
    while any(isinstance(item, mrg_cls) for item in flattened):
        next_level = []
        for item in flattened:
            if isinstance(item, mrg_cls):
                next_level.extend(item.args)
            else:
                next_level.append(item)
        flattened = next_level
    return flattened
def merge_drop_sort_collect(args, collect_to, drop=(), mrg_cls=None):
    """Canonicalise the argument list of an n-ary operator.

    Steps: sort *args*, flatten nested *mrg_cls* nodes (see ``merge``),
    discard elements found in *drop* (via their ``found_in`` method), sort
    again, and collapse runs of identical elements with ``collect``.
    """
    flattened = merge(sorted(args), mrg_cls)
    kept = [arg for arg in flattened if not arg.found_in(drop)]
    return collect(sorted(kept), collect_to)
@functools.total_ordering
class Basic(object):
    """Base class for every expression node.

    An expression stores its children in the tuple ``self.args``.
    Instances are ordered first by class name and then lexicographically
    by ``args``; commutative operators rely on this ordering to sort their
    operands into a canonical order.
    """
    __slots__ = ('args',)

    def is_atomic(self):
        return False

    def is_zero(self):
        return False

    def __init__(self, *args):
        self.args = args

    @classmethod
    def create(cls, args):
        """Build an expression from the sequence *args*.

        Subclasses override this to simplify, so the result may be of a
        different class than ``cls``.
        """
        return cls(*args)  # extra magic allowed

    def __hash__(self):
        return hash(self.args)

    def __lt__(self, other):
        other = Number.make(other)
        typ1, typ2 = type(self), type(other)
        if typ1 is not typ2:
            return typ1.__name__ < typ2.__name__
        if self.is_atomic():
            return self.args[0] < other.args[0]
        # Lexicographic order: the first differing child decides and a
        # strict prefix sorts first.  Returning False as soon as a2 < a1
        # keeps the relation antisymmetric, which sorted() relies on.
        for a1, a2 in zip(self.args, other.args):
            if a1 < a2:
                return True
            if a2 < a1:
                return False
        return len(self.args) < len(other.args)

    def __eq__(self, other):
        other = Number.make(other)
        if type(self) is not type(other):
            return False
        if self.is_atomic():
            return self.args[0] == other.args[0]
        # zip() stops at the shorter tuple, so compare lengths explicitly.
        if len(self.args) != len(other.args):
            return False
        return all(a1 == a2 for a1, a2 in zip(self.args, other.args))

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self == other

    def __repr__(self):
        return '%s(%s)' % (self.__class__.__name__, ', '.join(
            repr(arg) for arg in self.args))

    def _print_ccode(self):
        return str(self)

    def has(self, instance):
        """Return True if *instance* occurs anywhere below this node."""
        for arg in self.args:
            if arg.has(instance):
                return True
        return False

    def found_in(self, flat_iterable):
        """Return True if an element of *flat_iterable* equals this node."""
        for elem_key in flat_iterable:
            if self == elem_key:
                return True
        return False

    def _subs(self, symb, repl):
        if self.has(symb):
            if symb is self:
                raise ValueError("Impossible, bug!")
            else:
                return self.create(tuple(
                    repl if arg is symb else arg._subs(symb, repl)
                    for arg in self.args
                ))
        else:
            return self

    def subs(self, subs_dict):
        """Apply each ``key -> value`` substitution of *subs_dict* in turn."""
        result = self
        for key, val in subs_dict.items():
            result = result._subs(key, val)
        return result

    def expand(self):
        return self.create(tuple(arg.expand() for arg in self.args))

    @_wrap_numbers
    def __add__(self, other):
        return Add.create((self, other))

    def __radd__(self, other):
        return self+other

    @_wrap_numbers
    def __mul__(self, other):
        return Mul.create((self, other))

    def __rmul__(self, other):
        return self*other

    @_wrap_numbers
    def __pow__(base, exponent):
        return Pow.create((base, exponent))

    @_wrap_numbers
    def __rpow__(exponent, base):
        # Called for ``number ** expr``: self is the exponent.
        return Pow.create((base, exponent))

    @_wrap_numbers
    def __truediv__(num, denom):
        return Fraction.create((num, denom))

    @_wrap_numbers
    def __rtruediv__(denom, num):
        return Fraction.create((num, denom))

    @_deprecated('Use "from __future__ import division"')
    @_wrap_numbers
    def __div__(num, denom):
        return Fraction.create((num, denom))

    @_deprecated('Use "from __future__ import division"')
    @_wrap_numbers
    def __rdiv__(denom, num):
        return Fraction.create((num, denom))

    @_wrap_numbers
    def __sub__(self, other):
        return Sub.create((self, other))

    @_wrap_numbers
    def __rsub__(self, other):
        return Sub.create((other, self))

    def __neg__(self):
        return -One * self
class Relational(Basic):
    """Base class for boolean relations between expressions.

    Subclasses set ``_rel_op`` (a function from :mod:`operator`) and
    ``_rel_op_str`` (a %-format string used by ``__str__``).
    """
    _rel_op = None
    _rel_op_str = None

    @staticmethod
    def _evaluate(arg):
        """Value of one argument as seen by ``evalb``.

        Nested relations are evaluated with ``evalb``, expressions with
        ``evalf``; anything without ``evalf`` is passed through unchanged.
        """
        if isinstance(arg, Relational):
            return arg.evalb()
        evalf = getattr(arg, 'evalf', None)
        return arg if evalf is None else evalf()

    def evalb(self):
        """Evaluate the relation numerically and return its truth value.

        The arguments are evaluated first.  Otherwise ``_rel_op`` would
        compare the expression objects themselves, which only gives the
        structural ordering of ``Basic`` (by class name).  For ``Not`` it
        would negate an object that is always truthy.
        """
        return self._rel_op(*tuple(self._evaluate(arg) for arg in self.args))

    def __str__(self):
        return self._rel_op_str % self.args
class Eq(Relational):
    """Equality relation, printed as ``(a == b)``."""
    _rel_op_str = '(%s == %s)'
    _rel_op = operator.__eq__
class Ne(Relational):
    """Inequality relation, printed as ``(a != b)``."""
    _rel_op_str = '(%s != %s)'
    _rel_op = operator.__ne__
class Lt(Relational):
    """Strict less-than relation, printed as ``(a < b)``."""
    _rel_op_str = '(%s < %s)'
    _rel_op = operator.__lt__
class Le(Relational):
    """Less-than-or-equal relation, printed as ``(a <= b)``."""
    _rel_op_str = '(%s <= %s)'
    _rel_op = operator.__le__
class Gt(Relational):
    """Strict greater-than relation, printed as ``(a > b)``."""
    _rel_op_str = '(%s > %s)'
    _rel_op = operator.__gt__
class Ge(Relational):
    """Greater-than-or-equal relation, printed as ``(a >= b)``."""
    _rel_op_str = '(%s >= %s)'
    _rel_op = operator.__ge__
class Not(Relational):
    """Logical negation of a single argument, printed as ``(not a)``."""
    _rel_op_str = '(not %s)'
    _rel_op = operator.__not__
class Atomic(Basic):
    """Base class for leaf expressions (numbers and symbols).

    Instances are interned.  Constructing the same subclass twice with an
    equal, hashable argument returns the same object while that object is
    alive.  This is why atomic expressions are compared with ``is``
    throughout this module.
    """
    # Interning cache keyed on (class, argument).  Weak values let unused
    # atoms be garbage-collected.  Keying on the class keeps different
    # Atomic subclasses built from the same value apart.
    __all_instances = weakref.WeakValueDictionary()
    # ``args`` already lives in Basic's slot.  ``__weakref__`` is required
    # to store instances in the WeakValueDictionary above.
    __slots__ = ('__weakref__',)

    def is_atomic(self):
        return True

    def __new__(cls, arg):
        key = (cls, arg)
        instance = Atomic.__all_instances.get(key, None)
        if instance is None:
            instance = object.__new__(cls)
            instance.args = (arg,)
            Atomic.__all_instances[key] = instance
        return instance

    def __init__(self, arg):
        # ``args`` is set exactly once, in __new__.  Re-running
        # Basic.__init__ on a cached instance would overwrite the stored
        # value: e.g. Number(1.0) would silently turn the cached Number(1)
        # into Number(1.0).
        pass

    def has(self, instance):
        return instance is self

    def found_in(self, flat_iterable):
        """Identity-based membership test (atoms are interned)."""
        return any(elem is self for elem in flat_iterable)

    def _subs(self, symb, repl):
        return repl if self is symb else self

    def expand(self):
        return self
class Number(Atomic):
    """Interned numeric constant wrapping a Python (or NumPy) number."""
    _NUMBER_TYPES = (int, float)

    def __abs__(self):
        if self < 0:
            return -self
        return self

    def __hash__(self):
        return hash(self.args[0])

    def is_zero(self):
        return self.args[0] == 0

    @classmethod
    def make(cls, arg):
        """Wrap *arg* in ``cls`` when it is a number; otherwise return it as-is.

        Objects that have a ``dtype`` attribute (NumPy values) are wrapped
        as well.
        """
        if isinstance(arg, cls._NUMBER_TYPES) or hasattr(arg, 'dtype'):
            return cls(arg)
        return arg

    def diff(self, wrt):
        return Zero

    def evalf(self):
        value = self.args[0]
        return value if isinstance(value, self._NUMBER_TYPES) else float(value)

    def __neg__(self):
        return Number(-self.args[0])

    def __str__(self):
        return str(self.args[0])

    def _print_ccode(self):
        # Always emit a floating-point literal so the generated C code
        # never performs integer division.
        return str(float(self.args[0]))
class Symbol(Atomic):
    """Symbolic variable; ``args[0]`` holds its name."""

    def diff(self, instance):
        # A symbol's derivative is One with respect to itself, Zero otherwise.
        return One if instance is self else Zero

    def __str__(self):
        return str(self.args[0])
def Dummy():
    """Return a Symbol with a fresh name built from a decreasing counter.

    ``Dummy.counter`` is decremented on every call.  With its initial
    value of 1 the generated names are 'Dummy-1', 'Dummy-2', ...
    """
    Dummy.counter -= 1
    index = Dummy.counter - 1
    return Symbol('Dummy' + str(index))
# Starting value of the counter used by Dummy() to build unique names.
Dummy.counter = 1
# Frequently used interned constants.
Zero, One, Two = Number(0), Number(1), Number(2)
nan = Number(float('nan'))
class Operator(Basic):
    """Base class for expressions that combine their arguments with an operator.

    Class attributes set by subclasses:

    * ``_operator`` -- callable applied to the evaluated arguments in ``evalf``
    * ``_op_str`` -- %-format string used by ``__str__``
    * ``_op_cstr`` -- optional %-format string for C code (defaults to ``_op_str``)
    * ``_commutative`` -- whether ``sorted`` may reorder the arguments
    """
    _operator = None
    _op_str = None
    _op_cstr = None  # C-code
    _commutative = True

    def evalf(self):
        values = [arg.evalf() for arg in self.args]
        return self._operator(*values)

    def __str__(self):
        return self._op_str % self.args

    def _print_ccode(self):
        template = self._op_cstr or self._op_str
        return template % tuple(arg._print_ccode() for arg in self.args)

    def sorted(self):
        """Return the expression with canonically ordered arguments (if commutative)."""
        if not self._commutative:
            return self
        return self.create(sorted(self.args))
class Reduction(Operator):
    """n-ary operator whose value is ``_operator`` folded over its arguments."""

    @classmethod
    def create(cls, args):
        # A reduction over a single argument is just that argument.
        if len(args) != 1:
            return super(Reduction, cls).create(args)
        return args[0]

    def evalf(self):
        values = (arg.evalf() for arg in self.args)
        return functools.reduce(self._operator, values)

    def __str__(self):
        return '(%s)' % self._op_str.join(str(arg) for arg in self.args)

    def _print_ccode(self):
        return '(%s)' % self._op_str.join(
            arg._print_ccode() for arg in self.args)
class Add(Reduction):
    """Sum of terms.

    ``create`` drops zeros, flattens nested sums, sorts the terms and
    combines identical terms with their multiplicity via ``Mul.create``.
    """
    _operator = operator.add
    _op_str = ' + '

    @classmethod
    def create(cls, args):
        args = tuple(arg for arg in args if arg is not Zero)
        if not args:
            return Zero
        terms = merge_drop_sort_collect(args, Mul, (Zero, Mul(Zero)), Add)
        if not terms:
            # Every term was dropped or flattened away (e.g. only Mul(Zero)
            # terms or an empty nested Add): the sum is Zero, not Add().
            return Zero
        return super(Add, cls).create(terms)

    def diff(self, wrt):
        return self.create(tuple(arg.diff(wrt) for arg in self.args))

    def evalf(self):
        if not self.args:
            # Empty sum: return a plain number, as every other evalf does,
            # instead of the symbolic Zero.
            return 0
        return super(Add, self).evalf()

    def insert_mult(self, factor):
        """Distribute *factor* over the terms: ``sum(term * factor)``."""
        return self.create(tuple(Mul.create((arg, factor))
                                 for arg in self.args))
class Mul(Reduction):
    """Product of factors.

    ``create`` short-circuits to Zero, drops ones, flattens nested
    products, sorts the factors and combines identical factors into powers
    via ``Pow.create``.
    """
    _operator = operator.mul
    _op_str = '*'

    @classmethod
    def create(cls, args):
        if len(args) == 0:
            return One
        if Zero.found_in(args):
            return Zero
        factors = merge_drop_sort_collect(args, Pow, (One,), Mul)
        if not factors:
            # Every factor was One (or an empty nested Mul).  Without this
            # check an empty Mul() would be built, whose evalf fails in reduce().
            return One
        return super(Mul, cls).create(factors)

    def diff(self, wrt):
        """Product rule: sum over i of the product with factor i differentiated."""
        return Add.create(tuple(
            Mul.create(tuple(
                arg.diff(wrt) if i == idx else arg
                for i, arg in enumerate(self.args)))
            for idx in range(len(self.args))))

    def expand(self):
        """Distribute the product over the first Add factor, recursively."""
        for idx, arg in enumerate(self.args):
            if not isinstance(arg, Add):
                continue
            if idx == 0:
                # use of `create` guarantees len(args) > 1
                return arg.insert_mult(Mul.create(
                    self.args[idx+1:])).expand()
            # absorb the preceding factors into the first Add
            summation = arg.insert_mult(Mul.create(self.args[:idx]))
            return Mul.create(
                (summation,) + self.args[idx + 1:]
            ).expand()
        return self
class Binary(Operator):
    """Operator with exactly two arguments ``(a, b)``."""

    def __init__(self, a, b):
        # Enforce the arity in the signature, then store as Basic does.
        super(Binary, self).__init__(a, b)
class Sub(Binary):
    """Difference ``a - b``."""
    _operator = operator.sub
    _op_str = '(%s - %s)'
    _commutative = False

    @classmethod
    def create(cls, args):
        """Build ``a - b`` with trivial simplifications.

        ``0 - b`` gives ``-b``, ``a - 0`` gives ``a``, ``a - a`` gives Zero, and
        two Numbers are subtracted immediately.
        """
        minuend, subtrahend = args  # a - b
        if minuend.is_zero():
            return -subtrahend
        if subtrahend.is_zero():
            return minuend
        if minuend == subtrahend:
            return Zero
        both_numeric = (isinstance(minuend, Number) and
                        isinstance(subtrahend, Number))
        if both_numeric:
            return Number.make(minuend.args[0] - subtrahend.args[0])
        return cls(*args)

    def diff(self, wrt):
        minuend, subtrahend = self.args
        return Sub.create((minuend.diff(wrt), subtrahend.diff(wrt)))
class Fraction(Binary):
    """Quotient ``a / b``."""
    _operator = operator.truediv
    _op_str = '(%s/%s)'
    _commutative = False

    @classmethod
    def create(cls, args):
        """Build ``a / b``.

        Raises ZeroDivisionError when the denominator is zero and returns
        Zero when the numerator is zero.
        """
        instance = cls(*args)
        numerator, denominator = instance.args
        if denominator.is_zero():
            raise ZeroDivisionError
        if numerator.is_zero():
            return Zero
        return instance

    def evalf(self):
        numerator, denominator = self.args
        return numerator.evalf() / denominator.evalf()

    def diff(self, wrt):
        a, b = self.args  # a/b
        # quotient rule: (a'*b - a*b') / b**2
        numerator = Sub.create((a.diff(wrt)*b, Mul.create((a, b.diff(wrt)))))
        return self.create((numerator, Pow.create((b, Two))))
class Pow(Binary):
    """Power ``base ** exponent``."""
    _operator = operator.pow
    _commutative = False
    _op_str = '(%s**%s)'  # factorial has higher precedence (hence parenthesis)
    _op_cstr = 'pow(%s, %s)'

    def evalf(self):
        return self.args[0].evalf() ** self.args[1].evalf()

    def diff(self, wrt):
        base, exponent = self.args
        in_base = base.has(wrt)
        in_exponent = exponent.has(wrt)
        if not in_exponent:
            if not in_base:
                return Zero
            # power rule: e * b**(e - 1) * b'
            return Mul.create((
                exponent,
                Pow.create((base, Sub.create((exponent, One)))),
                base.diff(wrt),
            ))
        # general case: b**e == exp(log(b) * e)
        return exp.create((
            Mul.create((log.create((base,)), exponent)),
        )).diff(wrt)

    @classmethod
    def create(cls, args):
        """Build ``base ** exponent`` with trivial simplifications.

        Raises ZeroDivisionError for a zero base with a negative numeric
        exponent.  This matches Fraction.create's handling of a zero
        denominator and Python's own ``0 ** -1``.
        """
        base, exponent = args
        if exponent.is_zero():
            return One
        elif exponent == One:
            return base
        if base.is_zero():
            if isinstance(exponent, Number) and exponent.args[0] < 0:
                raise ZeroDivisionError
            return Zero
        return cls(*args)
class ITE(Basic):
    """If-then-else: ``if_true`` when ``cond.evalb()`` holds, else ``if_false``."""

    def __init__(self, cond, if_true, if_false):
        self.args = (cond, if_true, if_false)

    def _eval(self):
        cond, if_true, if_false = self.args
        return if_true if cond.evalb() else if_false

    def evalf(self):
        return self._eval().evalf()

    def diff(self, wrt):
        """Differentiate both branches; the condition is kept unchanged.

        No term is produced for a discontinuity at the branch point.
        """
        cond, if_true, if_false = self.args
        return ITE(cond, if_true.diff(wrt), if_false.diff(wrt))

    def __str__(self):
        return '({1} if {0} else {2})'.format(*self.args)

    def _print_ccode(self):
        # The %-operator needs a tuple; a bare generator counts as a single
        # argument and raises TypeError.
        return '((%s) ? %s : %s)' % tuple(
            arg._print_ccode() for arg in self.args)
class Function(Basic):
    """Application of a named numeric function to its arguments.

    Subclasses set ``_function`` (called by ``evalf``) and ``_func_str``
    (the printed name; ``str(_function)`` is used when it is unset).
    """
    _function = None
    _func_str = None

    def evalf(self):
        return self._function(*tuple(arg.evalf() for arg in self.args))

    def _func_name(self):
        """Name used when printing this function."""
        return self._func_str or str(self._function)

    def __str__(self):
        return self._func_name() + '(' + ', '.join(
            map(str, self.args)) + ')'

    def _print_ccode(self):
        """C expression for this call.

        Arguments are rendered with ``_print_ccode`` instead of ``str``.
        Nested operators then use their C forms (e.g. ``pow(a, b)`` rather
        than ``a**b``) and numbers are printed as floating-point literals.
        """
        return self._func_name() + '(' + ', '.join(
            arg._print_ccode() for arg in self.args) + ')'
class Function1(Function):
    """Function of a single argument; subclasses provide ``_deriv``."""

    @_wrap_numbers
    def __init__(self, arg):
        self.args = (arg,)

    @staticmethod
    def _deriv(arg):
        """Derivative of the function evaluated at *arg*."""
        raise NotImplementedError

    def diff(self, wrt):
        # chain rule: f'(g) * g'
        inner = self.args[0]
        outer_derivative = self._deriv(inner)
        return Mul.create((outer_derivative, inner.diff(wrt)))
class gamma(Function1):
    """Gamma function (no symbolic derivative is provided)."""
    _func_str = 'gamma'
    _function = math.gamma
class Abs(Function1):
    """Absolute value."""
    _func_str = 'Abs'
    _function = abs

    @staticmethod
    def _deriv(arg):
        # sign(arg): -1 below zero, +1 above zero, nan at zero
        positive_side = ITE(Gt(arg, Zero), One, nan)
        return ITE(Lt(arg, Zero), -One, positive_side)
class exp(Function1):
    """Natural exponential."""
    _func_str = 'exp'
    _function = math.exp

    @staticmethod
    def _deriv(arg):
        # d/dx exp(x) = exp(x)
        return exp(arg)
class sqrt(Function1):
    """Square root."""
    _func_str = 'sqrt'
    _function = math.sqrt

    @staticmethod
    def _deriv(arg):
        # d/dx sqrt(x) = 1 / (2 sqrt(x))
        denominator = 2*sqrt(arg)
        return 1/denominator
class log(Function1):
    """Natural logarithm."""
    _func_str = 'log'
    _function = math.log

    @staticmethod
    def _deriv(arg):
        # d/dx log(x) = x**-1
        reciprocal = Pow(arg, -One)
        return reciprocal
class sin(Function1):
    """Sine."""
    _func_str = 'sin'
    _function = math.sin

    @staticmethod
    def _deriv(arg):
        # d/dx sin(x) = cos(x)
        return cos(arg)
class cos(Function1):
    """Cosine."""
    _function = math.cos
    # Printed name; it was previously 'sin', so cos(x) printed (and
    # generated C code) as sin(x).
    _func_str = 'cos'

    @staticmethod
    def _deriv(arg):
        # d/dx cos(x) = -sin(x)
        return -sin(arg)
class tan(Function1):
    """Tangent."""
    _func_str = 'tan'
    _function = math.tan

    @staticmethod
    def _deriv(arg):
        # d/dx tan(x) = 1 + tan(x)**2
        squared = tan(arg)**Two
        return One + squared
class asin(Function1):
    """Inverse sine."""
    _func_str = 'asin'
    _function = math.asin

    @staticmethod
    def _deriv(arg):
        # d/dx asin(x) = 1 / (1 - x**2)**(1/2)
        radicand = One - arg**Two
        return One/radicand**(One/Two)
class acos(Function1):
    """Inverse cosine."""
    _func_str = 'acos'
    _function = math.acos

    @staticmethod
    def _deriv(arg):
        # d/dx acos(x) = -(d/dx asin(x))
        asin_slope = asin._deriv(arg)
        return -asin_slope
class atan(Function1):
    """Inverse tangent."""
    _func_str = 'atan'
    _function = math.atan

    @staticmethod
    def _deriv(arg):
        # d/dx atan(x) = 1 / (1 + x**2)
        denominator = One + arg**Two
        return One/denominator
class Vector(Basic):
    """Ordered collection of expressions; ``args`` holds the elements."""

    def __init__(self, *args):
        self.args = args

    def __len__(self):
        return len(self.args)

    def diff(self, wrt):
        """Element-wise derivative, returned as a Vector of the same length."""
        # The constructor takes the elements as separate positional
        # arguments.  Passing one tuple would build a 1-element Vector
        # that contains that tuple.
        return self.__class__(*tuple(arg.diff(wrt) for arg in self.args))

    def evalf(self):
        """Numerically evaluate every element and return them as a list."""
        return [arg.evalf() for arg in self.args]

    def __iter__(self):
        return iter(self.args)

    def __getitem__(self, key):
        return self.args[key]
class Matrix(Basic):
    """Dense matrix of expressions.

    ``args`` is ``(nrows, ncols, elem_00, elem_01, ...)`` with the elements
    stored row-major.

    Parameters
    ----------
    nrows, ncols : int
    source : callable or indexable
        Either ``source(ri, ci)`` returning an element, or a container read
        as ``source[ri, ci]``.  If that raises TypeError/IndexError, it is
        read as the flat row-major ``source[ri*ncols + ci]``.  Elements from
        a container are passed through ``Number.make``.
    """

    def __init__(self, nrows, ncols, source):
        if callable(source):
            callback = source
        else:
            def callback(ri, ci):
                try:
                    return Number.make(source[ri, ci])
                except (TypeError, IndexError):
                    return Number.make(source[ri*ncols + ci])
        self.args = (nrows, ncols) + tuple(
            callback(ri, ci) for ri, ci in itertools.product(
                range(nrows), range(ncols))
        )

    def _subs(self, symb, repl):
        return self.__class__(self.nrows, self.ncols,
                              self.flatten()._subs(symb, repl))

    def has(self, instance):
        # Basic.has would call .has() on the integer shape entries
        # (args[0:2]); only the elements are expressions.
        return any(elem.has(instance) for elem in self.args[2:])

    def expand(self):
        # Basic.expand would call .expand() on the integer shape entries and
        # rebuild through create(args), which does not match this
        # constructor's (nrows, ncols, source) signature.
        return self.__class__(self.nrows, self.ncols,
                              lambda ri, ci: self[ri, ci].expand())

    def diff(self, wrt):
        """Element-wise derivative with the same shape."""
        return self.__class__(self.nrows, self.ncols,
                              lambda ri, ci: self[ri, ci].diff(wrt))

    def __iter__(self):
        if self.shape[0] == 1:
            for ci in range(self.ncols):
                yield self[0, ci]
        elif self.shape[1] == 1:
            for ri in range(self.nrows):
                yield self[ri, 0]
        else:
            # one 1 x ncols Matrix per row; the callback runs immediately in
            # the constructor, so capturing ``ri`` in the lambda is safe
            for ri in range(self.nrows):
                yield self.__class__(1, self.ncols, lambda i, j: self[ri, j])

    @property
    def nrows(self):
        return self.args[0]

    @property
    def ncols(self):
        return self.args[1]

    @property
    def shape(self):
        return (self.nrows, self.ncols)

    def flatten(self):
        """Return the elements, row-major, as a Vector."""
        return Vector(*tuple(self[ri, ci] for ri, ci in itertools.product(
            range(self.nrows), range(self.ncols))))

    def _get_element(self, idx):
        return self.args[2+idx]

    def __getitem__(self, key):
        ri, ci = key
        return self._get_element(self.ncols*ri + ci)

    def jacobian(self, iterable):
        """Jacobian of a row or column matrix with respect to *iterable*.

        *iterable* may be any sequence of variables, or an object with a
        ``shape`` describing a row or column vector.  Raises ValueError for
        other shapes and TypeError when ``self`` is neither a row nor a
        column matrix.
        """
        try:
            shape = iterable.shape
            if len(shape) > 2:
                raise ValueError
        except AttributeError:
            iterable = tuple(iterable)
        else:
            if len(shape) == 2:
                if shape[0] != 1 and shape[1] != 1:
                    raise ValueError('need column or row vector')
                if shape[0] == 1:
                    iterable = tuple(iterable[0, i] for i in range(shape[1]))
                else:
                    iterable = tuple(iterable[i, 0] for i in range(shape[0]))
        if self.ncols != 1 and self.nrows != 1:
            raise TypeError('jacobian only defined for row or column matrices')
        if self.ncols == 1:
            return self.__class__(
                self.nrows, len(iterable),
                lambda ri, ci: self._get_element(ri).diff(iterable[ci]))
        elif self.nrows == 1:
            return self.__class__(
                max(self.shape), len(iterable),
                lambda ri, ci: self._get_element(ri).diff(iterable[ci]))

    def evalf(self):
        """Numerically evaluate into a list of row lists."""
        return [[self[ri, ci].evalf() for ci in range(self.ncols)]
                for ri in range(self.nrows)]