# Source code for kmapper.plotlyviz

from __future__ import division
from .utils import deprecated_alias

import numpy as np

from .visuals import (
    _scale_color_values,
    _size_node,
    _format_cluster_statistics,
    _node_color_function,
    _format_meta,
    _graph_data_distribution,
    _tooltip_components,
)

try:
    import igraph as ig
    import plotly.graph_objs as go
    import ipywidgets as ipw
    import plotly.io as pio
except ImportError:
    print(
        """To use the plotly visualization tools, you must have the packages igraph, plotly, and ipywidgets installed in your environment."""
        """ It looks like at least one of these is missing.  Please install again with"""
        """\n\n\t`pip install igraph plotly ipywidgets`\n\nand try again"""
    )
    raise


default_colorscale = [
    [0.0, "rgb(68, 1, 84)"],  # Viridis
    [0.1, "rgb(72, 35, 116)"],
    [0.2, "rgb(64, 67, 135)"],
    [0.3, "rgb(52, 94, 141)"],
    [0.4, "rgb(41, 120, 142)"],
    [0.5, "rgb(32, 144, 140)"],
    [0.6, "rgb(34, 167, 132)"],
    [0.7, "rgb(68, 190, 112)"],
    [0.8, "rgb(121, 209, 81)"],
    [0.9, "rgb(189, 222, 38)"],
    [1.0, "rgb(253, 231, 36)"],
]


def mpl_to_plotly(cmap, n_entries):
    """Sample a colormap callable into a Plotly-style colorscale.

    Parameters
    ----------
    cmap : callable
        Called as ``cmap(t)`` with ``t`` evenly spaced in ``[0, 1]``; the
        first three components of the returned sequence are scaled by 255
        and used as red, green and blue.  (Presumably a matplotlib
        colormap -- confirm against callers.)
    n_entries : int
        Number of evenly spaced samples to take; must be at least 2.

    Returns
    -------
    list of [float, str]
        ``[position, "rgb(r, g, b)"]`` pairs, with ``position`` rounded to
        two decimals.

    Raises
    ------
    ValueError
        If ``n_entries`` is smaller than 2 (the step ``1 / (n_entries - 1)``
        is undefined or meaningless in that case).
    """
    if n_entries < 2:
        raise ValueError(
            "n_entries must be at least 2, got {}".format(n_entries)
        )

    h = 1.0 / (n_entries - 1)
    pl_colorscale = []
    for k in range(n_entries):
        scaled = np.array(cmap(k * h)[:3]) * 255
        # Cast through np.uint8 (as before) and then to a plain int so the
        # string does not depend on the repr of NumPy scalars, which became
        # "np.uint8(68)" in NumPy 2 and would corrupt the colour string.
        red, green, blue = [int(np.uint8(c)) for c in scaled]
        pl_colorscale.append(
            [round(k * h, 2), "rgb({}, {}, {})".format(red, green, blue)]
        )
    return pl_colorscale
