Source code for finesse.symbols

"""Symbolic manipulations (`expand`, `collect`, etc.) are based on the book:

Cohen JS. Computer algebra and symbolic computation: mathematical methods. First
edition, 2003.
"""

import abc
import logging
import operator
import warnings
from contextlib import contextmanager
from functools import reduce
from numbers import Number

import numpy as np
from packaging.version import parse

from collections import defaultdict

from finesse import constants
from finesse.exceptions import FinesseException
from finesse.utilities import is_iterable
from finesse.utilities import OrderedSet
import finesse.cymath.ufuncs as ufuncs

LOGGER = logging.getLogger(__name__)


def MAKE_LOP(name, opfn):
    """Build a binary operator method recording ``self <op> other``.

    The returned callable does no simplification; it wraps both operands in
    a `Function` named `name` with underlying operation `opfn`.
    """

    def left_operator(self, other):
        return Function(name, opfn, self, other)

    return left_operator


def MAKE_ROP(name, opfn):
    """Build a reflected binary operator method recording ``other <op> self``."""

    def right_operator(self, other):
        return Function(name, opfn, other, self)

    return right_operator


# Simplification is off by default: symbol code that needs it must be wrapped
# in the `simplification` context manager below. While the flag is set, symbols
# attempt to simplify themselves using a standardised ordering of arguments.
# Simplified symbols only use the +, * and ** operators, which is what allows
# the various symbolic simplification algorithms to be applied.
_SIMPLIFICATION_ENABLED = False  # Global flag to state if simplification will happen


@contextmanager
def simplification(allow_flagged=False):
    """Enable symbolic simplification for the duration of a ``with`` block.

    While active, symbolic operations apply simplification rules instead of
    recording every individual operation. This is useful for situations like
    ``0*a -> 0`` or ``a*a -> a**2``. The simplification is not complete, but
    it generally yields more efficient symbolic expressions.

    Outside of this context, intent is preserved: KatScript needs that so it
    can serialise and deserialise (unparse and parse) a model into script form
    without losing specific equations, e.g. how many minus signs or factors of
    two were written in an expression.

    Parameters
    ----------
    allow_flagged : bool, optional
        When True, entering the context while simplification is already
        enabled is allowed; in that case the flag is left untouched on exit.

    Raises
    ------
    RuntimeError
        If simplification is already enabled and `allow_flagged` is False.
    """
    global _SIMPLIFICATION_ENABLED

    already_enabled = _SIMPLIFICATION_ENABLED
    if allow_flagged and already_enabled:
        # Nested use: the outermost context owns the flag and will reset it.
        yield
        return

    if already_enabled is True:
        raise RuntimeError("Simplification has already been enabled")

    try:
        _SIMPLIFICATION_ENABLED = True
        yield
    finally:
        _SIMPLIFICATION_ENABLED = False
def base_exponent(y):
    """Split `y` into a ``(base, exponent)`` pair.

    If `y` is a power expression (its ``op`` is `operator.pow`) its two
    arguments are returned. Anything else, including objects without an
    ``op`` attribute, is treated as ``y ** 1``.

    Returns
    -------
    tuple
        ``(base, exponent)``; the exponent is ``Constant(1)`` for non-powers.
    """
    try:
        is_power = y.op is operator.pow
        pair = (y.args[0], y.args[1]) if is_power else (y, Constant(1))
    except AttributeError:
        # No `op`/`args`: a plain symbol or number.
        pair = (y, Constant(1))
    return pair
def operator_add(*args):
    """N-ary addition: fold ``+`` from left to right over all arguments."""
    return reduce(lambda total, term: total + term, args)
def operator_mul(*args):
    """N-ary multiplication: fold ``*`` from left to right over all arguments."""
    return reduce(lambda product, factor: product * factor, args)
def operator_sub(*args):
    """N-ary subtraction: fold ``-`` from left to right over all arguments."""
    return reduce(lambda remainder, term: remainder - term, args)
def add_sort_key(a):
    """Sorting key used to order the arguments of an addition.

    Sums and powers get a key starting with ``chr(58)`` (the character that
    follows ``"9"``) so that numerical values sort before them. Objects
    without an ``op`` attribute are keyed on their ``name`` when they have
    one, otherwise on their string form.
    """
    after_digits = chr(58)
    try:
        op = a.op
        if op is operator_add:
            key = after_digits
        elif op is operator.pow:
            key = after_digits + str(a.args[0]) + str(a.args[1])
        else:
            key = str(a)
    except AttributeError:
        key = a.name if hasattr(a, "name") else str(a)
    return key
def MAKE_simplify_add(dir):
    """Build the ``__add__`` (``dir == "LOP"``) or ``__radd__`` implementation.

    Parameters
    ----------
    dir : str
        ``"LOP"`` for the left-hand operator; any other value is treated as
        the reflected (right-hand) operator, which swaps the operand order.

    Returns
    -------
    callable
        ``simplify_add(self, other)``. With simplification disabled it just
        records a binary ``+`` `Function`. With it enabled it folds zeros,
        merges unnamed constants and collects like terms into a single
        n-ary `operator_add` `Function` with sorted arguments.
    """

    def _nary_sum(self, other):
        # N-ary "+" node of the two operands, respecting operand direction.
        if dir == "LOP":
            return Function("+", operator_add, self, other)
        else:
            return Function("+", operator_add, other, self)

    def simplify_add(self, other):
        if not _SIMPLIFICATION_ENABLED:
            if dir == "LOP":
                return Function("+", operator.add, self, other)
            else:
                return Function("+", operator.add, other, self)

        if isinstance(self, Number) and np.all(self == 0):
            return other
        elif isinstance(other, Number) and np.all(other == 0):
            return self
        elif type(self) is Constant and type(other) is Constant:
            if self.is_named and other.is_named and self == other:
                return 2 * self
            elif self.is_named or other.is_named:
                # At least one named constant: names must survive, so keep the
                # sum symbolic. This previously evaluated `self + other` for
                # two different named constants, which re-entered this very
                # branch and recursed without end.
                return _nary_sum(self, other)
            else:
                return Constant(self.value + other.value)
        else:

            def process(a):
                # Flatten an existing n-ary sum into its terms.
                try:
                    if a.op is operator_add:
                        a = a.args
                    else:
                        a = (a,)
                except AttributeError:
                    a = (a,)
                return a

            a = process(self)
            b = process(other)
            args = [*a, *b]
            # Collect all the terms together. Whilst `collect` could be used
            # it has the potential for infinite loops and extra steps not needed
            terms = defaultdict(list)  # each of the terms and its multiplier
            for x in args:
                c, t = coefficient_and_term(x)
                terms[t].append(c)

            collected_args = []
            for term, coefficients in terms.items():
                if term is None:  # constants
                    # append individually and avoid using `sum` to stop
                    # infinite loops; don't include zeros in simplification
                    for constant in coefficients:
                        if constant != 0:
                            collected_args.append(constant)
                else:
                    total = sum(coefficients)
                    if total != 0:
                        collected_args.append(total * term)

            if len(collected_args) == 0:
                rtn = Constant(0)
            elif len(collected_args) == 1:
                rtn = collected_args[0]
            else:
                collected_args.sort(key=add_sort_key)
                rtn = Function("+", operator_add, *collected_args)

            return rtn

    return simplify_add
def mul_sort_key(a):
    """Sorting key for multiplication arguments.

    Numbers and constants get a key starting with ``chr(0)`` so they sort
    first. Sums are keyed on their string form, powers on the concatenated
    string forms of base and exponent, and other operations on their
    ``name``. Everything else is keyed on ``str(a)``.

    Returns
    -------
    str
        Always a string, so keys are mutually comparable when sorting.
    """
    if isinstance(a, (Number, Constant)):
        return chr(0) + str(a)

    try:
        if a.op is operator_add:
            return str(a)
        elif a.op is operator.pow:
            return str(a.args[0]) + str(a.args[1])
        elif hasattr(a, "name"):
            return a.name
    except AttributeError:
        # No `op` (or `args`): a plain symbol, handled by the fallback below.
        pass

    # Fallback. Previously an object with an `op` that was neither a sum nor
    # a power and had no `name` fell through and returned None, which cannot
    # be ordered against the string keys of the other arguments.
    return str(a)
