# Source code for finesse.plotting.graph

"""Graph plotting."""

from collections import defaultdict
import matplotlib.pyplot as plt
import networkx as nx
from ..utilities import graph_layouts, option_list
from ..utilities.graph import remove_orphans
from .tools import _in_ipython
from matplotlib import cm
from matplotlib.colors import rgb2hex


def plot_graph(
    network,
    layout,
    graphviz=False,
    **kwargs,
):
    """Plot a network with either NetworkX/matplotlib or graphviz.

    Parameters
    ----------
    network : :class:`networkx.Graph`
        The network to plot.
    layout : str
        The layout type to use. Valid names depend on the selected backend;
        see :func:`plot_nx_graph` and :func:`plot_graphviz`.
    graphviz : bool, optional
        If True, plot with :func:`plot_graphviz`; otherwise plot with
        :func:`plot_nx_graph`. Defaults to False.

    Other Parameters
    ----------------
    kwargs
        Forwarded to :func:`plot_nx_graph`. The graphviz backend accepts no
        extra keyword arguments.

    Raises
    ------
    ModuleNotFoundError
        If `graphviz` is True but pygraphviz/graphviz is not available.
    TypeError
        If extra keyword arguments are given together with ``graphviz=True``.
    """
    from ..env import has_pygraphviz

    if not graphviz:
        plot_nx_graph(network, layout, **kwargs)
        return

    if not has_pygraphviz():
        raise ModuleNotFoundError(
            "The graphviz option requires pygraphviz and graphviz to be installed"
        )
    # plot_graphviz only accepts (network, layout); report stray options
    # clearly instead of failing with an opaque "unexpected keyword" error.
    if kwargs:
        raise TypeError(
            "extra keyword arguments are not supported with graphviz=True: "
            + ", ".join(sorted(kwargs))
        )
    plot_graphviz(network, layout)
def _nx_group_colors():
    """Return a mapping that lazily assigns one color per distinct group key.

    Colors come from matplotlib's current ``axes.prop_cycle`` and wrap around
    when there are more groups than colors. The previous finite iterator
    raised StopIteration once the cycle was exhausted.
    """
    from itertools import cycle

    palette = cycle(plt.rcParams["axes.prop_cycle"].by_key()["color"])
    return defaultdict(lambda: next(palette))


def _nx_attr_label(attrs, keys):
    """Format an attribute dict as newline-separated ``key=value`` lines.

    `keys` is either ``True`` (show every attribute) or a container of the
    attribute names to show.
    """
    from ..utilities import stringify

    # Compare against True by identity: a non-empty list is also truthy but
    # means "only these keys".
    show_all = keys is True
    return "\n".join(
        f"{key}={stringify(value)}"
        for key, value in attrs.items()
        if show_all or key in keys
    )