@deprecated_alias(color_function="color_values")
def plotlyviz(
    scomplex,
    colorscale=None,
    title="Kepler Mapper",
    graph_layout="kk",
    color_values=None,
    color_function_name=None,
    node_color_function="mean",
    dashboard=False,
    graph_data=False,
    factor_size=3,
    edge_linewidth=1.5,
    node_linecolor="rgb(200,200,200)",
    width=600,
    height=500,
    bgcolor="rgba(240, 240, 240, 0.95)",
    left=10,
    bottom=35,
    summary_height=300,
    summary_width=600,
    summary_left=20,
    summary_right=20,
    hist_left=25,
    hist_right=25,
    member_textbox_width=800,
    filename=None,
):
    """
    Visualizations and dashboards for kmapper graphs using Plotly. This method is suitable for use in Jupyter notebooks.

    The generated FigureWidget can be updated (by performing a restyle or relayout). For example, let us add a title
    to the colorbar (the name of the color function, if any),
    and set the title font size. To perform these updates faster, Plotly 3.+ provides a context manager that batches up all data and layout updates:

    To display more info on the generated kmapper-graph, define two more FigureWidget(s):
    the global node distribution figure, and a dummy figure
    that displays info on the algorithms involved in getting the graph from data, as well as sklearn class instances.

    A FigureWidget has event listeners for hovering, clicking or selecting. Using the first one for `fw_graph`
    we define, via the function `hovering_widgets()`, widgets that display the node distribution, when the node is hovered over, and two textboxes for the cluster size and the member ids/labels of the hovered node members.

    Parameters
    -----------
    scomplex: dict
        Simplicial complex is the output from the KeplerMapper `map` method.
    title: str
        Title of output graphic
    graph_layout: igraph layout;
        recommended 'kk' (kamada-kawai) or 'fr' (fruchterman-reingold)
    colorscale:
        Plotly colorscale(colormap) to color graph nodes; the module's
        `default_colorscale` is used when this is empty or None
    color_values:
        Values used to color the nodes, forwarded to `get_mapper_graph()`.
        When None, `get_mapper_graph()` colors by row order in the data set.
    color_function_name:
        Name of the color function, forwarded to `get_mapper_graph()`; when
        truthy it is also set as the title of the colorbar
    node_color_function: default is "mean"
        Forwarded to `get_mapper_graph()` (presumably the aggregation applied
        to the member color values of each node -- confirm in `visuals`)
    dashboard: bool, default is False
        If true, display complete dashboard of node information
    graph_data: bool, default is False
        If true, display graph metadata
    factor_size: double, default is 3
        a factor for the node size
    edge_linewidth : double, default is 1.5
    node_linecolor: color str, default is "rgb(200,200,200)"
    width: int, default is 600,
    height: int, default is 500,
    bgcolor: color str, default is "rgba(240, 240, 240, 0.95)",
    left: int, default is 10,
    bottom: int, default is 35,
    summary_height: int, default is 300,
    summary_width: int, default is 600,
    summary_left: int, default is 20,
    summary_right: int, default is 20,
    hist_left: int, default is 25,
    hist_right: int, default is 25,
    member_textbox_width: int, default is 800,
    filename: str, default is None
        if filename is given, the graph figure will be saved to that file.

    Returns
    ---------
    result: plotly.FigureWidget
        A FigureWidget that can be shown or edited. See the Plotly Demo notebook for examples of use.
        When `dashboard` or `graph_data` is set, an ipywidgets box wrapping
        the graph figure and the additional widgets is returned instead.

    """

    if not colorscale:
        colorscale = default_colorscale

    kmgraph, mapper_summary, n_color_distribution = get_mapper_graph(
        scomplex,
        colorscale=colorscale,
        color_values=color_values,
        color_function_name=color_function_name,
        node_color_function=node_color_function,
    )

    annotation = get_kmgraph_meta(mapper_summary)

    plgraph_data = plotly_graph(
        kmgraph,
        graph_layout=graph_layout,
        colorscale=colorscale,
        factor_size=factor_size,
        edge_linewidth=edge_linewidth,
        node_linecolor=node_linecolor,
    )

    layout = plot_layout(
        title=title,
        width=width,
        height=height,
        annotation_text=annotation,
        bgcolor=bgcolor,
        left=left,
        bottom=bottom,
    )

    # Keep a dedicated handle on the graph figure: `result` may be replaced
    # by a widget container below, but only the figure can be exported.
    fw_graph = go.FigureWidget(data=plgraph_data, layout=layout)
    result = fw_graph

    if color_function_name:
        # data[1] is the node trace (plotly_graph returns [edges, nodes]).
        # NOTE(review): `colorbar.titlefont` is deprecated in newer Plotly
        # releases in favour of `colorbar.title.font` -- confirm against the
        # Plotly versions this package supports.
        with fw_graph.batch_update():
            fw_graph.data[1].marker.colorbar.title = color_function_name
            fw_graph.data[1].marker.colorbar.titlefont.size = 10

    if dashboard or graph_data:
        fw_hist = node_hist_fig(n_color_distribution, left=hist_left, right=hist_right)
        fw_summary = summary_fig(
            mapper_summary,
            width=summary_width,
            height=summary_height,
            left=summary_left,
            right=summary_right,
        )

        result = hovering_widgets(
            kmgraph, fw_graph, member_textbox_width=member_textbox_width
        )

        if graph_data:
            result = ipw.VBox([fw_graph, ipw.HBox([fw_summary, fw_hist])])

    if filename:
        # Always export the figure itself; `result` can be an ipywidgets
        # VBox here, which is not a figure that write_image can render.
        pio.write_image(fw_graph, filename)

    return result
