Source code for finesse.symbols

"""Symbolic manipulations (`expand`, `collect`, etc.) are based on the book:

Cohen JS. Computer algebra and symbolic computation: mathematical methods. First
edition, 2003.
"""

import warnings
import abc

import numpy as np

from finesse import constants
from finesse.utilities import is_iterable
from finesse.exceptions import FinesseException

from contextlib import contextmanager
from functools import reduce
import operator
import logging


LOGGER = logging.getLogger(__name__)

def MAKE_LOP(name, opfn):
    """Build a binary method computing ``self <op> other`` as a :class:`Function`.

    Parameters
    ----------
    name : str
        Operator name stored on the created Function.
    opfn : callable
        Binary callable used when the Function is evaluated.
    """

    def _left_operator(self, other):
        return Function(name, opfn, self, other)

    return _left_operator


def MAKE_ROP(name, opfn):
    """Build a reflected method computing ``other <op> self`` as a :class:`Function`.

    Parameters
    ----------
    name : str
        Operator name stored on the created Function.
    opfn : callable
        Binary callable used when the Function is evaluated.
    """

    def _right_operator(self, other):
        return Function(name, opfn, other, self)

    return _right_operator

# By default no simplification happens: symbol code that needs it must be wrapped
# in the ``simplification`` context manager. While the flag is set, symbols try to
# simplify themselves using a standardised ordering of arguments. Only the +, *
# and ** operators are used in simplified symbols, which allows various symbolic
# simplification algorithms to be applied to them.
_SIMPLIFICATION_ENABLED = False  # Global flag stating whether simplification will happen


@contextmanager
def simplification(allow_flagged=False):
    """Enable symbolic simplification for every symbol operation inside the block.

    Parameters
    ----------
    allow_flagged : bool, optional
        When True and simplification is already enabled, the body simply runs
        and the outer context stays responsible for resetting the flag. When
        False, entering while simplification is already enabled raises.

    Raises
    ------
    RuntimeError
        If simplification is already enabled and ``allow_flagged`` is False.
    """
    global _SIMPLIFICATION_ENABLED
    if not (allow_flagged and _SIMPLIFICATION_ENABLED):
        if _SIMPLIFICATION_ENABLED is True:
            raise RuntimeError("Simplification has already been enabled")
        try:
            _SIMPLIFICATION_ENABLED = True
            yield
        finally:
            # Always switch simplification back off, even if the body raised.
            _SIMPLIFICATION_ENABLED = False
    else:
        # Nested use inside an already-enabled context: leave the flag alone.
        yield
def base_exponent(y):
    """Split ``y`` into a ``(base, exponent)`` pair.

    Power expressions (``y.op is operator.pow``) give their two arguments.
    Anything else, including objects without ``op``/``args`` attributes, is
    treated as ``y**1`` and gives ``(y, Constant(1))``.
    """
    try:
        if y.op is operator.pow:
            return y.args[0], y.args[1]
    except AttributeError:
        pass
    return y, Constant(1)
def operator_add(*args):
    """N-ary addition: left-fold ``+`` over all arguments."""
    return reduce(lambda total, term: total + term, args)
def operator_mul(*args):
    """N-ary multiplication: left-fold ``*`` over all arguments."""
    return reduce(lambda product, factor: product * factor, args)
def operator_sub(*args):
    """N-ary subtraction: left-fold ``-`` over all arguments (``a - b - c ...``)."""
    return reduce(lambda difference, term: difference - term, args)
def MAKE_simplify_add(dir):
    """Create the ``__add__`` (``dir="LOP"``) or ``__radd__`` (``dir="ROP"``) method.

    Without simplification enabled, the returned method just builds a binary
    ``Function("+", operator.add, ...)`` node. With simplification enabled, it:

    - drops additive zeros;
    - folds two unnamed numeric constants;
    - keeps named constants symbolic;
    - otherwise flattens both operands into one n-ary ``operator_add`` node
      whose arguments are sorted into a standard order.

    Parameters
    ----------
    dir : str
        ``"LOP"`` when the symbol is the left operand of ``+``; any other value
        (``"ROP"``) when it is the right operand.

    Returns
    -------
    callable
        A ``simplify_add(self, other)`` method.
    """

    def _ordered(self, other):
        # Operand order as written in the source expression.
        if dir == "LOP":
            return self, other
        return other, self

    def _summands(a):
        # Flatten an n-ary sum into its terms; anything else is a single term.
        try:
            if a.op is operator_add:
                return a.args
            return (a,)
        except AttributeError:
            return (a,)

    def _sort_key(a):
        try:
            # chr(58) (':') sorts after the digits '0'-'9', so keys of numeric
            # constants sort before nested sums and powers.
            if a.op is operator_add:
                return chr(58)
            elif a.op is operator.pow:
                return chr(58) + str(a.args[0]) + str(a.args[1])
            else:
                return str(a)
        except AttributeError:
            return a.name

    def simplify_add(self, other):
        if not _SIMPLIFICATION_ENABLED:
            return Function("+", operator.add, *_ordered(self, other))

        if np.all(self == 0):
            return other
        elif np.all(other == 0):
            return self
        elif type(self) is Constant and type(other) is Constant:
            if self.is_named and other.is_named:
                if self == other:
                    return 2 * self
                # Different named constants (e.g. pi + c) stay symbolic. This
                # used to ``return self + other``, which re-entered this method
                # with identical operands and recursed until RecursionError.
                return Function("+", operator_add, *_ordered(self, other))
            elif self.is_named ^ other.is_named:
                # Mixing a named and an unnamed constant keeps the name visible.
                return Function("+", operator_add, *_ordered(self, other))
            else:
                return Constant(self.value + other.value)
        else:
            f = Function("+", operator_add, *_summands(self), *_summands(other))
            f.args.sort(key=_sort_key)
            return f

    return simplify_add