def plot_nx_graph(
    network,
    layout,
    node_labels=True,
    node_attrs=False,
    edge_attrs=False,
    node_color_key=None,
    edge_color_key=None,
    label_font_size=12,
    attr_font_size=6,
    edge_font_size=6,
    **kwargs,
):
    """Plot graph with NetworkX.

    Parameters
    ----------
    network : :class:`networkx.Graph`
        The network to plot.
    layout : str
        The layout type to use. Any layout algorithm provided by
        :mod:`networkx.drawing.layout` is supported.
    node_labels : :class:`bool`, optional
        Show node names; defaults to True.
    node_attrs : :class:`bool` or :class:`list`, optional
        Show node data. This can be `True`, in which case all node data is
        shown, or a list, in which case only the specified keys are shown.
        Defaults to `False`.
    edge_attrs : :class:`bool` or :class:`list`, optional
        Show edge data. This can be `True`, in which case all edge data is
        shown, or a list, in which case only the specified keys are shown.
        Defaults to `False`.
    node_color_key : callable, optional
        Key function accepting a node and its attribute :class:`dict` and
        returning a group. Each group is assigned a color from matplotlib's
        color cycle (colors repeat if there are more groups than colors). If
        not specified, nodes are not colored.
    edge_color_key : callable, optional
        Key function accepting an edge (u, v) and its attribute :class:`dict`
        and returning a group. Each group is assigned a color from
        matplotlib's color cycle (colors repeat if there are more groups than
        colors). If not specified, edges are not colored.
    label_font_size, attr_font_size, edge_font_size : :class:`int`, optional
        Font size for node labels, attributes and edges. Defaults to 12, 6 and
        6, respectively.

    Other Parameters
    ----------------
    kwargs
        Anything else supported by :func:`networkx.drawing.nx_pylab.draw`.

    Raises
    ------
    ValueError
        If the specified layout is not supported, or if both a color key and
        the corresponding explicit color argument are given.
    Exception
        If the graph cannot be represented with the specified (planar) layout.
    networkx.NetworkXException
        Any other error raised by the layout function is propagated.
    """
    if node_color_key is not None:
        if "node_color" in kwargs:
            raise ValueError(
                "cannot specify both 'node_color' and 'node_color_key' arguments"
            )
        node_groups = _nx_group_colors()
        kwargs["node_color"] = [
            node_groups[node_color_key(node, data)]
            for node, data in network.nodes(data=True)
        ]

    if edge_color_key is not None:
        if "edge_color" in kwargs:
            raise ValueError(
                "cannot specify both 'edge_color' and 'edge_color_key' arguments"
            )
        edge_groups = _nx_group_colors()
        kwargs["edge_color"] = [
            edge_groups[edge_color_key((u, v), data)]
            for u, v, data in network.edges(data=True)
        ]

    layouts = graph_layouts()
    try:
        posfunc = layouts[layout.casefold()]
    except KeyError:
        choices = option_list(layouts)
        raise ValueError(
            f"Layout '{layout}' is not available in NetworkX (choose from {choices})."
        )

    try:
        pos = posfunc(network)
    except nx.NetworkXException as e:
        if "G is not planar" in str(e):
            raise Exception(
                "Graph cannot be represented with a planar layout. Try a different layout."
            ) from e
        # Previously other layout errors were swallowed, leaving `pos`
        # unbound and causing an UnboundLocalError below.
        raise

    # Node names sit above each node; attribute text (if any) sits below.
    nx.draw(
        network,
        pos,
        with_labels=node_labels,
        verticalalignment="bottom",
        font_size=label_font_size,
        **kwargs,
    )

    if node_attrs:
        attr_labels = {
            node: _nx_attr_label(data, node_attrs)
            for node, data in network.nodes(data=True)
        }
        nx.draw_networkx_labels(
            network,
            pos,
            labels=attr_labels,
            verticalalignment="top",
            font_size=attr_font_size,
        )

    if edge_attrs:
        edge_labels = {
            (u, v): _nx_attr_label(data, edge_attrs)
            for u, v, data in network.edges(data=True)
        }
        nx.draw_networkx_edge_labels(
            network,
            pos,
            edge_labels=edge_labels,
            font_size=edge_font_size,
        )

    plt.show()
def plot_graphviz(network, layout):
    """Display a network using one of graphviz's layout programs.

    The `pygraphviz` Python package must be importable, and the `graphviz`
    executables must be available on the system path.

    Parameters
    ----------
    network : :class:`networkx.Graph`
        The network to plot.
    layout : str
        Name of the graphviz layout program, matched case-insensitively. One
        of ``neato``, ``dot``, ``fdp``, ``sfdp`` or ``circo``.

    Raises
    ------
    ValueError
        If `layout` is not one of the supported programs.
    ImportError
        If graphviz or pygraphviz is not installed.
    """
    from networkx.drawing.nx_agraph import view_pygraphviz

    supported_progs = ("neato", "dot", "fdp", "sfdp", "circo")
    prog = layout.casefold()

    if prog in supported_progs:
        view_pygraphviz(network, prog=prog)
        return

    raise ValueError(
        f"Layout '{layout}' is not available in graphviz "
        f"(choose from {option_list(supported_progs)})."
    )
