from __future__ import division
from .utils import deprecated_alias
import numpy as np
from .visuals import (
_scale_color_values,
_size_node,
_format_projection_statistics,
_format_cluster_statistics,
_node_color_function,
_format_meta,
_to_html_format,
_map_val2color,
_graph_data_distribution,
_build_histogram,
_tooltip_components,
)
try:
import igraph as ig
import plotly.graph_objs as go
import ipywidgets as ipw
import plotly.io as pio
except ImportError:
print(
"""To use the plotly visualization tools, you must have the packages python-igraph, plotly, and ipywidgets installed in your environment."""
""" It looks like at least one of these is missing. Please install again with"""
"""\n\n\t`pip install python-igraph plotly ipywidgets`\n\nand try again"""
)
raise
# Fallback Plotly colorscale: eleven [position, "rgb(r, g, b)"] stops at
# positions 0.0 through 1.0 in steps of 0.1. The functions in this module
# substitute it whenever they are called with a falsy ``colorscale``.
default_colorscale = [
[0.0, "rgb(68, 1, 84)"], # Viridis
[0.1, "rgb(72, 35, 116)"],
[0.2, "rgb(64, 67, 135)"],
[0.3, "rgb(52, 94, 141)"],
[0.4, "rgb(41, 120, 142)"],
[0.5, "rgb(32, 144, 140)"],
[0.6, "rgb(34, 167, 132)"],
[0.7, "rgb(68, 190, 112)"],
[0.8, "rgb(121, 209, 81)"],
[0.9, "rgb(189, 222, 38)"],
[1.0, "rgb(253, 231, 36)"],
]
def mpl_to_plotly(cmap, n_entries):
    """Sample a colormap callable into a Plotly colorscale.

    Parameters
    ----------
    cmap : callable
        Called with a float position and expected to return a sequence whose
        first three items are the red, green and blue components. Each
        component is multiplied by 255 and truncated to an unsigned 8-bit
        integer (e.g. a matplotlib colormap returning floats in [0, 1]).
    n_entries : int
        Number of evenly spaced positions, from 0 to 1, at which ``cmap`` is
        sampled.

    Returns
    -------
    list
        ``[position, "rgb(r, g, b)"]`` pairs, with ``position`` rounded to
        two decimals. Empty when ``n_entries`` is not positive.

    Raises
    ------
    ValueError
        If ``n_entries`` is 1: a single stop cannot span the [0, 1] range
        (the step size would require dividing by zero).
    """
    if n_entries == 1:
        raise ValueError(
            "n_entries must be at least 2 to build a colorscale spanning [0, 1]"
        )

    h = 1.0 / (n_entries - 1)
    pl_colorscale = []
    for k in range(n_entries):
        scaled = np.array(cmap(k * h)[:3]) * 255
        # Truncate through np.uint8 as before, then convert to a plain int:
        # NumPy >= 2 renders scalars as "np.uint8(68)", which would leak into
        # the colour string if the NumPy scalars were formatted via a tuple.
        C = [int(np.uint8(component)) for component in scaled]
        pl_colorscale.append(
            [round(k * h, 2), "rgb({}, {}, {})".format(C[0], C[1], C[2])]
        )
    return pl_colorscale