@deprecated_alias(color_function="color_values")
def scomplex_to_graph(
    simplicial_complex,
    color_values,
    X,
    X_names,
    lens,
    lens_names,
    custom_tooltips,
    colorscale,
    node_color_function="mean",
):
    """Convert a simplicial complex into a dict of numbered nodes and links.

    Parameters
    ----------
    simplicial_complex : dict
        Must provide ``"nodes"`` (node id -> member ids) and ``"links"``
        (node id -> iterable of linked node ids).
    color_values :
        Converted with ``np.array`` and passed to the tooltip and node-color
        helpers.
    X, X_names, lens, lens_names, colorscale :
        Forwarded unchanged to ``_tooltip_components``.
    custom_tooltips :
        Stored as-is on every node under ``"custom_tooltips"``.
    node_color_function :
        Forwarded to ``_node_color_function``.

    Returns
    -------
    dict
        ``{"nodes": [...], "links": [...]}`` where nodes are numbered in
        iteration order of ``simplicial_complex["nodes"]`` and links refer to
        those numbers through ``"source"`` / ``"target"``.
    """
    color_values = np.array(color_values)

    json_dict = {"nodes": [], "links": []}
    node_id_to_num = {}

    for i, (node_id, member_ids) in enumerate(simplicial_complex["nodes"].items()):
        node_id_to_num[node_id] = i
        projection_stats, cluster_stats, member_histogram = _tooltip_components(
            member_ids, X, X_names, lens, lens_names, color_values, i, colorscale
        )
        node_color = _node_color_function(member_ids, color_values, node_color_function)
        # Plain lists keep the node dict free of NumPy array objects.
        if isinstance(node_color, np.ndarray):
            node_color = node_color.tolist()
        json_dict["nodes"].append(
            {
                "id": i,
                "name": node_id,
                "member_ids": member_ids,
                "color": node_color,
                "size": _size_node(member_ids),
                "cluster": cluster_stats,
                "distribution": member_histogram,
                "projection": projection_stats,
                "custom_tooltips": custom_tooltips,
            }
        )

    for node_id, linked_node_ids in simplicial_complex["links"].items():
        for linked_node_id in linked_node_ids:
            json_dict["links"].append(
                {
                    "source": node_id_to_num[node_id],
                    "target": node_id_to_num[linked_node_id],
                }
            )

    return json_dict


@deprecated_alias(color_function="color_values")
def get_mapper_graph(
    simplicial_complex,
    color_values=None,
    color_function_name=None,
    node_color_function="mean",
    colorscale=None,
    custom_tooltips=None,
    custom_meta=None,
    X=None,
    X_names=None,
    lens=None,
    lens_names=None,
):
    """Generate data for mapper graph visualization and annotation.

    Parameters
    ----------
    simplicial_complex : dict
        Simplicial complex is the output from the KeplerMapper `map` method.
    color_values : optional
        Values used to color the nodes.  When None, nodes are colored by row
        order and ``color_function_name`` is replaced by ``["Row number"]``.
    color_function_name : optional
        Passed to ``_format_meta``.
    node_color_function : default "mean"
        Passed to the node-coloring helpers.
    colorscale : optional
        Plotly colorscale; ``default_colorscale`` is used when empty or None.
    custom_tooltips, custom_meta, X, X_names, lens, lens_names : optional
        Forwarded to ``scomplex_to_graph`` / ``_format_meta``; ``X_names``
        and ``lens_names`` default to empty lists.

    Returns
    -------
    the graph dictionary in a json representation, the mapper summary
    and the node_distribution

    Raises
    ------
    ValueError
        If the simplicial complex has no nodes.

    Example
    -------

    >>> kmgraph, mapper_summary, n_distribution = get_mapper_graph(simplicial_complex)

    """
    if not colorscale:
        colorscale = default_colorscale

    if len(simplicial_complex["nodes"]) == 0:
        # ValueError is still an Exception, so existing handlers keep working,
        # and it matches the error type raised by plotly_graph().
        raise ValueError(
            "A mapper graph should have more than 0 nodes. This might be because your clustering algorithm might be too sensitive and be classifying all points as noise."
        )

    if color_values is None:
        # If no color_values provided we color by row order in data set
        n_samples = (
            np.max([i for s in simplicial_complex["nodes"].values() for i in s]) + 1
        )
        color_values = np.arange(n_samples)
        color_function_name = ["Row number"]

    color_values = _scale_color_values(color_values)

    if X_names is None:
        X_names = []

    if lens_names is None:
        lens_names = []

    json_graph = scomplex_to_graph(
        simplicial_complex,
        color_values,
        X,
        X_names,
        lens,
        lens_names,
        custom_tooltips,
        colorscale=colorscale,
        node_color_function=node_color_function,
    )
    colorf_distribution = _graph_data_distribution(
        simplicial_complex, color_values, node_color_function, colorscale
    )
    mapper_summary = _format_meta(
        simplicial_complex,
        color_function_name=color_function_name,
        node_color_function=node_color_function,
        custom_meta=custom_meta,
    )

    return json_graph, mapper_summary, colorf_distribution