def reduce_mul_args(args):
    """Sort and reduce the arguments of a multiply operation.

    Steps:

    1. Unnamed numeric constants are folded into one leading coefficient.
    2. Other scalar factors (including named constants) are sorted by their
       ``str`` form, so repeated bases end up adjacent.
    3. Adjacent equal bases are merged by adding their exponents.

    :class:`Matrix` factors keep their original relative order and are placed
    after all scalar factors.

    Parameters
    ----------
    args : sequence
        Factors of the product. Plain numbers are wrapped with :func:`as_symbol`.

    Returns
    -------
    list
        The reduced factors:

        - with fewer than two factors, ``args`` is returned unchanged;
        - if only a unit coefficient remains, the matrix factors are returned,
          or ``[1]`` when there are none.
    """

    def _reduce_mul_args(acc, b):
        scalars, matrices = acc
        b = as_symbol(b)  # need this in case arg is a float/int type
        if type(b) is Constant and not b.is_named:
            # scalars[0] always holds the accumulated numeric coefficient.
            scalars[0] = Constant(scalars[0].value * b.value)
        elif type(b) is Matrix:
            matrices.append(b)
        else:
            scalars.append(b)
        return acc

    def _reduce_mul_args_pows(acc, b):
        # Sorted input puts equal bases next to each other: merge them by
        # summing exponents, otherwise start a new [base, exponent] entry.
        base, exp = base_exponent(b)
        if np.all(acc[-1][0] == base):
            acc[-1][1] += exp
        else:
            acc.append([base, exp])
        return acc

    def sort_key(a):
        try:
            # chr(58) (':') sorts after the digits '0'-'9', so keys of numeric
            # constants sort before nested sums and powers.
            if a.op is operator_add:
                return chr(58)
            elif a.op is operator.pow:
                return chr(58) + str(a.args[0]) + str(a.args[1])
            else:
                return str(a)
        except AttributeError:
            return a.name

    if len(args) > 1:
        scalars, matrices = reduce(_reduce_mul_args, args, ([Constant(1)], []))
        if np.all(scalars[0] == 1):
            if len(scalars) == 1:
                # Only the unit coefficient is left among the scalars. Matrix
                # factors must survive: previously they were dropped here, so
                # e.g. M1*M2 simplified to 1.
                return list(matrices) if matrices else [1]
            scalars.pop(0)
        # Sort arguments alphabetically using str form of symbols
        scalars.sort(key=sort_key)
        # Now it's ordered, powers will be grouped
        # together so we can reduce them too
        pow_args = reduce(
            _reduce_mul_args_pows, scalars[1:], [list(base_exponent(scalars[0]))]
        )
        args = [b**e for b, e in pow_args]
        args.extend(matrices)
    return args
def MAKE_simplify_mul(dir):
    """Create the ``__mul__`` (``dir="LOP"``) or ``__rmul__`` (``dir="ROP"``) method.

    Without simplification, a binary ``Function("*", operator.mul, ...)`` node
    is built. With simplification:

    - a zero operand gives ``0``;
    - a unit operand gives the other operand;
    - otherwise both operands are flattened into their factors and reduced with
      :func:`reduce_mul_args`, giving either a single factor or an n-ary
      ``operator_mul`` node.
    """

    def _factors(value):
        # The factors of an n-ary product, or the value itself as one factor.
        try:
            return value.args if value.op is operator_mul else (value,)
        except AttributeError:
            return (value,)

    def simplify_mul(self, other):
        if not _SIMPLIFICATION_ENABLED:
            ordered = (self, other) if dir == "LOP" else (other, self)
            return Function("*", operator.mul, *ordered)
        if np.all(self == 0) or np.all(other == 0):
            return 0
        if np.all(self == 1):
            return other
        if np.all(other == 1):
            return self
        lhs, rhs = _factors(self), _factors(other)
        combined = [*lhs, *rhs] if dir == "LOP" else [*rhs, *lhs]
        reduced = reduce_mul_args(combined)
        if len(reduced) > 1:
            return Function("*", operator_mul, *reduced)
        return reduced[0]

    return simplify_mul
def MAKE_simplify_sub(dir):
    """Create the ``__sub__`` (``dir="LOP"``) or ``__rsub__`` (``dir="ROP"``) method.

    With simplification enabled, subtraction is rewritten as adding a negated
    operand, so simplified trees only contain ``+``.
    """
    add_method = MAKE_simplify_add(dir)

    def simplify_sub(self, other):
        if _SIMPLIFICATION_ENABLED:
            if dir == "LOP":
                return add_method(self, -other)
            return add_method(-self, other)
        if dir == "LOP":
            return Function("-", operator.sub, self, other)
        return Function("-", operator.sub, other, self)

    return simplify_sub
def MAKE_simplify_neg():
    """Create the ``__neg__`` method.

    Under simplification, negation becomes multiplication by -1.
    """

    def simplify_neg(self):
        if _SIMPLIFICATION_ENABLED:
            return -1 * self
        return Function("-", operator.neg, self)

    return simplify_neg
def MAKE_simplify_pow():
    """Create the ``__pow__`` method.

    Under simplification:

    - ``x**0`` gives ``Constant(1)``;
    - ``x**1`` gives ``x``;
    - two Constants that are both named or both unnamed are evaluated
      numerically.
    """

    def simplify_pow(self, exp):
        if not _SIMPLIFICATION_ENABLED:
            return Function("**", operator.pow, self, exp)
        if np.all(exp == 0):
            return Constant(1)
        if np.all(exp == 1):
            return self
        both_constants = type(self) is Constant and type(exp) is Constant
        if both_constants and self.is_named == exp.is_named:
            return Constant(self.value**exp.value)
        return Function("**", operator.pow, self, exp)

    return simplify_pow
def MAKE_LOP_simplify_truediv():
    """Create the ``__truediv__`` method, computing ``self / other``.

    Under simplification, division becomes multiplication by ``other**-1``,
    and a negation on either side is pulled out in front.
    """

    def simplify_truediv(self, other):
        if not _SIMPLIFICATION_ENABLED:
            return Function("/", operator.truediv, self, other)
        if np.all(self == 0):
            return 0
        if np.all(other == 0):
            raise ZeroDivisionError()
        numerator_negated = isinstance(self, Function) and self.op is operator.neg
        denominator_negated = isinstance(other, Function) and other.op is operator.neg
        if numerator_negated:
            # (-a) / other == -a * other**-1
            return -self.args[0] * Function("**", operator.pow, other, -1)
        if denominator_negated:
            # self / (-b) == -self * b**-1
            return -self * Function("**", operator.pow, other.args[0], -1)
        return self * Function("**", operator.pow, other, -1)

    return simplify_truediv