@deprecated_alias(color_function="color_values")
def plotlyviz(
    scomplex,
    colorscale=None,
    title="Kepler Mapper",
    graph_layout="kk",
    color_values=None,
    color_function_name=None,
    node_color_function="mean",
    dashboard=False,
    graph_data=False,
    factor_size=3,
    edge_linewidth=1.5,
    node_linecolor="rgb(200,200,200)",
    width=600,
    height=500,
    bgcolor="rgba(240, 240, 240, 0.95)",
    left=10,
    bottom=35,
    summary_height=300,
    summary_width=600,
    summary_left=20,
    summary_right=20,
    hist_left=25,
    hist_right=25,
    member_textbox_width=800,
    filename=None,
):
    """
    Visualizations and dashboards for kmapper graphs using Plotly. This method is suitable for use in Jupyter notebooks.

    The generated FigureWidget can be updated (by performing a restyle or relayout), e.g. to add a title
    to the colorbar (the name of the color function, if any) and set the title font size. Plotly's
    ``batch_update()`` context manager batches up such data and layout updates.

    To display more info on the generated kmapper-graph, two more FigureWidgets are defined:
    the global node distribution figure, and a dummy figure that displays info on the algorithms
    involved in getting the graph from data, as well as sklearn class instances.

    A FigureWidget has event listeners for hovering, clicking or selecting. Using the hover listener of the
    graph figure, `hovering_widgets()` defines widgets that display the node distribution when a node is
    hovered over, and two textboxes for the cluster size and the member ids/labels of the hovered node.

    Parameters
    -----------
    scomplex: dict
        Simplicial complex is the output from the KeplerMapper `map` method.
    colorscale:
        Plotly colorscale(colormap) to color graph nodes. A falsy value selects `default_colorscale`.
    title: str
        Title of output graphic
    graph_layout: igraph layout;
        recommended 'kk' (kamada-kawai) or 'fr' (fruchterman-reingold)
    color_values:
        Forwarded to `get_mapper_graph()`; when None, that function colors by row number.
    color_function_name: str, default is None
        Forwarded to `get_mapper_graph()`; when set, it is also used as the colorbar title.
    node_color_function: default is "mean"
        Forwarded to `get_mapper_graph()`.
    dashboard: bool, default is False
        If true, display complete dashboard of node information
    graph_data: bool, default is False
        If true, display graph metadata
    factor_size: double, default is 3
        a factor for the node size
    edge_linewidth : double, default is 1.5
    node_linecolor: color str, default is "rgb(200,200,200)"
    width: int, default is 600,
    height: int, default is 500,
    bgcolor: color str, default is "rgba(240, 240, 240, 0.95)",
    left: int, default is 10,
    bottom: int, default is 35,
    summary_height: int, default is 300,
    summary_width: int, default is 600,
    summary_left: int, default is 20,
    summary_right: int, default is 20,
    hist_left: int, default is 25,
    hist_right: int, default is 25,
    member_textbox_width: int, default is 800,
    filename: str, default is None
        if filename is given, the graph figure will be saved to that file. In dashboard /
        graph_data mode only the graph figure is written, since the surrounding ipywidgets
        container is not a Plotly figure.

    Returns
    ---------
    result: plotly.FigureWidget
        A FigureWidget that can be shown or edited. When `dashboard` or `graph_data` is true, an
        ipywidgets box containing the graph figure is returned instead.
        See the Plotly Demo notebook for examples of use.
    """
    if not colorscale:
        colorscale = default_colorscale

    kmgraph, mapper_summary, n_color_distribution = get_mapper_graph(
        scomplex,
        colorscale=colorscale,
        color_values=color_values,
        color_function_name=color_function_name,
        node_color_function=node_color_function,
    )

    annotation = get_kmgraph_meta(mapper_summary)

    plgraph_data = plotly_graph(
        kmgraph,
        graph_layout=graph_layout,
        colorscale=colorscale,
        factor_size=factor_size,
        edge_linewidth=edge_linewidth,
        node_linecolor=node_linecolor,
    )

    layout = plot_layout(
        title=title,
        width=width,
        height=height,
        annotation_text=annotation,
        bgcolor=bgcolor,
        left=left,
        bottom=bottom,
    )

    # Keep a handle on the graph figure itself: `result` may later be rebound
    # to a widget container, but only a figure can be exported as an image.
    fw_graph = go.FigureWidget(data=plgraph_data, layout=layout)
    result = fw_graph

    if color_function_name:
        with fw_graph.batch_update():
            # data[1] is the node trace (data[0] holds the edges). The
            # `titlefont` attribute is deprecated and absent from recent Plotly
            # releases, so the font is set through `title.font` instead.
            fw_graph.data[1].marker.colorbar.title = dict(
                text=color_function_name, font=dict(size=10)
            )

    if dashboard or graph_data:
        fw_hist = node_hist_fig(n_color_distribution, left=hist_left, right=hist_right)
        fw_summary = summary_fig(
            mapper_summary,
            width=summary_width,
            height=summary_height,
            left=summary_left,
            right=summary_right,
        )

        # Always called in this branch: it registers the hover callback on
        # the graph figure, even if its box is replaced just below.
        result = hovering_widgets(
            kmgraph, fw_graph, member_textbox_width=member_textbox_width
        )

        if graph_data:
            result = ipw.VBox([fw_graph, ipw.HBox([fw_summary, fw_hist])])

    if filename:
        pio.write_image(fw_graph, filename)

    return result