def plotly_graph(
    kmgraph,
    graph_layout="kk",
    colorscale=None,
    showscale=True,
    factor_size=3,
    edge_linecolor="rgb(180,180,180)",
    edge_linewidth=1.5,
    node_linecolor="rgb(255,255,255)",
    node_linewidth=1.0,
):
    """Generate Plotly data structures that represent the mapper graph

    Parameters
    ----------
    kmgraph: dict representing the mapper graph,
             returned by the function get_mapper_graph()
    graph_layout: igraph layout; recommended 'kk' (kamada-kawai)
                  or 'fr' (fruchterman-reingold)
    colorscale: a Plotly colorscale(colormap) to color graph nodes;
                default_colorscale is used when empty or None
    showscale: boolean to display or not the colorbar
    factor_size: a factor for the node size
    edge_linecolor, edge_linewidth: color and width of the edge lines
    node_linecolor, node_linewidth: color and width of the node outlines

    Returns
    -------
    The plotly traces (dicts) representing the graph edges and nodes,
    as the list [edge_trace, node_trace]

    Raises
    ------
    ValueError
        If the graph has no nodes.
    """
    if not colorscale:
        colorscale = default_colorscale

    # define an igraph.Graph instance of n_nodes
    n_nodes = len(kmgraph["nodes"])
    if n_nodes == 0:
        raise ValueError("Your graph has 0 nodes")
    G = ig.Graph(n=n_nodes)
    links = [(e["source"], e["target"]) for e in kmgraph["links"]]
    G.add_edges(links)
    layt = G.layout(graph_layout)

    hover_text = [node["name"] for node in kmgraph["nodes"]]
    color_vals = [node["color"] for node in kmgraph["nodes"]]
    # dtype=int truncates the scaled sizes to whole numbers.
    node_size = np.array(
        [factor_size * node["size"] for node in kmgraph["nodes"]], dtype=int
    )
    Xn, Yn, Xe, Ye = _get_plotly_data(links, layt)

    edge_trace = dict(
        type="scatter",
        x=Xe,
        y=Ye,
        mode="lines",
        line=dict(color=edge_linecolor, width=edge_linewidth),
        hoverinfo="none",
    )

    node_trace = dict(
        type="scatter",
        x=Xn,
        y=Yn,
        mode="markers",
        marker=dict(
            size=node_size.tolist(),
            color=color_vals,
            opacity=1.0,
            colorscale=colorscale,
            showscale=showscale,
            line=dict(color=node_linecolor, width=node_linewidth),
            colorbar=dict(thickness=20, ticklen=4, x=1.01, tickfont=dict(size=10)),
        ),
        text=hover_text,
        hoverinfo="text",
    )

    return [edge_trace, node_trace]


def get_kmgraph_meta(mapper_summary):
    """Extract info from mapper summary to be displayed below the graph plot

    Reads ``n_cubes`` and ``perc_overlap`` from
    ``mapper_summary["custom_meta"]`` and the ``n_nodes``, ``n_edges``,
    ``n_total`` and ``n_unique`` entries of ``mapper_summary``; returns them
    as one HTML-formatted string.
    """
    d = mapper_summary["custom_meta"]
    meta = (
        "<b>N_cubes:</b> "
        + str(d["n_cubes"])
        + " <b>Perc_overlap:</b> "
        + str(d["perc_overlap"])
    )
    meta += (
        "<br><b>Nodes:</b> "
        + str(mapper_summary["n_nodes"])
        + " <b>Edges:</b> "
        + str(mapper_summary["n_edges"])
        + " <b>Total samples:</b> "
        + str(mapper_summary["n_total"])
        + " <b>Unique_samples:</b> "
        + str(mapper_summary["n_unique"])
    )

    return meta