def graphviz_draw(
    model=None,
    network=None,
    draw_labels=True,
    angle=0,
    overlap=True,
    ratio=0.45,
    edge_len=1.0,
    size=(13, 7),
    pad=(0.5, 0.5),
    format="svg",
    maxiter=500,
    layout="neato",
    mode="sgd",
):
    """Draw a network with |graphviz|, using the |neato| layout by default.

    This should get merged with plot_graphviz at some point. The default
    settings are tested to produce a passable drawing of the aLIGO DRMI graph.

    Parameters
    ----------
    model : Model, optional
        Model whose ``optical_network`` is drawn when `network` is not given.
    network : :class:`networkx.Graph`, optional
        Network to draw. Takes precedence over `model`. Orphan nodes are
        removed (on a copy) before drawing.
    draw_labels : bool
        If True, nodes are drawn as labelled ovals. If False, they are drawn
        as filled circles with a blank label.
    angle : float or bool
        The angle parameter rotates the graph by |angle| degrees relative to
        the first edge in the graph, which most of the time is the edge coming
        out of the laser. Set |angle=False| to disable rotation and let
        graphviz decide how to rotate the graph.
    overlap : bool or str
        Setting for how graphviz deals with node overlaps. Set to False for
        graphviz to attempt to remove overlaps. Note that overlap removal runs
        as a post-processing step after initial layout and usually makes the
        graph look worse.
    ratio : float
        Post-processing step to stretch the graph. Used for stretching
        horizontally to compensate for wider nodes that fit the node labels.
        Only applied when `draw_labels` is True.
    edge_len : float
        Value of the graphviz ``len`` edge attribute.
    size : tuple
        Two values written to the graphviz ``size`` graph attribute.
    pad : tuple
        Two values written to the graphviz ``pad`` graph attribute.
    format : str
        Output format passed to graphviz. Inside IPython only ``"svg"`` and
        ``"png"`` are accepted.
    maxiter : int
        Value of the graphviz ``maxiter`` graph attribute.
    layout : str
        Graphviz layout program used to render the graph.
    mode : str
        Value of the graphviz ``mode`` graph attribute.

    Returns
    -------
    IPython.display.SVG or IPython.display.Image or bytes
        Inside IPython, a displayable object for the ``"svg"``/``"png"``
        formats. Otherwise the raw output of the graphviz draw call.

    Raises
    ------
    ModuleNotFoundError
        If pygraphviz/graphviz is not available.
    ValueError
        If neither `model` nor `network` is given, or if an unsupported
        `format` is requested inside IPython.

    Notes
    -----
    The svg format sometimes crops the image too hard, which results in
    clipped nodes or edges. If that happens, increase the |pad| graph_attr or
    use the |png| format.

    Examples
    --------
    .. code-block:: python

        import finesse.ligo
        import finesse.plotting

        kat = finesse.ligo.make_aligo()
        finesse.plotting.graph.graphviz_draw(kat)
    """
    from ..env import has_pygraphviz

    if not has_pygraphviz():
        raise ModuleNotFoundError("Requires pygraphviz and graphviz to be installed")

    if network is None:
        # Previously this failed with an opaque AttributeError on None.
        if model is None:
            raise ValueError("either 'model' or 'network' must be given")
        network = model.optical_network

    G = remove_orphans(network, inplace=False)
    A = nx.drawing.nx_agraph.to_agraph(G)

    # Blank every per-node and per-edge attribute copied over from the
    # networkx graph, so it does not end up as metadata in the DOT file.
    for node in A.nodes():
        for k in node.attr.keys():
            node.attr[k] = ""
    for edge in A.edges():
        for k in edge.attr.keys():
            edge.attr[k] = ""

    A.graph_attr["mode"] = mode
    A.graph_attr["maxiter"] = maxiter
    A.graph_attr["size"] = f"{size[0]},{size[1]}"
    A.graph_attr["pad"] = f"{pad[0]},{pad[1]}"
    A.graph_attr["normalize"] = angle
    A.graph_attr["overlap"] = overlap
    A.edge_attr["len"] = edge_len

    if draw_labels:
        A.node_attr["shape"] = "oval"
        A.graph_attr["ratio"] = ratio
    else:
        A.node_attr["shape"] = "circle"
        A.node_attr["style"] = "filled"
        A.node_attr["label"] = " "

    byt = A.draw(format=format, prog=layout)

    if _in_ipython():
        from IPython.display import Image, SVG

        if format == "svg":
            out = SVG(byt)
        elif format in ["png"]:
            out = Image(byt)
        else:
            raise ValueError(f"unknown {format}")
    else:
        out = byt

    # TODO add option to write to file
    return out