def MAKE_ROP_simplify_truediv():
    """Create the ``__rtruediv__`` method, computing ``other / self``.

    Under simplification, division becomes multiplication by ``self**-1``,
    and a negation on either side is pulled out in front.

    Raises
    ------
    ZeroDivisionError
        Under simplification, when ``self`` compares equal to zero.
    """

    def simplify_truediv(self, other):
        if not _SIMPLIFICATION_ENABLED:
            return Function("/", operator.truediv, other, self)
        if np.all(other == 0):
            return 0
        elif np.all(self == 0):
            # np.all, matching the left-operand version: a bare ``self == 0``
            # raises "truth value of an array is ambiguous" for array-valued
            # constants.
            raise ZeroDivisionError()
        elif isinstance(self, Function) and self.op is operator.neg:
            # other / (-a) == -other * a**-1
            return -other * Function("**", operator.pow, self.args[0], -1)
        elif isinstance(other, Function) and other.op is operator.neg:
            # (-b) / self == -b * self**-1
            return -other.args[0] * Function("**", operator.pow, self, -1)
        else:
            return other * Function("**", operator.pow, self, -1)

    return simplify_truediv
# Supported operators.
OPERATORS = {
    "__add__": MAKE_simplify_add("LOP"),
    "__sub__": MAKE_simplify_sub("LOP"),
    "__mul__": MAKE_simplify_mul("LOP"),
    "__radd__": MAKE_simplify_add("ROP"),
    "__rsub__": MAKE_simplify_sub("ROP"),
    "__rmul__": MAKE_simplify_mul("ROP"),
    "__neg__": MAKE_simplify_neg(),
    "__pow__": MAKE_simplify_pow(),
    "__truediv__": MAKE_LOP_simplify_truediv(),
    "__rtruediv__": MAKE_ROP_simplify_truediv(),
    "__floordiv__": MAKE_LOP("//", operator.floordiv),
    "__rfloordiv__": MAKE_ROP("//", operator.floordiv),
    "__matmul__": MAKE_LOP("@", operator.matmul),
}

# Maps function names to the actual functions called. This is used for
# lambdifying symbolics, because the underlying function cannot be retrieved
# from the lambdas stored in FUNCTIONS below.
PYFUNCTION_MAP = {
    "abs": operator.abs,
    "neg": operator.neg,
    "pos": operator.pos,
    "pow": operator.pow,
    "conj": np.conj,
    "real": np.real,
    "imag": np.imag,
    "exp": np.exp,
    "log10": np.log10,
    "log": np.log,
    "sin": np.sin,
    "arcsin": np.arcsin,
    "cos": np.cos,
    "arccos": np.arccos,
    "tan": np.tan,
    "arctan": np.arctan,
    "arctan2": np.arctan2,
    "sqrt": np.sqrt,
    "std": np.std,
    "sum": np.sum,
    "dot": np.dot,
    "radians": np.radians,
    "degrees": np.degrees,
    "deg2rad": np.deg2rad,
    "rad2deg": np.rad2deg,
    "arange": np.arange,
    "linspace": np.linspace,
    "logspace": np.logspace,
    "geomspace": np.geomspace,
}

# Built-in symbolic functions: maps string names of functions to actual functions
FUNCTIONS = {
    "abs": lambda x: Function("abs", operator.abs, x),
    "neg": lambda x: Function("neg", operator.neg, x),
    "pos": lambda x: Function("pos", operator.pos, x),
    # operator.pow is binary; the previous single-argument lambda built a
    # Function that always raised TypeError when evaluated.
    "pow": lambda x, y: Function("pow", operator.pow, x, y),
    "conj": lambda x: Function("conj", np.conj, x),
    "real": lambda x: Function("real", np.real, x),
    "imag": lambda x: Function("imag", np.imag, x),
    "exp": lambda x: Function("exp", np.exp, x),
    "log": lambda x: Function("log", np.log, x),
    "log10": lambda x: Function("log10", np.log10, x),
    "sin": lambda x: Function("sin", np.sin, x),
    "arcsin": lambda x: Function("arcsin", np.arcsin, x),
    "cos": lambda x: Function("cos", np.cos, x),
    "arccos": lambda x: Function("arccos", np.arccos, x),
    "tan": lambda x: Function("tan", np.tan, x),
    "arctan": lambda x: Function("arctan", np.arctan, x),
    "arctan2": lambda y, x: Function("arctan2", np.arctan2, y, x),
    "sqrt": lambda x: Function("sqrt", np.sqrt, x),
    "std": lambda x: Function("std", np.std, x),
    "sum": lambda x: Function("sum", np.sum, x),
    "dot": lambda x, y: Function("dot", np.dot, x, y),
    "radians": lambda x: Function("radians", np.radians, x),
    "degrees": lambda x: Function("degrees", np.degrees, x),
    "deg2rad": lambda x: Function("deg2rad", np.deg2rad, x),
    "rad2deg": lambda x: Function("rad2deg", np.rad2deg, x),
    "arange": lambda a, b, c: Function("arange", np.arange, float(a), float(b), int(c)),
    "linspace": lambda a, b, c: Function(
        "linspace", np.linspace, float(a), float(b), int(c)
    ),
    "logspace": lambda a, b, c: Function(
        "logspace", np.logspace, float(a), float(b), int(c)
    ),
    "geomspace": lambda a, b, c: Function(
        "geomspace", np.geomspace, float(a), float(b), int(c)
    ),
}

# String formatters used by ``display`` for operations with a fixed textual form.
op_repr = {
    operator_add: lambda *args: "({})".format("+".join(args)),
    operator.add: "({}+{})".format,
    operator.sub: lambda *args: "({})".format("-".join(args)),
    operator_mul: lambda *args: "*".join(args),
    operator.mul: "({}*{})".format,
    operator.pow: "({})**({})".format,
    operator.truediv: "{}/{}".format,
    operator.floordiv: "{}//{}".format,
    operator.mod: "({}%{})".format,
    operator.matmul: "({}@{})".format,
    operator.neg: "-{}".format,
    operator.pos: "+{}".format,
    operator.abs: "abs({})".format,
    np.conj: "conj({})".format,
    np.sqrt: "sqrt({})".format,
}
def display(a):
    """Return a human-readable string for a symbol and the operations it contains.

    Parameters
    ----------
    a : :class:`.Symbol`
        Symbol to print.

    Returns
    -------
    str
        String form of the Symbol.
    """
    if hasattr(a, "op"):
        formatter = op_repr.get(a.op)
        if formatter is None:
            # No predefined format: render it as a generic call ``name(a,b,...)``.
            placeholders = ",".join(["{}"] * len(a.args))
            formatter = f"{a.op.__name__}({placeholders})".format
        rendered = formatter(*(display(arg) for arg in a.args))
        # Tidy up common artefacts of the mechanical formatting.
        for old, new in (("-1*", "-"), ("+-", "-"), ("*1/", "/")):
            rendered = rendered.replace(old, new)
        return rendered
    if hasattr(a, "name"):
        # Anything with a name attribute just displays that.
        return a.name
    if type(a) is Symbol:
        return f"<Symbol @ {hex(id(a))}>"
    return str(a)
