"""Symbolic manipulations (`expand`, `collect`, etc.) are based on the book:
Cohen JS. Computer algebra and symbolic computation: mathematical methods. First
edition, 2003.
"""
import warnings
import abc
import numpy as np
from finesse import constants
from finesse.utilities import is_iterable
from finesse.exceptions import FinesseException
from contextlib import contextmanager
from functools import reduce
import operator
import logging
# Module-level logger.
LOGGER = logging.getLogger(__name__)
# Factories for binary operator methods that never simplify. MAKE_LOP builds the
# left-operand form (arguments stored as self, other) and MAKE_ROP the reflected
# form (arguments stored as other, self).
MAKE_LOP = lambda name, opfn: lambda self, other: Function(name, opfn, self, other)
MAKE_ROP = lambda name, opfn: lambda self, other: Function(name, opfn, other, self)
# By default no simplification happens; symbol code that needs it must be wrapped
# with the context manager below. While the flag is set, symbols attempt to
# simplify themselves using a standardised ordering of arguments. Only the +, *
# and ** operators are used in simplified symbols, which allows various symbolic
# simplification algorithms to be used.
_SIMPLIFICATION_ENABLED = False  # Global flag to state if simplification will happen
@contextmanager
def simplification(allow_flagged=False):
    """Context manager that switches on simplification of symbol operations.

    Parameters
    ----------
    allow_flagged : bool, optional
        When True and simplification is already enabled, the context is entered
        without touching the flag (and without resetting it on exit).

    Raises
    ------
    RuntimeError
        If simplification is already enabled and `allow_flagged` is False.
    """
    global _SIMPLIFICATION_ENABLED

    if allow_flagged and _SIMPLIFICATION_ENABLED:
        # Nested use: an outer context owns the flag, so leave it alone.
        yield
        return

    if _SIMPLIFICATION_ENABLED is True:
        raise RuntimeError("Simplification has already been enabled")

    _SIMPLIFICATION_ENABLED = True
    try:
        yield
    finally:
        # Always switch simplification back off, even if the body raised.
        _SIMPLIFICATION_ENABLED = False
def base_exponent(y):
    """Split `y` into a ``(base, exponent)`` pair.

    If `y` has an ``op`` attribute that is ``operator.pow`` its two arguments are
    returned; anything else (including objects without ``op``/``args``) is
    treated as being raised to the power ``Constant(1)``.
    """
    try:
        is_power = y.op is operator.pow
        if is_power:
            base, exponent = y.args[0], y.args[1]
            return base, exponent
    except AttributeError:
        # Not an operation at all: fall through to the trivial power.
        pass
    return y, Constant(1)
def operator_add(*args):
    """N-ary addition: fold ``+`` over the arguments from left to right."""
    return reduce(lambda total, term: total + term, args)
def operator_mul(*args):
    """N-ary multiplication: fold ``*`` over the arguments from left to right."""
    return reduce(lambda product, factor: product * factor, args)
def operator_sub(*args):
    """N-ary subtraction: fold ``-`` over the arguments from left to right."""
    return reduce(lambda remainder, term: remainder - term, args)
def MAKE_simplify_add(dir):
    """Build the addition method used for ``__add__`` / ``__radd__``.

    Parameters
    ----------
    dir : str
        ``"LOP"`` when the returned function is the left-operand method
        (``self + other``); any other value is treated as the reflected form
        (``other + self``).

    Returns
    -------
    callable
        ``simplify_add(self, other)``. With simplification disabled it returns a
        binary ``operator.add`` Function. With simplification enabled it drops
        zeros, folds unnamed constants and otherwise returns a flattened, sorted
        n-ary ``operator_add`` Function.
    """

    def in_operand_order(self, other, opfn):
        # Two-argument sum with the operands in the order they were written.
        if dir == "LOP":
            return Function("+", opfn, self, other)
        return Function("+", opfn, other, self)

    def process(a):
        # Flatten an existing n-ary sum into its terms; anything else is one term.
        try:
            if a.op is operator_add:
                a = a.args
            else:
                a = (a,)
        except AttributeError:
            a = (a,)
        return a

    def sort_key(a):
        try:
            # chr(58) is the character after 9 so
            # numerical numbers come first.
            if a.op is operator_add:
                return chr(58)
            elif a.op is operator.pow:
                return chr(58) + str(a.args[0]) + str(a.args[1])
            else:
                return str(a)
        except AttributeError:
            return a.name

    def simplify_add(self, other):
        if not _SIMPLIFICATION_ENABLED:
            return in_operand_order(self, other, operator.add)

        if np.all(self == 0):
            return other
        if np.all(other == 0):
            return self

        if type(self) is Constant and type(other) is Constant:
            if self.is_named and other.is_named and self == other:
                return 2 * self
            if self.is_named or other.is_named:
                # At least one named constant: keep the sum symbolic. The
                # Function is built directly because `self + other` would
                # re-enter this function and recurse without end.
                return in_operand_order(self, other, operator_add)
            return Constant(self.value + other.value)

        f = Function("+", operator_add, *process(self), *process(other))
        f.args.sort(key=sort_key)
        return f

    return simplify_add
def reduce_mul_args(args):
    """Sorts and reduces the arguments of a multiply operation.

    Unnamed numeric constants are folded into a single leading coefficient,
    the remaining scalar factors are sorted by their string form and equal
    bases are merged into powers. Matrix factors are kept, in their original
    order, after the scalar factors.

    Parameters
    ----------
    args : sequence
        Factors of the product (symbols or plain numbers).

    Returns
    -------
    list or the input `args`
        The reduced factors. `args` is returned untouched when it has fewer
        than two elements. ``[1]`` is returned when everything reduces to unity.
    """

    def _reduce_mul_args(a, b):
        b = as_symbol(b)  # need this in case arg is a float/int type
        scalars, matrices = a
        if type(b) is Matrix:
            matrices.append(b)
        elif type(b) is Constant and not b.is_named:
            # The first scalar is always the numeric coefficient; fold
            # every unnamed constant into it.
            scalars[0] = Constant(scalars[0].value * b.value)
        else:
            scalars.append(b)
        return a

    def _reduce_mul_args_pows(a, b):
        base, exp = base_exponent(b)
        if np.all(a[-1][0] == base):
            a[-1][1] += exp
        else:
            a.append([base, exp])
        return a

    def sort_key(a):
        try:
            # chr(58) is the character after 9 so
            # numerical numbers come first.
            if a.op is operator_add:
                return chr(58)
            elif a.op is operator.pow:
                return chr(58) + str(a.args[0]) + str(a.args[1])
            else:
                return str(a)
        except AttributeError:
            return a.name

    if len(args) > 1:
        scalars, matrices = reduce(_reduce_mul_args, args, ([Constant(1)], []))
        if np.all(scalars[0] == 1):
            if len(scalars) == 1:
                # Only a unit coefficient is left on the scalar side. Keep any
                # matrix factors instead of collapsing the whole product to 1.
                return list(matrices) if matrices else [1]
            scalars.pop(0)
        # Sort arguments alphabetically using str form of symbols
        scalars.sort(key=sort_key)
        # Now it's ordered, powers will be grouped
        # together so we can reduce them too
        pow_args = reduce(
            _reduce_mul_args_pows, scalars[1:], [list(base_exponent(scalars[0]))]
        )
        args = [b**e for b, e in pow_args]
        args.extend(matrices)
    return args
