# Source code for kmapper.drawing
"""
Methods for drawing graphs
"""
import numpy as np
__all__ = ["draw_matplotlib"]
def draw_matplotlib(g, ax=None, fig=None, layout="kk"):
    """Draw the graph using NetworkX drawing functionality.

    Parameters
    ------------

    g: graph object returned by ``map``
        The Mapper graph as constructed by ``KeplerMapper.map``. Anything
        that is not already a ``networkx.Graph`` is converted with
        ``kmapper.adapter.to_networkx`` first.

    ax: matplotlib Axes object
        A matplotlib axes object to plot graph on. If none, then use the
        current axes of ``fig``.

    fig: matplotlib Figure object
        A matplotlib Figure object to plot graph on. If none, then use the
        figure that owns ``ax`` when ``ax`` is given, otherwise
        ``plt.figure()``.

    layout: string
        Key for which of NetworkX's layout functions.
        Key options implemented are:

        ::

            >>> "kk": nx.kamada_kawai_layout,
            >>> "spring": nx.spring_layout,
            >>> "bi": nx.bipartite_layout,
            >>> "circ": nx.circular_layout,
            >>> "spect": nx.spectral_layout

    Returns
    --------

    nodes: nx node set object list
        List of nodes constructed with Networkx ``draw_networkx_nodes``.
        This can be used to further customize node attributes.

    Raises
    ------

    KeyError
        If ``layout`` is not one of the keys listed above.
    """
    import networkx as nx
    import os

    # Fall back to a non-interactive backend when no display is available.
    # https://stackoverflow.com/a/50089385/5917194
    import matplotlib as mpl

    if os.environ.get("DISPLAY", "") == "":
        print("no display found. Using non-interactive Agg backend")
        mpl.use("Agg")
    import matplotlib.pyplot as plt

    # Keep the figure and the axes consistent with each other: a caller who
    # passes only ``ax`` gets that axes' own figure (no stray empty figure is
    # created), and a caller who passes only ``fig`` gets an axes on it.
    if fig is None:
        fig = ax.figure if ax is not None else plt.figure()
    if ax is None:
        ax = fig.gca()

    if not isinstance(g, nx.Graph):
        from .adapter import to_networkx

        g = to_networkx(g)

    # Determine a fine size for nodes
    bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    width, height = bbox.width, bbox.height
    area = width * height * fig.dpi

    # Clamp to 1 so an empty graph does not trigger a division by zero.
    n_nodes = max(len(g.nodes), 1)

    # size of node should be related to area and number of nodes -- heuristic
    node_size = np.pi * area / n_nodes
    node_r = np.sqrt(node_size / np.pi)
    node_edge = node_r / 3

    # NOTE(review): depending on the NetworkX version, ``nx.bipartite_layout``
    # may require a ``nodes`` argument in addition to the graph -- confirm
    # that the "bi" option works with the supported NetworkX releases.
    layouts = {
        "kk": nx.kamada_kawai_layout,
        "spring": nx.spring_layout,
        "bi": nx.bipartite_layout,
        "circ": nx.circular_layout,
        "spect": nx.spectral_layout,
    }
    if layout not in layouts:
        # Same exception type as a plain dict lookup, but with a usable hint.
        raise KeyError(
            "Unknown layout {!r}; expected one of {}".format(
                layout, sorted(layouts)
            )
        )

    pos = layouts[layout](g)

    nodes = nx.draw_networkx_nodes(g, node_size=node_size, pos=pos, ax=ax)
    nx.draw_networkx_edges(g, pos=pos, ax=ax)

    # White outline, proportional to the node radius, to separate
    # overlapping nodes visually.
    nodes.set_edgecolor("w")
    nodes.set_linewidth(node_edge)

    ax.axis("square")
    ax.axis("off")

    return nodes