# Source code for kmapper.plotlyviz

from __future__ import division
from .utils import deprecated_alias

import numpy as np

from .visuals import (
    _scale_color_values,
    _size_node,
    _format_projection_statistics,
    _format_cluster_statistics,
    _node_color_function,
    _format_meta,
    _to_html_format,
    _map_val2color,
    _graph_data_distribution,
    _build_histogram,
    _tooltip_components,
)

try:
    import igraph as ig
    import plotly.graph_objs as go
    import ipywidgets as ipw
    import plotly.io as pio
except ImportError:
    print(
        """To use the plotly visualization tools, you must have the packages python-igraph, plotly, and ipywidgets installed in your environment."""
        """ It looks like at least one of these is missing.  Please install again with"""
        """\n\n\t`pip install python-igraph plotly ipywidgets`\n\nand try again"""
    )
    raise


# Fallback colorscale for this module: eleven evenly spaced stops, each a
# [position, "rgb(r, g, b)"] pair with position running from 0.0 to 1.0.
# The functions below substitute it whenever the caller passes no colorscale.
default_colorscale = [
    [0.0, "rgb(68, 1, 84)"],  # Viridis
    [0.1, "rgb(72, 35, 116)"],
    [0.2, "rgb(64, 67, 135)"],
    [0.3, "rgb(52, 94, 141)"],
    [0.4, "rgb(41, 120, 142)"],
    [0.5, "rgb(32, 144, 140)"],
    [0.6, "rgb(34, 167, 132)"],
    [0.7, "rgb(68, 190, 112)"],
    [0.8, "rgb(121, 209, 81)"],
    [0.9, "rgb(189, 222, 38)"],
    [1.0, "rgb(253, 231, 36)"],
]


def mpl_to_plotly(cmap, n_entries):
    """Sample a matplotlib-style colormap into a Plotly colorscale.

    Parameters
    ----------
    cmap : callable
        Called as ``cmap(t)`` with ``t`` evenly spaced from 0 to 1. It must
        return a sequence whose first three items are the red, green and
        blue components as fractions of 1 (a matplotlib colormap satisfies
        this); any further items (e.g. alpha) are ignored.
    n_entries : int
        Number of colorscale stops to generate. Must be at least 2 so that
        the stops can span the whole [0, 1] interval.

    Returns
    -------
    list
        ``n_entries`` pairs ``[position, "rgb(r, g, b)"]``. ``position`` is
        rounded to two decimals and the channels are integers obtained by
        scaling to 255 and truncating to ``uint8``.

    Raises
    ------
    ValueError
        If ``n_entries`` is smaller than 2.
    """
    if n_entries < 2:
        raise ValueError(
            "n_entries must be at least 2, got {}".format(n_entries)
        )

    h = 1.0 / (n_entries - 1)
    pl_colorscale = []

    for k in range(n_entries):
        channels = (np.asarray(cmap(k * h)[:3], dtype=float) * 255).astype(np.uint8)
        # tolist() yields plain Python ints. Formatting NumPy scalars through
        # str(tuple) would embed their repr ("np.uint8(68)" on NumPy >= 2)
        # and produce an invalid color string.
        red, green, blue = channels.tolist()
        pl_colorscale.append(
            [round(k * h, 2), "rgb({}, {}, {})".format(red, green, blue)]
        )

    return pl_colorscale