@deprecated_alias(color_function="color_values")
def scomplex_to_graph(
    simplicial_complex,
    color_values,
    X,
    X_names,
    lens,
    lens_names,
    custom_tooltips,
    colorscale,
    node_color_function="mean",
):
    """Turn a simplicial complex into a ``{"nodes": [...], "links": [...]}`` dict.

    Every entry of ``simplicial_complex["nodes"]`` becomes a node dict that
    carries its running index, name, member ids, color, size and the
    statistics returned by ``_tooltip_components``. Every
    ``simplicial_complex["links"]`` entry becomes one
    ``{"source": index, "target": index}`` dict per linked node, using the
    running indices assigned to the nodes.
    """
    color_values = np.array(color_values)

    index_by_name = {}
    nodes = []
    for idx, (name, members) in enumerate(simplicial_complex["nodes"].items()):
        index_by_name[name] = idx

        projection_stats, cluster_stats, member_histogram = _tooltip_components(
            members, X, X_names, lens, lens_names, color_values, idx, colorscale
        )

        color = _node_color_function(members, color_values, node_color_function)
        if isinstance(color, np.ndarray):
            # Arrays are converted to plain lists.
            color = color.tolist()

        # NOTE(review): the whole `custom_tooltips` object is attached to every
        # node rather than a per-member selection -- confirm this is intended.
        nodes.append(
            {
                "id": idx,
                "name": name,
                "member_ids": members,
                "color": color,
                "size": _size_node(members),
                "cluster": cluster_stats,
                "distribution": member_histogram,
                "projection": projection_stats,
                "custom_tooltips": custom_tooltips,
            }
        )

    links = [
        {"source": index_by_name[source_name], "target": index_by_name[target_name]}
        for source_name, target_names in simplicial_complex["links"].items()
        for target_name in target_names
    ]

    return {"nodes": nodes, "links": links}
