# Source code for chempy.equilibria

# -*- coding: utf-8 -*-
"""
Module collecting classes and functions for dealing with (multiphase) chemical
equilibria.

.. Note::

  This module is provisional at the moment, i.e. the API is not stable and may
  break without a deprecation cycle.

"""
from __future__ import division, absolute_import

import warnings
from collections import defaultdict

import numpy as np

from .chemistry import ReactionSystem, equilibrium_quotient, Equilibrium, Species
from ._util import get_backend
from .util.pyutil import deprecated
from ._eqsys import EqCalcResult, NumSysLin, NumSysLog, NumSysSquare as _NumSysSquare


# Public name kept for backward compatibility: the ``NumSysSquare`` class from
# ``._eqsys`` (imported above as ``_NumSysSquare``) wrapped by ``deprecated()``
# (presumably emits a deprecation warning on use -- see util.pyutil).
NumSysSquare = deprecated()(_NumSysSquare)


class EqSystem(ReactionSystem):
    """Reaction system made up of equilibria (possibly spanning phases).

    Extends :class:`ReactionSystem` with helpers for building and solving
    the nonlinear system of equations whose root gives the equilibrium
    concentrations.
    """

    _BaseReaction = Equilibrium
    _BaseSubstance = Species

    def eq_constants(self, non_precip_rids=(), eq_params=None, small=0):
        """Return the equilibrium constants as a numpy array.

        Parameters
        ----------
        non_precip_rids : container of int
            Reaction indices whose constant is replaced by ``small``.
        eq_params : sequence or None
            Equilibrium constants; defaults to ``param`` of each reaction.
        small : number
            Value used for the reactions in ``non_precip_rids``.
        """
        if eq_params is None:
            eq_params = [eq.param for eq in self.rxns]
        return np.array([small if idx in non_precip_rids else param
                         for idx, param in enumerate(eq_params)])

    def upper_conc_bounds(self, init_concs, min_=min, dtype=np.float64):
        """Return an upper bound on the concentration of every substance.

        The total concentration of each composition component (key 0, the
        charge, is skipped) is accumulated from ``init_concs``. The bound for
        a substance is ``min_`` over (component total / coefficient) for the
        components it contains.

        Parameters
        ----------
        init_concs : array or dict
            Initial concentrations per substance.
        min_ : callable
            Reduction used to pick the bound from the candidate values.
        dtype : numpy dtype
            Passed on to :meth:`as_per_substance_array`.

        Returns
        -------
        list of bounds, one per substance (in substance order).
        """
        init_concs_arr = self.as_per_substance_array(init_concs, dtype=dtype)
        composition_conc = defaultdict(float)
        for conc, s_obj in zip(init_concs_arr, self.substances.values()):
            for comp_nr, coeff in s_obj.composition.items():
                if comp_nr == 0:  # charge may be created (if compensated)
                    continue
                composition_conc[comp_nr] += coeff*conc
        bounds = []
        for s_obj in self.substances.values():
            choose_from = [composition_conc[comp_nr]/coeff
                           for comp_nr, coeff in s_obj.composition.items()
                           if comp_nr != 0]
            bounds.append(min_(choose_from))
        return bounds

    def equilibrium_quotients(self, concs):
        """Return the equilibrium quotient of every reaction for ``concs``."""
        stoichs = self.stoichs()
        return [equilibrium_quotient(concs, stoichs[ri, :])
                for ri in range(self.nr)]

    def stoichs_constants(self, eq_params, rref=False, Matrix=None,
                          backend=None, non_precip_rids=()):
        """Return the stoichiometry matrix and equilibrium constants.

        When ``rref`` is true, the linear system ``stoichs * log(x) =
        log(K)`` is brought to reduced row echelon form (via
        ``pyneqsys.symbolic.linear_rref``) and the constants are mapped back
        with ``exp`` of ``backend``.
        """
        if rref:
            from pyneqsys.symbolic import linear_rref
            be = get_backend(backend)
            rA, rb = linear_rref(self.stoichs(non_precip_rids),
                                 list(map(be.log, eq_params)),
                                 Matrix)
            return rA.tolist(), list(map(be.exp, rb))
        else:
            return (self.stoichs(non_precip_rids), eq_params)

    def composition_conservation(self, concs, init_concs):
        """Return ``(composition keys, A @ concs.T, A @ init_concs.T)``.

        ``A`` is built from :meth:`composition_balance_vectors`; comparing
        the two products shows how well each component is conserved.
        """
        composition_vecs, comp_keys = self.composition_balance_vectors()
        A = np.array(composition_vecs)
        return (comp_keys,
                np.dot(A, self.as_per_substance_array(concs).T),
                np.dot(A, self.as_per_substance_array(init_concs).T))

    def other_phase_species_idxs(self, phase_idx=0):
        """Return indices of substances whose ``phase_idx`` differs."""
        return [idx for idx, s in enumerate(
            self.substances.values()) if s.phase_idx != phase_idx]

    @property
    @deprecated(last_supported_version='0.3.1',
                will_be_missing_in='0.5.0',
                use_instead=other_phase_species_idxs)
    def precipitate_substance_idxs(self):
        """Indices of substances flagged as ``precipitate`` (deprecated)."""
        return [idx for idx, s in enumerate(
            self.substances.values()) if s.precipitate]

    def phase_transfer_reaction_idxs(self, phase_idx=0):
        """Return indices of reactions that involve precipitates.

        NOTE: ``phase_idx`` is currently unused; the selection is based on
        ``rxn.has_precipitates`` only.
        """
        return [idx for idx, rxn in enumerate(self.rxns)
                if rxn.has_precipitates(self.substances)]

    @property
    @deprecated(last_supported_version='0.3.1',
                will_be_missing_in='0.5.0',
                use_instead=phase_transfer_reaction_idxs)
    def precipitate_rxn_idxs(self):
        """Indices of reactions involving precipitates (deprecated)."""
        return [idx for idx, rxn in enumerate(self.rxns)
                if rxn.has_precipitates(self.substances)]

    def dissolved(self, concs):
        """ Return dissolved concentrations

        For every reaction with a precipitate, the precipitated amount is
        removed and the corresponding amounts of the other participants are
        put back according to the net stoichiometry. ``concs`` itself is not
        modified (a copy is made).
        """
        new_concs = concs.copy()
        for r in self.rxns:
            if r.has_precipitates(self.substances):
                net_stoich = np.asarray(r.net_stoich(self.substances))
                s_net, s_stoich, s_idx = r.precipitate_stoich(self.substances)
                new_concs -= new_concs[s_idx]/s_stoich * net_stoich
        return new_concs

    def _fw_cond_factory(self, ri, rtol=1e-14):
        """Return a callback telling whether reaction ``ri`` should precipitate.

        The callback compares the reaction quotient of the dissolved
        concentrations with the equilibrium constant; the direction of the
        comparison depends on the sign of the precipitate's stoichiometric
        coefficient.
        """
        rxn = self.rxns[ri]

        def fw_cond(x, p):
            precip_stoich_coeff = rxn.precipitate_stoich(self.substances)[1]
            q = rxn.Q(self.substances, self.dissolved(x))
            k = rxn.K()
            if precip_stoich_coeff > 0:
                return q*(1+rtol) < k
            elif precip_stoich_coeff < 0:
                return q > k*(1+rtol)
            else:
                raise NotImplementedError
        return fw_cond

    def _bw_cond_factory(self, ri, small):
        """Return a callback telling whether the precipitate of ``ri`` persists.

        The callback returns False once the precipitate concentration drops
        below ``small``.
        """
        rxn = self.rxns[ri]

        def bw_cond(x, p):
            precipitate_idx = rxn.precipitate_stoich(self.substances)[2]
            if x[precipitate_idx] < small:
                return False
            else:
                return True
        return bw_cond

    def _SymbolicSys_from_NumSys(self, NS, conds, rref_equil, rref_preserv):
        """Build a ``pyneqsys.symbolic.SymbolicSys`` from a ``_NumSys`` class.

        The number of parameters is ``ns + nr`` (initial concentrations
        followed by equilibrium constants).
        """
        from pyneqsys.symbolic import SymbolicSys
        import sympy as sp
        ns = NS(self, backend=sp, rref_equil=rref_equil,
                rref_preserv=rref_preserv, precipitates=conds)
        symb_kw = {}
        if ns.pre_processor is not None:
            symb_kw['pre_processors'] = [ns.pre_processor]
        if ns.post_processor is not None:
            symb_kw['post_processors'] = [ns.post_processor]
        if ns.internal_x0_cb is not None:
            symb_kw['internal_x0_cb'] = ns.internal_x0_cb
        return SymbolicSys.from_callback(
            ns.f, self.ns, nparams=self.ns + self.nr, **symb_kw)

    def get_neqsys_conditional_chained(self, init_concs, rref_equil=False,
                                       rref_preserv=False, NumSys=NumSysLin):
        """Return a ConditionalNeqSys whose branches are ChainedNeqSys."""
        from pyneqsys import ConditionalNeqSys, ChainedNeqSys

        def factory(conds):
            return ChainedNeqSys([self._SymbolicSys_from_NumSys(
                NS, conds, rref_equil, rref_preserv) for NS in NumSys])

        cond_cbs = [(self._fw_cond_factory(ri),
                     self._bw_cond_factory(ri, NumSys[0].small)) for
                    ri in self.phase_transfer_reaction_idxs()]
        return ConditionalNeqSys(cond_cbs, factory)

    def get_neqsys_chained_conditional(self, init_concs, rref_equil=False,
                                       rref_preserv=False, NumSys=NumSysLin):
        """Return a ChainedNeqSys of ConditionalNeqSys (one per NumSys)."""
        from pyneqsys import ConditionalNeqSys, ChainedNeqSys

        def mk_factory(NS):
            def factory(conds):
                return self._SymbolicSys_from_NumSys(NS, conds, rref_equil,
                                                     rref_preserv)
            return factory

        return ChainedNeqSys(
            [ConditionalNeqSys(
                [(self._fw_cond_factory(ri),
                  self._bw_cond_factory(ri, NS.small)) for
                 ri in self.phase_transfer_reaction_idxs()],
                mk_factory(NS)
            ) for NS in NumSys])

    def get_neqsys_static_conditions(self, init_concs, rref_equil=False,
                                     rref_preserv=False, NumSys=NumSysLin,
                                     precipitates=None):
        """Return a ChainedNeqSys for a fixed set of precipitate conditions.

        ``precipitates`` defaults to False for every phase-transfer reaction.
        """
        if precipitates is None:
            precipitates = (False,)*len(self.phase_transfer_reaction_idxs())
        from pyneqsys import ChainedNeqSys
        return ChainedNeqSys([self._SymbolicSys_from_NumSys(
            NS, precipitates, rref_equil, rref_preserv) for NS in NumSys])

    def get_neqsys(self, neqsys_type, init_concs, NumSys=NumSysLin, **kwargs):
        """Dispatch to ``get_neqsys_<neqsys_type>``.

        Only ``rref_equil``, ``rref_preserv`` (and ``precipitates`` for
        ``'static_conditions'``) are taken from ``kwargs``; other keyword
        arguments are ignored. A single ``NumSys`` class is wrapped in a
        1-tuple.
        """
        new_kw = {'rref_equil': False, 'rref_preserv': False}
        if neqsys_type == 'static_conditions':
            new_kw['precipitates'] = None
        for k in new_kw:
            if k in kwargs:
                new_kw[k] = kwargs.pop(k)
        try:
            NumSys[0]
        except TypeError:
            new_kw['NumSys'] = (NumSys,)
        else:
            new_kw['NumSys'] = NumSys
        return getattr(self, 'get_neqsys_' + neqsys_type)(init_concs, **new_kw)

    def non_precip_rids(self, precipitates):
        """Return phase-transfer reaction indices whose flag is falsy."""
        return [idx for idx, precip in zip(
            self.phase_transfer_reaction_idxs(), precipitates) if not precip]

    def _result_is_sane(self, init_concs, x):
        """Check ``x`` for negative values or values above the upper bounds.

        Emits a warning for each kind of violation and returns False if any
        occurred, True otherwise.
        """
        sc_upper_bounds = np.array(self.upper_conc_bounds(init_concs))
        neg_conc, too_much = np.any(x < 0), np.any(
            x > sc_upper_bounds*(1 + 1e-12))
        if neg_conc or too_much:
            if neg_conc:
                warnings.warn("Negative concentration")
            if too_much:
                warnings.warn("Too much of at least one component")
            return False
        return True

    def _solve(self, init_concs, x0=None, NumSys=(NumSysLog, NumSysLin),
               neqsys='chained_conditional', **kwargs):
        """Solve for one set of initial concentrations.

        ``neqsys`` may be a prebuilt system or a string naming the
        ``get_neqsys_*`` flavour. Returns ``(x, sol, sane)``.
        """
        if isinstance(neqsys, str):
            neqsys = self.get_neqsys(
                neqsys, init_concs, NumSys=NumSys,
                rref_equil=kwargs.pop('rref_equil', False),
                rref_preserv=kwargs.pop('rref_preserv', False),
                precipitates=kwargs.pop('precipitates', None))
        if x0 is None:
            x0 = init_concs
        params = np.concatenate((init_concs, [float(elem) for elem
                                              in self.eq_constants()]))
        x, sol = neqsys.solve(x0, params, **kwargs)
        if not sol['success']:
            warnings.warn("Root-finding indicated as failed by solver.")
        sane = self._result_is_sane(init_concs, x)
        return x, sol, sane

    def solve(self, init_concs, varied=None, **kwargs):
        """Solve the system and return an :class:`EqCalcResult`.

        Keyword arguments are forwarded to ``EqCalcResult.solve``.
        """
        results = EqCalcResult(self, init_concs, varied)
        # Forward kwargs: they used to be accepted here and then dropped.
        results.solve(**kwargs)
        return results

    def root(self, init_concs, x0=None, neqsys=None,
             NumSys=NumSysLog, neqsys_type='chained_conditional', **kwargs):
        """Find the equilibrium concentrations for ``init_concs``.

        Returns ``(x, sol, sane)`` where ``sane`` is the outcome of the
        negative/upper-bound check.
        """
        init_concs = self.as_per_substance_array(init_concs)
        params = np.concatenate((init_concs, [float(elem) for elem
                                              in self.eq_constants()]))
        if neqsys is None:
            neqsys = self.get_neqsys(
                neqsys_type, init_concs, NumSys=NumSys,
                rref_equil=kwargs.pop('rref_equil', False),
                rref_preserv=kwargs.pop('rref_preserv', False),
                precipitates=kwargs.pop('precipitates', None))
        if x0 is None:
            x0 = init_concs
        x, sol = neqsys.solve(x0, params, **kwargs)
        if not sol['success']:
            warnings.warn("Root finding indicated as failed by solver.")
        sane = self._result_is_sane(init_concs, x)
        return x, sol, sane

    @staticmethod
    def _get_default_plot_ax(subplot_kwargs=None):
        """Create a single matplotlib subplot (log-log by default)."""
        import matplotlib.pyplot as plt
        if subplot_kwargs is None:
            subplot_kwargs = dict(xscale='log', yscale='log')
        return plt.subplot(1, 1, 1, **subplot_kwargs)

    def substance_labels(self, latex=False):
        """Return substance names, or ``$latex_name$`` strings if ``latex``."""
        if latex:
            result = ['$' + s.latex_name + '$' for s
                      in self.substances.values()]
            return result
        else:
            return [s.name for s in self.substances.values()]

    def roots(self, init_concs, varied_data, varied, x0=None,
              NumSys=NumSysLog, plot_kwargs=None,
              neqsys_type='chained_conditional', **kwargs):
        r"""Solve for a series of values of one initial concentration.

        Parameters
        ----------
        init_concs : array or dict
        varied_data : array
            Values taken by the initial concentration of ``varied``.
        varied : int or str
            Index or name of the substance whose initial concentration is
            varied.
        x0 : array
        NumSys : _NumSys subclass
            See :class:`NumSysLin`, :class:`NumSysLog`, etc.
        plot_kwargs : dict
            See py:meth:`pyneqsys.NeqSys.solve`. Two additional keys
            are intercepted here:
                latex_names: bool (default: False)
                conc_unit_str: str (default: 'M')
        neqsys_type : str
            what method to use for NeqSys construction (get_neqsys_*)
        \*\*kwargs :
            kwargs passed on to py:meth:`pyneqsys.NeqSys.solve_series`

        Returns
        -------
        xvecs, info_dicts, sanity
            Solutions, solver info dicts and per-solution sanity flags.
        """
        _plot = plot_kwargs is not None
        if _plot:
            latex_names = plot_kwargs.pop('latex_names', False)
            conc_unit_str = plot_kwargs.pop('conc_unit_str', 'M')
            if 'ax' not in plot_kwargs:
                plot_kwargs['ax'] = self._get_default_plot_ax()

        init_concs = self.as_per_substance_array(init_concs)
        neqsys = self.get_neqsys(
            neqsys_type, init_concs, NumSys=NumSys,
            rref_equil=kwargs.pop('rref_equil', False),
            rref_preserv=kwargs.pop('rref_preserv', False),
            precipitates=kwargs.pop('precipitates', None))
        if x0 is None:
            x0 = init_concs
        varied_idx = self.as_substance_index(varied)

        if _plot:
            cb = neqsys.solve_and_plot_series
            if 'plot_kwargs' not in kwargs:
                kwargs['plot_kwargs'] = {}
            if 'labels' not in kwargs['plot_kwargs']:
                kwargs['plot_kwargs']['labels'] = (
                    self.substance_labels(latex_names))
            if 'substances' in plot_kwargs:
                if 'indices' in plot_kwargs:
                    raise ValueError("Now I am confused..")
                # A list, not a lazy ``map``: a map object can only be
                # iterated once (and previously leaked a debug print).
                kwargs['plot_kwargs']['indices'] = [
                    self.as_substance_index(key)
                    for key in plot_kwargs.pop('substances')]
        else:
            cb = neqsys.solve_series

        params = np.concatenate((init_concs, self.eq_constants()))
        xvecs, info_dicts = cb(
            x0, params, varied_data, varied_idx,
            propagate=False, **kwargs)
        sanity = [self._result_is_sane(init_concs, x) for x in xvecs]

        if _plot:
            import matplotlib.pyplot as plt
            from pyneqsys.plotting import mpl_outside_legend
            mpl_outside_legend(plt.gca())
            # Look up by position so an integer ``varied`` works too.
            varied_subst = list(self.substances.values())[varied_idx]
            xlbl = ('$[' + varied_subst.latex_name + ']_0$' if latex_names
                    else str(varied_subst))
            plt.gca().set_xlabel(xlbl + ' / ' + conc_unit_str)
            plt.gca().set_ylabel('Concentration / ' + conc_unit_str)

        return xvecs, info_dicts, sanity

    def plot_errors(self, concs, init_concs, varied_data, varied, axes=None,
                    compositions=True, Q=True, subplot_kwargs=None):
        """Plot absolute (``axes[0]``) and relative (``axes[1]``) errors.

        Composition-conservation errors are plotted when ``compositions`` is
        true and equilibrium-quotient errors (Q - K) when ``Q`` is true.
        A new 1x2 figure is created when ``axes`` is None.
        """
        if axes is None:
            import matplotlib.pyplot as plt
            if subplot_kwargs is None:
                subplot_kwargs = dict(xscale='log')
            _, axes = plt.subplots(1, 2, figsize=(10, 4),
                                   subplot_kw=subplot_kwargs)
        varied_idx = self.as_substance_index(varied)
        ls, c = '- -- : -.'.split(), 'krgbcmy'
        all_inits = np.tile(self.as_per_substance_array(init_concs),
                            (len(varied_data), 1))
        all_inits[:, varied_idx] = varied_data
        # Line-style offset for the Q curves; must be defined even when
        # compositions is False (it used to raise NameError then).
        cidx = 0
        if compositions:
            cmp_nrs, m1, m2 = self.composition_conservation(concs, all_inits)
            for cidx, (cmp_nr, a1, a2) in enumerate(zip(cmp_nrs, m1, m2)):
                axes[0].plot(concs[:, varied_idx],
                             a1-a2, label='Comp ' + str(cmp_nr),
                             ls=ls[cidx % len(ls)], c=c[cidx % len(c)])
                axes[1].plot(concs[:, varied_idx],
                             (a1-a2)/np.abs(a2), label='Comp ' + str(cmp_nr),
                             ls=ls[cidx % len(ls)], c=c[cidx % len(c)])

        if Q:
            # TODO: handle precipitate phases in plotting Q error
            qs = self.equilibrium_quotients(concs)
            ks = [rxn.param for rxn in self.rxns]
            for idx, (q, k) in enumerate(zip(qs, ks)):
                axes[0].plot(concs[:, varied_idx],
                             q-k, label='K R:' + str(idx),
                             ls=ls[(idx+cidx) % len(ls)],
                             c=c[(idx+cidx) % len(c)])
                axes[1].plot(concs[:, varied_idx],
                             (q-k)/k, label='K R:' + str(idx),
                             ls=ls[(idx+cidx) % len(ls)],
                             c=c[(idx+cidx) % len(c)])

        from pyneqsys.plotting import mpl_outside_legend
        mpl_outside_legend(axes[0])
        mpl_outside_legend(axes[1])
        axes[0].set_title("Absolute errors")
        axes[1].set_title("Relative errors")