def reduce_mul_args(args):
    """Sorts and reduces a multiply operation arguments.

    Collect constants and sort variables by their `str`.

    Parameters
    ----------
    args : sequence
        Factors of a product. Plain numbers are accepted and converted with
        `as_symbol`.

    Returns
    -------
    list or the original `args`
        The reduced factors. When `args` has fewer than two elements it is
        returned unchanged. When everything reduces to the unit constant the
        result is ``[1]``.
    """

    def _reduce_mul_args(a, b):
        # `a` is the accumulator pair (scalars, matrices).
        b = as_symbol(b)  # need this in case arg is a float/int type
        # Assumes first element is a numeric constant and
        # store all constants there
        if type(b) is Constant:
            # NOTE(review): `a` is always the 2-tuple accumulator, so this
            # length test looks like it is always true - confirm.
            if len(a) >= 1:
                if b.is_named:
                    a[0].append(b)
                else:
                    # Fold unnamed constants into the leading numeric factor.
                    a[0][0] = Constant(a[0][0].value * b.value)
            else:
                a[0].append(b)
        elif type(b) is Matrix:
            a[1].append(b)
        else:
            a[0].append(b)
        return a

    def _reduce_exp_args(a, b):
        # `a` is [non_exp_factors, summed_exponent]: exp(x)*exp(y) -> exp(x+y)
        if hasattr(b, "op") and b.op == np.exp:
            a[1] += b.args[0]
        else:
            a[0].append(b)
        return a

    def _reduce_mul_args_pows(a, b):
        # `a` is a list of [base, exponent] pairs; merge `b` into the last
        # pair when it shares the same base (relies on the prior sort).
        base, exp = base_exponent(b)
        if np.all(a[-1][0] == base):
            a[-1][1] += exp
        else:
            a.append([base, exp])
        return a

    if len(args) > 1:
        scalars, matrices = reduce(_reduce_mul_args, args, ([Constant(1)], []))

        if np.all(scalars[0] == 1):
            if len(scalars) == 1:
                # NOTE(review): this returns before `matrices` is consulted,
                # so a product made only of Matrix factors reduces to [1] -
                # confirm this is intended.
                return [1]
            else:
                scalars.pop(0)
        # Exp function is an unusual one because it is actually a power function
        # so we need to collect it as if it was
        non_exp_args, exp_arg = reduce(_reduce_exp_args, scalars, [[1], 0])
        scalars = non_exp_args
        if exp_arg != 0:
            # if there is an exp arg then it needs to be added to the list
            # of non exp args
            scalars.append(np.exp(exp_arg))
        # Sort arguments alphabetically using str form of symbols
        scalars.sort(key=mul_sort_key)
        # Now it's ordered, powers will be grouped
        # together so we can reduce them too
        pow_args = reduce(
            _reduce_mul_args_pows, scalars[1:], [list(base_exponent(scalars[0]))]
        )
        # any 0**n or 1**n simplify out
        args = list(b**e for b, e in pow_args if e != 0 and b != 1)
        args.extend(matrices)
    return args
def MAKE_simplify_mul(dir):
    """Build the ``__mul__`` (``dir == "LOP"``) or ``__rmul__`` implementation.

    With simplification disabled the returned callable records a binary
    ``*`` `Function`. With it enabled, products with an obvious 0 or 1
    operand are short-circuited, existing products and unary +/- operands
    are flattened, and the factors are reduced with `reduce_mul_args`.
    """

    def process(term):
        # Flatten a term into the factors it contributes to the product.
        try:
            op = term.op
            if op is operator_mul:
                factors = term.args
            elif op is operator.pos:
                factors = [+1, *term.args]
            elif op is operator.neg:
                factors = [-1, *term.args]
            else:
                factors = (term,)
        except AttributeError:
            factors = (term,)
        return factors

    def simplify_mul(self, other):
        if not _SIMPLIFICATION_ENABLED:
            lhs, rhs = (self, other) if dir == "LOP" else (other, self)
            return Function("*", operator.mul, lhs, rhs)

        operand_pairs = ((self, other), (other, self))

        # Check if an obvious simplification can be made here when 0 or 1
        # are used: arrays first, then scalars and constants.
        for candidate, partner in operand_pairs:
            if isinstance(candidate, np.ndarray):
                if np.all(candidate == 0):
                    return 0
                elif np.all(candidate == 1):
                    return partner

        for candidate, partner in operand_pairs:
            if isinstance(candidate, (Number, Constant)):
                if candidate == 0:
                    return 0
                elif candidate == 1:
                    return partner

        # If not, extract out the arguments and order them depending on
        # whether it is a LOP or ROP
        own_factors = process(self)
        other_factors = process(other)
        if dir == "LOP":
            factors = [*own_factors, *other_factors]
        else:
            factors = [*other_factors, *own_factors]

        factors = reduce_mul_args(factors)
        remaining = len(factors)

        if remaining > 1:
            # Create a n-arg multiplication
            return Function("*", operator_mul, *factors)
        if remaining == 1:
            return factors[0]  # reduced to a single argument
        # 0 arguments left then it's cancelled out the args
        return Constant(1)

    return simplify_mul
def MAKE_simplify_sub(dir):
    """Build the ``__sub__`` (``dir == "LOP"``) or ``__rsub__`` implementation.

    With simplification disabled a binary ``-`` `Function` is recorded.
    With it enabled the subtraction is rewritten as the addition of a
    negated operand and delegated to the matching add implementation.
    """
    add_impl = MAKE_simplify_add(dir)

    def simplify_sub(self, other):
        is_left = dir == "LOP"
        if _SIMPLIFICATION_ENABLED:
            if is_left:
                return add_impl(self, -other)
            return add_impl(-self, other)

        minuend, subtrahend = (self, other) if is_left else (other, self)
        return Function("-", operator.sub, minuend, subtrahend)

    return simplify_sub
def MAKE_simplify_neg():
    """Build the ``__neg__`` implementation."""

    def simplify_neg(self):
        if _SIMPLIFICATION_ENABLED:
            # converts to a multiplication to make simplifying easier
            return -1 * self
        return Function("-", operator.neg, self)

    return simplify_neg
def MAKE_simplify_pos():
    """Build the ``__pos__`` implementation."""

    def simplify_pos(self):
        if _SIMPLIFICATION_ENABLED:
            # no need for any extra operation here
            return self
        return Function("+", operator.pos, self)

    return simplify_pos
def MAKE_simplify_pow(dir):
    """Build the ``__pow__`` (``dir == "LOP"``) or ``__rpow__`` implementation.

    With simplification enabled: a power of a power has its exponents
    multiplied, exponents of 0 and 1 are folded away, and two constants that
    are both named or both unnamed are evaluated numerically.
    """

    def simplify_pow(self, exp):
        if dir == "ROP":
            self, exp = exp, self

        if not _SIMPLIFICATION_ENABLED:
            return Function("**", operator.pow, self, exp)

        # If we're doing a power of a power then we can do some obvious
        # multiplication of the power terms. We override self and exp
        # so that the other simplification logic can then take place
        # i.e. (x**0.5)**2 -> x
        if hasattr(self, "op") and self.op is operator.pow:
            self, exp = self.args[0], self.args[1] * exp

        if np.all(exp == 0):
            return Constant(1)
        if np.all(exp == 1):
            return self

        both_constants = type(self) is Constant and type(exp) is Constant
        if both_constants and not (self.is_named ^ exp.is_named):
            return Constant(self.value**exp.value)
        return Function("**", operator.pow, self, exp)

    return simplify_pow
def MAKE_LOP_simplify_truediv():
    """Build the ``__truediv__`` implementation (``self / other``).

    With simplification enabled the division is rewritten as a product with
    ``other ** -1``, pulling a unary minus out of either operand first.
    """

    def simplify_truediv(self, other):
        if not _SIMPLIFICATION_ENABLED:
            return Function("/", operator.truediv, self, other)

        if np.all(self == 0):
            return 0
        if np.all(other == 0):
            raise ZeroDivisionError()

        if isinstance(self, Function) and self.op is operator.neg:
            numerator, denominator = -self.args[0], other
        elif isinstance(other, Function) and other.op is operator.neg:
            numerator, denominator = -self, other.args[0]
        else:
            numerator, denominator = self, other

        return numerator * Function("**", operator.pow, denominator, -1)

    return simplify_truediv