@deprecated_alias(color_function="color_values")
def get_mapper_graph(
    simplicial_complex,
    color_values=None,
    color_function_name=None,
    node_color_function="mean",
    colorscale=None,
    custom_tooltips=None,
    custom_meta=None,
    X=None,
    X_names=None,
    lens=None,
    lens_names=None,
):
    """Generate data for mapper graph visualization and annotation.

    Parameters
    ----------
    simplicial_complex : dict
        Simplicial complex is the output from the KeplerMapper `map` method.
    color_values : optional
        Values used to color the nodes; they are passed through
        ``_scale_color_values``. When None, samples are colored by row
        number and ``color_function_name`` is replaced by ``["Row number"]``.
    color_function_name : optional
        Forwarded to ``_format_meta``.
    node_color_function : default "mean"
        Forwarded to ``scomplex_to_graph``, ``_graph_data_distribution`` and
        ``_format_meta``.
    colorscale : optional
        Colorscale; a falsy value selects ``default_colorscale``.
    custom_tooltips : optional
        Forwarded to ``scomplex_to_graph``.
    custom_meta : optional
        Forwarded to ``_format_meta``.
    X, X_names, lens, lens_names : optional
        Forwarded to ``scomplex_to_graph``; the ``*_names`` arguments default
        to empty lists.

    Returns
    -------
    the graph dictionary in a json representation, the mapper summary
    and the node_distribution

    Raises
    ------
    ValueError
        If the simplicial complex has no nodes.

    Example
    -------

    >>> kmgraph, mapper_summary, n_distribution = get_mapper_graph(simplicial_complex)

    """
    if not colorscale:
        colorscale = default_colorscale

    if not simplicial_complex["nodes"]:
        # ValueError (still an Exception subclass, so existing handlers keep
        # working) matches the error raised by plotly_graph for empty graphs.
        raise ValueError(
            "A mapper graph should have more than 0 nodes. This might be because your clustering algorithm might be too sensitive and be classifying all points as noise."
        )

    if color_values is None:
        # If no color_values provided we color by row order in data set.
        # The sample count is taken as the largest member index plus one.
        n_samples = (
            np.max([i for s in simplicial_complex["nodes"].values() for i in s]) + 1
        )
        color_values = np.arange(n_samples)
        color_function_name = ["Row number"]

    color_values = _scale_color_values(color_values)

    if X_names is None:
        X_names = []

    if lens_names is None:
        lens_names = []

    json_graph = scomplex_to_graph(
        simplicial_complex,
        color_values,
        X,
        X_names,
        lens,
        lens_names,
        custom_tooltips,
        colorscale=colorscale,
        node_color_function=node_color_function,
    )
    colorf_distribution = _graph_data_distribution(
        simplicial_complex, color_values, node_color_function, colorscale
    )
    mapper_summary = _format_meta(
        simplicial_complex,
        color_function_name=color_function_name,
        node_color_function=node_color_function,
        custom_meta=custom_meta,
    )

    return json_graph, mapper_summary, colorf_distribution
def plotly_graph(
    kmgraph,
    graph_layout="kk",
    colorscale=None,
    showscale=True,
    factor_size=3,
    edge_linecolor="rgb(180,180,180)",
    edge_linewidth=1.5,
    node_linecolor="rgb(255,255,255)",
    node_linewidth=1.0,
):
    """Build the Plotly traces for the edges and nodes of a mapper graph.

    Parameters
    ----------
    kmgraph: dict
        the mapper graph, as returned by get_mapper_graph()
    graph_layout: igraph layout name
        recommended 'kk' (kamada-kawai) or 'fr' (fruchterman-reingold)
    colorscale:
        Plotly colorscale(colormap) used for the node colors; a falsy value
        selects ``default_colorscale``
    showscale: bool
        whether the colorbar is displayed
    factor_size:
        multiplier applied to each node's size
    edge_linecolor, edge_linewidth:
        color and width of the edge lines
    node_linecolor, node_linewidth:
        color and width of the node marker outlines

    Returns
    -------
    A two-item list ``[edge_trace, node_trace]`` of trace dicts

    Raises
    ------
    ValueError
        If the graph has no nodes.
    """
    if not colorscale:
        colorscale = default_colorscale

    nodes = kmgraph["nodes"]
    node_count = len(nodes)
    if node_count == 0:
        raise ValueError("Your graph has 0 nodes")

    # Lay the graph out with igraph to obtain node coordinates.
    graph = ig.Graph(n=node_count)
    edges = [(link["source"], link["target"]) for link in kmgraph["links"]]
    graph.add_edges(edges)
    positions = graph.layout(graph_layout)

    names = [node["name"] for node in nodes]
    colors = [node["color"] for node in nodes]
    # Marker sizes are cast to integers.
    sizes = np.array([factor_size * node["size"] for node in nodes], dtype=int)

    node_x, node_y, edge_x, edge_y = _get_plotly_data(edges, positions)

    edge_trace = {
        "type": "scatter",
        "x": edge_x,
        "y": edge_y,
        "mode": "lines",
        "line": {"color": edge_linecolor, "width": edge_linewidth},
        "hoverinfo": "none",
    }

    marker = {
        "size": sizes.tolist(),
        "color": colors,
        "opacity": 1.0,
        "colorscale": colorscale,
        "showscale": showscale,
        "line": {"color": node_linecolor, "width": node_linewidth},
        "colorbar": {"thickness": 20, "ticklen": 4, "x": 1.01, "tickfont": {"size": 10}},
    }
    node_trace = {
        "type": "scatter",
        "x": node_x,
        "y": node_y,
        "mode": "markers",
        "marker": marker,
        "text": names,
        "hoverinfo": "text",
    }

    return [edge_trace, node_trace]