def finesse2sympy(expr, iter_num=0):
    """Convert a Finesse symbolic expression into a Sympy expression.

    Constants become their numeric value and parameter references become
    ``sympy.Symbol`` objects named after the reference. Functions are converted
    recursively.

    Notes
    -----
    It is common for this function to raise a NotImplementedError: it maps
    operator and numpy functions to sympy by hand. If you hit this error, add
    the missing operation to the conversion table below. Over time this should
    cover most use cases.
    """
    import sympy
    from finesse.parameter import ParameterRef

    iter_num += 1

    if isinstance(expr, Constant):
        return expr.value
    if isinstance(expr, ParameterRef):
        return sympy.Symbol(expr.name)
    if not isinstance(expr, Function):
        raise NotImplementedError(
            f"{expr} undefined. {finesse2sympy} needs to be updated."
        )

    sympy_args = [finesse2sympy(arg, iter_num) for arg in expr.args]

    # (finesse operations, equivalent sympy callable), checked in order.
    conversions = (
        ((operator_mul, operator.mul), sympy.Mul),
        ((operator_add, operator.add), sympy.Add),
        ((operator.truediv,), lambda a, b: sympy.Mul(a, sympy.Pow(b, -1))),
        ((operator.pow,), sympy.Pow),
        ((operator.sub,), lambda x, y: sympy.Add(x, -y)),
        ((np.conj,), sympy.conjugate),
        ((np.radians,), sympy.rad),
        ((np.exp,), sympy.exp),
        ((np.cos,), sympy.cos),
        ((np.sin,), sympy.sin),
        ((np.tan,), sympy.tan),
        ((np.sqrt,), sympy.sqrt),
        ((operator.abs,), sympy.Abs),
        ((operator.neg,), lambda x: sympy.Mul(-1, x)),
        ((np.imag,), sympy.im),
        ((np.real,), sympy.re),
    )
    for finesse_ops, converter in conversions:
        if any(expr.op == candidate for candidate in finesse_ops):
            break
    else:
        # Fall back to a sympy function with the same name as the operation.
        try:
            converter = getattr(sympy, expr.op.__name__)
        except AttributeError:
            raise NotImplementedError(
                f"undefined Function {expr.op} in {expr}. {finesse2sympy} needs to be updated."
            )
    return converter(*sympy_args)
def sympy2finesse(expr, symbol_dict=None, iter_num=0):
    """Convert a Sympy expression back into a Finesse symbolic expression.

    Parameters
    ----------
    expr : sympy expression
        Expression to convert.
    symbol_dict : dict, optional
        Maps ``str`` names of sympy symbols to the Finesse objects to use.
        Defaults to an empty mapping.
    iter_num : int, optional
        Recursion counter.

    Returns
    -------
    Symbol or number
        The converted expression.

    Raises
    ------
    Exception
        If ``expr`` contains a construct that has no conversion here.
    """
    import sympy

    # Avoid a shared mutable default argument.
    if symbol_dict is None:
        symbol_dict = {}
    iter_num += 1

    def convert(arg):
        return sympy2finesse(arg, symbol_dict, iter_num=iter_num)

    if isinstance(expr, sympy.Mul):
        # np.product was removed in NumPy 2.0; np.prod is the supported name.
        return np.prod([convert(arg) for arg in expr.args])
    elif isinstance(expr, sympy.Add):
        return np.sum([convert(arg) for arg in expr.args])
    elif isinstance(expr, sympy.conjugate):
        return np.conj(convert(expr.args[0]))
    elif isinstance(expr, sympy.exp):
        return np.exp(convert(expr.args[0]))
    elif isinstance(expr, sympy.Pow):
        return np.power(convert(expr.args[0]), convert(expr.args[1]))
    elif expr.is_NumberSymbol:
        # sympy class for named symbols (eg Pi, golden ratio, ...)
        if str(expr) == "pi":
            return CONSTANTS["pi"]
        else:
            return complex(expr)
    elif expr.is_number:
        if expr.is_integer:
            return int(expr)
        elif expr.is_real:
            return float(expr)
        else:
            return complex(expr)
    elif expr.is_symbol:
        return symbol_dict[str(expr)]
    else:
        raise Exception(f"{expr} undefined")
def _collect_expanded(x):
    """Expand and then collect ``x`` if it is a Symbol; other values pass through."""
    if isinstance(x, Symbol):
        return collect(expand(x))
    return x


# Element-wise symbolic simplification for object arrays of symbols.
simplify_symbolic_numpy = np.vectorize(_collect_expanded, otypes="O")
def np_eval_symbolic_numpy(a, *keep):
    """Evaluate ``a`` if it is a Symbol, keeping the variables in ``keep`` symbolic.

    Non-symbol values are returned unchanged.
    """
    return a.eval(keep=keep) if isinstance(a, Symbol) else a
# Vectorised form of np_eval_symbolic_numpy producing object arrays.
__eval_symbolic_numpy = np.vectorize(pyfunc=np_eval_symbolic_numpy, otypes="O")
def eval_symbolic_numpy(a, *keep):
    """Element-wise evaluate an array of symbols, keeping ``keep`` variables symbolic."""
    vectorised_eval = __eval_symbolic_numpy
    return vectorised_eval(a, *keep)
def as_symbol(x):
    """Return ``x`` unchanged if it is a Symbol, otherwise wrap it in a :class:`Constant`."""
    return x if isinstance(x, Symbol) else Constant(x)
def evaluate(x):
    """Evaluate a symbol or N-dimensional array of symbols.

    Parameters
    ----------
    x : :class:`.Symbol` or array-like
        A symbolic expression or an array of symbolic expressions.

    Returns
    -------
    out : float, complex, :class:`numpy.ndarray`
        A single value for the evaluated expression if `x` is not array-like.
        Otherwise an array of the evaluated expressions: float64 when every
        imaginary part is zero, complex128 otherwise.
    """
    if is_iterable(x):
        y = np.array(x, dtype=np.complex128)
        if not np.any(y.imag):
            # Purely real symbols in array. Taking the real part directly gives
            # the same float64 values without a complex->float cast, so no
            # ComplexWarning has to be suppressed. (``np.ComplexWarning``, used
            # previously, no longer exists at the top level in NumPy 2.0.)
            y = y.real.astype(np.float64)
        return y
    if isinstance(x, Symbol):
        return x.eval()
    # If not a symbol then just return x directly
    return x
