"""A components sub-module containing classes for detecting intensity fluctuations at a
physical point in a model.
These Readout components essentially describe baseband and broadband detectors such as
DC and RF demodulated photodiodes typically used in optical experiments.
"""
import types
from collections import defaultdict
import numpy as np
import finesse
from finesse.components.general import Connector, borrows_nodes
from finesse.components.node import Node, NodeDirection, NodeType, Port
from finesse.components.workspace import ConnectorWorkspace
from finesse.detectors import pdtypes
from finesse.detectors.compute.quantum import QShot0Workspace, QShotNWorkspace
from finesse.element import ModelElement
from finesse.parameter import float_parameter
# NumPy-style "Parameters" fragment describing the constructor arguments that
# the readout classes in this module have in common. It is not referenced in
# the visible part of this file; presumably it is spliced into docstrings
# elsewhere -- confirm before removing.
doc_readout_param = """
Parameters
----------
name : str
    Name of readout element
optical_node : Node
    Node object which this readout element should look at
pdtype : str, dict
    A name of a pdtype definition or a dict representing a pdtype definition
"""
class ReadoutWorkspace(ConnectorWorkspace):
    """Workspace type created by the readout components in this module.

    It adds nothing to ``ConnectorWorkspace`` itself; the readouts attach the
    per-simulation attributes they need to each instance when they build it
    in their ``_get_workspace`` methods.
    """
# IMPORTANT: renaming this class impacts the katscript spec and should be avoided!
class _Readout(Connector):
    def __init__(
        self, name: str, optical_node: Node, pdtype=None, output_detectors: bool = False
    ):
        """Abstract class that provides basic functionality similar to all Readouts.

        Underscore because users should not be accessing it directly.

        An optical port ``p1`` with nodes ``i`` and ``o`` is always created. When
        `optical_node` is given, ``p1.i`` is that node and ``p1.o`` is the other node
        of the same port; otherwise two new nodes are created for the port.

        Parameters
        ----------
        name : str
            Name of readout element
        optical_node : Node
            Node object which this readout element should look at. May be ``None``,
            in which case the readout gets input/output nodes of its own.
        pdtype : str, dict
            A name of a pdtype definition or a dict representing a pdtype definition
        output_detectors : bool, optional
            Flag stored on the readout and exposed through the `output_detectors`
            property, by default False. How the flag is consumed is not visible in
            this module.
        """
        super().__init__(name)
        self.pdtype = pdtype
        self.__output_detectors = output_detectors
        self._add_port("p1", NodeType.OPTICAL)
        if optical_node is not None:
            # NOTE(review): when a Port is passed, the Port object itself is
            # registered as the "i" node below and the first node of that port
            # becomes "o". Confirm whether a Port should instead be resolved to
            # one of its nodes before being used here.
            port = optical_node if isinstance(optical_node, Port) else optical_node.port
            other_node = tuple(o for o in port.nodes if o is not optical_node)[0]
            self.p1._add_node("i", None, optical_node)
            self.p1._add_node("o", None, other_node)
        else:
            self.p1._add_node("i", NodeDirection.INPUT)
            self.p1._add_node("o", NodeDirection.OUTPUT)

    def _on_add(self, model):
        """Check that the observed node lives in the model this readout is added to.

        Raises
        ------
        Exception
            If `model` is not the model that port ``p1`` reports.
        """
        if model is not self.p1._model:
            # Report the node that is actually being observed. The attribute
            # previously interpolated here (`self.node`) is never defined on a
            # readout, so formatting the message failed before it could be raised.
            raise Exception(
                f"{repr(self)} is using a node {self.p1.i} from a different model"
            )

    def _on_remove(self):
        """Remove every entry listed in ``self.outputs`` from the model.

        ``self.outputs`` is a namespace that the concrete readout classes in this
        module fill with the names of their output elements.
        """
        for output in vars(self.outputs).values():
            self._model.remove(output)

    def _get_output_workspaces(self, model):
        """Return the output workspaces of this readout; the base class has none."""
        return None

    @property
    def optical_node(self):
        """The node being observed, or ``None`` if ``p1.i`` is owned by this readout."""
        if self.p1.i.component != self:
            return self.p1.i
        return None

    @property
    def has_mask(self):
        """Always ``False`` for readouts."""
        return False

    @property
    def output_detectors(self):
        """The `output_detectors` flag given at construction or set afterwards."""
        return self.__output_detectors

    @output_detectors.setter
    def output_detectors(self, value: bool):
        self.__output_detectors = value
