"""Graph plotting."""
from collections import defaultdict
import matplotlib.pyplot as plt
import networkx as nx
from ..utilities import graph_layouts, option_list
from ..utilities.graph import remove_orphans
from .tools import _in_ipython
from matplotlib import cm
from matplotlib.colors import rgb2hex
def plot_graph(network, layout, graphviz=False, **kwargs):
    """Plot a graph using either NetworkX or graphviz.

    Parameters
    ----------
    network : :class:`networkx.Graph`
        The network to plot.
    layout : str
        The layout type to use. It is handed unchanged to the selected plotting
        function, which validates it.
    graphviz : :class:`bool`, optional
        If set, plot with :func:`plot_graphviz`; otherwise plot with
        :func:`plot_nx_graph`. Defaults to False.

    Other Parameters
    ----------------
    kwargs
        Forwarded as-is to the selected plotting function.

    Raises
    ------
    ModuleNotFoundError
        If `graphviz` is set but pygraphviz is reported as unavailable.
    """
    from ..env import has_pygraphviz

    # Default path: plain NetworkX drawing needs no availability check.
    if not graphviz:
        plot_nx_graph(network, layout, **kwargs)
        return

    if not has_pygraphviz():
        raise ModuleNotFoundError(
            "The graphviz option requires pygraphviz and graphviz to be installed"
        )
    plot_graphviz(network, layout, **kwargs)
def plot_nx_graph(
    network,
    layout,
    node_labels=True,
    node_attrs=False,
    edge_attrs=False,
    node_color_key=None,
    edge_color_key=None,
    label_font_size=12,
    attr_font_size=6,
    edge_font_size=6,
    **kwargs,
):
    """Plot graph with NetworkX.

    Parameters
    ----------
    network : :class:`networkx.Graph`
        The network to plot.
    layout : str
        The layout type to use. Any layout algorithm provided by
        :mod:`networkx.drawing.layout` is supported.
    node_labels : :class:`bool`, optional
        Show node names; defaults to True.
    node_attrs : :class:`bool` or :class:`list`, optional
        Show node data. This can be `True`, in which case all node data is shown, or a
        list, in which case only the specified keys are shown. Defaults to `False`.
    edge_attrs : :class:`bool` or :class:`list`, optional
        Show edge data. This can be `True`, in which case all edge data is shown, or a
        list, in which case only the specified keys are shown. Defaults to `False`.
    node_color_key : callable, optional
        Key function accepting a node and its attribute :class:`dict` and returning a
        group. Each group is assigned a color from the active Matplotlib color cycle;
        colors repeat if there are more groups than colors in the cycle. If not
        specified, nodes are not colored.
    edge_color_key : callable, optional
        Key function accepting an edge (u, v) and its attribute :class:`dict` and
        returning a group. Each group is assigned a color from the active Matplotlib
        color cycle; colors repeat if there are more groups than colors in the cycle.
        If not specified, edges are not colored.
    label_font_size, attr_font_size, edge_font_size : :class:`int`, optional
        Font size for node labels, attributes and edges. Defaults to 12, 6 and 6,
        respectively.

    Other Parameters
    ----------------
    kwargs
        Anything else supported by :func:`networkx.drawing.nx_pylab.draw`.

    Raises
    ------
    ValueError
        If the specified layout is not supported, or if a color key function is given
        together with the corresponding explicit `node_color` / `edge_color` keyword.
    Exception
        If the graph cannot be represented with the specified (planar) layout.
    networkx.NetworkXException
        Any other error raised by the layout function is propagated unchanged.
    """
    from ..utilities import stringify

    def describe(attrs, selection):
        # `selection is True` means "show everything"; the identity check is required
        # because any other truthy value is treated as a container of keys to keep.
        return "\n".join(
            f"{key}={stringify(value)}"
            for key, value in attrs.items()
            if selection is True or key in selection
        )

    if node_color_key is not None:
        if "node_color" in kwargs:
            raise ValueError(
                "cannot specify both 'node_color' and 'node_color_key' arguments"
            )
        kwargs["node_color"] = _group_colors(
            node_color_key(node, data) for node, data in network.nodes(data=True)
        )

    if edge_color_key is not None:
        if "edge_color" in kwargs:
            raise ValueError(
                "cannot specify both 'edge_color' and 'edge_color_key' arguments"
            )
        kwargs["edge_color"] = _group_colors(
            edge_color_key((u, v), data) for u, v, data in network.edges(data=True)
        )

    layouts = graph_layouts()
    try:
        posfunc = layouts[layout.casefold()]
    except KeyError:
        choices = option_list(layouts)
        # The KeyError adds nothing to the message below, so suppress the chaining.
        raise ValueError(
            f"Layout '{layout}' is not available in NetworkX (choose from {choices})."
        ) from None

    try:
        pos = posfunc(network)
    except nx.NetworkXException as e:
        if "G is not planar" in str(e):
            raise Exception(
                "Graph cannot be represented with a planar layout. Try a different layout."
            ) from e
        # Any other layout failure must propagate; swallowing it would leave `pos`
        # undefined and fail later with an unrelated UnboundLocalError.
        raise

    nx.draw(
        network,
        pos,
        with_labels=node_labels,
        verticalalignment="bottom",
        font_size=label_font_size,
        **kwargs,
    )

    if node_attrs:
        # Attribute text hangs below the node position ("top" alignment) while the node
        # name drawn above sits on top of it ("bottom" alignment).
        nx.draw_networkx_labels(
            network,
            pos,
            labels={
                node: describe(node_data, node_attrs)
                for node, node_data in network.nodes(data=True)
            },
            verticalalignment="top",
            font_size=attr_font_size,
        )

    if edge_attrs:
        nx.draw_networkx_edge_labels(
            network,
            pos,
            edge_labels={
                (u, v): describe(edge_data, edge_attrs)
                for u, v, edge_data in network.edges(data=True)
            },
            font_size=edge_font_size,
        )

    plt.show()