class Symbol(abc.ABC):
    """Base class for all symbolic objects.

    Arithmetic dunder methods come from ``OPERATORS`` and the common
    mathematical methods from ``FUNCTIONS``, so combining symbols builds
    :class:`Function` expression trees. Subclasses provide ``eval`` and, where
    needed, ``__symeq__`` and ``__hash__``.
    """

    __add__ = OPERATORS["__add__"]
    __sub__ = OPERATORS["__sub__"]
    __mul__ = OPERATORS["__mul__"]
    __radd__ = OPERATORS["__radd__"]
    __rsub__ = OPERATORS["__rsub__"]
    __rmul__ = OPERATORS["__rmul__"]
    __pow__ = OPERATORS["__pow__"]
    __truediv__ = OPERATORS["__truediv__"]
    __rtruediv__ = OPERATORS["__rtruediv__"]
    __floordiv__ = OPERATORS["__floordiv__"]
    __rfloordiv__ = OPERATORS["__rfloordiv__"]
    __matmul__ = OPERATORS["__matmul__"]
    __neg__ = OPERATORS["__neg__"]
    __abs__ = FUNCTIONS["abs"]
    __pos__ = FUNCTIONS["pos"]

    conjugate = FUNCTIONS["conj"]
    conj = FUNCTIONS["conj"]
    exp = FUNCTIONS["exp"]
    sin = FUNCTIONS["sin"]
    arcsin = FUNCTIONS["arcsin"]
    cos = FUNCTIONS["cos"]
    arccos = FUNCTIONS["arccos"]
    tan = FUNCTIONS["tan"]
    arctan = FUNCTIONS["arctan"]
    arctan2 = FUNCTIONS["arctan2"]
    sqrt = FUNCTIONS["sqrt"]
    radians = FUNCTIONS["radians"]
    degrees = FUNCTIONS["degrees"]
    deg2rad = FUNCTIONS["deg2rad"]
    rad2deg = FUNCTIONS["rad2deg"]

    @property
    def real(self):
        return FUNCTIONS["real"](self)

    @property
    def imag(self):
        return FUNCTIONS["imag"](self)

    def __eq__(self, obj):
        """Inheriting classes should implement __symeq__ and do any symbol
        specific equality checks there. This top level handles initial n-arg
        conversion before symbolic comparison.
        """
        # Need to convert any symbolic expressions (Functions)
        # to nary form for equality checks
        if isinstance(obj, Function) and not obj._is_narg_expression_tree:
            obj = obj.to_nary_add_mul()
        if isinstance(self, Function) and not self._is_narg_expression_tree:
            self = self.to_nary_add_mul()
        return self.__symeq__(obj)

    def __symeq__(self, obj):
        # Default symbolic equality is object identity.
        return id(self) == id(obj)

    def __float__(self):
        v = self.eval()
        if np.isscalar(v):
            return float(v)
        else:
            raise TypeError(f"Can't cast {type(v)} ({v}) into a single float value")

    def __complex__(self):
        v = self.eval()
        if np.isscalar(v):
            return complex(v)
        else:
            raise TypeError(f"Can't cast {type(v)} ({v}) into a single complex value")

    def __int__(self):
        v = self.eval()
        if np.isscalar(v):
            return int(v)
        else:
            raise TypeError(f"Can't cast {type(v)} into a single int value")

    @property
    def value(self):
        return self.eval()

    def __str__(self):
        return display(self)

    def __repr__(self):
        return f"<Symbolic='{display(self)}' @ {hex(id(self))}>"

    @property
    def is_changing(self):
        """Return True if any argument of this symbolic object may vary while a
        simulation is running.

        Operations check their arguments recursively. Parameter references
        report their parameter's ``is_tunable or is_changing``. Anything else
        is not changing.
        """
        res = False
        if hasattr(self, "op"):
            res = any([_.is_changing for _ in self.args])
        elif hasattr(self, "parameter"):
            res = self.parameter.is_tunable or self.parameter.is_changing
        return res

    def parameters(self, memo=None):
        """Return all the parameter references present in this symbolic statement.

        Parameters
        ----------
        memo : set, optional
            Set to accumulate into (used for the recursion).

        Returns
        -------
        list
            The unique parameter references; order is not guaranteed.
        """
        if memo is None:
            memo = set()
        if hasattr(self, "op"):
            for _ in self.args:
                _.parameters(memo)
        elif hasattr(self, "parameter"):
            memo.add(self)
        return list(memo)

    def all(self, predicate, memo=None):
        """Return all the arguments present in this symbolic statement that
        satisfy the predicate.

        Parameters
        ----------
        predicate : callable
            Takes an argument and returns True if it matches.
        memo : set, optional
            Set to accumulate into (used for the recursion).

        Returns
        -------
        list
            The unique matching leaves; order is not guaranteed.
        """
        if memo is None:
            memo = set()
        if hasattr(self, "op"):
            for _ in self.args:
                _.all(predicate, memo)
        elif predicate(self):
            memo.add(self)
        return list(memo)

    def changing_parameters(self):
        """Return the parameter references whose ``is_changing`` is True."""
        p = np.array(self.parameters())
        return list(p[list(map(lambda x: x.is_changing, p))])

    def to_sympy(self):
        """Convert a Finesse symbolic expression into a Sympy expression.

        Warning: for large functions this can be quite slow.
        """
        return finesse2sympy(self)

    def sympy_simplify(self):
        """Simplify this expression with Sympy and convert the result back into a
        Finesse symbolic expression.

        Parameter references are matched back up by name.
        """
        refs = {_.name: _ for _ in self.parameters()}  # get a list of symbols we're using
        sympy_expr = finesse2sympy(self)
        return sympy2finesse(sympy_expr.simplify(), refs)

    def collect(self):
        """Collect like terms in the expression."""
        return collect(self)

    def expand(self):
        """Perform a basic expansion of the symbolic expression."""
        return expand(self)

    def expand_symbols(self):
        """Expand symbolic parameter references that are themselves symbolic.

        Every parameter reference whose parameter holds a symbolic value is
        substituted by that value. This is repeated until no remaining reference
        is symbolic, which gives an expression that only depends on references
        that are numeric.

        Examples
        --------
        >>> import finesse
        >>> model = finesse.Model()
        >>> model.parse(
        ...     '''
        ...     var d 300
        ...     var c 6000
        ...     var b c+d
        ...     var a b+1
        ...     '''
        ... )
        >>> model.a.value.value.expand_symbols()
        <Symbolic='((c.value+d.value)+1)' @ 0x7faa4d351c10>

        Returns
        -------
        Symbol
            The expanded expression (``self`` if nothing needed expanding).
        """

        def process(p):
            # Symbolic parameters are replaced by their value; numeric ones are
            # mapped to themselves in ``subs``.
            if p.parameter.is_symbolic:
                return p.parameter.value
            else:
                return p

        def _expand(sym):
            params = sym.parameters()
            if len(params) == 0:
                return None
            elif not any(p.parameter.is_symbolic for p in params):
                # Nothing left to expand. This used ``all`` before, which
                # stopped as soon as one numeric reference appeared alongside
                # symbolic ones (making the numeric branch of ``process``
                # unreachable).
                return None
            else:
                return sym.eval(subs={p: process(p) for p in params})

        sym = self
        while True:
            res = _expand(sym)
            if res is None:
                return sym
            sym = res

    def to_binary_add_mul(self):
        """Convert a symbolic expression to use the binary operator.add and
        operator.mul operators, rather than the n-ary operator_add and
        operator_mul.

        Returns
        -------
        Symbol
        """
        if hasattr(self, "op"):
            return self.op(*(_.to_binary_add_mul() for _ in self.args))
        else:
            return self

    def to_nary_add_mul(self):
        """Convert a symbolic expression to use the n-ary operator_add and
        operator_mul operators, rather than the binary operator.add and
        operator.mul.

        Returns
        -------
        Symbol
        """
        if hasattr(self, "op"):
            with simplification(allow_flagged=True):
                # calling to_nary_add_mul here as it's basically the same but
                # just won't keep opening a new context manager each time
                return self.op(*(_.to_nary_add_mul() for _ in self.args))
        else:
            return self