# IMPORTANT: renaming this class impacts the katscript spec and should be avoided!
class ReadoutDetectorOutput(ModelElement):
    """A placeholder element that represents a detector output generated by a Readout
    element.

    Parameters
    ----------
    name : str
        Name of the output element.
    readout : _Readout
        The readout component this output belongs to.

    Notes
    -----
    These should not be created directly by a user.
    It is internally created and added by a Readout component.
    """

    def __init__(self, name: str, readout: _Readout):
        super().__init__(name)
        self.__readout = readout

    @property
    def readout(self):
        """The readout component that this output was created for."""
        return self.__readout
@borrows_nodes()
# IMPORTANT: renaming this class impacts the katscript spec and should be avoided!
class ReadoutDC(_Readout):
    def __init__(
        self,
        name: str,
        optical_node: Node = None,
        pdtype=None,
        output_detectors: bool = False,
    ):
        """A Readout component which represents a photodiode measuring the intensity of
        some incident field. Audio band intensity signals present in the incident
        optical field are converted into an electrical signal and output at the
        `self.DC` port, which has a single `self.DC.o` node.

        Parameters
        ----------
        name : str
            Name of readout element
        optical_node : Node
            Node object which this readout element should look at
        pdtype : str, dict
            A name of a pdtype definition or a dict representing a pdtype definition
        output_detectors : bool, optional
            Flag stored on the readout and exposed through the `output_detectors`
            property, by default False. How the flag is consumed is not visible in
            this module.
        """
        super().__init__(
            name, optical_node, pdtype=pdtype, output_detectors=output_detectors
        )
        # Replace the raw name/dict stored by the base class with the value
        # returned by the pdtype lookup.
        self.pdtype = pdtypes.get_pdtype(pdtype)
        self._add_port("DC", NodeType.ELECTRICAL)
        self.DC._add_node("o", NodeDirection.OUTPUT)
        self._register_node_coupling("P1i_DC", self.p1.i, self.DC.o)
        # Names of the output elements that `_on_add` registers with the model.
        self.outputs = types.SimpleNamespace()
        self.outputs.DC = f"{self.name}_DC"

    def _on_add(self, model):
        """Validate the model and add the ``<name>_DC`` output placeholder to it."""
        super()._on_add(model)
        model.add(ReadoutDetectorOutput(f"{self.name}_DC", self))

    def _get_workspace(self, sim):
        """Build the signal workspace for this readout.

        Returns ``None`` when the simulation has no signal part or when the
        ``DC.o`` node is not among the signal simulation's nodes.
        """
        if not sim.signal:
            return None
        if self.DC.o.full_name not in sim.signal.nodes:
            return None  # Don't do anything if no nodes included

        ws = ReadoutWorkspace(self, sim)
        # -1 never equals a real solve count, so the first fill always runs
        ws.prev_carrier_solve_num = -1
        ws.I = np.eye(sim.model_settings.num_HOMs, dtype=np.complex128)
        ws.signal.add_fill_function(self._fill_matrix, True)
        ws.frequencies = sim.signal.signal_frequencies[self.DC.o].frequencies
        ws.is_segmented = self.pdtype is not None
        if ws.is_segmented:
            ws.K = pdtypes.construct_segment_beat_matrix(
                sim.model.mode_index_map, self.pdtype  # , sparse_output=True
            )
        return ws

    def _get_output_workspaces(self, sim):
        """Create the workspaces that compute this readout's outputs.

        Always contains a power output named ``<name>_DC``; when the simulation
        has a signal part a ``<name>_shot_noise`` output is added as well.
        """
        # Imported here rather than at module level, presumably to avoid a
        # circular import with the detectors package -- confirm.
        from finesse.detectors import PowerDetector, QuantumShotNoiseDetector
        from finesse.detectors.compute.power import PD0Workspace
        from finesse.detectors.workspace import OutputInformation

        wss = []
        # Setup a DC output photodiode detector for
        # using for outputs
        oinfo = OutputInformation(
            self.name + "_DC",
            PowerDetector,
            (self.p1.i,),
            np.float64,
            "W",
            None,
            "W",
            True,
            False,
        )
        wss.append(PD0Workspace(self, sim, oinfo=oinfo, pdtype=self.pdtype))

        if sim.signal:
            oinfo = OutputInformation(
                self.name + "_shot_noise",
                QuantumShotNoiseDetector,
                (self.p1.i,),
                np.float64,
                "W/rtHz",
                None,
                "ASD",
                True,
                False,
            )
            wss.append(QShot0Workspace(self, sim, False, output_info=oinfo))
        return wss

    def _fill_matrix(self, ws):
        """Computing E.conj() * upper + E * lower.conj()

        For every signal optical frequency the conjugated carrier field at
        ``p1.i`` (multiplied by the segment matrix ``ws.K`` for segmented
        photodiodes) is written into the ``p1.i -> DC.o`` matrix element of each
        electrical frequency in ``ws.frequencies``.
        """
        # if the previous fill was done with this carrier then there
        # is no need to refill it...
        if ws.prev_carrier_solve_num == ws.sim.carrier.num_solves:
            return

        conn_idx = ws.signal.connections.P1i_DC_idx
        # Nothing to fill when the connection index is negative
        if conn_idx > -1:
            for freq in ws.sim.signal.optical_frequencies.frequencies:
                # Get the carrier HOMs for this frequency
                cidx = freq.audio_carrier_index
                Ec = np.conjugate(ws.sim.carrier.node_field_vector(self.p1.i, cidx))
                # The filled values do not depend on the electrical frequency,
                # so they are computed once per optical frequency.
                values = np.dot(ws.K, Ec) if ws.is_segmented else Ec
                for efreq in ws.frequencies:
                    with ws.sim.signal.component_edge_fill3(
                        ws.owner_id,
                        conn_idx,
                        freq.index,
                        efreq.index,
                    ) as mat:
                        mat[:] = values

        # store what carrier solve number this fill was done with
        ws.prev_carrier_solve_num = ws.sim.carrier.num_solves