def _group_colors(groups):
    """Return one color per item of `groups`, with equal groups sharing a color.

    Colors are taken in order of first appearance from the active Matplotlib property
    cycle and start over from the beginning once the cycle is exhausted.
    """
    from itertools import cycle

    palette = cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"])
    assigned = defaultdict(lambda: next(palette))
    return [assigned[group] for group in groups]
def plot_graphviz(network, layout):
    """Plot graph with graphviz.

    The `pygraphviz` Python package must be installed and available on the current
    Python path, and `graphviz` must be available on the system path.

    Parameters
    ----------
    network : :class:`networkx.Graph`
        The network to plot.
    layout : str
        The layout type to use, matched case-insensitively. One of "neato", "dot",
        "fdp", "sfdp" or "circo".

    Raises
    ------
    ValueError
        If the specified layout is not supported.
    ImportError
        If graphviz or pygraphviz is not installed.
    """
    from networkx.drawing.nx_agraph import view_pygraphviz

    supported = ("neato", "dot", "fdp", "sfdp", "circo")
    prog = layout.casefold()

    if prog in supported:
        view_pygraphviz(network, prog=prog)
        return

    # Report the layout exactly as the caller spelled it.
    choices = option_list(supported)
    raise ValueError(
        f"Layout '{layout}' is not available in graphviz (choose from {choices})."
    )