def get_kmgraph_meta(mapper_summary):
    """Build the HTML annotation shown below the graph plot.

    The first line reports ``n_cubes`` and ``perc_overlap`` read from
    ``mapper_summary["custom_meta"]``; after a ``<br>`` the second line
    reports the ``n_nodes``, ``n_edges``, ``n_total`` and ``n_unique``
    entries of ``mapper_summary``.
    """
    custom = mapper_summary["custom_meta"]
    fragments = [
        "<b>N_cubes:</b> ",
        str(custom["n_cubes"]),
        " <b>Perc_overlap:</b> ",
        str(custom["perc_overlap"]),
        "<br><b>Nodes:</b> ",
        str(mapper_summary["n_nodes"]),
        " <b>Edges:</b> ",
        str(mapper_summary["n_edges"]),
        " <b>Total samples:</b> ",
        str(mapper_summary["n_total"]),
        " <b>Unique_samples:</b> ",
        str(mapper_summary["n_unique"]),
    ]
    return "".join(fragments)
def plot_layout(
    title="TDA KMapper",
    width=600,
    height=600,
    bgcolor="rgba(255, 255, 255, 1)",
    annotation_text=None,
    annotation_x=0,
    annotation_y=-0.01,
    top=100,
    left=60,
    right=60,
    bottom=60,
):
    """Build the Plotly layout dict for the graph figure.

    Parameters
    ----------
    title: string
        figure title
    width, height: integers
        width and height of the plot window
    bgcolor: string,
        rgba or hex color code for the plot background
    annotation_text: string
        meta data to be displayed; when None no annotation is added
    annotation_x & annotation_y:
        paper coordinates of the annotation anchor; the negative default for
        y places the annotation below the plot
    top, left, right, bottom: integers
        figure margins

    Returns
    -------
    dict describing the layout
    """
    pl_layout = {
        "title": title,
        "font": {"size": 12},
        "showlegend": False,
        "autosize": False,
        "width": width,
        "height": height,
        "xaxis": {"visible": False},
        "yaxis": {"visible": False},
        "hovermode": "closest",
        "plot_bgcolor": bgcolor,
        "margin": {"t": top, "b": bottom, "l": left, "r": right},
    }

    if annotation_text is not None:
        pl_layout["annotations"] = [
            {
                "showarrow": False,
                "text": annotation_text,
                "xref": "paper",
                "yref": "paper",
                "x": annotation_x,
                "y": annotation_y,
                "align": "left",
                "xanchor": "left",
                "yanchor": "top",
                "font": {"size": 12},
            }
        ]

    return pl_layout