@borrows_nodes()
@float_parameter("f", "Frequency")
@float_parameter("phase", "Phase")
# IMPORTANT: renaming this class impacts the katscript spec and should be avoided!
class ReadoutRF(_Readout):
    def __init__(
        self,
        name,
        optical_node=None,
        *,
        f=None,
        phase=0,
        output_detectors=False,
        pdtype=None,
    ):
        """A Readout component which represents a photodiode whose signal is
        demodulated at the frequency `f`. The two demodulation quadratures are
        output at the `self.I` and `self.Q` electrical ports, each of which has a
        single `o` node. The Q quadrature uses the demodulation phase `phase + 90`
        degrees.

        Parameters
        ----------
        name : str
            Name of readout element
        optical_node : Node
            Node object which this readout element should look at
        f : float
            Demodulation frequency (presumably in Hz -- confirm).
        phase : float
            Demodulation phase in degrees, by default 0
        output_detectors : bool, optional
            Flag stored on the readout and exposed through the `output_detectors`
            property, by default False. How the flag is consumed is not visible in
            this module.
        pdtype : str, dict
            A name of a pdtype definition or a dict representing a pdtype definition
        """
        super().__init__(
            name, optical_node, pdtype=pdtype, output_detectors=output_detectors
        )
        # Resolve the name/dict through the pdtype lookup, exactly as ReadoutDC
        # does, so the segment beat matrix is always built from a resolved
        # definition instead of a raw name.
        self.pdtype = pdtypes.get_pdtype(pdtype)
        self.f = f
        self.phase = phase
        self._add_port("I", NodeType.ELECTRICAL)
        self.I._add_node("o", NodeDirection.OUTPUT)
        self._add_port("Q", NodeType.ELECTRICAL)
        self.Q._add_node("o", NodeDirection.OUTPUT)
        self._register_node_coupling("P1i_I", self.p1.i, self.I.o)
        self._register_node_coupling("P1i_Q", self.p1.i, self.Q.o)
        # Names of the output elements that `_on_add` registers with the model.
        self.outputs = types.SimpleNamespace()
        self.outputs.I = f"{self.name}_I"
        self.outputs.Q = f"{self.name}_Q"
        self.outputs.DC = f"{self.name}_DC"

    def _on_add(self, model):
        """Validate the model and add the DC, I and Q output placeholders to it."""
        super()._on_add(model)
        model.add(ReadoutDetectorOutput(self.name + "_DC", self))
        model.add(ReadoutDetectorOutput(self.name + "_I", self))
        model.add(ReadoutDetectorOutput(self.name + "_Q", self))

    def _get_workspace(self, sim):
        """Build the signal workspace for this readout.

        Returns ``None`` when the simulation has no signal part or when neither
        ``I.o`` nor ``Q.o`` is among the signal simulation's nodes.
        """
        if not sim.signal:
            return None

        has_I_node = self.I.o.full_name in sim.signal.nodes
        has_Q_node = self.Q.o.full_name in sim.signal.nodes
        if not (has_I_node or has_Q_node):
            return None  # Don't do anything if no nodes included

        ws = ReadoutWorkspace(self, sim)
        # -1 never equals a real solve count, so the first fill always runs
        ws.prev_carrier_solve_num = -1
        ws.signal.add_fill_function(self._fill_matrix, True)
        # The frequencies are taken from whichever quadrature node is present,
        # preferring I.
        ws.frequencies = sim.signal.signal_frequencies[
            self.I.o if has_I_node else self.Q.o
        ].frequencies
        ws.dc_node_id = sim.carrier.node_id(self.p1.i)
        ws.is_segmented = self.pdtype is not None
        if ws.is_segmented:
            ws.K = pdtypes.construct_segment_beat_matrix(
                sim.model.mode_index_map, self.pdtype  # , sparse_output=True
            )
        return ws

    def _get_output_workspaces(self, sim):
        """Create the workspaces that compute this readout's outputs.

        Contains, in order, the demodulated ``<name>_I`` and ``<name>_Q``
        outputs and the ``<name>_DC`` power output; when the simulation has a
        signal part a ``<name>_shot_noise`` output is added as well.
        """
        # Imported here rather than at module level, presumably to avoid a
        # circular import with the detectors package -- confirm.
        from finesse.detectors import (
            PowerDetector,
            PowerDetectorDemod1,
            QuantumShotNoiseDetectorDemod1,
        )
        from finesse.detectors.compute.power import PD0Workspace, PD1Workspace
        from finesse.detectors.workspace import OutputInformation

        wss = []
        for quadrature in ("I", "Q"):
            # Setup a single demodulation photodiode detector for
            # using for outputs
            oinfo = OutputInformation(
                self.name + "_" + quadrature,
                PowerDetectorDemod1,
                (self.p1.i,),
                np.float64,
                "W",
                None,
                "W",
                True,
                False,
            )
            # Q is demodulated 90 degrees away from I
            poff = 90 if quadrature == "Q" else 0
            wss.append(
                PD1Workspace(
                    self,
                    sim,
                    self.f,
                    self.phase,
                    phase_offset=poff,
                    oinfo=oinfo,
                    pdtype=self.pdtype,
                )
            )

        # Setup a DC output photodiode detector for
        # using for outputs
        oinfo = OutputInformation(
            self.name + "_DC",
            PowerDetector,
            (self.p1.i,),
            np.float64,
            "W",
            None,
            "W",
            True,
            False,
        )
        # NOTE(review): unlike ReadoutDC, the DC output here is built without
        # `pdtype=self.pdtype`. Confirm whether that is intentional for
        # segmented photodiodes.
        wss.append(PD0Workspace(self, sim, oinfo=oinfo))

        if sim.signal:
            oinfo = OutputInformation(
                self.name + "_shot_noise",
                QuantumShotNoiseDetectorDemod1,
                (self.p1.i,),
                np.float64,
                "W/rtHz",
                None,
                "ASD",
                True,
                False,
            )
            wss.append(
                QShotNWorkspace(
                    self,
                    sim,
                    [
                        (self.f, self.phase),
                    ],
                    False,
                    output_info=oinfo,
                )
            )
        return wss

    def _fill_matrix(self, ws):
        """Fill the ``p1.i -> I.o`` and ``p1.i -> Q.o`` signal matrix elements.

        Every ordered pair of carrier frequencies whose difference equals plus or
        minus the demodulation frequency contributes conjugated carrier fields to
        the lower audio sideband of the second frequency and the upper audio
        sideband of the first. Contributions to the same element are summed,
        multiplied by the segment matrix ``ws.K`` for segmented photodiodes, and
        written at electrical frequency index 0.
        """
        if ws.prev_carrier_solve_num == ws.sim.carrier.num_solves:
            return

        # extra factor of two we do not apply here as we work
        # directly with amplitudes from the matrix solution
        # need one half gain from demod. Other factor of two from
        # signal scaling and 0.5 from second demod cancel out
        factorI = (
            0.5
            * ws.sim.model_settings.EPSILON0_C
            * np.exp(-1j * ws.values.phase * finesse.constants.DEG2RAD)
        )
        factorQ = (
            0.5
            * ws.sim.model_settings.EPSILON0_C
            * np.exp(-1j * (ws.values.phase + 90) * finesse.constants.DEG2RAD)
        )

        # (connection index, demodulation factor) for each connected quadrature;
        # a negative connection index means there is nothing to fill for it.
        connections = ws.signal.connections
        quadratures = []
        if connections.P1i_I_idx >= 0:
            quadratures.append((connections.P1i_I_idx, factorI))
        if connections.P1i_Q_idx >= 0:
            quadratures.append((connections.P1i_Q_idx, factorQ))

        terms = defaultdict(list)
        if quadratures:
            carrier = ws.sim.carrier
            optical_frequencies = carrier.optical_frequencies
            demod_f = ws.values.f
            for f1 in optical_frequencies.frequencies:
                for f2 in optical_frequencies.frequencies:
                    df = f1.f - f2.f
                    # NOTE(review): exact floating point comparison -- pairs are
                    # only picked up when the spacing equals `f` exactly.
                    below = df == -demod_f
                    above = df == demod_f
                    if not (below or above):
                        continue

                    # Get the carrier HOMs for this pair of frequencies
                    E1c = np.conjugate(carrier.node_field_vector(self.p1.i, f1.index))
                    E2c = np.conjugate(carrier.node_field_vector(self.p1.i, f2.index))
                    lower_idx = optical_frequencies.get_info(f2.index)[
                        "audio_lower_index"
                    ]
                    upper_idx = optical_frequencies.get_info(f1.index)[
                        "audio_upper_index"
                    ]

                    # The append order (difference -f before +f, lower before
                    # upper) is kept so the per-element sums are accumulated in
                    # the same order as before.
                    if below:
                        for conn_idx, factor in quadratures:
                            terms[(ws.owner_id, conn_idx, lower_idx)].append(
                                factor * E1c
                            )
                            terms[(ws.owner_id, conn_idx, upper_idx)].append(
                                factor.conjugate() * E2c
                            )
                    if above:
                        for conn_idx, factor in quadratures:
                            terms[(ws.owner_id, conn_idx, lower_idx)].append(
                                factor.conjugate() * E1c
                            )
                            terms[(ws.owner_id, conn_idx, upper_idx)].append(
                                factor * E2c
                            )

        for key, values in terms.items():
            total = sum(values)
            if ws.is_segmented:
                total = np.dot(ws.K, total)
            with ws.sim.signal.component_edge_fill3(*key, 0) as mat:
                mat[:] = total

        # store previous carrier solve number this fill was done with
        # so we don't have to repeat it
        ws.prev_carrier_solve_num = ws.sim.carrier.num_solves