@deprecated_alias(color_function="color_values")
def plotlyviz(
    scomplex,
    colorscale=None,
    title="Kepler Mapper",
    graph_layout="kk",
    color_values=None,
    color_function_name=None,
    node_color_function="mean",
    dashboard=False,
    graph_data=False,
    factor_size=3,
    edge_linewidth=1.5,
    node_linecolor="rgb(200,200,200)",
    width=600,
    height=500,
    bgcolor="rgba(240, 240, 240, 0.95)",
    left=10,
    bottom=35,
    summary_height=300,
    summary_width=600,
    summary_left=20,
    summary_right=20,
    hist_left=25,
    hist_right=25,
    member_textbox_width=800,
    filename=None,
):
    """Visualizations and dashboards for kmapper graphs using Plotly.

    This method is suitable for use in Jupyter notebooks.

    The mapper graph is drawn as a ``plotly`` FigureWidget, which can be
    updated afterwards (by performing a restyle or relayout). When
    ``color_function_name`` is given, it is set as the colorbar title inside
    a single ``batch_update()`` so both colorbar changes are sent together.

    With ``dashboard`` or ``graph_data`` enabled, a hover listener is
    attached to the node trace via ``hovering_widgets()``. It drives a
    histogram of the hovered node's distribution and two textboxes showing
    the cluster size and the member ids of the hovered node. With
    ``graph_data`` two further figures are built: the global node
    distribution histogram and a text-only figure describing the algorithms
    used to build the graph.

    Parameters
    ----------
    scomplex : dict
        Simplicial complex, the output of the KeplerMapper ``map`` method.
    colorscale : list, optional
        Plotly colorscale (colormap) used to color the graph nodes. Falls
        back to ``default_colorscale`` when falsy.
    title : str
        Title of the output graphic.
    graph_layout : str
        igraph layout name; recommended ``'kk'`` (kamada-kawai) or ``'fr'``
        (fruchterman-reingold).
    color_values : array-like, optional
        Values used to color the nodes, forwarded to ``get_mapper_graph()``,
        which colors by row order when this is None.
    color_function_name : str, optional
        Name of the color function. When truthy it is used as the colorbar
        title.
    node_color_function : str, default "mean"
        Forwarded to ``get_mapper_graph()``; presumably the aggregation of
        member color values into one node color (confirm in ``visuals``).
    dashboard : bool, default False
        If true, return the hover dashboard built by ``hovering_widgets()``.
    graph_data : bool, default False
        If true, return the graph stacked above the summary figure and the
        node distribution histogram.
    factor_size : double, default 3
        A factor for the node size.
    edge_linewidth : double, default 1.5
    node_linecolor : color str, default "rgb(200,200,200)"
    width : int, default 600
    height : int, default 500
    bgcolor : color str, default "rgba(240, 240, 240, 0.95)"
    left : int, default 10
    bottom : int, default 35
    summary_height : int, default 300
    summary_width : int, default 600
    summary_left : int, default 20
    summary_right : int, default 20
    hist_left : int, default 25
    hist_right : int, default 25
    member_textbox_width : int, default 800
    filename : str, default None
        If given, the graph figure is saved to that file with
        ``plotly.io.write_image``.

    Returns
    -------
    result : plotly FigureWidget or ipywidgets box
        The graph FigureWidget when neither ``dashboard`` nor ``graph_data``
        is set; otherwise an ipywidgets container holding it. The result can
        be shown or edited. See the Plotly Demo notebook for examples of use.
    """
    if not colorscale:
        colorscale = default_colorscale

    kmgraph, mapper_summary, n_color_distribution = get_mapper_graph(
        scomplex,
        colorscale=colorscale,
        color_values=color_values,
        color_function_name=color_function_name,
        node_color_function=node_color_function,
    )

    annotation = get_kmgraph_meta(mapper_summary)

    plgraph_data = plotly_graph(
        kmgraph,
        graph_layout=graph_layout,
        colorscale=colorscale,
        factor_size=factor_size,
        edge_linewidth=edge_linewidth,
        node_linecolor=node_linecolor,
    )

    layout = plot_layout(
        title=title,
        width=width,
        height=height,
        annotation_text=annotation,
        bgcolor=bgcolor,
        left=left,
        bottom=bottom,
    )
    fw_graph = go.FigureWidget(data=plgraph_data, layout=layout)

    if color_function_name:
        with fw_graph.batch_update():
            # data[1] is the node trace (plotly_graph returns [edges, nodes]).
            colorbar = fw_graph.data[1].marker.colorbar
            colorbar.title = color_function_name
            try:
                # Current Plotly spelling; `titlefont` is deprecated and has
                # been dropped from recent releases.
                colorbar.title.font.size = 10
            except AttributeError:
                # Older Plotly keeps the title as a plain string.
                colorbar.titlefont.size = 10

    result = fw_graph
    if dashboard or graph_data:
        # Called in both modes because it registers the hover callback on
        # the graph figure as a side effect.
        result = hovering_widgets(
            kmgraph, fw_graph, member_textbox_width=member_textbox_width
        )

        if graph_data:
            fw_hist = node_hist_fig(
                n_color_distribution, left=hist_left, right=hist_right
            )
            fw_summary = summary_fig(
                mapper_summary,
                width=summary_width,
                height=summary_height,
                left=summary_left,
                right=summary_right,
            )
            result = ipw.VBox([fw_graph, ipw.HBox([fw_summary, fw_hist])])

    if filename:
        # Always export the figure itself: write_image cannot serialise the
        # ipywidgets containers returned in dashboard / graph_data mode.
        pio.write_image(fw_graph, filename)

    return result