def MAKE_simplify_mul(dir):
    """Build the multiplication method used for ``__mul__`` / ``__rmul__``.

    `dir` is ``"LOP"`` for the left-operand method; any other value gives the
    reflected form, in which ``other`` is the left factor.
    """
    left_to_right = dir == "LOP"

    def factors_of(term):
        # An existing n-ary product contributes its factors; anything else is
        # a single factor.
        try:
            return term.args if term.op is operator_mul else (term,)
        except AttributeError:
            return (term,)

    def simplify_mul(self, other):
        if not _SIMPLIFICATION_ENABLED:
            lhs, rhs = (self, other) if left_to_right else (other, self)
            return Function("*", operator.mul, lhs, rhs)

        # Trivial products first: anything times zero, then unit factors.
        if np.all(self == 0) or np.all(other == 0):
            return 0
        if np.all(self == 1):
            return other
        if np.all(other == 1):
            return self

        mine = factors_of(self)
        theirs = factors_of(other)
        factors = [*mine, *theirs] if left_to_right else [*theirs, *mine]
        factors = reduce_mul_args(factors)
        if len(factors) > 1:
            return Function("*", operator_mul, *factors)
        return factors[0]

    return simplify_mul
def MAKE_simplify_sub(dir):
    """Build the subtraction method used for ``__sub__`` / ``__rsub__``.

    With simplification enabled a subtraction is rewritten as the addition of
    a negated operand, reusing the adder built for the same `dir`.
    """
    add_terms = MAKE_simplify_add(dir)
    is_left_operand = dir == "LOP"

    def simplify_sub(self, other):
        if _SIMPLIFICATION_ENABLED:
            if is_left_operand:
                return add_terms(self, -other)
            return add_terms(-self, other)

        # Unsimplified: keep an explicit binary subtraction in written order.
        minuend, subtrahend = (self, other) if is_left_operand else (other, self)
        return Function("-", operator.sub, minuend, subtrahend)

    return simplify_sub
def MAKE_simplify_neg():
    """Build the unary negation method used for ``__neg__``."""

    def simplify_neg(self):
        if _SIMPLIFICATION_ENABLED:
            # Simplified expressions only use +, * and **, so negate by
            # multiplying with -1.
            return -1 * self
        return Function("-", operator.neg, self)

    return simplify_neg
def MAKE_simplify_pow():
    """Build the power method used for ``__pow__``."""

    def simplify_pow(self, exp):
        if not _SIMPLIFICATION_ENABLED:
            return Function("**", operator.pow, self, exp)

        if np.all(exp == 0):
            return Constant(1)
        if np.all(exp == 1):
            return self

        both_constants = type(self) is Constant and type(exp) is Constant
        # Fold two constants numerically when both are named or both unnamed.
        if both_constants and self.is_named == exp.is_named:
            return Constant(self.value**exp.value)
        return Function("**", operator.pow, self, exp)

    return simplify_pow
def MAKE_LOP_simplify_truediv():
    """Build the division method used for ``__truediv__`` (``self / other``).

    With simplification enabled the quotient is expressed as a product with
    the denominator raised to the power -1.
    """

    def reciprocal(denominator):
        return Function("**", operator.pow, denominator, -1)

    def simplify_truediv(self, other):
        if not _SIMPLIFICATION_ENABLED:
            return Function("/", operator.truediv, self, other)

        if np.all(self == 0):
            return 0
        if np.all(other == 0):
            raise ZeroDivisionError()

        # Pull an explicit negation of either operand out to the front.
        if isinstance(self, Function) and self.op is operator.neg:
            numerator = -self.args[0]
            return numerator * reciprocal(other)
        if isinstance(other, Function) and other.op is operator.neg:
            numerator = -self
            return numerator * reciprocal(other.args[0])
        return self * reciprocal(other)

    return simplify_truediv
def MAKE_ROP_simplify_truediv():
    """Build the reflected division method used for ``__rtruediv__``.

    The returned function computes ``other / self``: ``self`` is the
    denominator and ``other`` the numerator.

    Returns
    -------
    callable
        ``simplify_truediv(self, other)``. With simplification disabled it
        returns a ``operator.truediv`` Function. With simplification enabled
        it returns 0 for a zero numerator, raises ``ZeroDivisionError`` for a
        zero denominator and otherwise returns the numerator multiplied by the
        denominator raised to the power -1.
    """

    def simplify_truediv(self, other):
        if not _SIMPLIFICATION_ENABLED:
            return Function("/", operator.truediv, other, self)
        if np.all(other == 0):
            return 0
        # np.all keeps this consistent with every other zero test in the
        # module and copes with comparisons that yield an array.
        elif np.all(self == 0):
            raise ZeroDivisionError()
        elif isinstance(self, Function) and self.op is operator.neg:
            return -other * Function("**", operator.pow, self.args[0], -1)
        elif isinstance(other, Function) and other.op is operator.neg:
            return -other.args[0] * Function("**", operator.pow, self, -1)
        else:
            return other * Function("**", operator.pow, self, -1)

    return simplify_truediv