def MAKE_ROP_simplify_truediv():
    """Build the ``__rtruediv__`` implementation (``other / self``).

    With simplification enabled the division is rewritten as a product with
    ``self ** -1``, pulling a unary minus out of either operand first.
    """

    def simplify_truediv(self, other):
        if not _SIMPLIFICATION_ENABLED:
            return Function("/", operator.truediv, other, self)

        if np.all(other == 0):
            return 0
        if self == 0:
            raise ZeroDivisionError()

        if isinstance(self, Function) and self.op is operator.neg:
            numerator, denominator = -other, self.args[0]
        elif isinstance(other, Function) and other.op is operator.neg:
            numerator, denominator = -other.args[0], self
        else:
            numerator, denominator = other, self

        return numerator * Function("**", operator.pow, denominator, -1)

    return simplify_truediv
# Supported operators: maps dunder method names to their implementations.
OPERATORS = {
    "__add__": MAKE_simplify_add("LOP"),
    "__sub__": MAKE_simplify_sub("LOP"),
    "__mul__": MAKE_simplify_mul("LOP"),
    "__radd__": MAKE_simplify_add("ROP"),
    "__rsub__": MAKE_simplify_sub("ROP"),
    "__rmul__": MAKE_simplify_mul("ROP"),
    "__neg__": MAKE_simplify_neg(),
    "__pos__": MAKE_simplify_pos(),
    "__pow__": MAKE_simplify_pow("LOP"),
    "__rpow__": MAKE_simplify_pow("ROP"),
    "__truediv__": MAKE_LOP_simplify_truediv(),
    "__rtruediv__": MAKE_ROP_simplify_truediv(),
    "__floordiv__": MAKE_LOP("//", operator.floordiv),
    "__rfloordiv__": MAKE_ROP("//", operator.floordiv),
    "__matmul__": MAKE_LOP("@", operator.matmul),
    "__mod__": MAKE_LOP("%", operator.mod),
    "__rmod__": MAKE_ROP("%", operator.mod),
}

# Maps function names to actual functions called,
# this is used for lambdifying symbolics as you
# can't grab the underlying function from the lambdas
# stored in FUNCTIONS below
PYFUNCTION_MAP = {
    "abs": operator.abs,
    "neg": operator.neg,
    "pos": operator.pos,
    "pow": operator.pow,
    "conj": np.conj,
    "real": np.real,
    "imag": np.imag,
    "exp": np.exp,
    "log10": np.log10,
    "log": np.log,
    "sin": np.sin,
    "arcsin": np.arcsin,
    "cos": np.cos,
    "arccos": np.arccos,
    "tan": np.tan,
    "arctan": np.arctan,
    "arctan2": np.arctan2,
    "sqrt": np.sqrt,
    "std": np.std,
    "sum": np.sum,
    "dot": np.dot,
    "radians": np.radians,
    "degrees": np.degrees,
    "deg2rad": np.deg2rad,
    "rad2deg": np.rad2deg,
    "arange": np.arange,
    "linspace": np.linspace,
    "logspace": np.logspace,
    "geomspace": np.geomspace,
    "jv": ufuncs.jv,
}

# Built-in symbolic functions: maps string names of functions to actual functions
FUNCTIONS = {
    "abs": lambda x: Function("abs", operator.abs, x),
    "neg": lambda x: Function("neg", operator.neg, x),
    "pos": lambda x: Function("pos", operator.pos, x),
    # `operator.pow` is binary, so the exponent has to be accepted here. The
    # previous single-argument form rejected `pow(base, exponent)` outright.
    # The exponent is taken variadically so that any existing one-argument
    # construction keeps behaving as before.
    "pow": lambda x, *exponent: Function("pow", operator.pow, x, *exponent),
    "conj": lambda x: Function("conj", np.conj, x),
    "real": lambda x: Function("real", np.real, x),
    "imag": lambda x: Function("imag", np.imag, x),
    "exp": lambda x: Function("exp", np.exp, x),
    "log": lambda x: Function("log", np.log, x),
    "log10": lambda x: Function("log10", np.log10, x),
    "sin": lambda x: Function("sin", np.sin, x),
    "arcsin": lambda x: Function("arcsin", np.arcsin, x),
    "cos": lambda x: Function("cos", np.cos, x),
    "arccos": lambda x: Function("arccos", np.arccos, x),
    "tan": lambda x: Function("tan", np.tan, x),
    "arctan": lambda x: Function("arctan", np.arctan, x),
    "arctan2": lambda y, x: Function("arctan2", np.arctan2, y, x),
    "sqrt": lambda x: Function("sqrt", np.sqrt, x),
    "std": lambda x: Function("std", np.std, x),
    "sum": lambda x: Function("sum", np.sum, x),
    "dot": lambda x, y: Function("dot", np.dot, x, y),
    "radians": lambda x: Function("radians", np.radians, x),
    "degrees": lambda x: Function("degrees", np.degrees, x),
    "deg2rad": lambda x: Function("deg2rad", np.deg2rad, x),
    "rad2deg": lambda x: Function("rad2deg", np.rad2deg, x),
    # NOTE(review): the third argument of `np.arange` is the step, which is
    # truncated by `int(c)` here - confirm that integer steps are intended.
    "arange": lambda a, b, c: Function("arange", np.arange, float(a), float(b), int(c)),
    "linspace": lambda a, b, c: Function(
        "linspace", np.linspace, float(a), float(b), int(c)
    ),
    "logspace": lambda a, b, c: Function(
        "logspace", np.logspace, float(a), float(b), int(c)
    ),
    "geomspace": lambda a, b, c: Function(
        "geomspace", np.geomspace, float(a), float(b), int(c)
    ),
    "jv": lambda v, x: Function("jv", ufuncs.jv, v, x),
}

# String formatters used by `display` for operations that have a dedicated
# textual form; each one receives the already-rendered argument strings.
op_repr = {
    operator_add: lambda *args: "({})".format("+".join(args)),
    operator.add: "({}+{})".format,
    operator.sub: lambda *args: "({})".format("-".join(args)),
    operator_mul: lambda *args: "*".join(args),
    operator.mul: "({}*{})".format,
    operator.pow: "({})**({})".format,
    operator.truediv: "({}/{})".format,
    operator.floordiv: "({}//{})".format,
    operator.mod: "({}%{})".format,
    operator.matmul: "({}@{})".format,
    operator.neg: "-{}".format,
    operator.pos: "+{}".format,
    operator.abs: "abs({})".format,
    np.conj: "conj({})".format,
    np.sqrt: "sqrt({})".format,
}
def display(a, dunder=()):
    """Return a human readable string for the operations a Symbol contains.

    Parameters
    ----------
    a : :class:`.Symbol`
        Symbol to print
    dunder : tuple
        Symbols whose names should be displayed with double underscores
        prefixing and suffixing the name.

    Returns
    -------
    str
        String form of the symbol.
    """
    if hasattr(a, "op"):
        # Use the predefined string format when there is one, otherwise
        # render the operation as a plain function call.
        formatter = op_repr.get(a.op) if a.op in op_repr else None
        if formatter is None:
            placeholders = ",".join(["{}"] * len(a.args))
            formatter = (a.op.__name__ + "(" + placeholders + ")").format

        rendered = formatter(*(display(arg, dunder=dunder) for arg in a.args))
        # Tidy up artefacts of the simplified (+, *, **) representation.
        for clumsy, tidy in (("-1*", "-"), ("+-", "-"), ("*1/", "/")):
            rendered = rendered.replace(clumsy, tidy)
        return rendered

    if hasattr(a, "name"):
        # Anything with a name attribute just displays that
        return "__" + a.name + "__" if a in dunder else a.name

    if type(a) is Symbol:
        return f"<Symbol @ {hex(id(a))}>"

    return str(a)