@deprecated_alias(color_function="color_values")
def scomplex_to_graph(
    simplicial_complex,
    color_values,
    X,
    X_names,
    lens,
    lens_names,
    custom_tooltips,
    colorscale,
    node_color_function="mean",
):
    """Convert a simplicial complex into a dict of numbered nodes and links.

    Parameters
    ----------
    simplicial_complex : dict
        Must provide ``"nodes"`` (node id -> member ids) and ``"links"``
        (node id -> linked node ids).
    color_values : array-like
        Converted with ``np.array`` and handed to the tooltip and node color
        helpers.
    X, X_names, lens, lens_names
        Forwarded unchanged to ``_tooltip_components``.
    custom_tooltips
        Stored unchanged on every node under ``"custom_tooltips"``.
    colorscale
        Forwarded to ``_tooltip_components``.
    node_color_function : str, default "mean"
        Forwarded to ``_node_color_function``.

    Returns
    -------
    dict
        ``{"nodes": [...], "links": [...]}``. Each node carries its
        enumeration index as ``"id"`` and its original id as ``"name"``;
        each link refers to nodes by those indices.
    """
    color_values = np.array(color_values)

    json_dict = {"nodes": [], "links": []}
    node_id_to_num = {}
    for i, (node_id, member_ids) in enumerate(simplicial_complex["nodes"].items()):
        node_id_to_num[node_id] = i
        projection_stats, cluster_stats, member_histogram = _tooltip_components(
            member_ids, X, X_names, lens, lens_names, color_values, i, colorscale
        )
        node_color = _node_color_function(member_ids, color_values, node_color_function)
        if isinstance(node_color, np.ndarray):
            # Plain lists keep the node dict free of NumPy objects.
            node_color = node_color.tolist()
        json_dict["nodes"].append(
            {
                "id": i,
                "name": node_id,
                "member_ids": member_ids,
                "color": node_color,
                "size": _size_node(member_ids),
                "cluster": cluster_stats,
                "distribution": member_histogram,
                "projection": projection_stats,
                "custom_tooltips": custom_tooltips,
            }
        )

    for node_id, linked_node_ids in simplicial_complex["links"].items():
        for linked_node_id in linked_node_ids:
            json_dict["links"].append(
                {
                    "source": node_id_to_num[node_id],
                    "target": node_id_to_num[linked_node_id],
                }
            )
    return json_dict