def plot_layout(
    title="TDA KMapper",
    width=600,
    height=600,
    bgcolor="rgba(255, 255, 255, 1)",
    annotation_text=None,
    annotation_x=0,
    annotation_y=-0.01,
    top=100,
    left=60,
    right=60,
    bottom=60,
):
    """Set the plotly layout

    Parameters
    ----------
    title: string, the figure title
    width, height: integers
        setting  width and height of plot window
    bgcolor: string,
        rgba or hex color code for the background color
    annotation_text: string
        meta data to be displayed; no annotation is added when None
    annotation_x & annotation_y:
        The coordinates of the point where we insert the annotation; the negative sign for y coord points out that annotation is inserted below the plot
    top, left, right, bottom: integers, the figure margins

    Returns
    -------
    dict describing the Plotly layout
    """
    pl_layout = dict(
        title=title,
        font=dict(size=12),
        showlegend=False,
        autosize=False,
        width=width,
        height=height,
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        hovermode="closest",
        plot_bgcolor=bgcolor,
        margin=dict(t=top, b=bottom, l=left, r=right),
    )

    if annotation_text is not None:
        pl_layout.update(
            annotations=[
                dict(
                    showarrow=False,
                    text=annotation_text,
                    xref="paper",
                    yref="paper",
                    x=annotation_x,
                    y=annotation_y,
                    align="left",
                    xanchor="left",
                    yanchor="top",
                    font=dict(size=12),
                )
            ]
        )

    return pl_layout


def node_hist_fig(
    node_color_distribution,
    title="Graph Node Distribution",
    width=400,
    height=300,
    top=60,
    left=25,
    bottom=60,
    right=25,
    bgcolor="rgb(240,240,240)",
    y_gridcolor="white",
):
    """Define the plotly plot representing the node histogram

    Parameters
    ----------
    node_color_distribution: list of dicts describing the histogram bars;
        each dict is read for its "perc", "height" and "color" entries
    title: string, the figure title
    width, height: integers -  width and height of the histogram FigureWidget
    left, top, right, bottom: ints; number of pixels around the FigureWidget
    bgcolor: rgb or hex color code for the figure background color
    y_gridcolor: rgb or hex color code for the yaxis y_gridcolor

    Returns
    -------
    FigureWidget object representing the histogram of the graph nodes
    """
    # Explicit formatting instead of "{perc}%".format(**locals()), which
    # depended on the comprehension's local namespace.
    text = ["{}%".format(d["perc"]) for d in node_color_distribution]

    pl_hist = go.Bar(
        y=[d["height"] for d in node_color_distribution],
        marker=dict(color=[d["color"] for d in node_color_distribution]),
        text=text,
        hoverinfo="y+text",
    )

    hist_layout = dict(
        title=title,
        width=width,
        height=height,
        font=dict(size=12),
        xaxis=dict(showline=True, zeroline=False, showgrid=False, showticklabels=False),
        yaxis=dict(showline=False, gridcolor=y_gridcolor, tickfont=dict(size=10)),
        bargap=0.01,
        margin=dict(l=left, r=right, b=bottom, t=top),
        hovermode="x",
        plot_bgcolor=bgcolor,
    )

    return go.FigureWidget(data=[pl_hist], layout=hist_layout)


def summary_fig(
    mapper_summary,
    width=600,
    height=500,
    top=60,
    left=20,
    bottom=60,
    right=20,
    bgcolor="rgb(240,240,240)",
):
    """Define a dummy figure that displays info on the algorithms and
    sklearn class instances or methods used

    The text comes from ``_text_mapper_summary(mapper_summary)`` and is drawn
    by a text-mode scatter trace on hidden axes.

    Returns a FigureWidget object representing the figure
    """
    text = _text_mapper_summary(mapper_summary)

    data = [
        dict(
            type="scatter",
            x=[0, width],
            y=[height, 0],
            mode="text",
            text=[text, ""],
            textposition="bottom right",
            hoverinfo="none",
        )
    ]

    layout = dict(
        title="Algorithms and scikit-learn objects/methods",
        width=width,
        height=height,
        font=dict(size=12),
        xaxis=dict(visible=False),
        yaxis=dict(visible=False, range=[0, height + 5]),
        margin=dict(t=top, b=bottom, l=left, r=right),
        plot_bgcolor=bgcolor,
    )

    return go.FigureWidget(data=data, layout=layout)