def finesse2sympy(expr, iter_num=0):
    """Convert a Finesse symbolic expression into a Sympy expression.

    Parameters
    ----------
    expr : Symbol
        Expression to convert. Constants become their value, parameter
        references and variables become ``sympy.Symbol`` named after them,
        and functions are converted recursively.
    iter_num : int, optional
        Recursion depth counter passed down to the recursive calls.

    Raises
    ------
    NotImplementedError
        If `expr`, or an operation inside it, has no known Sympy mapping.

    Notes
    -----
    It might be common for this function to throw a NotImplementedError.
    Operators and numpy functions are mapped to sympy by hand; if you come
    across this error the translation table below needs the missing
    operation added. Operations not in the table fall back to a sympy
    attribute with the same name as the operation.
    """
    import sympy
    from finesse.parameter import ParameterRef

    iter_num += 1

    if isinstance(expr, Constant):
        return expr.value
    if isinstance(expr, (ParameterRef, Variable)):
        return sympy.Symbol(expr.name)
    if not isinstance(expr, Function):
        raise NotImplementedError(
            f"{expr} undefined. {finesse2sympy} needs to be updated."
        )

    sympy_args = [finesse2sympy(arg, iter_num) for arg in expr.args]

    # Ordered (finesse operations, sympy equivalent) pairs; first match wins.
    translations = (
        ((operator_mul, operator.mul), sympy.Mul),
        ((operator_add, operator.add), sympy.Add),
        ((operator.truediv,), lambda a, b: sympy.Mul(a, sympy.Pow(b, -1))),
        ((operator.mod,), sympy.Mod),
        ((operator.pow,), sympy.Pow),
        ((operator.sub,), lambda x, y: sympy.Add(x, -y)),
        ((np.conj,), sympy.conjugate),
        ((np.radians,), sympy.rad),
        ((np.exp,), sympy.exp),
        ((np.cos,), sympy.cos),
        ((np.sin,), sympy.sin),
        ((np.tan,), sympy.tan),
        ((np.sqrt,), sympy.sqrt),
        ((operator.abs,), sympy.Abs),
        ((operator.neg,), lambda x: sympy.Mul(-1, x)),
        ((operator.pos,), lambda x: sympy.Mul(+1, x)),
        ((np.imag,), sympy.im),
        ((np.real,), sympy.re),
    )

    for finesse_ops, sympy_op in translations:
        if any(expr.op == candidate for candidate in finesse_ops):
            op = sympy_op
            break
    else:
        try:
            op = getattr(sympy, expr.op.__name__)
        except AttributeError:
            raise NotImplementedError(
                f"undefined Function {expr.op} in {expr}. {finesse2sympy} needs to be updated."
            )

    return op(*sympy_args)
def sympy2finesse(expr, symbol_dict=None, iter_num=0):
    """Convert a Sympy expression back into Finesse symbols / numbers.

    Parameters
    ----------
    expr : sympy expression
        Expression to convert.
    symbol_dict : dict, optional
        Maps the string form of each sympy symbol to the object to use for
        it in the result. Defaults to an empty mapping.
    iter_num : int, optional
        Recursion depth counter.

    Raises
    ------
    Exception
        If `expr` is of a kind that is not handled here.
    """
    import sympy

    if symbol_dict is None:
        symbol_dict = {}
    iter_num += 1

    if isinstance(expr, sympy.Mul):
        factors = [
            sympy2finesse(arg, symbol_dict, iter_num=iter_num) for arg in expr.args
        ]
        return np.prod(factors)
    if isinstance(expr, sympy.Add):
        terms = [
            sympy2finesse(arg, symbol_dict, iter_num=iter_num) for arg in expr.args
        ]
        return np.sum(terms)
    if isinstance(expr, sympy.conjugate):
        return np.conj(sympy2finesse(*expr.args, symbol_dict))
    if isinstance(expr, sympy.exp):
        return np.exp(sympy2finesse(*expr.args, symbol_dict))
    if isinstance(expr, sympy.Pow):
        base = sympy2finesse(expr.args[0], symbol_dict)
        exponent = sympy2finesse(expr.args[1], symbol_dict)
        return np.power(base, exponent)
    if isinstance(expr, sympy.Mod):
        dividend = sympy2finesse(expr.args[0], symbol_dict)
        divisor = sympy2finesse(expr.args[1], symbol_dict)
        return np.mod(dividend, divisor)

    if expr.is_NumberSymbol:
        # sympy class for named symbols (eg Pi, golden ratio, ...)
        return CONSTANTS["pi"] if str(expr) == "pi" else complex(expr)

    if expr.is_number:
        if expr.is_integer:
            return int(expr)
        if expr.is_real:
            return float(expr)
        return complex(expr)

    if expr.is_symbol:
        return symbol_dict[str(expr)]

    raise Exception(f"{expr} undefined")
def _simplify_symbolic_element(x):
    # Expand then collect symbols; anything else is passed through untouched.
    if isinstance(x, Symbol):
        return collect(expand(x))
    return x


# Element-wise simplification over (object) arrays of symbols and numbers.
simplify_symbolic_numpy = np.vectorize(_simplify_symbolic_element, otypes="O")
def np_eval_symbolic_numpy(a, *keep):
    """Evaluate `a` if it is a Symbol, passing `keep` on to its ``eval``.

    Non-symbolic values are returned unchanged.
    """
    return a.eval(keep=keep) if isinstance(a, Symbol) else a
# Element-wise version of `np_eval_symbolic_numpy`; `otypes="O"` makes the
# result an object array.
__eval_symbolic_numpy = np.vectorize(np_eval_symbolic_numpy, otypes="O")
def eval_symbolic_numpy(a, *keep):
    """Evaluate every symbol in `a` element-wise.

    Parameters
    ----------
    a : array-like
        Values to evaluate; non-symbolic entries are passed through.
    *keep
        Forwarded to each symbol's ``eval`` as the ``keep`` argument.
    """
    vectorised_eval = __eval_symbolic_numpy
    return vectorised_eval(a, *keep)
def as_symbol(x):
    """Return `x` itself if it is a Symbol, otherwise wrap it in a `Constant`."""
    return x if isinstance(x, Symbol) else Constant(x)
def evaluate(x):
    """Evaluates a symbol or N-dimensional array of symbols.

    Parameters
    ----------
    x : :class:`.Symbol` or array-like
        A symbolic expression or an array of symbolic expressions.

    Returns
    -------
    out : float, complex, :class:`numpy.ndarray`
        A single value for the evaluated expression if `x` is not
        array-like, otherwise an array of the evaluated expressions. Arrays
        are returned as float64 when every imaginary part is zero, and as
        complex128 otherwise. Non-symbolic scalars are returned unchanged.
    """
    if not is_iterable(x):
        # If not a symbol then just return x directly
        return x.eval() if isinstance(x, Symbol) else x

    values = np.array(x, dtype=np.complex128)
    if np.any(values.imag):
        return values

    # purely real symbols in array
    with warnings.catch_warnings():
        # ComplexWarning is looked up in a different place depending on the
        # installed numpy version.
        legacy_numpy = parse(np.__version__) < parse("1.25")
        category = np.ComplexWarning if legacy_numpy else np.exceptions.ComplexWarning

        # suppress 'casting to float discards imag part' warning
        # as we know that all imag parts are zero here anyway
        warnings.simplefilter("ignore", category=category)
        values = np.array(values, dtype=np.float64)
    return values