# Supported operators: maps a special-method name to the function installed
# under that name on Symbol. The add/sub/mul/neg/pow/truediv entries honour the
# simplification flag; floor division and matmul always build a plain Function.
OPERATORS = {
    "__add__": MAKE_simplify_add("LOP"),
    "__sub__": MAKE_simplify_sub("LOP"),
    "__mul__": MAKE_simplify_mul("LOP"),
    "__radd__": MAKE_simplify_add("ROP"),
    "__rsub__": MAKE_simplify_sub("ROP"),
    "__rmul__": MAKE_simplify_mul("ROP"),
    "__neg__": MAKE_simplify_neg(),
    "__pow__": MAKE_simplify_pow(),
    "__truediv__": MAKE_LOP_simplify_truediv(),
    "__rtruediv__": MAKE_ROP_simplify_truediv(),
    "__floordiv__": MAKE_LOP("//", operator.floordiv),
    "__rfloordiv__": MAKE_ROP("//", operator.floordiv),
    "__matmul__": MAKE_LOP("@", operator.matmul),
}
# Lookup from a function's name to the underlying Python callable. Needed when
# lambdifying symbolics, because the callable cannot be recovered from the
# lambdas stored in FUNCTIONS below. Every key equals the attribute name on
# the module that provides it.
_OPERATOR_FUNCTION_NAMES = ("abs", "neg", "pos", "pow")
_NUMPY_FUNCTION_NAMES = (
    "conj",
    "real",
    "imag",
    "exp",
    "log10",
    "log",
    "sin",
    "arcsin",
    "cos",
    "arccos",
    "tan",
    "arctan",
    "arctan2",
    "sqrt",
    "std",
    "sum",
    "dot",
    "radians",
    "degrees",
    "deg2rad",
    "rad2deg",
    "arange",
    "linspace",
    "logspace",
    "geomspace",
)
PYFUNCTION_MAP = {
    **{_name: getattr(operator, _name) for _name in _OPERATOR_FUNCTION_NAMES},
    **{_name: getattr(np, _name) for _name in _NUMPY_FUNCTION_NAMES},
}
# Built-in symbolic functions: maps string names of functions to actual functions.
# Each entry builds a Function whose operation is the matching Python/NumPy
# callable; nothing is evaluated until the Function itself is evaluated.
FUNCTIONS = {
    "abs": lambda x: Function("abs", operator.abs, x),
    "neg": lambda x: Function("neg", operator.neg, x),
    "pos": lambda x: Function("pos", operator.pos, x),
    # operator.pow is binary, so both base and exponent must be supplied.
    "pow": lambda x, y: Function("pow", operator.pow, x, y),
    "conj": lambda x: Function("conj", np.conj, x),
    "real": lambda x: Function("real", np.real, x),
    "imag": lambda x: Function("imag", np.imag, x),
    "exp": lambda x: Function("exp", np.exp, x),
    "log": lambda x: Function("log", np.log, x),
    "log10": lambda x: Function("log10", np.log10, x),
    "sin": lambda x: Function("sin", np.sin, x),
    "arcsin": lambda x: Function("arcsin", np.arcsin, x),
    "cos": lambda x: Function("cos", np.cos, x),
    "arccos": lambda x: Function("arccos", np.arccos, x),
    "tan": lambda x: Function("tan", np.tan, x),
    "arctan": lambda x: Function("arctan", np.arctan, x),
    "arctan2": lambda y, x: Function("arctan2", np.arctan2, y, x),
    "sqrt": lambda x: Function("sqrt", np.sqrt, x),
    "std": lambda x: Function("std", np.std, x),
    "sum": lambda x: Function("sum", np.sum, x),
    "dot": lambda x, y: Function("dot", np.dot, x, y),
    "radians": lambda x: Function("radians", np.radians, x),
    "degrees": lambda x: Function("degrees", np.degrees, x),
    "deg2rad": lambda x: Function("deg2rad", np.deg2rad, x),
    "rad2deg": lambda x: Function("rad2deg", np.rad2deg, x),
    # The range builders cast their arguments up front, so they need values
    # that float()/int() accept.
    # NOTE(review): the third argument of np.arange is a step, so int(c)
    # truncates fractional steps - confirm whether that is intended.
    "arange": lambda a, b, c: Function("arange", np.arange, float(a), float(b), int(c)),
    "linspace": lambda a, b, c: Function(
        "linspace", np.linspace, float(a), float(b), int(c)
    ),
    "logspace": lambda a, b, c: Function(
        "logspace", np.logspace, float(a), float(b), int(c)
    ),
    "geomspace": lambda a, b, c: Function(
        "geomspace", np.geomspace, float(a), float(b), int(c)
    ),
}
# String formatters used by `display`, keyed by the operation callable. Each
# value is called with the already-rendered argument strings and returns the
# rendered expression. The n-ary operator_add/operator_mul and operator.sub
# entries join any number of arguments; the others expect a fixed number.
op_repr = {
    operator_add: lambda *args: "({})".format("+".join(args)),
    operator.add: "({}+{})".format,
    operator.sub: lambda *args: "({})".format("-".join(args)),
    operator_mul: lambda *args: "*".join(args),
    operator.mul: "({}*{})".format,
    operator.pow: "({})**({})".format,
    operator.truediv: "{}/{}".format,
    operator.floordiv: "{}//{}".format,
    operator.mod: "({}%{})".format,
    operator.matmul: "({}@{})".format,
    operator.neg: "-{}".format,
    operator.pos: "+{}".format,
    operator.abs: "abs({})".format,
    np.conj: "conj({})".format,
    np.sqrt: "sqrt({})".format,
}
def display(a):
    """Render a symbol as a human readable string.

    Operations are rendered recursively, using the formatter registered in
    ``op_repr`` when there is one and a generic ``name(arg, ...)`` form
    otherwise. Objects without an ``op`` fall back to their ``name``, then to
    a placeholder for bare ``Symbol`` instances, then to ``str``.

    Parameters
    ----------
    a : :class:`.Symbol`
        Symbol to print

    Returns
    -------
    String form of Symbol
    """
    if hasattr(a, "op"):
        formatter = op_repr.get(a.op)
        if formatter is None:
            # No predefined format: treat the operation as a plain function call.
            placeholders = ",".join("{}" for _ in a.args)
            formatter = f"{a.op.__name__}({placeholders})".format
        rendered = formatter(*(display(arg) for arg in a.args))
        # Tidy up the most common artefacts of simplified expressions.
        for clumsy, tidy in (("-1*", "-"), ("+-", "-"), ("*1/", "/")):
            rendered = rendered.replace(clumsy, tidy)
        return rendered

    if hasattr(a, "name"):  # anything with a name is shown by that name
        return a.name

    if type(a) is Symbol:
        return f"<Symbol @ {hex(id(a))}>"

    return str(a)