def node_hist_fig(
    node_color_distribution,
    title="Graph Node Distribution",
    width=400,
    height=300,
    top=60,
    left=25,
    bottom=60,
    right=25,
    bgcolor="rgb(240,240,240)",
    y_gridcolor="white",
):
    """Build the bar-chart FigureWidget for a node histogram.

    Parameters
    ----------
    node_color_distribution: list of dicts
        one dict per bar; the "perc", "height" and "color" entries are read
    title: string
        figure title
    width, height: integers
        width and height of the histogram FigureWidget
    left, top, right, bottom: ints
        margins, in pixels, around the FigureWidget
    bgcolor: string
        rgb or hex color code for the figure background
    y_gridcolor: string
        rgb or hex color code for the y-axis grid lines

    Returns
    -------
    FigureWidget object representing the histogram of the graph nodes
    """
    # Bar labels: each bin's percentage followed by a percent sign.
    percentages = [entry["perc"] for entry in node_color_distribution]
    labels = ["{}%".format(value) for value in percentages]

    bars = go.Bar(
        y=[entry["height"] for entry in node_color_distribution],
        marker={"color": [entry["color"] for entry in node_color_distribution]},
        text=labels,
        hoverinfo="y+text",
    )

    x_axis = {
        "showline": True,
        "zeroline": False,
        "showgrid": False,
        "showticklabels": False,
    }
    y_axis = {"showline": False, "gridcolor": y_gridcolor, "tickfont": {"size": 10}}
    figure_layout = {
        "title": title,
        "width": width,
        "height": height,
        "font": {"size": 12},
        "xaxis": x_axis,
        "yaxis": y_axis,
        "bargap": 0.01,
        "margin": {"l": left, "r": right, "b": bottom, "t": top},
        "hovermode": "x",
        "plot_bgcolor": bgcolor,
    }

    return go.FigureWidget(data=[bars], layout=figure_layout)
def summary_fig(
    mapper_summary,
    width=600,
    height=500,
    top=60,
    left=20,
    bottom=60,
    right=20,
    bgcolor="rgb(240,240,240)",
):
    """Build a text-only figure describing the algorithms and scikit-learn
    objects/methods used.

    The text comes from ``_text_mapper_summary(mapper_summary)`` and is
    anchored at the point ``(0, height)`` of a scatter trace in "text" mode
    with both axes hidden.

    Returns a FigureWidget object representing the figure
    """
    summary_text = _text_mapper_summary(mapper_summary)

    # Two points: the first carries the summary text, the second an empty
    # label at (width, 0).
    text_trace = {
        "type": "scatter",
        "x": [0, width],
        "y": [height, 0],
        "mode": "text",
        "text": [summary_text, ""],
        "textposition": "bottom right",
        "hoverinfo": "none",
    }

    figure_layout = {
        "title": "Algorithms and scikit-learn objects/methods",
        "width": width,
        "height": height,
        "font": {"size": 12},
        "xaxis": {"visible": False},
        "yaxis": {"visible": False, "range": [0, height + 5]},
        "margin": {"t": top, "b": bottom, "l": left, "r": right},
        "plot_bgcolor": bgcolor,
    }

    return go.FigureWidget(data=[text_trace], layout=figure_layout)