class Symbol(abc.ABC):
    """Base class of all symbolic objects.

    Arithmetic operators are taken from `OPERATORS` and the common
    mathematical functions from `FUNCTIONS`, so applying them to a symbol
    builds a new symbolic expression. This class defines ``__eq__`` without
    ``__hash__``; subclasses that need to be hashable define their own.
    """

    __add__ = OPERATORS["__add__"]
    __sub__ = OPERATORS["__sub__"]
    __mul__ = OPERATORS["__mul__"]
    __radd__ = OPERATORS["__radd__"]
    __rsub__ = OPERATORS["__rsub__"]
    __rmul__ = OPERATORS["__rmul__"]
    __pow__ = OPERATORS["__pow__"]
    __rpow__ = OPERATORS["__rpow__"]
    __truediv__ = OPERATORS["__truediv__"]
    __rtruediv__ = OPERATORS["__rtruediv__"]
    __floordiv__ = OPERATORS["__floordiv__"]
    __rfloordiv__ = OPERATORS["__rfloordiv__"]
    __matmul__ = OPERATORS["__matmul__"]
    __neg__ = OPERATORS["__neg__"]
    __pos__ = OPERATORS["__pos__"]
    __mod__ = OPERATORS["__mod__"]
    __rmod__ = OPERATORS["__rmod__"]
    __abs__ = FUNCTIONS["abs"]

    conjugate = FUNCTIONS["conj"]
    conj = FUNCTIONS["conj"]
    exp = FUNCTIONS["exp"]
    sin = FUNCTIONS["sin"]
    arcsin = FUNCTIONS["arcsin"]
    cos = FUNCTIONS["cos"]
    arccos = FUNCTIONS["arccos"]
    tan = FUNCTIONS["tan"]
    arctan = FUNCTIONS["arctan"]
    arctan2 = FUNCTIONS["arctan2"]
    sqrt = FUNCTIONS["sqrt"]
    radians = FUNCTIONS["radians"]
    degrees = FUNCTIONS["degrees"]
    deg2rad = FUNCTIONS["deg2rad"]
    rad2deg = FUNCTIONS["rad2deg"]
    log = FUNCTIONS["log"]
    log10 = FUNCTIONS["log10"]

    @property
    def real(self):
        """Symbolic real part of this expression."""
        return FUNCTIONS["real"](self)

    @property
    def imag(self):
        """Symbolic imaginary part of this expression."""
        return FUNCTIONS["imag"](self)

    def __eq__(self, obj):
        """Inheriting classes should implement __symeq__ and do any symbol specific
        equality checks there.

        This top level handles initial n-arg conversion before symbolic comparison.
        """
        # Need to convert any symbolic expressions (Functions)
        # to nary form for equality checks
        if isinstance(obj, Function) and not obj._is_narg_expression_tree:
            obj = obj.to_nary_add_mul()
        if isinstance(self, Function) and not self._is_narg_expression_tree:
            self = self.to_nary_add_mul()
        # Might have been simplified to a number
        if isinstance(self, Number):
            return self == obj
        return self.__symeq__(obj)

    def __symeq__(self, obj):
        """Default symbolic equality: identity, deferring to Functions."""
        if isinstance(obj, Function):
            return obj == self
        else:
            return id(self) == id(obj)

    def __float__(self):
        v = self.eval()
        if np.isscalar(v):
            return float(v)
        else:
            raise TypeError(f"Can't cast {type(v)} ({v}) into a single float value")

    def __complex__(self):
        v = self.eval()
        if np.isscalar(v):
            return complex(v)
        else:
            raise TypeError(f"Can't cast {type(v)} ({v}) into a single complex value")

    def __int__(self):
        v = self.eval()
        if np.isscalar(v):
            return int(v)
        else:
            raise TypeError(f"Can't cast {type(v)} into a single int value")

    @property
    def value(self):
        "The current value of this symbol"
        return self.eval()

    def __str__(self):
        return display(self)

    def __repr__(self):
        return f"<Symbolic='{display(self)}' @ {hex(id(self))}>"

    def __bool__(self):
        return bool(self.eval())

    @property
    def is_changing(self):
        """True if this symbolic object can vary.

        For an expression this is True when any of its arguments is changing;
        for a symbol with a ``parameter`` attribute it is True when that
        parameter is tunable or changing. Anything else is not changing.
        """
        res = False
        if hasattr(self, "op"):
            res = any([_.is_changing for _ in self.args])
        elif hasattr(self, "parameter"):
            res = self.parameter.is_tunable or self.parameter.is_changing
        return res

    def parameters(self, memo=None):
        """Returns all the parameters that are present in this symbolic expression.

        Parameters are symbols whose values are attached to a model; here
        that means every leaf symbol carrying a ``parameter`` attribute.

        Parameters
        ----------
        memo : OrderedSet, optional
            Accumulator used by the recursion; leave as None when calling.

        Returns
        -------
        list
            The parameter symbols, without duplicates, in order of discovery.
        """
        if memo is None:
            memo = OrderedSet()

        if hasattr(self, "op"):
            for _ in self.args:
                _.parameters(memo)
        elif hasattr(self, "parameter"):
            memo.add(self)

        return list(memo)

    def all(self, predicate, memo=None):
        """Returns all the symbols that are present in this expression which satisfy
        the predicate.

        Parameters
        ----------
        predicate : callable
            Method which takes in an argument and returns True if it matches.
        memo : OrderedSet, optional
            Accumulator used by the recursion; leave as None when calling.

        Examples
        --------
        To select all `Constant`s and `Variable`s from an expression `y`:

        >>> y.all(lambda a: isinstance(a, (Constant, Variable)))
        """
        if memo is None:
            memo = OrderedSet()

        if hasattr(self, "op"):
            for _ in self.args:
                _.all(predicate, memo)

        if predicate(self):
            memo.add(self)

        return list(memo)

    def changing_parameters(self):
        """Returns the parameters of this expression that are changing."""
        # A comprehension avoids building a numpy object array of symbols and
        # relying on a list of flags being interpreted as a boolean mask.
        return [p for p in self.parameters() if p.is_changing]

    def to_sympy(self):
        """Converts a Finesse symbolic expression into a Sympy expression.

        Warning: for large functions this can be quite slow.
        """
        return finesse2sympy(self)

    def sympy_simplify(self):
        """Simplifies this expression using Sympy.

        The expression is converted to Sympy, simplified there, and converted
        back into Finesse symbols.
        """
        import sympy

        # Names -> symbols used to map the simplified expression back.
        # `finesse2sympy` turns both variables and parameter references into
        # sympy symbols, so both are needed here; looking up only the
        # parameters raised a KeyError for any expression with a Variable.
        refs = {
            _.name: _
            for _ in self.all(
                lambda s: isinstance(s, Variable) or hasattr(s, "parameter")
            )
        }
        # The module-level `sympy.simplify` also accepts the plain number
        # that `finesse2sympy` returns for a constant expression.
        return sympy2finesse(sympy.simplify(finesse2sympy(self)), refs)

    def collect(self):
        """Collects like terms in the expressions."""
        return collect(self)

    def expand(self):
        """Performs a basic expansion of the symbolic expression."""
        return expand(self)

    def expand_symbols(self):
        """A method that expands any symbolic parameter references that are
        themselves symbolic. This can be used to get an expression that only depends
        on references that are numeric.

        Examples
        --------
        >>> import finesse
        >>> model = finesse.Model()
        >>> model.parse(
        ... '''
        ... var d 300
        ... var c 6000
        ... var b c+d
        ... var a b+1
        ... '''
        ... )
        >>> model.a.value.expand_symbols()
        <Symbolic='((c+d)+1)' @ 0x7faa4d351c10>

        Returns
        -------
        Symbol
            The expression with every symbolic parameter reference replaced
            by that parameter's value, repeated until none remain.
        """

        def process(p):
            if p.parameter.is_symbolic:
                return p.parameter.value
            else:
                return p

        def _expand(sym):
            # One substitution pass; None signals there is nothing left to do.
            params = sym.parameters()
            if len(params) == 0:
                return None
            elif not any(p.parameter.is_symbolic for p in params):
                return None
            else:
                subs = {p: process(p) for p in params if p.parameter.is_symbolic}
                return sym.substitute(subs)

        sym = self
        while True:
            res = _expand(sym)
            if res is None:
                return sym
            else:
                sym = res

    def to_binary_add_mul(self):
        """Converts a symbolic expression to use binary forms of operator.add and
        operator.mul operators, rather than the n-ary operator_add and operator_mul.

        Returns
        -------
        Symbol
        """
        if hasattr(self, "op"):
            return self.op(*(_.to_binary_add_mul() for _ in self.args))
        else:
            return self

    def to_nary_add_mul(self):
        """Converts a symbolic expression to use n-ary forms of operator_add and
        operator_mul operators, rather than the binary-ary operator.add and
        operator.mul.

        Returns
        -------
        Symbol
        """
        if hasattr(self, "op"):
            with simplification(allow_flagged=True):
                if self.op is operator.pos:
                    return self.args[0].to_nary_add_mul()
                elif self.op is operator.neg:
                    return -1 * self.args[0].to_nary_add_mul()
                else:
                    # calling to_nary_add_mul here as it's basically the same but
                    # just won't keep opening a new context manager each time
                    return self.op(*(_.to_nary_add_mul() for _ in self.args))
        else:
            return self

    def lambdify(self, *args, expand_symbols=False, ignore_unused_symbols=False):
        """Converts this symbolic expression into a function that can be called.

        Parameters
        ----------
        args : Symbols
            Symbols to use to make up the arguments of the generated function. If
            none are provided then the current values of `ParameterRef`s are used
            and `Variables` are left as they are.

        expand_symbols : bool, optional
            If True, the expression will first have any dependent variables
            expanded. See `.expand_symbols`.

        ignore_unused_symbols : bool, optional
            If True, symbols in `args` that do not appear in the expression
            are accepted as arguments instead of raising a NameError.

        Raises
        ------
        NameError
            If a symbol in `args` is not in the expression and
            `ignore_unused_symbols` is False.
        """
        from finesse.parameter import ParameterRef

        # Convert to a string form, which is in a python evaluatable format, then
        # run eval on it to get a lambda function
        if expand_symbols:
            expr = self.expand_symbols()
        else:
            expr = self
        # We dunder the arguments so we can easily find and string replace them
        # later with the lambda function argument. This has to be done as `.`
        # can't be used in variable name (i.e. l1.P). This should stop clashes
        # with other names.
        sym_str = display(expr, dunder=args)
        params = expr.all(lambda x: isinstance(x, (Variable, ParameterRef)))
        # Curly braces are not valid in identifiers either, swap them out.
        fix_curly = lambda x: x.replace("{", "Ç").replace("}", "ç")
        ARGS = []
        for arg in args:
            if not ignore_unused_symbols and arg not in params:
                raise NameError(
                    f"`{arg}` is not a valid symbol to make as an argument to this the lambda function in this expression: {sym_str}"
                )
            if hasattr(arg, "full_name"):
                ARGS.append(arg.full_name.replace(".", "__"))
                ARGS[-1] = fix_curly(ARGS[-1])
                sym_str = sym_str.replace("__" + arg.full_name + "__", ARGS[-1])
            else:
                ARGS.append(arg.name.replace(".", "__"))
                ARGS[-1] = fix_curly(ARGS[-1])
                sym_str = sym_str.replace("__" + arg.name + "__", ARGS[-1])

        # Names the generated lambda needs to resolve that are not arguments.
        _globals = {}
        for arg in params:
            if arg not in args:
                if isinstance(arg, Variable):
                    _globals[arg.name] = arg
                elif hasattr(arg, "owner") and hasattr(arg.owner, "name"):
                    _globals[arg.owner.name] = arg.owner
                else:  # for model parameters
                    _globals[arg.name] = arg.owner.get(arg.name)
        # get functions used so that they can be exposed to the eval
        for func in expr.all(lambda x: isinstance(x, Function)):
            # NOTE(review): this inspects the module of the symbol's class,
            # not of `func.op`, so it looks like it is always true - confirm.
            if func.__class__.__module__ != "__builtin__":
                _globals[func.name] = func.op

        s = f"lambda {','.join(ARGS)}: {sym_str}"
        # NOTE: `eval` runs source text generated from the expression's display
        # form; do not lambdify expressions built from untrusted input.
        return eval(s, _globals)

    @staticmethod
    def _check_substitution(subs):
        """Pre-substitute purely numeric values into the symbolic values of `subs`.

        This modifies `subs` in place and marks it with a ``"__checked__"``
        key so that the work is only done once per mapping.
        """
        if "__checked__" not in subs:
            # Need this check to pre-substitute any constants back into
            # the symbolic substitutions
            constants = {k: v for k, v in subs.items() if not hasattr(v, "substitute")}
            if len(constants) > 0:
                for key in list(subs.keys()):
                    if hasattr(subs[key], "substitute"):
                        subs[key] = subs[key].substitute(constants)
            subs["__checked__"] = True

    def substitute(self, mapping):
        """Uses a dictionary to substitute terms in this expression with another. This
        does not perform any evaluation of any terms, unlike `eval(subs=...)`.

        Notes
        -----
        The symbolic substitution implemented here is not recursive, consider:

        >>> y = a + b
        >>> y.substitute({a:a+b, b:a}) # results in => a+b+a

        Here `b` is not replaced in the substitutions. The only time this happens is
        if one mapping is purely numeric:

        >>> y.substitute({a:a+b, b:1}) # results in => a+2

        Parameters
        ----------
        mapping : dict
            Dictionary of substitutions/mappings to make. Keys can be the actual
            symbol or the name of the symbol in string form. Values must all be
            proper symbols. Note that the dictionary is modified in place (see
            `_check_substitution`).
        """
        self._check_substitution(mapping)
        if self.name in mapping:
            return mapping[self.name]
        elif self in mapping:
            return mapping[self]
        else:
            return self