def finesse2sympy(expr, iter_num=0):
    """Translate a Finesse symbolic expression into its Sympy counterpart.

    Constants become their value, parameter references become
    ``sympy.Symbol`` objects named after the reference, and Functions are
    translated recursively.

    Notes
    -----
    It might be common for this function to throw a NotImplementedError.
    The mapping from operator and numpy functions to sympy is maintained by
    hand; an operation that is not listed is looked up on the ``sympy`` module
    by name and, failing that, raises. If you come across this error, the
    mapping below needs to be extended with the missing operation.
    """
    import sympy
    from finesse.parameter import ParameterRef

    iter_num += 1

    if isinstance(expr, Constant):
        return expr.value
    if isinstance(expr, ParameterRef):
        return sympy.Symbol(expr.name)
    if not isinstance(expr, Function):
        raise NotImplementedError(
            f"{expr} undefined. {finesse2sympy} needs to be updated."
        )

    converted = [finesse2sympy(arg, iter_num) for arg in expr.args]

    # (operations, sympy equivalent) pairs, tested in order with ``==``.
    translations = (
        ((operator_mul, operator.mul), sympy.Mul),
        ((operator_add, operator.add), sympy.Add),
        ((operator.truediv,), lambda a, b: sympy.Mul(a, sympy.Pow(b, -1))),
        ((operator.pow,), sympy.Pow),
        ((operator.sub,), lambda x, y: sympy.Add(x, -y)),
        ((np.conj,), sympy.conjugate),
        ((np.radians,), sympy.rad),
        ((np.exp,), sympy.exp),
        ((np.cos,), sympy.cos),
        ((np.sin,), sympy.sin),
        ((np.tan,), sympy.tan),
        ((np.sqrt,), sympy.sqrt),
        ((operator.abs,), sympy.Abs),
        ((operator.neg,), lambda x: sympy.Mul(-1, x)),
        ((np.imag,), sympy.im),
        ((np.real,), sympy.re),
    )

    for operations, sympy_op in translations:
        if any(expr.op == candidate for candidate in operations):
            break
    else:
        # Last resort: a sympy function with the same name as the operation.
        try:
            sympy_op = getattr(sympy, expr.op.__name__)
        except AttributeError:
            raise NotImplementedError(
                f"undefined Function {expr.op} in {expr}. {finesse2sympy} needs to be updated."
            )

    return sympy_op(*converted)
def sympy2finesse(expr, symbol_dict=None, iter_num=0):
    """Convert a Sympy expression back into Finesse symbols and plain numbers.

    Parameters
    ----------
    expr : sympy expression
        Expression to convert.
    symbol_dict : dict, optional
        Maps the string form of each Sympy symbol in `expr` to the object that
        should replace it. Defaults to an empty mapping.
    iter_num : int, optional
        Recursion depth counter.

    Returns
    -------
    The converted expression: products, sums, conjugates, exponentials and
    powers are rebuilt with NumPy functions applied to the converted
    arguments, numbers become ``int``/``float``/``complex`` (``pi`` becomes
    ``CONSTANTS["pi"]``) and symbols are looked up in `symbol_dict`.

    Raises
    ------
    KeyError
        If a Sympy symbol has no entry in `symbol_dict`.
    NotImplementedError
        If `expr` is of a kind that is not handled here.
    """
    import sympy

    if symbol_dict is None:
        symbol_dict = {}

    iter_num += 1
    if isinstance(expr, sympy.Mul):
        # np.prod rather than np.product: the latter was removed in NumPy 2.0.
        return np.prod(
            [sympy2finesse(arg, symbol_dict, iter_num=iter_num) for arg in expr.args]
        )
    elif isinstance(expr, sympy.Add):
        return np.sum(
            [sympy2finesse(arg, symbol_dict, iter_num=iter_num) for arg in expr.args]
        )
    elif isinstance(expr, sympy.conjugate):
        return np.conj(sympy2finesse(*expr.args, symbol_dict))
    elif isinstance(expr, sympy.exp):
        return np.exp(sympy2finesse(*expr.args, symbol_dict))
    elif isinstance(expr, sympy.Pow):
        return np.power(
            sympy2finesse(expr.args[0], symbol_dict),
            sympy2finesse(expr.args[1], symbol_dict),
        )
    elif (
        expr.is_NumberSymbol
    ):  # sympy class for named symbols (eg Pi, golden ratio, ...)
        if str(expr) == "pi":
            return CONSTANTS["pi"]
        else:
            return complex(expr)
    elif expr.is_number:
        if expr.is_integer:
            return int(expr)
        elif expr.is_real:
            return float(expr)
        else:
            return complex(expr)
    elif expr.is_symbol:
        return symbol_dict[str(expr)]
    else:
        # Subclass of Exception, so existing `except Exception` handlers
        # still catch it.
        raise NotImplementedError(f"{expr} undefined")
# Element-wise helper for arrays of symbols: every Symbol element is replaced by
# collect(expand(element)), anything else is passed through unchanged. The
# result always has object dtype.
simplify_symbolic_numpy = np.vectorize(
    lambda x: collect(expand(x)) if isinstance(x, Symbol) else x, otypes="O"
)
def np_eval_symbolic_numpy(a, *keep):
    """Evaluate `a` if it is a Symbol, otherwise return it unchanged.

    Any extra positional arguments are forwarded, as a tuple, through the
    ``keep`` keyword of the symbol's ``eval``.
    """
    if not isinstance(a, Symbol):
        return a
    return a.eval(keep=keep)
# Element-wise (object dtype) version of np_eval_symbolic_numpy.
__eval_symbolic_numpy = np.vectorize(np_eval_symbolic_numpy, otypes="O")


def eval_symbolic_numpy(a, *keep):
    """Apply np_eval_symbolic_numpy element-wise to `a`, forwarding `keep`."""
    evaluated = __eval_symbolic_numpy(a, *keep)
    return evaluated
def as_symbol(x):
    """Return `x` itself if it is already a Symbol, else wrap it in a Constant."""
    return x if isinstance(x, Symbol) else Constant(x)
def evaluate(x):
    """Evaluates a symbol or N-dimensional array of symbols.

    Parameters
    ----------
    x : :class:`.Symbol` or array-like
        A symbolic expression or an array of symbolic expressions.

    Returns
    -------
    out : float, complex, :class:`numpy.ndarray`
        A single value for the evaluated expression if `x` is not
        array-like, otherwise an array of the evaluated expressions. The
        array is ``float64`` when every imaginary part is zero and
        ``complex128`` otherwise. Anything that is neither iterable nor a
        Symbol is returned unchanged.
    """
    if is_iterable(x):
        y = np.array(x, dtype=np.complex128)
        if not np.any(y.imag):  # purely real symbols in array
            # Take the real parts directly instead of casting complex to
            # float. This gives the same float64 array without emitting a
            # ComplexWarning, so no reference to np.ComplexWarning (removed
            # in NumPy 2.0) is needed.
            y = y.real.copy()
        return y

    if isinstance(x, Symbol):
        return x.eval()

    # If not a symbol then just return x directly
    return x