def graphviz_draw(
    model=None,
    network=None,
    draw_labels=True,
    angle=0,
    overlap=True,
    ratio=0.45,
    edge_len=1.0,
    size=(13, 7),
    pad=(0.5, 0.5),
    format="svg",
    maxiter=500,
    layout="neato",
    mode="sgd",
):
    """Draw a network as a |graphviz| figure, by default using |neato| layout.

    This should get merged with plot_graphviz at some point.

    The default settings are tested to produce a passable drawing of the
    aLIGO DRMI graph.

    Parameters
    ----------
    model : optional
        Model whose `optical_network` attribute is drawn when `network` is not given.
    network : :class:`networkx.Graph`, optional
        The network to draw. Takes precedence over `model`. At least one of `model`
        and `network` must be given.
    draw_labels : bool
        If set, nodes are drawn as ovals with their labels and `ratio` is applied.
        Otherwise nodes are drawn as filled circles with a blank label and `ratio` is
        not applied.
    angle : float or bool
        The angle parameter rotates the graph by |angle| degrees relative to the
        first edge in the graph, which most of the time is the edge coming out
        of the laser. Set |angle=False| to disable rotation and let graphviz
        decide how to rotate the graph.
    overlap : bool or str
        Setting for how graphviz deals with node overlaps. Set to False for
        graphviz to attempt to remove overlaps. Note that overlap removal runs as
        a post-processing step after initial layout and usually makes the graph look
        worse.
    ratio : float
        Post processing step to stretch the graph. Used for stretching horizontally
        to compensate for wider nodes to fit node labels.
    edge_len : float
        Value of the graphviz `len` attribute applied to edges.
    size : tuple
        Pair written to the graphviz `size` graph attribute as "first,second".
    pad : tuple
        Pair written to the graphviz `pad` graph attribute as "first,second".
    format : str
        Output format requested from graphviz. Inside IPython only "svg" and "png" are
        accepted.
    maxiter : int
        Value of the graphviz `maxiter` graph attribute.
    layout : str
        The graphviz layout program used to draw the graph.
    mode : str
        Value of the graphviz `mode` graph attribute.

    Returns
    -------
    object
        Inside IPython, an :class:`IPython.display.SVG` or
        :class:`IPython.display.Image` wrapping the drawing. Otherwise the raw output
        of the pygraphviz draw call.

    Raises
    ------
    ValueError
        If neither `model` nor `network` is given, or if running inside IPython and
        `format` is neither "svg" nor "png".
    ModuleNotFoundError
        If pygraphviz is reported as unavailable.

    Notes
    -----
    The svg format sometimes crops the image too hard, which results in clipped
    nodes or edges, if that happens increase the |pad| graph_attr or use the
    |png| format.

    Examples
    --------
    .. code-block::

        import finesse.ligo
        import finesse.plotting
        kat = finesse.ligo.make_aligo()
        finesse.plotting.graph.graphviz_draw(kat)
    """
    # Reject an unusable call explicitly instead of failing later with an
    # AttributeError on ``None``.
    if model is None and network is None:
        raise ValueError("either 'model' or 'network' must be specified")

    from ..env import has_pygraphviz

    if not has_pygraphviz():
        raise ModuleNotFoundError("Requires pygraphviz and graphviz to be installed")

    if network is None:
        network = model.optical_network

    G = remove_orphans(network, inplace=False)
    A = nx.drawing.nx_agraph.to_agraph(G)

    # remove unnecessary metadata from DOT file
    for node in A.nodes():
        for k in node.attr.keys():
            node.attr[k] = ""
    for edge in A.edges():
        for k in edge.attr.keys():
            edge.attr[k] = ""

    A.graph_attr["mode"] = mode
    A.graph_attr["maxiter"] = maxiter
    A.graph_attr["size"] = f"{size[0]},{size[1]}"
    A.graph_attr["pad"] = f"{pad[0]},{pad[1]}"
    A.graph_attr["normalize"] = angle
    A.graph_attr["overlap"] = overlap
    A.edge_attr["len"] = edge_len

    if draw_labels:
        A.node_attr["shape"] = "oval"
        A.graph_attr["ratio"] = ratio
    else:
        A.node_attr["shape"] = "circle"
        A.node_attr["style"] = "filled"
        A.node_attr["label"] = " "

    byt = A.draw(format=format, prog=layout)

    # Outside IPython there is nothing to display the drawing with, so hand back the
    # raw output unchanged.
    if not _in_ipython():
        # TODO add option to write to file
        return byt

    from IPython.display import Image, SVG

    if format == "svg":
        return SVG(byt)
    if format == "png":
        return Image(byt)
    raise ValueError(f"unknown {format}")