class Matrix(Symbol):
    """A Matrix symbol.

    It is identified by its name (see ``__hash__``) and evaluates to itself
    unless a substitution for it is supplied.
    """

    def __init__(self, name):
        self.name = str(name)

    def __hash__(self):
        return hash((self.name,))

    def eval(self, subs=None, **kwargs):
        """Evaluate this matrix symbol.

        Accepts the same keyword arguments as the other symbols' ``eval``, so a
        Matrix can appear inside a :class:`Function` evaluated with ``subs`` or
        ``keep``. Previously any keyword argument raised a TypeError.

        Parameters
        ----------
        subs : dict, optional
            If this symbol is a key, the mapped value is returned instead.

        Returns
        -------
        The substituted value, otherwise ``self``.
        """
        if subs and self in subs:
            return subs[self]
        return self
class Constant(Symbol):
    """A constant symbol that can be used in symbolic math.

    Parameters
    ----------
    value : float, int
        Value of constant.
    name : str, optional
        Name of the constant to use when printing.
    """

    def __init__(self, value, name=None):
        self.__value = value
        self.__name = name

    def __str__(self):
        if self.__name:
            return self.__name
        return str(self.__value)

    def __repr__(self):
        label = self.__name or self.__value
        return str(label)

    def __symeq__(self, obj):
        # Compare by value against either another Constant or a plain object.
        other = obj.value if isinstance(obj, Constant) else obj
        return other == self.value

    def __hash__(self):
        """Constant hash. This is used by the tokenizer.

        Constants are by definition immutable.
        """
        # Include the class to reduce the chance of hash collisions.
        return hash((type(self), self.value))

    def eval(self, subs=None, **kwargs):
        """Evaluate this constant.

        If this constant is a key in ``subs``, the substituted value is used
        instead of ``self.value``. A wrapped value that has its own ``eval`` is
        evaluated.
        """
        if subs and self in subs:
            return subs[self]
        inner = self.__value
        if hasattr(inner, "eval"):
            return inner.eval()
        return inner

    @property
    def is_named(self):
        """Whether this constant was given a specific name."""
        return self.__name is not None

    @property
    def name(self):
        if self.__name is not None:
            return self.__name
        return str(self.value)
# Constants.
# NOTE: The keys here are used by the parser to recognise constants in kat script.
CONSTANTS = dict(
    pi=Constant(constants.PI, name="π"),
    c0=Constant(constants.C_LIGHT, name="c"),
)
class Resolving(Symbol):
    """Placeholder for a symbol that has not been resolved yet.

    The parser uses it to support self-referencing parameters. Reading its value
    raises an error.
    """

    def eval(self, **kwargs):
        message = (
            "an attempt has been made to read the value of a resolving symbol (hint: "
            "symbols should not be evaluated until parsing has fully finished)"
        )
        raise RuntimeError(message)

    @property
    def name(self):
        return "RESOLVING"
class Variable(Symbol):
    """A named variable symbol for symbolic math. Values must be substituted in
    when evaluating an expression.

    Examples
    --------
    Using some variables to make an expression and evaluating it:

    >>> import numpy as np
    >>> x = Variable('x')
    >>> y = Variable('y')
    >>> z = 4*x**2 - np.cos(y)
    >>> print(z)
    ((4*(x)**(2))-cos(y))
    >>> print(round(z.eval(subs={x: 2, y: 3}), 4))
    16.99

    Parameters
    ----------
    name : str
        Name of the variable (converted with ``str``). Must not be None.
    """

    def __init__(self, name):
        if name is None:
            raise ValueError("Name must be provided")
        self.__name = str(name)

    def __hash__(self):
        return hash((Variable, self.name))

    @property
    def name(self):
        return self.__name

    def eval(self, subs=None, keep=None, **kwargs):
        """Evaluate this variable, giving either itself or a substituted value.

        Parameters
        ----------
        subs : dict
            Values to substitute. Looked up first by the variable's name, then
            by the variable itself.
        keep : iterable, str
            Name(s) or variables to keep symbolic when evaluating.
        """
        if keep:
            if self.name == keep:
                return self
            if is_iterable(keep) and (self.name in keep or self in keep):
                return self
        if subs:
            for key in (self.name, self):
                if key in subs:
                    return subs[key]
        return self