class Symbol(abc.ABC):
    """Base class of the symbolic objects defined in this module.

    The arithmetic special methods are taken from ``OPERATORS`` and a set of
    mathematical functions from ``FUNCTIONS``, so that combining symbols
    builds symbolic expressions. Subclasses provide ``eval`` and may override
    ``__symeq__`` for their own notion of equality.
    """

    __add__ = OPERATORS["__add__"]
    __sub__ = OPERATORS["__sub__"]
    __mul__ = OPERATORS["__mul__"]
    __radd__ = OPERATORS["__radd__"]
    __rsub__ = OPERATORS["__rsub__"]
    __rmul__ = OPERATORS["__rmul__"]
    __pow__ = OPERATORS["__pow__"]
    __truediv__ = OPERATORS["__truediv__"]
    __rtruediv__ = OPERATORS["__rtruediv__"]
    __floordiv__ = OPERATORS["__floordiv__"]
    __rfloordiv__ = OPERATORS["__rfloordiv__"]
    __matmul__ = OPERATORS["__matmul__"]
    __neg__ = OPERATORS["__neg__"]
    __abs__ = FUNCTIONS["abs"]
    __pos__ = FUNCTIONS["pos"]
    conjugate = FUNCTIONS["conj"]
    conj = FUNCTIONS["conj"]
    exp = FUNCTIONS["exp"]
    sin = FUNCTIONS["sin"]
    arcsin = FUNCTIONS["arcsin"]
    cos = FUNCTIONS["cos"]
    arccos = FUNCTIONS["arccos"]
    tan = FUNCTIONS["tan"]
    arctan = FUNCTIONS["arctan"]
    arctan2 = FUNCTIONS["arctan2"]
    sqrt = FUNCTIONS["sqrt"]
    radians = FUNCTIONS["radians"]
    degrees = FUNCTIONS["degrees"]
    deg2rad = FUNCTIONS["deg2rad"]
    rad2deg = FUNCTIONS["rad2deg"]

    @property
    def real(self):
        """Symbolic real part of this symbol."""
        return FUNCTIONS["real"](self)

    @property
    def imag(self):
        """Symbolic imaginary part of this symbol."""
        return FUNCTIONS["imag"](self)

    def __eq__(self, obj):
        """Inheriting classes should implement __symeq__ and do any symbol specific
        equality checks there.

        This top level handles initial n-arg conversion before symbolic comparison.
        """
        # Need to convert any symbolic expressions (Functions)
        # to nary form for equality checks
        if isinstance(obj, Function) and not obj._is_narg_expression_tree:
            obj = obj.to_nary_add_mul()
        if isinstance(self, Function) and not self._is_narg_expression_tree:
            self = self.to_nary_add_mul()
        return self.__symeq__(obj)

    def __symeq__(self, obj):
        """Default symbolic equality: the very same object."""
        return id(self) == id(obj)

    def __float__(self):
        v = self.eval()
        if np.isscalar(v):
            return float(v)
        else:
            raise TypeError(f"Can't cast {type(v)} ({v}) into a single float value")

    def __complex__(self):
        v = self.eval()
        if np.isscalar(v):
            return complex(v)
        else:
            raise TypeError(f"Can't cast {type(v)} ({v}) into a single complex value")

    def __int__(self):
        v = self.eval()
        if np.isscalar(v):
            return int(v)
        else:
            raise TypeError(f"Can't cast {type(v)} into a single int value")

    @property
    def value(self):
        """Result of evaluating this symbol with no arguments."""
        return self.eval()

    def __str__(self):
        return display(self)

    def __repr__(self):
        return f"<Symbolic='{display(self)}' @ {hex(id(self))}>"

    @property
    def is_changing(self):
        """Whether this symbolic object depends on something that is changing.

        For an operation this is True if any of its arguments is changing. For
        a symbol with a ``parameter`` attribute it is True if that parameter is
        tunable or changing. Everything else is not changing.
        """
        res = False
        if hasattr(self, "op"):
            res = any(_.is_changing for _ in self.args)
        elif hasattr(self, "parameter"):
            res = self.parameter.is_tunable or self.parameter.is_changing
        return res

    def parameters(self, memo=None):
        """Returns all the parameters that are present in this symbolic statement.

        Parameters
        ----------
        memo : set, optional
            Set used to collect the parameters during recursion; a new one is
            created when omitted.

        Returns
        -------
        list
            The symbols that have a ``parameter`` attribute, without duplicates.
        """
        if memo is None:
            memo = set()

        if hasattr(self, "op"):
            for _ in self.args:
                _.parameters(memo)
        elif hasattr(self, "parameter"):
            memo.add(self)

        return list(memo)

    def all(self, predicate, memo=None):
        """Returns all the arguments that are present in this symbolic statement which
        satisfy the predicate.

        Parameters
        ----------
        predicate : callable
            Method which takes in an argument and returns True if it matches.
        memo : set, optional
            Set used to collect the matches during recursion; a new one is
            created when omitted.

        Returns
        -------
        list
            The matching arguments, without duplicates.
        """
        if memo is None:
            memo = set()

        if hasattr(self, "op"):
            for _ in self.args:
                _.all(predicate, memo)
        elif predicate(self):
            memo.add(self)

        return list(memo)

    def changing_parameters(self):
        """Returns the parameters of this expression whose ``is_changing`` is true."""
        return [p for p in self.parameters() if p.is_changing]

    def to_sympy(self):
        """Converts a Finesse symbolic expression into a Sympy expression.

        Warning: for large functions this can be quite slow.
        """
        return finesse2sympy(self)

    def sympy_simplify(self):
        """Simplifies this expression with Sympy.

        The expression is converted to Sympy, simplified there and converted
        back, mapping each Sympy symbol to the parameter of the same name.
        """
        # Map of parameter name -> parameter symbol used in this expression
        refs = {_.name: _ for _ in self.parameters()}
        sympy_expr = finesse2sympy(self)
        return sympy2finesse(sympy_expr.simplify(), refs)

    def collect(self):
        """Collects like terms in the expressions."""
        return collect(self)

    def expand(self):
        """Performs a basic expansion of the symbolic expression."""
        return expand(self)

    def expand_symbols(self):
        """A method that expands any symbolic parameter references that are themselves
        symbolic. This can be used to get an expression that only depends on references
        that are numeric.

        Examples
        --------
        >>> import finesse
        >>> model = finesse.Model()
        >>> model.parse(
        ...     '''
        ...     var d 300
        ...     var c 6000
        ...     var b c+d
        ...     var a b+1
        ...     '''
        ... )
        >>> model.a.value.value.expand_symbols()
        <Symbolic='((c.value+d.value)+1)' @ 0x7faa4d351c10>

        Returns
        -------
        Symbol
            The expression with every symbolic parameter reference replaced by
            the value of its parameter, repeated until no reference is symbolic.
        """

        def process(p):
            if p.parameter.is_symbolic:
                return p.parameter.value
            else:
                return p

        def _expand(sym):
            params = sym.parameters()
            # Stop once nothing is left to substitute. `any` (not `all`) so
            # that symbolic references are still expanded when they appear
            # next to numeric ones; `process` leaves the numeric ones alone.
            if not any(p.parameter.is_symbolic for p in params):
                return None
            return sym.eval(subs={p: process(p) for p in params})

        sym = self
        while True:
            res = _expand(sym)
            if res is None:
                return sym
            else:
                sym = res

    def to_binary_add_mul(self):
        """Converts a symbolic expression to use binary forms of operator.add and
        operator.mul operators, rather than the n-ary operator_add and operator_mul.

        Returns
        -------
        Symbol
        """
        if hasattr(self, "op"):
            return self.op(*(_.to_binary_add_mul() for _ in self.args))
        else:
            return self

    def to_nary_add_mul(self):
        """Converts a symbolic expression to use n-ary forms of operator_add and
        operator_mul operators, rather than the binary operator.add and
        operator.mul.

        Returns
        -------
        Symbol
        """
        if hasattr(self, "op"):
            # allow_flagged=True makes the nested contexts opened by the
            # recursive calls harmless: only the outermost one owns the flag.
            with simplification(allow_flagged=True):
                return self.op(*(_.to_nary_add_mul() for _ in self.args))
        else:
            return self