@deprecated_alias(color_function="color_values")
def get_mapper_graph(
    simplicial_complex,
    color_values=None,
    color_function_name=None,
    node_color_function="mean",
    colorscale=None,
    custom_tooltips=None,
    custom_meta=None,
    X=None,
    X_names=None,
    lens=None,
    lens_names=None,
):
    """Generate data for mapper graph visualization and annotation.

    Parameters
    ----------
    simplicial_complex : dict
        Simplicial complex is the output from the KeplerMapper `map` method.
    color_values : array-like, optional
        Values used to color the nodes. When None, samples are colored by
        row order and ``color_function_name`` is replaced by
        ``["Row number"]``. The values are always passed through
        ``_scale_color_values``.
    color_function_name : optional
        Forwarded to ``_format_meta``.
    node_color_function : str, default "mean"
        Forwarded to the node color and distribution helpers.
    colorscale : list, optional
        Plotly colorscale; ``default_colorscale`` is used when falsy.
    custom_tooltips, X, lens : optional
        Forwarded unchanged to ``scomplex_to_graph``.
    custom_meta : optional
        Forwarded to ``_format_meta``.
    X_names, lens_names : list, optional
        Replaced by an empty list when None.

    Returns
    -------
    the graph dictionary in a json representation, the mapper summary
    and the node_distribution

    Raises
    ------
    Exception
        If the simplicial complex has no nodes.

    Example
    -------

    >>> kmgraph, mapper_summary, n_distribution = get_mapper_graph(simplicial_complex)

    """
    if not colorscale:
        colorscale = default_colorscale

    if not simplicial_complex["nodes"]:
        raise Exception(
            "A mapper graph should have more than 0 nodes. This might be because your clustering algorithm might be too sensitive and be classifying all points as noise."
        )

    if color_values is None:
        # If no color_values provided we color by row order in data set.
        # The largest member id seen in any node fixes the sample count.
        n_samples = (
            max(i for s in simplicial_complex["nodes"].values() for i in s) + 1
        )
        color_values = np.arange(n_samples)
        color_function_name = ["Row number"]

    color_values = _scale_color_values(color_values)

    if X_names is None:
        X_names = []

    if lens_names is None:
        lens_names = []

    json_graph = scomplex_to_graph(
        simplicial_complex,
        color_values,
        X,
        X_names,
        lens,
        lens_names,
        custom_tooltips,
        colorscale=colorscale,
        node_color_function=node_color_function,
    )
    colorf_distribution = _graph_data_distribution(
        simplicial_complex, color_values, node_color_function, colorscale
    )
    mapper_summary = _format_meta(
        simplicial_complex,
        color_function_name=color_function_name,
        node_color_function=node_color_function,
        custom_meta=custom_meta,
    )

    return json_graph, mapper_summary, colorf_distribution


def plotly_graph(
    kmgraph,
    graph_layout="kk",
    colorscale=None,
    showscale=True,
    factor_size=3,
    edge_linecolor="rgb(180,180,180)",
    edge_linewidth=1.5,
    node_linecolor="rgb(255,255,255)",
    node_linewidth=1.0,
):
    """Generate Plotly data structures that represent the mapper graph

    Parameters
    ----------
    kmgraph: dict representing the mapper graph,
             returned by the function get_mapper_graph()
    graph_layout: igraph layout; recommended 'kk' (kamada-kawai)
                  or 'fr' (fruchterman-reingold)
    colorscale: a Plotly colorscale(colormap) to color graph nodes;
                default_colorscale is used when falsy
    showscale: boolean to display or not the colorbar
    factor_size: a factor for the node size
    edge_linecolor, edge_linewidth: color and width of the edge lines
    node_linecolor, node_linewidth: color and width of the node outlines

    Returns
    -------
    The plotly traces (dicts) representing the graph edges and nodes,
    as the list [edge_trace, node_trace]

    Raises
    ------
    ValueError
        If kmgraph has no nodes.
    """
    if not colorscale:
        colorscale = default_colorscale

    # define an igraph.Graph instance of n_nodes
    n_nodes = len(kmgraph["nodes"])
    if n_nodes == 0:
        raise ValueError("Your graph has 0 nodes")
    G = ig.Graph(n=n_nodes)
    links = [(e["source"], e["target"]) for e in kmgraph["links"]]
    G.add_edges(links)
    layt = G.layout(graph_layout)

    hover_text = [node["name"] for node in kmgraph["nodes"]]
    color_vals = [node["color"] for node in kmgraph["nodes"]]
    # dtype=int truncates the scaled sizes to whole numbers.
    node_size = np.array(
        [factor_size * node["size"] for node in kmgraph["nodes"]], dtype=int
    )
    Xn, Yn, Xe, Ye = _get_plotly_data(links, layt)

    edge_trace = dict(
        type="scatter",
        x=Xe,
        y=Ye,
        mode="lines",
        line=dict(color=edge_linecolor, width=edge_linewidth),
        hoverinfo="none",
    )

    node_trace = dict(
        type="scatter",
        x=Xn,
        y=Yn,
        mode="markers",
        marker=dict(
            size=node_size.tolist(),
            color=color_vals,
            opacity=1.0,
            colorscale=colorscale,
            showscale=showscale,
            line=dict(color=node_linecolor, width=node_linewidth),
            colorbar=dict(thickness=20, ticklen=4, x=1.01, tickfont=dict(size=10)),
        ),
        text=hover_text,
        hoverinfo="text",
    )

    return [edge_trace, node_trace]


