python/aws/amazon-braket-examples/examples/quantum_annealing/Dwave_StructuralImbalance/helpers/draw.py

draw.py
# Copyright 2020 D-Wave Systems Inc.
#
#    Licensed under the Apache License, Version 2.0 (the "License");
#    you may not use this file except in compliance with the License.
#    You may obtain a copy of the License at
#
#        http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS,
#    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#    See the License for the specific language governing permissions and
#    limitations under the License.

import sys
from collections import defaultdict
from itertools import product

import networkx as nx
from bokeh.io import output_notebook, show
from bokeh.models import (
    Circle,
    ColumnDataSource,
    HoverTool,
    LabelSet,
    MultiLine,
    PanTool,
    Plot,
    Range1d,
    Row,
    WheelZoomTool,
    ZoomInTool,
    ZoomOutTool,
)
from bokeh.models.graphs import EdgesAndLinkedNodes, from_networkx

# Initialise Bokeh's notebook output only the first time this module is loaded.
# importlib.reload() re-executes the module in its existing namespace, so the
# ``bokeh_loaded`` flag assigned below is still present on a reload and the
# output_notebook() call is skipped.
me = sys.modules[__name__]
if "bokeh_loaded" not in vars(me):
    output_notebook()
    bokeh_loaded = True


def draw(S, position=None, with_labels=False):
    """Plot the given signed social network in the Bokeh notebook output.

    Every edge must carry a ``"sign"`` attribute. Friendly edges (``sign == 1``) and
    hostile edges are drawn in different colors. If every edge also carries a
    ``"frustrated"`` attribute, frustrated edges are drawn dashed and the others solid.
    Otherwise all edges are drawn solid. If the edges carry ``"event_year"`` and
    ``"event_description"`` attributes, hovering an edge shows them. When every node
    has a ``"color"`` attribute, nodes are filled according to it.

    Args:
        S (networkx.Graph):
            The signed network.
        position (dict, optional):
            The position for the nodes, keyed by node. If no position is provided, a
            layout will be calculated. If the nodes have 'color' attributes, a
            Kamada-Kawai layout will be used to group nodes of the same color together.
            Otherwise, a circular layout will be used.
        with_labels (bool, optional):
            If True, draw each node's name next to it. Defaults to False.

    Returns:
        A dictionary of positions keyed by node.

    Example:
        position = draw(S_before)          # compute a layout and keep it
        draw(S_after, position)            # redraw another graph with the same layout
        draw(S_after, with_labels=True)    # let draw() compute a fresh layout
    """

    # we need a consistent ordering of the edges
    edgelist = S.edges()
    nodelist = S.nodes()

    def layout_wrapper(S):
        """Return node positions; used as the layout function for from_networkx.

        Reads the enclosing ``position`` variable. Once it has been set to the layout of
        the full graph (below), every later call, including calls for the dashed/solid
        subgraphs, returns that same layout, so all renderers share one layout.
        """
        pos = position
        if pos is None:
            try:
                # group bipartition if nodes are colored: the target distance is 1
                # between same-colored nodes and 2 between differently colored ones
                dist = defaultdict(dict)
                for u, v in product(nodelist, repeat=2):
                    if u == v:  # node has no distance from itself
                        dist[u][v] = 0
                    elif nodelist[u]["color"] == nodelist[v]["color"]:
                        dist[u][v] = 1  # make same color nodes closer together
                    else:
                        dist[u][v] = 2  # make different color nodes further apart
                pos = nx.kamada_kawai_layout(S, dist)
            except KeyError:
                # default to circular layout if nodes aren't colored
                pos = nx.circular_layout(S)
        return pos

    # call layout wrapper once with all nodes to store position for calls with partial graph
    position = layout_wrapper(S)

    # NOTE(review): plot_width/plot_height were deprecated in Bokeh 2.4 and removed in
    # Bokeh 3.0 in favour of width/height; confirm the pinned bokeh version.
    plot = Plot(
        plot_width=600, plot_height=400, x_range=Range1d(-1.2, 1.2), y_range=Range1d(-1.2, 1.2)
    )
    tools = [WheelZoomTool(), ZoomInTool(), ZoomOutTool(), PanTool()]
    plot.add_tools(*tools)
    plot.toolbar.active_scroll = tools[0]

    def get_graph_renderer(S, line_dash):
        """Build a Bokeh graph renderer for ``S`` whose edges use ``line_dash``.

        Adds a hover tool scoped to this renderer when the edges carry event data.
        """
        # we need a consistent ordering of the edges
        edgelist = S.edges()
        nodelist = S.nodes()

        # get the colors assigned to each edge based on friendly/hostile
        sign_edge_color = ["#87DACD" if S[u][v]["sign"] == 1 else "#FC9291" for u, v in edgelist]

        # get the colors assigned to each node by coloring
        try:
            coloring_node_color = [
                "#4378F8" if nodelist[v]["color"] else "#FFE897" for v in nodelist
            ]
        except KeyError:
            coloring_node_color = ["#FFFFFF" for __ in nodelist]

        graph_renderer = from_networkx(S, layout_wrapper)

        circle_size = 10
        graph_renderer.node_renderer.data_source.add(coloring_node_color, "color")
        graph_renderer.node_renderer.glyph = Circle(size=circle_size, fill_color="color")

        edge_source = graph_renderer.edge_renderer.data_source
        edge_source.add(sign_edge_color, "color")

        # only the attribute lookups can raise; touch the data source afterwards
        try:
            event_years = [S[u][v]["event_year"] for u, v in edgelist]
            event_descriptions = [S[u][v]["event_description"] for u, v in edgelist]
        except KeyError:
            pass  # edges carry no event metadata: no hover tool for this renderer
        else:
            edge_source.add(event_years, "event_year")
            edge_source.add(event_descriptions, "event_description")
            plot.add_tools(
                HoverTool(
                    tooltips=[("Year", "@event_year"), ("Description", "@event_description")],
                    line_policy="interp",
                    # scope to this renderer; the default ("auto") would attach it to
                    # every renderer and show duplicate tooltips when the plot holds
                    # both a dashed and a solid graph
                    renderers=[graph_renderer],
                )
            )
        graph_renderer.edge_renderer.glyph = MultiLine(line_color="color", line_dash=line_dash)

        graph_renderer.inspection_policy = EdgesAndLinkedNodes()

        return graph_renderer

    # split the edges by their "frustrated" flag; keep only the lookups in the try
    try:
        dashed_edges = [(u, v) for u, v in edgelist if S[u][v]["frustrated"]]
        solid_edges = [(u, v) for u, v in edgelist if not S[u][v]["frustrated"]]
    except KeyError:
        # edges carry no "frustrated" flag: draw the whole graph solid
        plot.renderers.append(get_graph_renderer(S, "solid"))
    else:
        plot.renderers.append(get_graph_renderer(S.edge_subgraph(dashed_edges), "dashed"))
        plot.renderers.append(get_graph_renderer(S.edge_subgraph(solid_edges), "solid"))

    plot.background_fill_color = "#202239"

    # the layout was already computed (or supplied) above
    positions = position
    if with_labels:
        data = {"xpos": [], "ypos": [], "label": []}
        for label, pos in positions.items():
            data["label"].append(label)
            data["xpos"].append(pos[0])
            data["ypos"].append(pos[1])

        labels = LabelSet(
            x="xpos",
            y="ypos",
            text="label",
            level="glyph",
            source=ColumnDataSource(data),
            x_offset=-5,
            y_offset=10,
            text_color="#F5F7FB",
            text_font_size="12pt",
        )
        plot.add_layout(labels)

    show(Row(plot))

    return positions