class Matrix(Symbol):
    """A Matrix symbol.

    Parameters
    ----------
    name : object
        Name of the matrix; stored as its string form.
    """

    def __init__(self, name):
        self.name = str(name)

    def __hash__(self):
        return hash((self.name,))

    def eval(self, **kwargs):
        """Return the matrix symbol itself.

        Keyword arguments (such as ``subs`` or ``keep``) are accepted and
        ignored so that a Matrix can be an argument of an operation, whose
        ``eval`` forwards its keyword arguments to every argument.
        """
        return self
class Constant(Symbol):
    """A constant symbol for use in symbolic math.

    Parameters
    ----------
    value : float, int
        Value of constant
    name : str, optional
        Name of the constant to use when printing
    """

    def __init__(self, value, name=None):
        self.__value = value
        self.__name = name

    @property
    def is_named(self):
        """Was this constant given a specific name."""
        return self.__name is not None

    @property
    def name(self):
        """The given name, or the string form of the value when unnamed."""
        if self.__name is None:
            return str(self.value)
        return self.__name

    def __str__(self):
        if self.__name:
            return self.__name
        return str(self.__value)

    def __repr__(self):
        shown = self.__name or self.__value
        return str(shown)

    def __symeq__(self, obj):
        # Constants compare by value; anything else is compared to the value.
        other_value = obj.value if isinstance(obj, Constant) else obj
        return other_value == self.value

    def __hash__(self):
        """Constant hash.

        This is used by the tokenizer. Constants are by definition immutable.
        """
        # The class is part of the key to reduce the chance of hash collisions.
        return hash((type(self), self.value))

    def eval(self, subs=None, **kwargs):
        """Evaluate this constant.

        A substitution for this constant in `subs` takes precedence. Otherwise
        the stored value is returned, after evaluating it first if it has an
        ``eval`` method of its own.
        """
        if subs and self in subs:
            return subs[self]
        stored = self.__value
        if hasattr(stored, "eval"):
            return stored.eval()
        return stored
# Named constants, with values taken from finesse.constants and a short display
# name each.
# NOTE: The keys here are used by the parser to recognise constants in kat script.
CONSTANTS = {
    "pi": Constant(constants.PI, name="π"),
    "c0": Constant(constants.C_LIGHT, name="c"),
}
class Resolving(Symbol):
    """Placeholder for a symbol whose value has not been resolved yet.

    This is used in the parser to support self-referencing parameters.
    Trying to read the value raises an error.
    """

    @property
    def name(self):
        return "RESOLVING"

    def eval(self, **kwargs):
        """Always raises: a resolving symbol has no value to read."""
        message = (
            "an attempt has been made to read the value of a resolving symbol (hint: "
            "symbols should not be evaluated until parsing has fully finished)"
        )
        raise RuntimeError(message)
class Variable(Symbol):
    """A named variable for use in symbolic math. Values must be substituted
    in when evaluating an expression.

    Examples
    --------
    Build an expression from two variables and evaluate it:

    >>> x = Variable('x')
    >>> y = Variable('y')
    >>> z = 4*x**2 - y
    >>> z.eval(subs={x: 2, y: 3})
    13

    Parameters
    ----------
    name : str
        Name of the variable; must not be None.

    Raises
    ------
    ValueError
        If `name` is None.
    """

    def __init__(self, name):
        if name is None:
            raise ValueError("Name must be provided")
        self.__name = str(name)

    @property
    def name(self):
        return self.__name

    def __hash__(self):
        return hash((Variable, self.name))

    def eval(self, subs=None, keep=None, **kwargs):
        """Evaluates this variable and returns either itself or a substituted value.

        Parameters
        ----------
        subs : dict
            Values to substitute for variables, keyed by variable name or by
            the variable itself (the name is looked up first).
        keep : iterable, str
            A name, or a collection of names or variables, to keep as
            variables when evaluating. Takes precedence over `subs`.
        """
        label = self.name

        if keep:
            kept = label == keep or (
                is_iterable(keep) and (label in keep or self in keep)
            )
            if kept:
                return self

        if subs:
            for key in (label, self):
                if key in subs:
                    return subs[key]

        # Nothing to substitute: the variable stays symbolic.
        return self