def get_kmgraph_meta(mapper_summary):
    """Extract info from mapper summary to be displayed below the graph plot

    Reads ``n_cubes`` and ``perc_overlap`` from
    ``mapper_summary["custom_meta"]`` and ``n_nodes``, ``n_edges``,
    ``n_total`` and ``n_unique`` from ``mapper_summary`` itself, and returns
    them as one HTML-formatted string.
    """
    d = mapper_summary["custom_meta"]
    meta = "<b>N_cubes:</b> {} <b>Perc_overlap:</b> {}".format(
        d["n_cubes"], d["perc_overlap"]
    )
    meta += (
        "<br><b>Nodes:</b> {} <b>Edges:</b> {}"
        " <b>Total samples:</b> {} <b>Unique_samples:</b> {}"
    ).format(
        mapper_summary["n_nodes"],
        mapper_summary["n_edges"],
        mapper_summary["n_total"],
        mapper_summary["n_unique"],
    )
    return meta


def plot_layout(
    title="TDA KMapper",
    width=600,
    height=600,
    bgcolor="rgba(255, 255, 255, 1)",
    annotation_text=None,
    annotation_x=0,
    annotation_y=-0.01,
    top=100,
    left=60,
    right=60,
    bottom=60,
):
    """Set the plotly layout

    Parameters
    ----------
    title: string; the figure title
    width, height: integers
        setting width and height of plot window
    bgcolor: string,
        rgba or hex color code for the background color
    annotation_text: string
        meta data to be displayed; no annotation is added when None
    annotation_x & annotation_y:
        The coordinates of the point where we insert the annotation, in
        paper coordinates; the negative sign of the y coordinate means the
        annotation is inserted below the plot
    top, left, right, bottom: ints; the plot margins

    Returns
    -------
    dict describing the plotly layout
    """
    pl_layout = dict(
        title=title,
        font=dict(size=12),
        showlegend=False,
        autosize=False,
        width=width,
        height=height,
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        hovermode="closest",
        plot_bgcolor=bgcolor,
        margin=dict(t=top, b=bottom, l=left, r=right),
    )

    if annotation_text is not None:
        pl_layout.update(
            annotations=[
                dict(
                    showarrow=False,
                    text=annotation_text,
                    xref="paper",
                    yref="paper",
                    x=annotation_x,
                    y=annotation_y,
                    align="left",
                    xanchor="left",
                    yanchor="top",
                    font=dict(size=12),
                )
            ]
        )
    return pl_layout