class Matrix(Symbol):
    """A Matrix symbol.

    Parameters
    ----------
    name : str
        Name of the matrix; stored as its string form.
    """

    def __init__(self, name):
        label = str(name)
        self.name = label

    def __hash__(self):
        # Hash on the name only.
        key = (self.name,)
        return hash(key)

    def eval(self, **kwargs):
        """A matrix symbol evaluates to itself; keyword arguments are ignored."""
        return self
class Constant(Symbol):
    """Defines a constant symbol that can be used in symbolic math.

    Parameters
    ----------
    value : float, int
        Value of constant
    name : str, optional
        Name of the constant to use when printing
    """

    def __init__(self, value, name=None):
        self.__value = value
        self.__name = name

    def __str__(self):
        label = self.__name
        return label if label else str(self.__value)

    def __repr__(self):
        shown = self.__name or self.__value
        return str(shown)

    def __symeq__(self, obj):
        """Compare by value: against another constant's value, else directly."""
        mine = self.value
        if not isinstance(obj, Function) and isinstance(obj, Constant):
            return obj.value == mine
        return obj == mine

    def __hash__(self):
        """Constant hash.

        This is used by the tokenizer. Constants are by definition immutable.
        """
        # Add the class to reduce chance of hash collisions.
        return hash((type(self), self.value))

    def substitute(self, subs):
        """Uses a dictionary to substitute terms in this expression with another. This
        does not perform any evaluation of any terms, unlike `eval(subs=...)`.

        Parameters
        ----------
        subs : dict
            Dictionary of substitutions to make. A constant is only replaced
            when it is itself a key of the dictionary. Values must all be
            proper symbols.
        """
        self._check_substitution(subs)
        replaced = subs and self in subs
        return subs[self] if replaced else self

    def eval(self, subs=None, **kwargs):
        """Evaluate this constant.

        If a substitution is available the value of that will be used instead
        of `self.value`. A stored value that itself has an ``eval`` method is
        evaluated before being returned.
        """
        if subs and self in subs:
            return subs[self]

        stored = self.__value
        if hasattr(stored, "eval"):
            return stored.eval()
        return stored

    @property
    def is_named(self):
        """Was this constant given a specific name."""
        return self.__name is not None

    @property
    def name(self):
        """The given name, or the string form of the value when unnamed."""
        if self.__name is None:
            return str(self.value)
        return self.__name
# Constants. # NOTE: The keys here are used by the parser to recognise constants in kat script. CONSTANTS = { "pi": Constant(constants.PI, name="π"), "c0": Constant(constants.C_LIGHT, name="c"), }
class Resolving(Symbol):
    """A special symbol that represents a symbol that is not yet resolved.

    This is used in the parser to support self-referencing parameters. An
    error is thrown if the value is attempted to be read.
    """

    def eval(self, **kwargs):
        """Always raise: a resolving symbol has no value to read.

        Raises
        ------
        RuntimeError
            On every call, regardless of the keyword arguments given.
        """
        message = (
            "an attempt has been made to read the value of a resolving symbol (hint: "
            "symbols should not be evaluated until parsing has fully finished)"
        )
        raise RuntimeError(message)

    @property
    def name(self):
        """Fixed placeholder name shared by all resolving symbols."""
        return "RESOLVING"