class Function(Symbol):
    """This is a symbol to represent a mathematical function. This could be a simple
    addition, or a more complicated multi-argument function.

    It supports creating new mathematical operations::

        import math
        import cmath
        cos   = lambda x: finesse.symbols.Function("cos", math.cos, x)
        sin   = lambda x: finesse.symbols.Function("sin", math.sin, x)
        atan2 = lambda y, x: finesse.symbols.Function("atan2", math.atan2, y, x)

    Complex math can also be used::

        import numpy as np
        angle = lambda x: finesse.symbols.Function("angle", np.angle, x)
        print(f"{angle(1+1j)} = {angle(1+1j).eval()}")

    Parameters
    ----------
    name : str
        The operation name. This is used for dumping operations to kat script.
    operation : callable
        The function to pass the arguments of this operation to.

    Other Parameters
    ----------------
    *args
        The arguments to pass to `operation` during a call. Each one is
        converted with `as_symbol` before being stored.
    """

    def __init__(self, name, operation, *args):
        # Record whether simplification was active when this node was built;
        # the expansion/collection routines use this to decide whether the
        # tree must first be converted with `to_nary_add_mul`.
        self.__simplified_expr_tree = _SIMPLIFICATION_ENABLED
        self.name = str(name)
        self.args = [as_symbol(_) for _ in args]
        self.op = operation

    def eval(self, **kwargs):
        """Evaluates the operation.

        Parameters
        ----------
        subs : dict, optional
            Parameter substitutions can be given via an optional
            ``subs`` dict (mapping parameters to substituted values).
        keep : iterable, str
            A collection of names of variables to keep as variables when
            evaluating.

        Notes
        -----
        A division by zero will return a NaN, rather than raise an exception.

        Returns
        -------
        result : number or array-like
            The single-valued result of evaluation of the operation (if no
            substitutions given, or all substitutions are scalar-valued). Otherwise,
            if any parameter substitution was a :class:`numpy.ndarray`, then a corresponding array
            of results.

        Raises
        ------
        FinesseException
            If evaluating the operation raises a `TypeError`. The message
            includes hints for any argument whose value is `None`.
        """
        try:
            args = tuple(_.eval(**kwargs) for _ in self.args)
            return self.op(*args)
        except ZeroDivisionError:
            return np.nan
        except TypeError as ex:
            msg = f"Whilst evaluating {self.op}{self.args} this error was raised:\n`    {ex}`"
            for arg in self.args:
                # Not every argument type carries `value`/`parameter`, so
                # probe for them rather than letting an AttributeError hide
                # the error being reported.
                if hasattr(arg, "value") and arg.value is None:
                    msg += f"\nHint: {arg} is `None` make sure it is defined before being used."
                    parameter = getattr(arg, "parameter", None)
                    if getattr(parameter, "full_name", None) == "fsig.f":
                        msg += " if using KatScript use `fsig(f)` before this line."
            raise FinesseException(msg) from ex

    @property
    def _is_narg_expression_tree(self):
        """Whether this expression was built while the global
        `_SIMPLIFICATION_ENABLED` flag was `True`."""
        return self.__simplified_expr_tree

    @property
    def contains_unresolved_symbol(self):
        """Whether the operation contains any unresolved symbols.

        :getter: Returns true if any symbol in the operation is an instance
                 of :class:`.Resolving`, false otherwise. Read-only.
        """
        for arg in self.args:
            if isinstance(arg, Resolving):
                return True
            # Only stop early on a positive result, so that arguments after
            # a fully resolved nested function are still inspected.
            if isinstance(arg, Function) and arg.contains_unresolved_symbol:
                return True
        return False

    def __symeq__(self, obj):
        """Structural equality: same operation and pairwise-equal arguments."""
        if not isinstance(obj, Function):
            return False
        if not (self.op == obj.op):
            return False
        # zip() stops at the shorter sequence, so argument counts must be
        # compared explicitly.
        if len(self.args) != len(obj.args):
            return False
        return all(a == b for a, b in zip(self.args, obj.args))

    def __hash__(self):
        return hash((Function, self.op, *(hash(_) for _ in self.args)))
class LazySymbol(Symbol):
    """A generic way to make some lazily evaluated symbol.

    The value is dependent on a function and some arbitrary arguments
    which will be called when the symbol is evaluated.

    Parameters
    ----------
    name : str
        Human readable string name for the symbol
    function : callable
        Function to call when evaluating the symbol
    *args : objects
        Arguments to pass to `function` when evaluating

    Examples
    --------
    >>> a = LazySymbol('a', lambda x: x**2, 10)
    >>> print((a*10).eval())
    1000
    """

    def __init__(self, name, function, *args):
        self.__name = name
        self.function = function
        self.args = args

    def eval(self, **kwargs):
        """Calls the stored function with the stored arguments.

        Any keyword arguments (such as `subs` or `keep`) are accepted for
        compatibility with other symbols but are not used.
        """
        return self.function(*self.args)

    @property
    def name(self):
        """The name given to this symbol. Read-only."""
        return self.__name
def collect(y):
    """Collects like terms of a symbolic expression.

    For an n-ary sum, each argument is split with `coefficient_and_term`;
    coefficients belonging to equal terms are added together and the sum is
    rebuilt from ``term * coefficient`` products. For any other operation
    with arguments, the operation is re-applied to its collected arguments.

    Parameters
    ----------
    y : object
        The expression to collect. Anything without an `op` attribute is
        returned unchanged.

    Returns
    -------
    The collected expression, or `y` itself if there is nothing to collect.
    """
    if not hasattr(y, "op"):
        return y

    if not y._is_narg_expression_tree:
        y = y.to_nary_add_mul()  # simplification happens on nary trees
        if not hasattr(y, "op"):
            # The conversion reduced the expression to a leaf
            return y

    with simplification(allow_flagged=True):
        if y.op is operator_add:
            coefficients = {}  # term -> accumulated coefficient
            out = 0
            for x in y.args:
                c, t = coefficient_and_term(x)
                if t is None:  # just a coefficient/constant
                    out += x
                elif t in coefficients:
                    coefficients[t] += c
                else:
                    coefficients[t] = c
            for term, coeff in coefficients.items():
                out += term * coeff
            return out

        if len(y.args) > 0:
            # y is some other operator, which may have args that need collecting
            return y.op(*(collect(_) for _ in y.args))

        return y