def node_hist_fig(
    node_color_distribution,
    title="Graph Node Distribution",
    width=400,
    height=300,
    top=60,
    left=25,
    bottom=60,
    right=25,
    bgcolor="rgb(240,240,240)",
    y_gridcolor="white",
):
    """Define the plotly plot representing the node histogram

    Parameters
    ----------
    node_color_distribution: list of dicts describing the histogram bars;
        each dict must provide the keys "perc", "height" and "color"
    title: string; the figure title
    width, height: integers -  width and height of the histogram FigureWidget
    left, top, right, bottom: ints; number of pixels around the FigureWidget
    bgcolor: rgb or hex color code for the figure background color
    y_gridcolor: rgb or hex color code for the yaxis gridcolor

    Returns
    -------
    FigureWidget object representing the histogram of the graph nodes
    """
    # Explicit argument instead of format(**locals()), whose result depends
    # on how the interpreter scopes comprehension variables.
    text = ["{}%".format(d["perc"]) for d in node_color_distribution]

    pl_hist = go.Bar(
        y=[d["height"] for d in node_color_distribution],
        marker=dict(color=[d["color"] for d in node_color_distribution]),
        text=text,
        hoverinfo="y+text",
    )

    hist_layout = dict(
        title=title,
        width=width,
        height=height,
        font=dict(size=12),
        xaxis=dict(showline=True, zeroline=False, showgrid=False, showticklabels=False),
        yaxis=dict(showline=False, gridcolor=y_gridcolor, tickfont=dict(size=10)),
        bargap=0.01,
        margin=dict(l=left, r=right, b=bottom, t=top),
        hovermode="x",
        plot_bgcolor=bgcolor,
    )

    return go.FigureWidget(data=[pl_hist], layout=hist_layout)


def summary_fig(
    mapper_summary,
    width=600,
    height=500,
    top=60,
    left=20,
    bottom=60,
    right=20,
    bgcolor="rgb(240,240,240)",
):
    """Define a dummy figure that displays info on the algorithms and
    sklearn class instances or methods used

    The text comes from ``_text_mapper_summary(mapper_summary)`` and is
    drawn by a text-only scatter trace anchored at the top-left corner.

    Returns a FigureWidget object representing the figure
    """
    text = _text_mapper_summary(mapper_summary)

    data = [
        dict(
            type="scatter",
            x=[0, width],
            y=[height, 0],
            mode="text",
            text=[text, ""],
            textposition="bottom right",
            hoverinfo="none",
        )
    ]

    layout = dict(
        title="Algorithms and scikit-learn objects/methods",
        width=width,
        height=height,
        font=dict(size=12),
        xaxis=dict(visible=False),
        yaxis=dict(visible=False, range=[0, height + 5]),
        margin=dict(t=top, b=bottom, l=left, r=right),
        plot_bgcolor=bgcolor,
    )

    return go.FigureWidget(data=data, layout=layout)