def graphviz_draw_beam_trace(
    model=None,
    network=None,
    draw_labels=True,
    angle=0,
    overlap=True,
    ratio=0.45,
    edge_len=1.0,
    size=(13, 7),
    pad=(0.5, 0.5),
    format="svg",
    maxiter=500,
    layout="neato",
    mode="sgd",
    cmap=cm.tab10,
):
    """Draw a network with graphviz, coloring nodes by trace dependency.

    Every dependency in ``model.trace_forest.dependencies`` is assigned a color from
    `cmap`. Each optical node of the model is filled with the color of the dependency
    returned by ``model.trace_forest.find_dependency_from_node`` for it, and its tooltip
    is set to the ``q`` attribute of ``model.get(node)`` (presumably the beam parameter
    at that node; confirm against the model API). Each dependency is additionally
    drawn as a rectangular node in its own color, connected to ``dep.node`` and
    ``dep.node.opposite``.

    Parameters
    ----------
    model
        The model to draw. Required, despite the ``None`` default.
    network : :class:`networkx.Graph`, optional
        The network to draw; defaults to ``model.optical_network``. Every node in it
        must be an optical node of `model`.
    draw_labels : bool
        If set, nodes are drawn as filled ovals with their labels and `ratio` is
        applied. Otherwise nodes are drawn as filled circles with a blank label and
        `ratio` is not applied.
    angle : float or bool
        Value of the graphviz `normalize` graph attribute.
    overlap : bool or str
        Value of the graphviz `overlap` graph attribute.
    ratio : float
        Value of the graphviz `ratio` graph attribute (only used with `draw_labels`).
    edge_len : float
        Value of the graphviz `len` attribute applied to edges.
    size : tuple
        Pair written to the graphviz `size` graph attribute as "first,second".
    pad : tuple
        Pair written to the graphviz `pad` graph attribute as "first,second".
    format : str
        Output format requested from graphviz. When IPython is importable only "svg"
        and "png" are accepted.
    maxiter : int
        Value of the graphviz `maxiter` graph attribute.
    layout : str
        The graphviz layout program used to draw the graph.
    mode : str
        Value of the graphviz `mode` graph attribute.
    cmap
        Colormap exposing a ``colors`` sequence. Colors are reused from the start if
        there are more dependencies than colors.

    Returns
    -------
    object
        An :class:`IPython.display.SVG` or :class:`IPython.display.Image` wrapping the
        drawing when IPython is importable, otherwise the raw output of the pygraphviz
        draw call.

    Raises
    ------
    ValueError
        If `model` is not given, or if IPython is importable and `format` is neither
        "svg" nor "png".
    ModuleNotFoundError
        If pygraphviz is reported as unavailable.
    """
    # The trace forest is read from the model, so it cannot be omitted.
    if model is None:
        raise ValueError("'model' must be specified")

    from ..env import has_pygraphviz

    if not has_pygraphviz():
        raise ModuleNotFoundError("Requires pygraphviz and graphviz to be installed")

    forest = model.trace_forest
    palette = cmap.colors
    # Wrap around the palette so that more dependencies than colors does not fail.
    colors = {
        dep: rgb2hex(palette[i % len(palette)])
        for i, dep in enumerate(forest.dependencies)
    }
    node_colors = {
        n.full_name: colors[forest.find_dependency_from_node(n)]
        for n in model.optical_nodes
    }

    # Honor an explicitly supplied network rather than silently discarding it.
    if network is None:
        network = model.optical_network

    G = remove_orphans(network, inplace=False)
    A = nx.drawing.nx_agraph.to_agraph(G)

    # remove unnecessary metadata from DOT file
    for node in A.nodes():
        for k in node.attr.keys():
            node.attr[k] = ""
        node.attr["fillcolor"] = node_colors[node]
        node.attr["tooltip"] = model.get(node).q
    for edge in A.edges():
        for k in edge.attr.keys():
            edge.attr[k] = ""

    A.graph_attr["mode"] = mode
    A.graph_attr["maxiter"] = maxiter
    A.graph_attr["size"] = f"{size[0]},{size[1]}"
    A.graph_attr["pad"] = f"{pad[0]},{pad[1]}"
    A.graph_attr["normalize"] = angle
    A.graph_attr["overlap"] = overlap
    A.edge_attr["len"] = edge_len

    # Nodes are always filled so that the dependency colors are visible.
    A.node_attr["style"] = "filled"
    if draw_labels:
        A.node_attr["shape"] = "oval"
        A.graph_attr["ratio"] = ratio
    else:
        A.node_attr["shape"] = "circle"
        A.node_attr["label"] = " "

    # Add one rectangular marker node per dependency, linked to both of its nodes.
    for dep in forest.dependencies:
        A.add_node(dep.name)
        A.add_edge(dep.name, dep.node.full_name)
        A.add_edge(dep.name, dep.node.opposite.full_name)
        marker = A.get_node(dep.name)
        marker.attr["fillcolor"] = colors[dep]
        marker.attr["shape"] = "rectangle"

    byt = A.draw(format=format, prog=layout)

    try:
        from IPython.display import Image, SVG
    except ImportError:
        # Without IPython there is nothing to wrap the drawing in; return it raw
        # instead of failing after all the work has been done.
        return byt

    if format == "svg":
        return SVG(byt)
    if format == "png":
        return Image(byt)
    raise ValueError(f"unknown {format}")