def coefficient_and_term(y):
    """Splits `y` into a constant coefficient and the remaining term.

    Parameters
    ----------
    y : object
        The expression to split.

    Returns
    -------
    coeff, term : tuple
        * For an n-ary product whose first reduced argument is exactly a
          `Constant`: that constant and the product of the remaining
          arguments.
        * For any other n-ary product: ``Constant(1)`` and the product of
          all arguments.
        * For an object that is exactly a `Constant`: the constant and
          `None`.
        * For anything else: ``Constant(1)`` and `y` itself.
    """
    try:
        if y.op is operator_mul:
            if not y._is_narg_expression_tree:
                y = y.to_nary_add_mul()  # simplification happens on nary trees
                if getattr(y, "op", None) is not operator_mul:
                    # The conversion did not return an n-ary product, so its
                    # arguments must not be read as multiplicands. Classify
                    # the converted expression from scratch instead.
                    return coefficient_and_term(y)
            args = reduce_mul_args(y.args)
            if type(args[0]) is not Constant:
                coeff = Constant(1)
                term = np.prod(args)
            else:
                coeff = args[0]
                term = np.prod(args[1:])
            return coeff, term
        else:
            return Constant(1), y
    except AttributeError:
        # Not an operator...
        if type(y) is Constant:
            return y, None
        else:
            return Constant(1), y
def expand_mul(y):
    """Expands a product over any sums that appear as its direct arguments.

    For an n-ary product, every argument that is an n-ary sum is distributed
    over the running list of terms (via an outer product); every other
    argument multiplies all current terms. The terms are then summed. For an
    n-ary sum, each argument is expanded and the results are added.

    Parameters
    ----------
    y : object
        The expression to expand. Anything without an `op` attribute is
        returned unchanged.

    Returns
    -------
    The expanded expression, or `y` if it is neither a product nor a sum.
    """
    if not hasattr(y, "op"):
        return y

    with simplification(allow_flagged=True):
        if not y._is_narg_expression_tree:
            y = y.to_nary_add_mul()  # simplification happens on nary trees
            if not hasattr(y, "op"):
                # The conversion reduced the expression to a leaf
                return y

        if y.op is operator_mul:
            terms = np.array([1], dtype="O")
            for x in y.args:
                if getattr(x, "op", None) is operator_add:
                    # Distribute the sum over every term found so far
                    terms = np.outer(terms, x.args)
                else:
                    terms *= x
            return np.sum(terms)

        if y.op is operator_add:
            return sum(expand_mul(_) for _ in y.args)

        return y
def is_integer(n):
    """Checks if `n` is an integer.

    Parameters
    ----------
    n : str, float
        Input to check

    Returns
    -------
    bool
        `True` if `n` converts to a float with no fractional part. `False`
        if it has a fractional part, is infinite or NaN, or cannot be
        converted to a float at all (e.g. `None`, complex numbers or
        non-numeric strings).
    """
    try:
        as_float = float(n)
    except (TypeError, ValueError):
        # Not convertible to a real number, so not an integer
        return False
    except OverflowError:
        # Too large to be represented as a float; still an integer if it
        # is a Python int.
        return isinstance(n, int)
    return as_float.is_integer()
def expand_pow(y):
    """Expands powers of products and integer powers of sums.

    * ``(a*b)**n`` becomes ``a**n * b**n``.
    * ``(a+b)**n`` with a `Constant` integer exponent is multiplied out and
      collected; a negative exponent yields the expanded sum raised to -1
      and a zero exponent yields 1.
    * For an n-ary sum or product, each argument is expanded in turn.

    Parameters
    ----------
    y : object
        The expression to expand. Anything without an `op` attribute is
        returned unchanged.

    Returns
    -------
    The expanded expression, or `y` if no expansion applies.
    """
    if not hasattr(y, "op"):
        return y

    with simplification(allow_flagged=True):
        if not y._is_narg_expression_tree:
            y = y.to_nary_add_mul()  # simplification happens on nary trees
            if not hasattr(y, "op"):
                # The conversion reduced the expression to a leaf
                return y

        if y.op is operator.pow:
            base, exp = y.args[0], y.args[1]
            try:
                if base.op is operator_mul:
                    z = 1
                    for factor in base.args:
                        z *= factor**exp
                    return z

                if (
                    base.op is operator_add
                    and type(exp) is Constant
                    and is_integer(exp.value)
                ):
                    n = int(exp.value)
                    if n == 0:
                        return 1
                    # Repeated outer products generate every cross term
                    terms = np.array([1], dtype="O")
                    for _ in range(abs(n)):
                        terms = np.outer(terms, base.args)
                    expanded = terms.sum().collect()
                    return expanded if n > 0 else expanded**-1

                return y
            except AttributeError:
                # The base is not an operator, nothing to expand
                return y

        if y.op is operator_add:
            return sum(expand_pow(_) for _ in y.args)

        if y.op is operator_mul:
            return np.prod(tuple(expand_pow(_) for _ in y.args))

        return y
def expand(y):
    """Fully expands products and powers in a symbolic expression.

    Products are distributed over sums and powers are expanded (see
    `expand_mul` and `expand_pow`). The terms of the resulting sum are
    expanded again until the expression stops changing.

    Parameters
    ----------
    y : object
        The expression to expand. Anything that is not a `Function` is
        returned unchanged.

    Returns
    -------
    The expanded expression.
    """
    with simplification(allow_flagged=True):
        if not isinstance(y, Function):
            return y  # Nothing to expand

        if not y._is_narg_expression_tree:
            y = y.to_nary_add_mul()  # simplification happens on nary trees
            if not isinstance(y, Function):
                return y  # Nothing to expand

        if y.op is operator_mul:
            # Run through and expand any pows before the
            # full mul expansion
            y = Function(y.name, y.op, *[expand_pow(_) for _ in y.args])
            y = expand_mul(y)
            # Nothing left to expand as no add operator
            if not isinstance(y, Function) or y.op is operator_mul:
                return y
        elif y.op is operator.pow:
            y = expand_pow(y)
            # expand_pow can return something that is not a Function (for
            # example 1 for a zero exponent), so check before reading `op`.
            # Nothing left to expand as no add operator
            if not isinstance(y, Function) or y.op is operator.pow:
                return y

        if y.op is not operator_add:
            # some other type of operator, expand its args
            if len(y.args) > 0:
                return y.op(*(expand(_) for _ in y.args))
            return y

        # We have a summation of terms from the initial expansion
        # so expand each term
        out = []
        for x in y.args:
            try:
                if x.op is operator_mul:
                    x = Function(x.name, x.op, *[expand_pow(_) for _ in x.args])
                    z = expand_mul(x)
                elif x.op is operator.pow:
                    z = expand_pow(x)
                else:
                    z = x

                if hasattr(z, "op") and z.op is operator_add:
                    # Flatten nested sums into this one
                    out.extend(z.args)
                else:
                    out.append(z)
            except AttributeError:
                # Term is not an operator, keep it as it is
                out.append(x)

        z = Function(y.name, y.op, *out)
        if y == z:
            # Nothing changed this pass, so the expansion has finished
            return z
        return expand(z)