def hovering_widgets(
    kmgraph,
    graph_fw,
    ctooltips=False,
    width=400,
    height=300,
    top=100,
    left=50,
    bgcolor="rgb(240,240,240)",
    y_gridcolor="white",
    member_textbox_width=200,
):
    """Defines the widgets that display the distribution of each node on hover
    and the members of each nodes

    The widgets are initialised from the first node of the graph and updated
    by a hover callback registered on ``graph_fw.data[1]`` (the node trace).

    Parameters
    ----------
    kmgraph: the kepler-mapper graph dict returned by `get_mapper_graph()`
    graph_fw: the FigureWidget representing the graph
    ctooltips: boolean; if True/False the node["custom_tooltips"]/"member_ids"
               are passed to member_textbox
    width, height, top, left refer to the figure
               size and position of the hovered node distribution
    bgcolor, y_gridcolor: colors forwarded to node_hist_fig()
    member_textbox_width: int; width in pixels of the members textbox

    Returns
    -------
    a box containing the graph figure, the figure of the hovered node
    distribution, and the textboxes displaying the cluster size and member_ids
    or custom tooltips for hovered node members
    """

    def members_text(node):
        # One formatter for both the initial value and the hover update.
        items = node["custom_tooltips"] if ctooltips else node["member_ids"]
        return ", ".join(str(x) for x in items)

    fnode = kmgraph["nodes"][0]
    fwc = node_hist_fig(
        fnode["distribution"],
        title="Cluster Member Distribution",
        width=width,
        height=height,
        top=top,
        left=left,
        bgcolor=bgcolor,
        y_gridcolor=y_gridcolor,
    )
    clust_textbox = ipw.Text(
        value="{:d}".format(fnode["cluster"]["size"]),
        description="Cluster size:",
        disabled=False,
        continuous_update=True,
    )

    clust_textbox.layout = dict(margin="10px 10px 10px 10px", width="200px")

    member_textbox = ipw.Textarea(
        value=members_text(fnode),
        description="Members:",
        disabled=False,
        continuous_update=True,
    )

    member_textbox.layout = dict(
        margin="5px 5px 5px 10px", width=str(member_textbox_width) + "px"
    )

    def do_on_hover(trace, points, state):
        if not points.point_inds:
            return
        ind = points.point_inds[0]  # get the index of the hovered node
        node = kmgraph["nodes"][ind]
        # on hover do:
        with fwc.batch_update():  # update data in the cluster member histogram
            fwc.data[0].text = [
                "{:.1f}%".format(d["perc"]) for d in node["distribution"]
            ]
            fwc.data[0].y = [d["height"] for d in node["distribution"]]
            fwc.data[0].marker.color = [d["color"] for d in node["distribution"]]

        clust_textbox.value = "{:d}".format(node["cluster"]["size"])
        member_textbox.value = members_text(node)

    trace = graph_fw.data[1]
    trace.on_hover(do_on_hover)
    return ipw.VBox([ipw.HBox([graph_fw, fwc]), clust_textbox, member_textbox])


def _get_plotly_data(E, coords):
    """Split a graph layout into node and edge coordinate lists for Plotly.

    E : the list of tuples representing the graph edges
    coords: list of node coordinates assigned by igraph.Layout

    Returns ``(Xnodes, Ynodes, Xedges, Yedges)``. The edge lists hold, per
    edge, the two endpoint coordinates followed by ``None``, so a single
    line trace draws every edge as a separate segment.
    """
    Xnodes = [point[0] for point in coords]  # x-coordinates of nodes
    Ynodes = [point[1] for point in coords]  # y-coordinates of nodes

    Xedges = []
    Yedges = []
    for source, target in ((e[0], e[1]) for e in E):
        Xedges.extend([coords[source][0], coords[target][0], None])
        Yedges.extend([coords[source][1], coords[target][1], None])

    return Xnodes, Ynodes, Xedges, Yedges


def _text_mapper_summary(mapper_summary):
    """Build the HTML text shown by summary_fig().

    Reads ``projection``, ``clusterer`` and ``scaler`` (all expected to be
    strings, since they are concatenated directly) and, when present,
    ``color_function`` from ``mapper_summary["custom_meta"]``.
    """
    d = mapper_summary["custom_meta"]
    text = "<br><b>Projection: </b>" + d["projection"]
    text += (
        "<br><b>Clusterer: </b>" + d["clusterer"] + "<br><b>Scaler: </b>" + d["scaler"]
    )
    if "color_function" in d:
        text += "<br><b>Color function: </b>" + d["color_function"]

    return text


def _hover_format(member_ids, custom_tooltips, X, X_names, lens, lens_names):
    """Return the cluster size of a node as a string.

    The size is read from ``_format_cluster_statistics(member_ids, X,
    X_names)``. ``custom_tooltips``, ``lens`` and ``lens_names`` are kept
    for interface compatibility and do not affect the result.
    """
    cluster_data = _format_cluster_statistics(member_ids, X, X_names)
    return "{}".format(cluster_data["size"])