class Function(Symbol):
    """A symbol representing a mathematical function.

    This could be a simple addition or a more complicated multi-argument
    function. It supports creating new mathematical operations::

        import math
        import cmath

        cos = lambda x: finesse.symbols.Function("cos", math.cos, x)
        sin = lambda x: finesse.symbols.Function("sin", math.sin, x)
        atan2 = lambda y, x: finesse.symbols.Function("atan2", math.atan2, y, x)

    Complex math can also be used::

        import numpy as np

        angle = lambda x: finesse.symbols.Function("angle", np.angle, x)
        print(f"{angle(1+1j)} = {angle(1+1j).eval()}")

    Parameters
    ----------
    name : str
        The operation name. This is used for dumping operations to kat script.

    operation : callable
        The function to pass the arguments of this operation to.

    Other Parameters
    ----------------
    *args
        The arguments to pass to `operation` during a call. Plain values are
        wrapped with ``as_symbol``.
    """

    def __init__(self, name, operation, *args):
        # Record whether this node was built while simplification was enabled,
        # i.e. whether it belongs to an n-ary (operator_add/operator_mul) tree.
        self.__simplified_expr_tree = _SIMPLIFICATION_ENABLED
        self.name = str(name)
        self.args = [as_symbol(arg) for arg in args]
        self.op = operation

    def eval(self, **kwargs):
        """Evaluate the operation.

        Parameters
        ----------
        subs : dict, optional
            Parameter substitutions can be given via an optional ``subs`` dict
            (mapping parameters to substituted values).
        keep : iterable, str
            A collection of names of variables to keep as variables when
            evaluating.

        Notes
        -----
        A division by zero returns NaN rather than raising an exception.

        Returns
        -------
        result : number or array-like
            The single-valued result of evaluating the operation, if no
            substitutions were given or all were scalar-valued. If any parameter
            substitution was a :class:`numpy.ndarray`, a corresponding array of
            results.

        Raises
        ------
        FinesseException
            If evaluating the arguments or the operation raises a TypeError.
        """
        try:
            args = tuple(arg.eval(**kwargs) for arg in self.args)
            return self.op(*args)
        except ZeroDivisionError:
            return np.nan
        except TypeError as ex:
            msg = f"Whilst evaluating {self.op}{self.args} this error was raised:\n` {ex}`"
            for arg in self.args:
                # The hints are best-effort: they must never mask the original
                # error. Previously ``arg.parameter`` raised AttributeError for
                # arguments without a parameter, e.g. Constants.
                try:
                    value = arg.value
                except Exception:
                    continue
                if value is None:
                    msg += f"\nHint: {arg} is `None` make sure it is defined before being used."
                    parameter = getattr(arg, "parameter", None)
                    if getattr(parameter, "full_name", None) == "fsig.f":
                        msg += " if using KatScript use `fsig(f)` before this line."
            raise FinesseException(msg) from ex

    @property
    def _is_narg_expression_tree(self):
        """Whether this expression was built with the global
        `_SIMPLIFICATION_ENABLED` set to `True`, i.e. as an n-ary
        (``operator_add``/``operator_mul``) expression tree."""
        return self.__simplified_expr_tree

    @property
    def contains_unresolved_symbol(self):
        """Whether the operation contains any unresolved symbols.

        :getter: Returns true if any symbol in the operation, at any depth, is
                 an instance of :class:`.Resolving`, false otherwise. Read-only.
        """
        for arg in self.args:
            if isinstance(arg, Resolving):
                return True
            # Keep scanning later arguments when a nested Function is clean.
            # Previously the first nested Function's result was returned
            # directly, so later arguments were never checked.
            if isinstance(arg, Function) and arg.contains_unresolved_symbol:
                return True
        return False

    def __symeq__(self, obj):
        if not isinstance(obj, Function):
            return False
        if not self.op == obj.op:
            return False
        # zip() truncates, so without the length check e.g. (a+b) and (a+b+c)
        # compared equal.
        if len(self.args) != len(obj.args):
            return False
        return all([a == b for a, b in zip(self.args, obj.args)])

    def __hash__(self):
        return hash((Function, self.op, *(hash(_) for _ in self.args)))
class LazySymbol(Symbol):
    """A generic lazily evaluated symbol.

    Its value comes from calling a function with some arbitrary arguments, which
    happens each time the symbol is evaluated.

    Parameters
    ----------
    name : str
        Human-readable name for the symbol.
    function : callable
        Function to call when evaluating the symbol.
    *args : objects
        Arguments to pass to `function` when evaluating.

    Examples
    --------
    >>> a = LazySymbol('a', lambda x: x**2, 10)
    >>> print(a)
    a
    >>> print(a*10)
    (a*10)
    >>> print((a*10).eval())
    1000
    """

    def __init__(self, name, function, *args):
        self.__name = name
        self.function = function
        self.args = args

    def eval(self, **kwargs):
        call, arguments = self.function, self.args
        return call(*arguments)

    @property
    def name(self):
        return self.__name
def collect(y):
    """Collect like terms in a symbolic expression.

    For an n-ary sum, each term is split with ``coefficient_and_term``. Terms
    with the same non-constant part have their coefficients added, and purely
    constant terms are summed separately. For any other operation, ``collect``
    is applied to each argument and the operation is rebuilt. Non-operations
    are returned unchanged.

    Parameters
    ----------
    y : Symbol or other value
        Expression to collect.

    Returns
    -------
    The collected expression, built with simplification enabled.
    """
    if hasattr(y, "op"):
        if not y._is_narg_expression_tree:
            y = y.to_nary_add_mul()  # simplification happens on nary trees
        # Conversion to n-ary form may have simplified y to a non-operation.
        if not hasattr(y, "op"):
            return y
        with simplification(allow_flagged=True):
            if y.op is operator_add:
                args = {}  # maps each non-constant term to its summed coefficient
                out = 0
                for x in y.args:
                    c, t = coefficient_and_term(x)
                    if t is None:  # just a coefficient/constant
                        out += x
                    else:
                        if t in args:
                            args[t] += c
                        else:
                            args[t] = c
                for k, v in args.items():
                    out += k * v
                return out
            elif len(y.args) > 0:
                # y is some other operator, which may have args that need collecting
                return y.op(*(collect(_) for _ in y.args))
            else:
                return y
    else:
        return y