def hovering_widgets(
    kmgraph,
    graph_fw,
    ctooltips=False,
    width=400,
    height=300,
    top=100,
    left=50,
    bgcolor="rgb(240,240,240)",
    y_gridcolor="white",
    member_textbox_width=200,
):
    """Defines the widgets that display the distribution of each node on hover
    and the members of each nodes

    Parameters
    ----------
    kmgraph: the kepler-mapper graph dict returned by `get_mapper_graph()``
    graph_fw: the FigureWidget representing the graph
    ctooltips: boolean; if True/False the node["custom_tooltips"]/"member_ids"
               are passed to member_textbox
    width, height, top, left, bgcolor, y_gridcolor refer to the size, position
               and colors of the hovered node distribution figure
    member_textbox_width: width in pixels of the members text area

    Returns
    -------
    a box containing the graph figure, the figure of the hovered node
    distribution, and the textboxes displaying the cluster size and member_ids
    or custom tooltips for hovered node members
    """

    def _members_text(node):
        # Text shown in the members box for a given node.
        members = node["custom_tooltips"] if ctooltips else node["member_ids"]
        return ", ".join(str(x) for x in members)

    # The widgets are initialised with the first node of the graph.
    fnode = kmgraph["nodes"][0]
    fwc = node_hist_fig(
        fnode["distribution"],
        title="Cluster Member Distribution",
        width=width,
        height=height,
        top=top,
        left=left,
        bgcolor=bgcolor,
        y_gridcolor=y_gridcolor,
    )
    clust_textbox = ipw.Text(
        value="{:d}".format(fnode["cluster"]["size"]),
        description="Cluster size:",
        disabled=False,
        continuous_update=True,
    )

    clust_textbox.layout = dict(margin="10px 10px 10px 10px", width="200px")

    member_textbox = ipw.Textarea(
        value=_members_text(fnode),
        description="Members:",
        disabled=False,
        continuous_update=True,
    )

    member_textbox.layout = dict(
        margin="5px 5px 5px 10px", width=str(member_textbox_width) + "px"
    )

    def do_on_hover(trace, points, state):
        if not points.point_inds:
            return
        ind = points.point_inds[0]  # get the index of the hovered node
        node = kmgraph["nodes"][ind]
        # on hover do:
        with fwc.batch_update():  # update data in the cluster member histogram
            fwc.data[0].text = [
                "{:.1f}%".format(d["perc"]) for d in node["distribution"]
            ]
            fwc.data[0].y = [d["height"] for d in node["distribution"]]
            fwc.data[0].marker.color = [d["color"] for d in node["distribution"]]

        clust_textbox.value = "{:d}".format(node["cluster"]["size"])
        member_textbox.value = _members_text(node)

    # data[1] is the node trace of the graph figure.
    trace = graph_fw.data[1]
    trace.on_hover(do_on_hover)
    return ipw.VBox([ipw.HBox([graph_fw, fwc]), clust_textbox, member_textbox])


def _get_plotly_data(E, coords):
    """Split node coordinates and edges into the lists Plotly traces expect.

    E: the list of tuples representing the graph edges
    coords: list of node coordinates assigned by igraph.Layout

    Returns the tuple (Xnodes, Ynodes, Xedges, Yedges); every edge adds its
    two end points followed by None to the edge lists.
    """
    N = len(coords)
    Xnodes = [coords[k][0] for k in range(N)]  # x-coordinates of nodes
    Ynodes = [coords[k][1] for k in range(N)]  # y-coordinates of nodes

    Xedges = []
    Yedges = []
    for e in E:
        Xedges.extend([coords[e[0]][0], coords[e[1]][0], None])
        Yedges.extend([coords[e[0]][1], coords[e[1]][1], None])

    return Xnodes, Ynodes, Xedges, Yedges


def _text_mapper_summary(mapper_summary):
    """Build the HTML text shown in the summary figure.

    Reads "projection", "clusterer", "scaler" and, when present,
    "color_function" from ``mapper_summary["custom_meta"]``.
    """
    d = mapper_summary["custom_meta"]
    # str() keeps the concatenation from raising TypeError on non-string
    # metadata values, consistent with get_kmgraph_meta().
    text = "<br><b>Projection: </b>" + str(d["projection"])
    text += (
        "<br><b>Clusterer: </b>"
        + str(d["clusterer"])
        + "<br><b>Scaler: </b>"
        + str(d["scaler"])
    )
    if "color_function" in d:
        text += "<br><b>Color function: </b>" + str(d["color_function"])

    return text


def _hover_format(member_ids, custom_tooltips, X, X_names, lens, lens_names):
    """Return the "size" entry of the cluster statistics as a string.

    ``custom_tooltips``, ``lens`` and ``lens_names`` are accepted for
    interface compatibility but do not influence the result.
    """
    cluster_data = _format_cluster_statistics(member_ids, X, X_names)
    return "{}".format(cluster_data["size"])