def graphviz_draw_beam_trace(
    model=None,
    network=None,
    draw_labels=True,
    angle=0,
    overlap=True,
    ratio=0.45,
    edge_len=1.0,
    size=(13, 7),
    pad=(0.5, 0.5),
    format="svg",
    maxiter=500,
    layout="neato",
    mode="sgd",
    cmap=cm.tab10,
):
    """Draw the optical network with nodes colored by beam-trace dependency.

    Each dependency in ``model.trace_forest.dependencies`` is assigned a
    color. Every optical node is filled with the color of the dependency
    returned by ``model.trace_forest.find_dependency_from_node``. For each
    dependency, an extra rectangular node named after the dependency is added
    and connected to the dependency's node and to its opposite node.

    Parameters
    ----------
    model : Model
        Model providing the trace forest and optical nodes. Required despite
        the ``None`` default.
    network : :class:`networkx.Graph`, optional
        Network to draw. Defaults to ``model.optical_network``. Orphan nodes
        are removed (on a copy) before drawing.
    draw_labels, angle, overlap, ratio, edge_len, size, pad, maxiter, layout, mode
        Graphviz drawing options, applied exactly as in :func:`graphviz_draw`.
    format : str
        Output format passed to graphviz. Only ``"svg"`` and ``"png"`` are
        accepted.
    cmap : :class:`matplotlib.colors.Colormap`
        Colormap used for the dependencies. If it has a ``colors`` list (as
        the default ``tab10`` does), its entries are used in order and reused
        cyclically once exhausted. Otherwise colors are sampled evenly across
        the colormap.

    Returns
    -------
    IPython.display.SVG or IPython.display.Image
        Displayable object wrapping the rendered graph. IPython must be
        installed.

    Raises
    ------
    ModuleNotFoundError
        If pygraphviz/graphviz is not available.
    ValueError
        If `model` is not given, or `format` is not ``"svg"`` or ``"png"``.
    """
    from ..env import has_pygraphviz

    if not has_pygraphviz():
        raise ModuleNotFoundError("Requires pygraphviz and graphviz to be installed")
    if model is None:
        raise ValueError("'model' is required to compute the beam trace coloring")

    dependencies = list(model.trace_forest.dependencies)

    palette = getattr(cmap, "colors", None)
    if palette is not None:
        # Wrap around rather than raising IndexError when there are more
        # dependencies than palette entries (tab10 only has 10 colors).
        dep_rgb = [palette[i % len(palette)] for i in range(len(dependencies))]
    else:
        # Colormaps without a discrete palette: sample evenly across [0, 1].
        denom = max(len(dependencies) - 1, 1)
        dep_rgb = [cmap(i / denom) for i in range(len(dependencies))]
    colors = {dep: rgb2hex(rgb) for dep, rgb in zip(dependencies, dep_rgb)}

    node_colors = {
        n.full_name: colors[model.trace_forest.find_dependency_from_node(n)]
        for n in model.optical_nodes
    }

    # Honor an explicitly passed network. It used to be silently overwritten.
    if network is None:
        network = model.optical_network
    G = remove_orphans(network, inplace=False)
    A = nx.drawing.nx_agraph.to_agraph(G)

    # Blank the attributes copied from the networkx graph (unnecessary DOT
    # metadata), then set the fill color and tooltip for each optical node.
    for node in A.nodes():
        for k in node.attr.keys():
            node.attr[k] = ""
        node.attr["fillcolor"] = node_colors[node]
        node.attr["tooltip"] = model.get(node).q
    for edge in A.edges():
        for k in edge.attr.keys():
            edge.attr[k] = ""

    A.graph_attr["mode"] = mode
    A.graph_attr["maxiter"] = maxiter
    A.graph_attr["size"] = f"{size[0]},{size[1]}"
    A.graph_attr["pad"] = f"{pad[0]},{pad[1]}"
    A.graph_attr["normalize"] = angle
    A.graph_attr["overlap"] = overlap
    A.edge_attr["len"] = edge_len

    if draw_labels:
        A.node_attr["shape"] = "oval"
        A.node_attr["style"] = "filled"
        A.graph_attr["ratio"] = ratio
    else:
        A.node_attr["shape"] = "circle"
        A.node_attr["style"] = "filled"
        A.node_attr["label"] = " "

    # One rectangle per dependency, linked to both sides of its node.
    for dep in dependencies:
        A.add_node(dep.name)
        A.add_edge(dep.name, dep.node.full_name)
        A.add_edge(dep.name, dep.node.opposite.full_name)
        dep_node = A.get_node(dep.name)
        dep_node.attr["fillcolor"] = colors[dep]
        dep_node.attr["shape"] = "rectangle"

    byt = A.draw(format=format, prog=layout)

    from IPython.display import Image, SVG

    if format == "svg":
        out = SVG(byt)
    elif format in ["png"]:
        out = Image(byt)
    else:
        raise ValueError(f"unknown {format}")
    return out