# IMPORTANT: renaming this class impacts the katscript spec and should be avoided!
class Variable(Symbol):
    """Makes a variable symbol that can be used in symbolic math.

    Values must be substituted in when evaluating an expression.

    Examples
    --------
    Using some variables to make an expression and evaluating it:

    >>> x = Variable('x')
    >>> y = Variable('y')
    >>> z = 4*x**2 - y
    >>> print(f"{z} = {z.eval(subs={x:2, y:3})} : x={2}, y={3}")
    (4*x**2-y) = 13 : x=2, y=3

    Parameters
    ----------
    name : str
        Name of the variable. It is required (``None`` is rejected) and is
        stored as ``str(name)``.

    Raises
    ------
    ValueError
        If `name` is ``None``.
    """

    def __init__(self, name):
        if name is None:
            raise ValueError("Name must be provided")
        self.__name = str(name)

    def __hash__(self):
        return hash((Variable, self.name))

    @property
    def name(self):
        return self.__name

    def __symeq__(self, obj):
        # Two variables are the same symbol exactly when their names match.
        return isinstance(obj, Variable) and self.name == obj.name

    def eval(self, subs=None, keep=None, **kwargs):
        """Evaluates this variable and returns either itself or a substituted
        value.

        Parameters
        ----------
        subs : dict
            Dictionary of substitutions, passed on to `substitute`.

        keep : iterable, str
            A collection of names of variables to keep as variables when
            evaluating. Keep will override any substitution.
        """
        if keep:
            # `keep` is either a single name or a collection of names/symbols.
            kept = self.name == keep or (
                is_iterable(keep) and (self.name in keep or self in keep)
            )
            if kept:
                return self

        return self if subs is None else self.substitute(subs)
class Function(Symbol):
    """This is a symbol to represent a mathematical function. This could be a
    simple addition, or a more complicated multi-argument function.

    It supports creating new mathematical operations::

        import math
        import cmath

        cos   = lambda x: finesse.symbols.Function("cos", math.cos, x)
        sin   = lambda x: finesse.symbols.Function("sin", math.sin, x)
        atan2 = lambda y, x: finesse.symbols.Function("atan2", math.atan2, y, x)

    Complex math can also be used::

        import numpy as np
        angle = lambda x: finesse.symbols.Function("angle", np.angle, x)
        print(f"{angle(1+1j)} = {angle(1+1j).eval()}")

    Parameters
    ----------
    name : str
        The operation name. This is used for dumping operations to kat script.

    operation : callable
        The function to pass the arguments of this operation to.

    Other Parameters
    ----------------
    *args
        The arguments to pass to `operation` during a call. Each one is
        converted with `as_symbol` on construction.
    """

    def __init__(self, name, operation, *args):
        # Remember whether simplification was active when this node was built.
        self.__simplified_expr_tree = _SIMPLIFICATION_ENABLED
        self.name = str(name)
        self.args = [as_symbol(arg) for arg in args]
        self.op = operation

    def _type_error_message(self, ex):
        """Build the error text used when applying `op` raises a TypeError.

        A hint is appended for every argument whose value is `None`.
        """
        msg = f"Whilst evaluating {self.op}{self.args} this error was raised:\n` {ex}`"
        for arg in self.args:
            if arg.value is None:
                msg += f"\nHint: {arg} is `None` make sure it is defined before being used."
                # Not every symbol has a `parameter`; look it up defensively so
                # that building the hint can never hide the original error.
                parameter = getattr(arg, "parameter", None)
                if getattr(parameter, "full_name", None) == "fsig.f":
                    msg += " if using KatScript use `fsig(f)` before this line."
        return msg

    def substitute(self, subs):
        """Uses a dictionary to substitute terms in this expression with
        another.

        Each argument is substituted individually and the operation is then
        re-applied to the substituted arguments.

        Parameters
        ----------
        subs : dict
            Dictionary of substitutions to make. Keys can be the actual symbol
            or the name of the symbol in string form. Values must all be
            proper symbols.

        Returns
        -------
        The result of applying the operation to the substituted arguments, or
        NaN if doing so raises a ZeroDivisionError.

        Raises
        ------
        FinesseException
            If applying the operation raises a TypeError.
        """
        try:
            args = tuple(arg.substitute(subs) for arg in self.args)
            return self.op(*args)
        except ZeroDivisionError:
            return np.nan
        except TypeError as ex:
            raise FinesseException(self._type_error_message(ex)) from ex

    def eval(self, **kwargs):
        """Evaluates the operation.

        Parameters
        ----------
        subs : dict, optional
            Parameter substitutions can be given via an optional ``subs``
            dict (mapping parameters to substituted values).

        keep : iterable, str
            A collection of names of variables to keep as variables when
            evaluating.

        Notes
        -----
        A division by zero will return a NaN, rather than raise an exception.
        All keyword arguments are forwarded unchanged to the ``eval`` of every
        argument.

        Returns
        -------
        result : number or array-like
            The single-valued result of evaluation of the operation (if no
            substitutions given, or all substitutions are scalar-valued).
            Otherwise, if any parameter substitution was a
            :class:`numpy.ndarray`, then a corresponding array of results.

        Raises
        ------
        FinesseException
            If applying the operation raises a TypeError.
        """
        try:
            args = tuple(arg.eval(**kwargs) for arg in self.args)
            return self.op(*args)
        except ZeroDivisionError:
            return np.nan
        except TypeError as ex:
            raise FinesseException(self._type_error_message(ex)) from ex

    @property
    def _is_narg_expression_tree(self):
        """Was this expression built whilst the global
        `_SIMPLIFICATION_ENABLED` flag was set to `True`."""
        return self.__simplified_expr_tree

    @property
    def contains_unresolved_symbol(self):
        """Whether the operation contains any unresolved symbols.

        :`getter`: Returns true if any symbol in the operation is an instance of
                   :class:`.Resolving`, false otherwise. Read-only.
        """
        for arg in self.args:
            if isinstance(arg, Resolving):
                return True
            # Only stop early on a positive result, otherwise the remaining
            # arguments would never be inspected.
            if isinstance(arg, Function) and arg.contains_unresolved_symbol:
                return True
        return False

    def __symeq__(self, obj):
        if not isinstance(obj, Symbol) and (obj == 0 or obj == 1):
            # Some dumb checks for common comparisons for things like 0
            return False

        if self is obj:
            return True

        A = self
        B = obj

        if isinstance(B, Function):
            if A.op == B.op:
                if len(A.args) != len(B.args):
                    return False
                return all(a == b for a, b in zip(A.args, B.args))

            if B.op is operator.pos:
                return A == B.args[0]
            elif B.op is operator.neg:
                return A == -1 * B.args[0]
            else:
                return False
        else:
            # pos and neg comparisons are some what annoying, if we
            # convert them to multiply it becomes a simple product arg
            # comparison
            if A.op is operator.pos:
                return A.args[0] == B
            elif A.op is operator.neg:
                return (-1) * A.args[0] == B
            else:
                return False

    def __hash__(self):
        return hash((Function, self.op, *(hash(arg) for arg in self.args)))
class LazySymbol(Symbol):
    """A generic way to make some lazily evaluated symbol.

    The value is dependant on a function and some arbitrary arguments
    which will be called when the symbol is evaluated.

    Parameters
    ----------
    name : str
        Human readable string name for the symbol
    function : callable
        Function to call when evaluating the symbol
    *args : objects
        Arguments to pass to `function` when evaluating

    Examples
    --------
    >>> a = LazySymbol('a', lambda x: x**2, 10)
    >>> print((a*10).eval())
    1000
    """

    def __init__(self, name, function, *args):
        self.__name = name
        self.function = function
        self.args = args

    def eval(self, **kwargs):
        """Call the stored function with the stored arguments.

        Keyword arguments are accepted but not used.
        """
        positional = self.args
        return self.function(*positional)

    @property
    def name(self):
        return self.__name