def hovering_widgets(
    kmgraph,
    graph_fw,
    ctooltips=False,
    width=400,
    height=300,
    top=100,
    left=50,
    bgcolor="rgb(240,240,240)",
    y_gridcolor="white",
    member_textbox_width=200,
):
    """Define the widgets that follow the hovered node of the graph figure.

    Parameters
    ----------
    kmgraph: the kepler-mapper graph dict returned by `get_mapper_graph()`
    graph_fw: the FigureWidget representing the graph
    ctooltips: boolean; if True the node["custom_tooltips"] entries are shown
        in the member textbox, otherwise the node["member_ids"] entries
    width, height, top, left, bgcolor, y_gridcolor: size, margins and colors
        of the hovered node distribution figure
    member_textbox_width: width, in pixels, of the member textbox

    Returns
    -------
    a box containing the graph figure, the figure of the hovered node
    distribution, and the textboxes displaying the cluster size and member_ids
    or custom tooltips for hovered node members
    """

    def members_text(node):
        # Comma-separated listing of whichever node entry `ctooltips` selects.
        source_key = "custom_tooltips" if ctooltips else "member_ids"
        return ", ".join(str(item) for item in node[source_key])

    # All widgets start out showing the first node of the graph.
    first_node = kmgraph["nodes"][0]

    hist_fw = node_hist_fig(
        first_node["distribution"],
        title="Cluster Member Distribution",
        width=width,
        height=height,
        top=top,
        left=left,
        bgcolor=bgcolor,
        y_gridcolor=y_gridcolor,
    )

    clust_textbox = ipw.Text(
        value="{:d}".format(first_node["cluster"]["size"]),
        description="Cluster size:",
        disabled=False,
        continuous_update=True,
    )
    clust_textbox.layout = dict(margin="10px 10px 10px 10px", width="200px")

    member_textbox = ipw.Textarea(
        value=members_text(first_node),
        description="Members:",
        disabled=False,
        continuous_update=True,
    )
    member_textbox.layout = dict(
        margin="5px 5px 5px 10px", width=str(member_textbox_width) + "px"
    )

    def do_on_hover(trace, points, state):
        hovered = points.point_inds
        if not hovered:
            return

        # Only the first hovered point is used.
        node = kmgraph["nodes"][hovered[0]]
        distribution = node["distribution"]

        # Refresh the cluster member histogram in a single batched update.
        with hist_fw.batch_update():
            bar = hist_fw.data[0]
            bar.text = ["{:.1f}%".format(entry["perc"]) for entry in distribution]
            bar.y = [entry["height"] for entry in distribution]
            bar.marker.color = [entry["color"] for entry in distribution]

        clust_textbox.value = "{:d}".format(node["cluster"]["size"])
        member_textbox.value = members_text(node)

    # The hover callback is attached to the second trace of the graph figure.
    graph_fw.data[1].on_hover(do_on_hover)

    top_row = ipw.HBox([graph_fw, hist_fw])
    return ipw.VBox([top_row, clust_textbox, member_textbox])
def _get_plotly_data(E, coords):
    """Split node coordinates and edges into the x/y lists used by the traces.

    E : the list of tuples representing the graph edges (pairs of node indices)
    coords : indexable collection of node coordinates, e.g. as assigned by
        an igraph layout; ``coords[k][0]`` and ``coords[k][1]`` are read

    Returns ``(Xnodes, Ynodes, Xedges, Yedges)``. Each edge contributes its
    two endpoint coordinates followed by ``None`` to the edge lists.
    """
    node_xs = []
    node_ys = []
    for k in range(len(coords)):
        point = coords[k]
        node_xs.append(point[0])
        node_ys.append(point[1])

    edge_xs = []
    edge_ys = []
    for edge in E:
        start, end = coords[edge[0]], coords[edge[1]]
        # None terminates this edge's pair of endpoints.
        edge_xs += [start[0], end[0], None]
        edge_ys += [start[1], end[1], None]

    return node_xs, node_ys, edge_xs, edge_ys
def _text_mapper_summary(mapper_summary):
    """Format the projection, clusterer and scaler entries of
    ``mapper_summary["custom_meta"]`` as HTML lines, followed by the
    color function entry when one is present.
    """
    meta = mapper_summary["custom_meta"]

    text = ""
    for label, key in (
        ("<br><b>Projection: </b>", "projection"),
        ("<br><b>Clusterer: </b>", "clusterer"),
        ("<br><b>Scaler: </b>", "scaler"),
    ):
        text = text + label + meta[key]

    if "color_function" in meta:
        text = text + "<br><b>Color function: </b>" + meta["color_function"]

    return text
def _hover_format(member_ids, custom_tooltips, X, X_names, lens, lens_names):
    """Return the cluster size of a node, formatted as a string.

    The size is the "size" entry of
    ``_format_cluster_statistics(member_ids, X, X_names)``. ``lens`` and
    ``lens_names`` are accepted but not used.
    """
    stats = _format_cluster_statistics(member_ids, X, X_names)

    # The selected tooltips are not part of the returned text; the lookup is
    # kept so that an indexing error on `custom_tooltips` still surfaces here.
    if custom_tooltips is not None:
        custom_tooltips = custom_tooltips[member_ids]
    else:
        custom_tooltips = member_ids

    return "{}".format(stats["size"])