def coefficient_and_term(y):
    """Split ``y`` into a ``(coefficient, term)`` pair for collecting like terms.

    - For an n-ary product, the factors are reduced with ``reduce_mul_args``.
      If the first reduced factor is a Constant (named or not), it becomes the
      coefficient and the product of the remaining factors becomes the term.
      Otherwise the coefficient is ``Constant(1)`` and the term is the whole
      reduced product.
    - A bare Constant gives ``(y, None)``.
    - Anything else gives ``(Constant(1), y)``.

    NOTE(review): the ``except AttributeError`` also catches AttributeErrors
    raised inside ``reduce_mul_args``/``np.prod``, which would then be treated
    as "not an operator".
    """
    try:
        if y.op is operator_mul:
            if not y._is_narg_expression_tree:
                y = y.to_nary_add_mul()  # simplification happens on nary trees
            args = reduce_mul_args(y.args)
            if type(args[0]) is not Constant:
                coeff = Constant(1)
                term = np.prod(args)
            else:
                coeff = args[0]
                term = np.prod(args[1:])
            return coeff, term
        else:
            return Constant(1), y
    except AttributeError:
        # Not an operator...
        if type(y) is Constant:
            return y, None
        else:
            return Constant(1), y
def expand_mul(y):
    """Distribute products over sums.

    For an n-ary product, each sum factor is expanded with an outer product
    against the terms collected so far; every other factor multiplies all
    current terms. The expansion is the sum of the resulting terms. Sums have
    each of their terms expanded; anything else is returned unchanged. Runs with
    simplification enabled.
    """
    with simplification(allow_flagged=True):
        if not hasattr(y, "op"):
            return y
        elif not y._is_narg_expression_tree:
            y = y.to_nary_add_mul()  # simplification happens on nary trees

        if y.op is operator_mul:
            # Object array of the partial expansion's terms, starting from 1.
            terms = np.array([1], dtype="O")
            for x in y.args:
                try:
                    if x.op is operator_add:
                        # Every existing term times every summand of x.
                        terms = np.outer(terms, x.args)
                    else:
                        terms *= x
                except AttributeError:
                    # x has no ``op`` (a leaf symbol or number).
                    terms *= x
            res = np.sum(terms)
            return res
        elif y.op is operator_add:
            return sum(expand_mul(_) for _ in y.args)
        else:
            return y
def is_integer(n):
    """Check whether `n` represents an integral real number.

    Parameters
    ----------
    n : str, float, int or other number-like
        Input to check.

    Returns
    -------
    bool
        True if ``float(n)`` succeeds and has no fractional part. Inputs that
        cannot be converted to a float give False instead of raising. These are
        unparsable strings (``ValueError``) and ``None``, complex numbers and
        other non-numeric objects (``TypeError``, previously uncaught).
    """
    try:
        value = float(n)
    except (TypeError, ValueError):
        return False
    return value.is_integer()
def expand_pow(y):
    """Expand powers of products and integer powers of sums.

    Cases:

    - ``(a*b)**e`` becomes ``a**e * b**e``.
    - ``(a+b)**n`` with an integer Constant ``n`` is multiplied out and the
      result is collected. For ``n < 0`` the expanded ``|n|`` power is inverted;
      for ``n == 0`` the result is ``1``.
    - Sums and products have each of their arguments expanded.
    - Anything else is returned unchanged.

    Runs with simplification enabled.
    """
    with simplification(allow_flagged=True):
        if not hasattr(y, "op"):
            return y
        elif not y._is_narg_expression_tree:
            y = y.to_nary_add_mul()  # simplification happens on nary trees

        if y.op is operator.pow:
            z = 1
            exp = y.args[1]
            base = y.args[0]
            try:
                if base.op is operator_mul:
                    # Distribute the exponent over every factor of the product.
                    for _ in base.args:
                        z *= _**exp
                    return z
                elif base.op is operator_add:
                    # Only integer exponents given as Constants are multiplied out.
                    if type(exp) is Constant and is_integer(exp.value):
                        n = int(exp.value)
                        if n == 0:
                            return 1
                        # Repeated outer products enumerate every term of the
                        # |n|-fold product of the sum.
                        terms = np.array([1], dtype="O")
                        for i in range(abs(n)):
                            terms = np.outer(terms, base.args)
                        if n > 0:
                            return terms.sum().collect()
                        else:
                            return (terms.sum().collect()) ** -1
                    else:
                        return y
                else:
                    return y
            except AttributeError:
                # base has no ``op``: nothing to expand.
                return y
        elif y.op is operator_add:
            return sum(expand_pow(_) for _ in y.args)
        elif y.op is operator_mul:
            return np.prod(tuple(expand_pow(_) for _ in y.args))
        else:
            return y
def expand(y):
    """Perform a basic expansion of a symbolic expression.

    Powers are expanded with ``expand_pow`` and products with ``expand_mul``.
    Non-sum operations have their arguments expanded recursively. For sums,
    each term is expanded and flattened into a new sum, and the process repeats
    until the result no longer changes. Runs with simplification enabled.

    Parameters
    ----------
    y : Symbol or other value
        Expression to expand. Anything that is not a :class:`Function` is
        returned unchanged.
    """
    with simplification(allow_flagged=True):
        if not isinstance(y, Function):
            return y  # Nothing to expand
        else:
            if not y._is_narg_expression_tree:
                y = y.to_nary_add_mul()  # simplification happens on nary trees

            if not isinstance(y, Function):
                return y  # Nothing to expand

            if y.op is operator_mul:
                # Run through and expand any pows before the
                # full mul expansion
                y = Function(y.name, y.op, *[expand_pow(_) for _ in y.args])
                y = expand_mul(y)
                # Nothing left to expand as no add operator
                if not isinstance(y, Function):
                    return y
                elif y.op is operator_mul:
                    return y
            elif y.op is operator.pow:
                y = expand_pow(y)
                # Nothing left to expand as no add operator
                # NOTE(review): assumes expand_pow returned an object with
                # ``op``; it can return a plain value (e.g. ``1``).
                if y.op is operator.pow:
                    return y

            if y.op is not operator_add:
                # some other type of operator: expand its args
                if len(y.args) > 0:
                    return y.op(*(expand(_) for _ in y.args))
                else:
                    return y
            else:
                # We have a summation of terms from the initial expansion
                # so expand each term
                out = []
                for x in y.args:
                    try:
                        if x.op is operator_mul:
                            x = Function(x.name, x.op, *[expand_pow(_) for _ in x.args])
                            z = expand_mul(x)
                        elif x.op is operator.pow:
                            z = expand_pow(x)
                        else:
                            z = x

                        # Flatten nested sums into this sum's terms.
                        if hasattr(z, "op") and z.op is operator_add:
                            out.extend(z.args)
                        else:
                            out.append(z)
                    except AttributeError:
                        # x has no ``op``: keep it as-is.
                        out.append(x)
                z = Function(y.name, y.op, *out)
                # Repeat until a pass no longer changes the expression.
                if y == z:
                    return z
                else:
                    return expand(z)