def collect(y):
    """Collect like terms of a symbolic expression.

    For a summation, every argument is split into a coefficient and a term
    with `coefficient_and_term`; coefficients of identical terms are added
    together and the sum is rebuilt. For any other operation the arguments
    are collected recursively (`pos` is dropped and `neg` becomes a
    multiplication by -1).

    Parameters
    ----------
    y : object
        Expression to process. Anything without an ``op`` attribute is
        returned unchanged.

    Returns
    -------
    The collected expression.
    """
    if not hasattr(y, "op"):
        return y

    if not y._is_narg_expression_tree:
        y = y.to_nary_add_mul()  # simplification happens on nary trees
        if not hasattr(y, "op"):
            return y

    with simplification(allow_flagged=True):
        if y.op is operator_add:
            coefficients = {}
            total = 0
            for arg in y.args:
                coeff, term = coefficient_and_term(arg)
                if term is None:  # just a coefficient/constant
                    total += arg
                elif term in coefficients:
                    coefficients[term] += coeff
                else:
                    coefficients[term] = coeff

            for term, coeff in coefficients.items():
                total += term * coeff
            return total

        if len(y.args) == 0:
            return y

        # y is some other operator, which may have args that need collecting.
        # pos and neg we can simplify here to just return the argument, or
        # the argument multiplied by -1.
        if y.op is operator.pos:
            return collect(y.args[0])
        if y.op is operator.neg:
            return -1 * collect(y.args[0])
        return y.op(*(collect(_) for _ in y.args))
def coefficient_and_term(y):
    """Split an expression into a ``(coefficient, term)`` pair.

    Parameters
    ----------
    y : object
        Expression to split.

    Returns
    -------
    coefficient, term : tuple
        * product of only constants (e.g. ``2*pi``): ``(y, None)``
        * product whose first argument is a constant: that constant and the
          product of the remaining arguments
        * product with no leading constant: ``(Constant(1), product of args)``
        * ``+x`` / ``-x``: ``(Constant(1), x)`` / ``(Constant(-1), x)``
        * a bare `Constant`: ``(y, None)``
        * anything else: ``(Constant(1), y)``
    """
    try:
        op = y.op
        if op is operator_mul or op is operator.mul:
            if not y._is_narg_expression_tree:
                y = y.to_nary_add_mul()  # simplification happens on nary trees
            factors = y.args

            if all(type(_) is Constant for _ in factors):
                # if all constants then probably a named constant times
                # another, like 2*pi
                return y, None
            if type(factors[0]) is Constant:
                # Just the constant at the start, then the rest of the
                # args should be variables/etc
                return factors[0], np.prod(factors[1:])
            # There are no constant values so just variables/functions/etc
            return Constant(1), np.prod(factors)

        if op is operator.pos:
            return Constant(1), y.args[0]
        if op is operator.neg:
            return Constant(-1), y.args[0]
        return Constant(1), y
    except AttributeError:
        # Not an operator...
        if type(y) is Constant:
            return y, None
        return Constant(1), y
def expand_mul(y):
    """Expand a product over any sums it contains.

    For a multiplication, every argument that is a summation is distributed
    over the terms accumulated so far and the resulting terms are summed.
    For a summation, each argument is expanded in turn. Any other expression
    is returned as is.

    Parameters
    ----------
    y : object
        Expression to expand. Anything without an ``op`` attribute is
        returned unchanged.

    Returns
    -------
    The expanded expression.
    """
    with simplification(allow_flagged=True):
        if not hasattr(y, "op"):
            return y

        if not y._is_narg_expression_tree:
            y = y.to_nary_add_mul()  # simplification happens on nary trees
            # The conversion can simplify the function away entirely, in
            # which case there is nothing left to expand.
            if not hasattr(y, "op"):
                return y

        if y.op is operator_mul:
            terms = np.array([1], dtype="O")
            for factor in y.args:
                if getattr(factor, "op", None) is operator_add:
                    # Multiply every term so far with every term of the sum
                    terms = np.outer(terms, factor.args)
                else:
                    terms *= factor
            return np.sum(terms)

        if y.op is operator_add:
            return sum(expand_mul(_) for _ in y.args)

        return y
def is_integer(n):
    """Checks if `n` is an integer.

    Parameters
    ----------
    n : str, float, int or object
        Input to check. Anything that `float` cannot convert (for example
        ``None``, a complex number or a non-numeric string) is reported as
        not being an integer.

    Returns
    -------
    bool
        True if ``float(n)`` succeeds and has no fractional part.
    """
    try:
        value = float(n)
    except (TypeError, ValueError):
        # float() raises ValueError for bad strings and TypeError for
        # unsupported types such as None or complex.
        return False
    return value.is_integer()
def expand_pow(y):
    """Expand powers of products and integer powers of sums.

    * ``(a*b)**n`` becomes ``a**n * b**n``.
    * ``(a+b)**n`` with `n` an integer-valued `Constant` is multiplied out
      and collected; a negative `n` gives the inverse of the expansion for
      ``abs(n)`` and ``n == 0`` gives ``1``.
    * For a summation or a product each argument is expanded in turn.

    Anything else is returned as is.

    Parameters
    ----------
    y : object
        Expression to expand. Anything without an ``op`` attribute is
        returned unchanged.

    Returns
    -------
    The expanded expression.
    """
    with simplification(allow_flagged=True):
        if not hasattr(y, "op"):
            return y

        if not y._is_narg_expression_tree:
            y = y.to_nary_add_mul()  # simplification happens on nary trees
            # The conversion can simplify the function away entirely, in
            # which case there is nothing left to expand.
            if not hasattr(y, "op"):
                return y

        if y.op is operator.pow:
            z = 1
            exp = y.args[1]
            base = y.args[0]
            try:
                if base.op is operator_mul:
                    for factor in base.args:
                        z *= factor**exp
                    return z
                elif base.op is operator_add:
                    if type(exp) is Constant and is_integer(exp.value):
                        n = int(exp.value)
                        if n == 0:
                            return 1
                        # Repeated outer products generate every cross term
                        terms = np.array([1], dtype="O")
                        for _ in range(abs(n)):
                            terms = np.outer(terms, base.args)
                        if n > 0:
                            return terms.sum().collect()
                        else:
                            return (terms.sum().collect()) ** -1
                    else:
                        return y
                else:
                    return y
            except AttributeError:
                # The base is not an operation, nothing to expand
                return y
        elif y.op is operator_add:
            return sum(expand_pow(_) for _ in y.args)
        elif y.op is operator_mul:
            return np.prod(tuple(expand_pow(_) for _ in y.args))
        else:
            return y
def expand(y):
    """Fully expand products and powers in a symbolic expression.

    Powers are expanded with `expand_pow` and products with `expand_mul`.
    When the result is a summation each of its terms is expanded again, and
    the whole process repeats until the expression no longer changes.
    Arguments of any other kind of operation are expanded recursively.

    Parameters
    ----------
    y : object
        Expression to expand. Anything that is not a `Function` is returned
        unchanged.

    Returns
    -------
    The expanded expression.
    """
    with simplification(allow_flagged=True):
        if not isinstance(y, Function):
            return y  # Nothing to expand

        if not y._is_narg_expression_tree:
            y = y.to_nary_add_mul()  # simplification happens on nary trees
            if not isinstance(y, Function):
                return y  # Nothing to expand

        if y.op is operator_mul:
            # Run through and expand any pows before the
            # full mul expansion
            y = Function(y.name, y.op, *[expand_pow(_) for _ in y.args])
            y = expand_mul(y)
            # Nothing left to expand as no add operator
            if not isinstance(y, Function):
                return y
            elif y.op is operator_mul:
                return y
        elif y.op is operator.pow:
            y = expand_pow(y)
            # expand_pow may hand back something that is not a Function
            # (e.g. the integer 1 for a zero exponent). In that case, or if
            # it is still a pow, there is no add operator left to expand.
            if not isinstance(y, Function) or y.op is operator.pow:
                return y

        if y.op is not operator_add:
            # some other type of operator, expand its args
            if len(y.args) > 0:
                return y.op(*(expand(_) for _ in y.args))
            else:
                return y
        else:
            # We have a summation of terms from the initial expansion
            # so expand each term
            out = []
            for x in y.args:
                try:
                    if x.op is operator_mul:
                        x = Function(x.name, x.op, *[expand_pow(_) for _ in x.args])
                        z = expand_mul(x)
                    elif x.op is operator.pow:
                        z = expand_pow(x)
                    else:
                        z = x

                    # Flatten nested sums into this summation
                    if hasattr(z, "op") and z.op is operator_add:
                        out.extend(z.args)
                    else:
                        out.append(z)
                except AttributeError:
                    # Term is not an operation, keep it as is
                    out.append(x)

            z = Function(y.name, y.op, *out)
            # Stop once another pass makes no further change
            if y == z:
                return z
            else:
                return expand(z)