"""
Functions to draw maps.
"""
import seaborn as sns
import numpy as np
import pandas as pd
import matplotlib as mpl
from matplotlib import pyplot as plt
from matplotlib.ticker import FuncFormatter
from scipy.spatial.distance import cdist
import numpy as np
import pandas as pd
from scipy.stats import spearmanr
from evomap.mapping._regression import IsotonicRegression
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from cycler import cycler
import warnings
import statsmodels.api as sm
# Default marker area passed to ``scatter(s=...)`` when no sizes are given.
DEFAULT_BUBBLE_SIZE = 50
# Default font size for point labels and annotations.
DEFAULT_FONT_SIZE = 12

# Font dictionaries for titles, point labels, and axis labels (all sans-serif).
title_fontdict_large = dict(size=20, family='sans-serif')
title_fontdict = dict(size=18, family='sans-serif')
text_fontdict = dict(size=DEFAULT_FONT_SIZE, family='sans-serif')
axis_label_fontdict = dict(size=12, family='sans-serif')
[docs]
def init_params(custom_params=None):
    """
    Apply the module's default plot aesthetics to ``matplotlib.rcParams``.

    Parameters
    ----------
    custom_params : dict, optional
        rcParams entries that override the defaults defined here. Ignored
        when None or empty.
    """
    palette = ["#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
               "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"]

    axes_style = {
        "axes.prop_cycle": cycler('color', palette),
        "axes.linewidth": 1,
        "axes.titlesize": 22,
        "axes.labelsize": 16,
        "axes.edgecolor": "black",
        "axes.grid": False,
    }
    font_style = {"font.family": 'sans-serif'}

    # x and y ticks share identical settings.
    tick_style = {}
    for axis_name in ("xtick", "ytick"):
        tick_style[f"{axis_name}.labelsize"] = 12
        tick_style[f"{axis_name}.major.size"] = 0.2
        tick_style[f"{axis_name}.minor.size"] = 0.1

    grid_style = {
        "grid.alpha": 0.3,
        "grid.linewidth": 0.5,
        "grid.linestyle": "--",
        "grid.color": "black",
    }
    save_style = {
        "savefig.facecolor": "w",
        "savefig.transparent": False,
        "savefig.bbox": "tight",
        "savefig.format": "png",
    }

    style = {**axes_style, **font_style, **tick_style, **grid_style, **save_style}
    if custom_params:
        style.update(custom_params)
    mpl.rcParams.update(style)
[docs]
def style_axes(ax, show_axes=True, show_box=True, show_grid=False, axes_at_origin=False):
    """
    Apply the map styling (axes visibility, spines, grid, labels, ticks) to ``ax``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to style in place.
    show_axes : bool, optional
        Show the x/y axes, their "Dimension 1"/"Dimension 2" labels, and ticks.
    show_box : bool, optional
        Show all spines. When False, the top and right spines are hidden and
        the left/bottom spines follow ``show_axes``.
    show_grid : bool, optional
        Show grid lines; only effective when ``show_axes`` is also set.
    axes_at_origin : bool, optional
        When the axes are shown, move the left/bottom spines to zero and hide
        the top/right spines. When False, the left/bottom spines sit at the
        plot edges.
    """
    spines = ax.spines
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_visible(show_axes)
    ax.grid(show_grid and show_axes)

    if show_box:
        for spine in spines.values():
            spine.set_visible(True)
    else:
        for name, visible in (('top', False), ('right', False),
                              ('left', show_axes), ('bottom', show_axes)):
            spines[name].set_visible(visible)

    if not axes_at_origin:
        for name in ('left', 'bottom'):
            spines[name].set_position(('outward', 0))
    elif show_axes:
        for name in ('left', 'bottom'):
            spines[name].set_position('zero')
        for name in ('right', 'top'):
            spines[name].set_visible(False)

    if show_axes:
        ax.set_xlabel("Dimension 1", fontdict=axis_label_fontdict)
        ax.set_ylabel("Dimension 2", fontdict=axis_label_fontdict)
    ax.tick_params(axis='x', which='both', bottom=show_axes, labelbottom=show_axes)
    ax.tick_params(axis='y', which='both', left=show_axes, labelleft=show_axes)

    # Ticks are placed from the current limits; the aspect stays 'auto' so
    # each axis scales independently with the data.
    set_equal_ticks(ax)
    ax.set_aspect('auto')
    ax.autoscale_view()
[docs]
def set_equal_ticks(ax, num_ticks=5, decimals=1):
    """
    Place the same number of evenly spaced ticks on both axes of ``ax``.

    Ticks run from the lower to the upper limit of each axis, as returned by
    ``get_xlim``/``get_ylim`` at call time. ``set_xticks``/``set_yticks`` fix
    the tick positions, so call this again if the limits change afterwards.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes whose ticks are set in place.
    num_ticks : int, optional
        Number of ticks per axis, including both limits. Default is 5.
    decimals : int, optional
        Number of decimal places shown in the tick labels. Default is 1.
    """
    x_min, x_max = ax.get_xlim()
    y_min, y_max = ax.get_ylim()
    ax.set_xticks(np.linspace(x_min, x_max, num_ticks))
    ax.set_yticks(np.linspace(y_min, y_max, num_ticks))

    # Build the format string once; both axes share the same label precision.
    label_format = f'{{:.{decimals}f}}'
    ax.xaxis.set_major_formatter(FuncFormatter(lambda value, _pos: label_format.format(value)))
    ax.yaxis.set_major_formatter(FuncFormatter(lambda value, _pos: label_format.format(value)))
[docs]
def draw_map(X, label=None, color=None, size=None, inclusions=None, zoom_on_cluster=None, highlighted_labels=None,
             show_box=True, show_grid=False, show_axes=True, axes_at_origin=False, show_legend=False,
             cmap=None, filename=None, ax=None, fig_size=None,
             title=None, rotate_labels=0, scatter_kws=None, fontdict=None, rcparams=None):
    """
    Plot a scatter map with optional labels, coloring, and sizing.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_features) or (n_samples,)
        Data points to plot. n_features must be 1 or 2. One-dimensional maps
        (including a flat array) are drawn along the horizontal axis at y = 0.
    label : array-like of shape (n_samples,), optional
        Labels for each data point.
    color : array-like of shape (n_samples,), optional
        Group identifiers for each data point. Each distinct value is mapped
        to one colormap entry. If None, all points are drawn white with black
        edges.
    size : array-like of shape (n_samples,) or scalar, optional
        Marker sizes (matplotlib ``s``). If None, DEFAULT_BUBBLE_SIZE is used.
    inclusions : array-like of shape (n_samples,), optional
        Only points whose entry equals 1 (or True) are drawn.
    zoom_on_cluster : int or string, optional
        Only draw points whose ``color`` value equals this identifier.
        Requires ``color``; it is ignored with a warning otherwise.
    highlighted_labels : list, optional
        If given, only these labels are written on the map.
    show_box : bool, optional
        If True, show a box around the plot. Default is True.
    show_grid : bool, optional
        If True, show grid lines on the plot. Default is False.
    show_axes : bool, optional
        If True, show the axes of the plot. Default is True.
    axes_at_origin : bool, optional
        If True, draw axes lines through the origin. Default is False.
    show_legend : bool, optional
        If True, display a legend with one entry per drawn ``color`` group.
        Default is False.
    cmap : str or Colormap, optional
        Colormap (name or instance) used for the ``color`` groups.
        Defaults to 'tab10'.
    filename : str, optional
        Path to save the figure (PNG, 300 dpi). If None, the figure is not saved.
    ax : matplotlib.axes.Axes, optional
        Pre-existing axes for the plot. If None, a new figure and axes are created.
    fig_size : tuple, optional
        Size of the figure to create (default (6, 6)). Ignored if `ax` is not None.
    title : str, optional
        Title of the plot.
    rotate_labels : int, optional
        Angle to rotate the labels. Default is 0.
    scatter_kws : dict, optional
        Additional keyword arguments for ``ax.scatter``. They override the
        defaults (black edges, alpha 0.75).
    fontdict : dict, optional
        Font dictionary for the labels. If None, ``text_fontdict`` is used.
    rcparams : dict, optional
        Overrides passed to ``init_params`` before drawing.

    Returns
    -------
    matplotlib.figure.Figure or None
        The newly created figure if `ax` is None. Otherwise None; the figure
        owning `ax` is left open for the caller.

    Raises
    ------
    ValueError
        If X does not have one or two columns, or `label` does not have one
        entry per point.
    """
    X = np.asarray(X, dtype=float)
    if X.ndim == 1:
        # A flat array holds one coordinate per sample; np.atleast_2d would
        # turn it into a single row instead.
        X = X.reshape(-1, 1)
    if X.ndim != 2 or X.shape[1] not in (1, 2):
        raise ValueError(f"X must have shape (n_samples, 1) or (n_samples, 2), got {X.shape}.")
    if X.shape[1] == 1:
        X = np.hstack([X, np.zeros_like(X)])
    n_samples = X.shape[0]

    # Accept a colormap name or a Colormap instance; fall back to tab10.
    cmap = plt.get_cmap('tab10' if cmap is None else cmap)

    groups = None
    categories = None
    if color is None:
        face_colors = np.tile([1.0, 1.0, 1.0, 1.0], (n_samples, 1))  # all white
    else:
        groups = np.asarray(color)
        categories = np.unique(groups)
        if len(categories) > cmap.N:
            warnings.warn("Number of color labels exceeds the number of available colors in the colormap; colors will repeat.")
        codes = pd.Categorical(groups, categories=categories).codes
        face_colors = cmap(codes % cmap.N)  # modulo wraps groups beyond cmap.N

    if size is None:
        sizes = np.full(n_samples, DEFAULT_BUBBLE_SIZE, dtype=float)
    else:
        sizes = np.broadcast_to(np.asarray(size, dtype=float), (n_samples,))

    labels = None
    if label is not None:
        labels = list(label)
        if len(labels) != n_samples:
            raise ValueError(f"label has {len(labels)} entries but X has {n_samples} points.")

    # Boolean mask of points to draw.
    keep = np.ones(n_samples, dtype=bool)
    if inclusions is not None:
        keep &= np.asarray(inclusions) == 1
    if zoom_on_cluster is not None:
        if groups is None:
            warnings.warn("zoom_on_cluster requires `color`; showing all points.")
        else:
            # Filter on the group identifiers, not on the mapped RGBA colors.
            keep &= groups == zoom_on_cluster

    init_params(rcparams)
    if ax is None:
        fig, ax = plt.subplots(figsize=fig_size or (6, 6))
        created_fig = True
    else:
        fig = ax.figure
        created_fig = False

    scatter_args = {'edgecolors': 'black', 'alpha': 0.75}
    if scatter_kws:
        scatter_args.update(scatter_kws)
    ax.scatter(X[keep, 0], X[keep, 1], c=face_colors[keep], s=sizes[keep], **scatter_args)

    if labels is not None:
        text_style = fontdict or text_fontdict
        for i in np.flatnonzero(keep):
            if highlighted_labels is None or labels[i] in highlighted_labels:
                ax.text(X[i, 0], X[i, 1], labels[i], fontdict=text_style, rotation=rotate_labels)

    style_axes(ax, show_axes, show_box, show_grid, axes_at_origin)
    if title:
        ax.set_title(title, fontdict=title_fontdict)
    if show_legend:
        if groups is None:
            ax.legend(title='Cluster')
        else:
            from matplotlib.lines import Line2D
            # The scatter itself carries no per-group labels, so build proxy
            # handles for the groups that are actually drawn.
            handles = [
                Line2D([], [], marker='o', linestyle='', markerfacecolor=cmap(code % cmap.N),
                       markeredgecolor='black', label=str(category))
                for code, category in enumerate(categories)
                if np.any(keep & (groups == category))
            ]
            ax.legend(handles=handles, title='Cluster')

    if filename:
        # Save the figure that owns `ax`, not whatever figure is current.
        fig.savefig(filename, dpi=300, format='png', bbox_inches='tight')
    if created_fig:
        return fig
    return None
[docs]
def normalize_dhat(d_hat, n_samples):
    """
    Normalize dissimilarity predictions.

    Rescales ``d_hat`` so that its sum of squares equals the number of object
    pairs, ``n_samples * (n_samples - 1) / 2``.

    Parameters
    ----------
    d_hat : array-like
        Fitted dissimilarities, one per object pair.
    n_samples : int
        Number of objects on the map.

    Returns
    -------
    numpy.ndarray
        The rescaled values as a float array. An all-zero input is returned
        unchanged instead of dividing by zero (which would yield NaN).
    """
    d_hat = np.asarray(d_hat, dtype=float)
    sum_sq = np.sum(d_hat ** 2)
    if sum_sq == 0:
        return d_hat.copy()
    n_pairs = n_samples * (n_samples - 1) / 2
    return d_hat * np.sqrt(n_pairs / sum_sq)
[docs]
def draw_shepard_diagram(X, D, ax=None, show_grid=False, show_rank_correlation=True):
    """
    Draw a Shepard diagram of input dissimilarities vs map distances.

    Each pair of objects contributes one point (input dissimilarity, map
    distance). The line shows an isotonic regression of the map distances on
    the dissimilarities, rescaled with ``normalize_dhat``.

    Parameters
    ----------
    X : array-like of shape (n_samples, n_dims)
        Configuration of objects on the map.
    D : array-like of shape (n_samples, n_samples)
        Dissimilarity matrix. Only its strictly lower triangle is used.
    ax : matplotlib.axes.Axes, optional
        Axes to draw the diagram on. If None, a new 6x6 figure is created.
    show_grid : bool, optional
        Whether to show grid lines on the plot. Default is False.
    show_rank_correlation : bool, optional
        Whether to print the Spearman rank correlation between dissimilarities
        and map distances below the plot. Default is True.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the diagram was drawn on.

    Raises
    ------
    ValueError
        If fewer than two objects are given or D is not (n_samples, n_samples).
    """
    X = np.asarray(X, dtype=float)
    D = np.asarray(D, dtype=float)
    n_samples = X.shape[0]
    if n_samples < 2:
        raise ValueError("A Shepard diagram needs at least two objects.")
    if D.shape != (n_samples, n_samples):
        raise ValueError(f"D must have shape ({n_samples}, {n_samples}), got {D.shape}.")

    # Every unordered pair exactly once, in the same order for both arrays.
    pairs = np.tril_indices(n_samples, -1)
    distances_flat = cdist(X, X, metric='euclidean')[pairs]
    disparities_flat = D[pairs]

    # Fit isotonic regression to the flattened arrays
    ir = IsotonicRegression()
    disp_hat = ir.fit_transform(X=disparities_flat, y=distances_flat)
    disp_hat = normalize_dhat(disp_hat, n_samples)

    df = pd.DataFrame({
        'Disparities': disparities_flat,
        'Distances': distances_flat,
        'Fitted Distances': disp_hat
    }).sort_values('Disparities')

    if ax is None:
        _, ax = plt.subplots(figsize=(6, 6))

    ax.scatter(df['Disparities'], df['Distances'], color="darkblue", label="Original", alpha=0.6)
    ax.plot(df['Disparities'], df['Fitted Distances'], color="orange", label="Fitted", marker='o', linestyle='-')
    ax.set_xlabel('Input Dissimilarity', fontdict=axis_label_fontdict)
    ax.set_ylabel('Map Distance', fontdict=axis_label_fontdict)
    ax.legend()
    if show_grid:
        ax.grid(True)

    if show_rank_correlation:
        rank_corr = spearmanr(df['Disparities'], df['Distances'])[0]
        ax.text(0.5, -0.15, f'Rank Correlation: {rank_corr:.2f}', transform=ax.transAxes,
                ha='center', fontsize=14)

    # Include the fitted curve in the y-range so it is never clipped.
    y_max = max(df['Distances'].max(), df['Fitted Distances'].max())
    ax.set_ylim(0, y_max * 1.15)
    ax.set_xlim(df['Disparities'].min() * 0.99, df['Disparities'].max() * 1.01)
    ax.xaxis.set_major_formatter(FuncFormatter(lambda x, _: f'{x:.2f}'))
    ax.yaxis.set_major_formatter(FuncFormatter(lambda y, _: f'{y:.2f}'))
    return ax
[docs]
def prepare_transparencies(n_periods, start, end, final):
    """
    Build the per-period alpha values for a dynamic map.

    The first ``n_periods - 1`` values ramp linearly from ``start`` to ``end``
    and the last period gets ``final``.

    Returns
    -------
    list of float
        ``n_periods`` alpha values.
    """
    ramp = np.linspace(start, end, num=n_periods - 1)
    return [*ramp.tolist(), final]
[docs]
def draw_dynamic_map(X_t, color_t=None, size_t=None, incl_t=None, label=None, show_arrows=False,
                     show_last_positions_only=False, time_labels=None, show_axes=True, show_box=True,
                     axes_at_origin=False, show_grid=False, cmap=None, transparency_start=0.1,
                     transparency_end=0.4, transparency_final=.75, scatter_kws=None, **kwargs):
    """
    Visualize dynamic map data over multiple periods in a single plot.

    Parameters
    ----------
    X_t : list of array-like
        Coordinates per period, each of shape (n_samples, n_features); the
        first two columns are drawn.
    color_t : list of array-like, optional
        Group identifiers per period. Colors are assigned from the categories
        found across all periods, so a group keeps its color over time.
    size_t : list of array-like, optional
        Marker sizes per period. Defaults to DEFAULT_BUBBLE_SIZE.
    incl_t : list of array-like, optional
        Inclusion masks per period (boolean or 0/1). A None entry includes
        every point of that period.
    label : array-like of shape (n_samples,), optional
        Labels written at the included positions of the last period.
    show_arrows : bool, optional
        If True, draw a gray line from each point's previous position to its
        current one, for points included in both periods that moved.
        Default is False.
    show_last_positions_only : bool, optional
        If True, only the last period's points are drawn; movement lines (if
        enabled) are still drawn for all periods. Default is False.
    time_labels : list of str, optional
        Currently not drawn.
    show_axes, show_box, axes_at_origin, show_grid : bool, optional
        Passed to ``style_axes``.
    cmap : str or Colormap, optional
        Colormap (name or instance) for the ``color_t`` groups. Defaults to 'tab10'.
    transparency_start : float, optional
        Alpha of the first period.
    transparency_end : float, optional
        Alpha of the period just before the last one.
    transparency_final : float, optional
        Alpha of the last period.
    scatter_kws : dict, optional
        Extra keyword arguments for ``ax.scatter``; they override the black edges.
    **kwargs : dict
        ax : matplotlib.axes.Axes, optional
            Existing axes to draw into.
        fig_size : tuple, optional
            Size of a newly created figure. Default is (6, 6).
        filename : str, optional
            Save the figure (300 dpi) to this path.

    Returns
    -------
    matplotlib.figure.Figure or None
        The newly created figure if no ``ax`` was given. Otherwise None; the
        figure owning ``ax`` is left open for the caller.

    Raises
    ------
    ValueError
        If X_t is empty.
    """
    n_periods = len(X_t)
    if n_periods == 0:
        raise ValueError("X_t must contain at least one period.")
    X_t = [np.asarray(X, dtype=float) for X in X_t]
    n_samples = X_t[0].shape[0]
    last = n_periods - 1

    # Accept a colormap name or a Colormap instance; fall back to tab10.
    cmap = plt.get_cmap('tab10' if cmap is None else cmap)

    def _given(values):
        """True if a per-period sequence was supplied (not None and not empty)."""
        return values is not None and len(values) > 0

    # Boolean masks: a 0/1 integer array would otherwise act as positional indices.
    if _given(incl_t):
        masks = [np.ones(n_samples, dtype=bool) if incl is None else np.asarray(incl).astype(bool)
                 for incl in incl_t]
    else:
        masks = [np.ones(n_samples, dtype=bool)] * n_periods

    if _given(color_t):
        groups_t = [np.asarray(c) for c in color_t]
        # Categories from all periods, so group -> color is stable over time.
        categories = np.unique(np.concatenate([g.ravel() for g in groups_t]))
        if len(categories) > cmap.N:
            warnings.warn("Number of color labels exceeds the number of available colors in the colormap; colors will repeat.")
        colors_t = [cmap(pd.Categorical(g, categories=categories).codes % cmap.N) for g in groups_t]
    else:
        colors_t = [np.tile(cmap(0), (n_samples, 1))] * n_periods

    if _given(size_t):
        sizes_t = [np.broadcast_to(np.asarray(s, dtype=float), (n_samples,)) for s in size_t]
    else:
        sizes_t = [np.full(n_samples, DEFAULT_BUBBLE_SIZE)] * n_periods

    alphas = prepare_transparencies(n_periods, transparency_start, transparency_end, transparency_final)

    ax = kwargs.get('ax')
    if ax is None:
        fig, ax = plt.subplots(figsize=kwargs.get('fig_size', (6, 6)))
        created_fig = True
    else:
        fig = ax.figure
        created_fig = False

    scatter_args = {'edgecolors': 'black'}
    if scatter_kws:
        scatter_args.update(scatter_kws)

    for t in range(n_periods):
        X, mask, alpha = X_t[t], masks[t], alphas[t]
        if not show_last_positions_only or t == last:
            ax.scatter(X[mask, 0], X[mask, 1], alpha=alpha, s=sizes_t[t][mask],
                       c=colors_t[t][mask], **scatter_args)
        if show_arrows and t > 0:
            prev = X_t[t - 1]
            moved = masks[t - 1] & mask & np.any(prev != X, axis=1)
            for i in np.flatnonzero(moved):
                ax.plot([prev[i, 0], X[i, 0]], [prev[i, 1], X[i, 1]],
                        color='gray', alpha=alpha, linewidth=1)

    if label is not None:
        final_X, final_mask = X_t[last], masks[last]
        for i, txt in enumerate(label):
            if final_mask[i]:
                ax.text(final_X[i, 0], final_X[i, 1], txt, fontsize=DEFAULT_FONT_SIZE)

    style_axes(ax, show_axes, show_box, show_grid, axes_at_origin)
    ax.axis('equal')

    filename = kwargs.get('filename')
    if filename:
        # Save the figure that owns `ax`, not whatever figure is current.
        fig.savefig(filename, dpi=300)
    if created_fig:
        return fig
    return None
[docs]
def plot_data(X, colors, sizes, inclusions, labels, transparency_start, transparency_end, transparency_final, period, n_periods, ax, show_last_positions_only):
    """
    Helper function to manage data plotting for each period.

    Draws the included points of one period. The alpha comes from a linear
    ramp over ``n_periods`` values from ``transparency_start`` to
    ``transparency_end``. The ramp value is replaced by
    ``transparency_final`` for the last period when
    ``show_last_positions_only`` is set. For the last period, labels are
    written at the included points.

    Parameters
    ----------
    X : array-like of shape (n_samples, >=2)
        Coordinates of this period.
    colors : array-like of shape (n_samples, ...)
        Per-point colors accepted by ``ax.scatter(color=...)``.
    sizes : scalar, array-like of shape (n_samples,), or None
        Marker sizes. None uses DEFAULT_BUBBLE_SIZE.
    inclusions : array-like of shape (n_samples,) or None
        Boolean or 0/1 mask of points to draw. None draws every point.
    labels : array-like of shape (n_samples,) or None
        Labels written in the last period.
    transparency_start, transparency_end, transparency_final : float
        Alpha settings (see above).
    period : int
        Zero-based index of this period.
    n_periods : int
        Total number of periods.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    show_last_positions_only : bool
        Whether the last period uses ``transparency_final``.
    """
    X = np.asarray(X)
    colors = np.asarray(colors)
    if inclusions is None:
        keep = np.ones(X.shape[0], dtype=bool)
    else:
        # Cast to bool: 0/1 integer arrays would otherwise be used as positions.
        keep = np.asarray(inclusions).astype(bool)

    transparency = np.linspace(transparency_start, transparency_end, n_periods)[period]
    if show_last_positions_only and period == n_periods - 1:
        transparency = transparency_final

    if sizes is None:
        marker_sizes = DEFAULT_BUBBLE_SIZE
    elif np.ndim(sizes) == 0:
        marker_sizes = sizes
    else:
        # Per-point sizes must be filtered like x/y, or scatter rejects the
        # length mismatch.
        marker_sizes = np.asarray(sizes)[keep]

    ax.scatter(X[keep, 0], X[keep, 1], alpha=transparency, s=marker_sizes, color=colors[keep])
    if period == n_periods - 1 and labels is not None:
        for i, txt in enumerate(labels):
            if keep[i]:
                ax.text(X[i, 0], X[i, 1], txt, fontsize=DEFAULT_FONT_SIZE)
[docs]
def draw_trajectories(Y_ts, labels, selected_labels=None,
                      show_axes=False, show_box=True, show_grid=False, axes_at_origin=False,
                      annotate_periods=True, period_labels=None, fig_size=None, **kwargs):
    """
    Draw the trajectories of selected objects.

    Parameters
    ----------
    Y_ts : list of ndarrays, each of shape (n_samples, d)
        Sequence of map coordinates; the first two columns are drawn.
    labels : array-like of shape (n_samples,)
        Object labels (str).
    selected_labels : array-like of shape (n_selected,), optional
        Labels of the objects to draw. If None (default), all objects are drawn.
    show_axes, show_box, show_grid, axes_at_origin : bool, optional
        Passed to ``style_axes``.
    annotate_periods : bool, optional
        If True, every position except the last is annotated with its period
        label. Default is True.
    period_labels : array-like of shape (n_periods,), optional
        Period labels (str). Defaults to "Period 1", "Period 2", ...
    fig_size : tuple, optional
        Size of the figure created when no ``ax`` is given. Default is (6, 6).
    **kwargs : dict
        ax : matplotlib.axes.Axes, optional
            Pre-existing axes for the plot.
        title : str, optional
            Figure title.
        figsize : tuple, optional
            Alias for ``fig_size``.

    Returns
    -------
    matplotlib.figure.Figure or None
        The newly created figure if no ``ax`` was given. Otherwise None; the
        figure owning ``ax`` is left open for the caller.
    """
    n_periods = len(Y_ts)
    Y_ts = [np.asarray(Y) for Y in Y_ts]
    n_objects = Y_ts[0].shape[0]
    # `is None` (not `== None`) so ndarray inputs don't trigger elementwise comparison.
    selected = set(labels) if selected_labels is None else set(selected_labels)

    ax = kwargs.get('ax')
    if ax is None:
        fig, ax = plt.subplots(figsize=fig_size or kwargs.get('figsize') or (6, 6))
        created_fig = True
    else:
        fig = ax.figure
        created_fig = False

    if annotate_periods and period_labels is None:
        period_labels = [f"Period {t + 1}" for t in range(n_periods)]

    point_color, line_color = 'black', 'grey'
    # Points fade in over time, from 0.5 / n_periods (first) to 0.5 (last),
    # so that every position, including the first, stays visible.
    alphas = [0.5 * (t + 1) / n_periods for t in range(n_periods)]

    for i in range(n_objects):
        if labels[i] not in selected:
            continue
        xs = [Y[i, 0] for Y in Y_ts]
        ys = [Y[i, 1] for Y in Y_ts]
        for t in range(n_periods):
            ax.scatter(xs[t], ys[t], c=point_color, alpha=alphas[t])
            if t == n_periods - 1:
                ax.text(xs[t], ys[t], labels[i], c=point_color, alpha=.7, fontsize=DEFAULT_FONT_SIZE)
            elif annotate_periods:
                ax.text(xs[t], ys[t], period_labels[t], c=line_color, alpha=.5,
                        fontsize=DEFAULT_FONT_SIZE * 0.8)
        ax.plot(xs, ys, c=line_color, alpha=.4)

    title = kwargs.get('title')
    if title:
        ax.set_title(title, fontdict=title_fontdict)
    style_axes(ax=ax, show_axes=show_axes, show_box=show_box, show_grid=show_grid,
               axes_at_origin=axes_at_origin)
    if created_fig:
        return fig
    return None
[docs]
def fit_attribute(coords, attribute_label, attribute_values, ax, coord_range):
    """
    Fit an attribute to the map and display the resultant vector.

    Regresses the attribute on the map coordinates (OLS with intercept). It
    then draws a dashed arrow from the origin along the fitted coefficient
    vector and labels it with ``attribute_label``. The arrow is scaled by R^2
    times half the current x-range. The axes limits are restored afterwards,
    so the arrow does not rescale the map.

    Parameters
    ----------
    coords : array-like of shape (n_samples, 2)
        Map coordinates.
    attribute_label : str
        Text written at the arrow tip.
    attribute_values : array-like of shape (n_samples,)
        Attribute value of each object.
    ax : matplotlib.axes.Axes
        Axes to draw on.
    coord_range : tuple
        Unused; kept for backward compatibility.
    """
    x_lim = ax.get_xlim()
    y_lim = ax.get_ylim()

    X = np.asarray(coords, dtype=float)
    y = np.asarray(attribute_values, dtype=float)
    # has_constant='add' guarantees params[0] is the intercept even if a map
    # dimension is constant.
    result = sm.OLS(y, sm.add_constant(X, has_constant='add')).fit()
    # np.asarray keeps indexing positional even if params comes back labelled.
    vector = np.asarray(result.params)[1:]
    r_squared = result.rsquared

    # NOTE(review): the coefficient vector is not normalized, so arrow length
    # also depends on the attribute's scale — confirm this is intended.
    length = ((x_lim[1] - x_lim[0]) / 2) * r_squared
    ax.arrow(0, 0, vector[0] * length, vector[1] * length, linestyle='--', lw=.25, alpha=.75,
             width=.001, color='grey', head_width=0.1)
    ax.text(vector[0] * 1.1 * length, vector[1] * 1.1 * length, attribute_label,
            fontdict={'size': 8, 'color': 'darkblue', 'weight': 'normal'})

    ax.set_xlim(x_lim)
    ax.set_ylim(y_lim)
[docs]
def fit_attributes(map_coords, df_attributes, ax):
    """
    Fit every column of ``df_attributes`` to the map and draw its vector.

    Parameters
    ----------
    map_coords : array-like
        Map coordinates, one row per object.
    df_attributes : pandas.DataFrame
        One attribute per column, one row per object.
    ax : matplotlib.axes.Axes
        Axes to draw the vectors on.
    """
    # Overall min/max across all coordinates, forwarded as `coord_range`.
    value_span = (np.min(map_coords), np.max(map_coords))
    for column_name in df_attributes.columns:
        values = df_attributes[column_name]
        fit_attribute(map_coords, column_name, values, ax, value_span)
[docs]
def draw_map_sequence(X_t, color_t=None, size_t=None, incl_t=None, n_cols=4, time_labels=None,
                      show_axes=True, show_box=True, axes_at_origin=False, show_grid=False,
                      cmap=None, fontdict=None, scatter_kws=None, **kwargs):
    """
    Draw a sequence of static maps next to each other.

    All drawn maps share the same x/y limits (the union of their individual
    limits). Unused grid cells are hidden.

    Parameters
    ----------
    X_t : list of ndarray
        List of arrays containing map coordinates for each period.
    color_t : list of ndarray, optional
        Color for each data point per period.
    size_t : list of ndarray, optional
        Size for each data point per period.
    incl_t : list of ndarray, optional
        Inclusion array to filter points for each period.
    n_cols : int, optional
        Maximum number of maps shown in one row. Default is 4.
    time_labels : list of str, optional
        Titles for each period. Defaults to "Period 1", "Period 2", ...
    show_axes : bool, optional
        Show axes of the plot.
    show_box : bool, optional
        Show a box around the plot.
    axes_at_origin : bool, optional
        Draw axes lines through the origin.
    show_grid : bool, optional
        Show grid lines on the plot.
    cmap : str or Colormap, optional
        Colormap to use for coloring the points.
    fontdict : dict, optional
        Font dictionary for the labels.
    scatter_kws : dict, optional
        Additional keyword arguments to pass to the scatter plot function.
    **kwargs : dict
        fig_size : tuple, optional
            Figure size. Default is (4 * n_cols, 4 * n_rows).

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the sequence of maps.

    Raises
    ------
    ValueError
        If X_t is empty.
    """
    n_periods = len(X_t)
    if n_periods == 0:
        raise ValueError("X_t must contain at least one period.")
    n_rows = int(np.ceil(n_periods / n_cols))
    fig_size = kwargs.get('fig_size', (4 * n_cols, 4 * n_rows))
    # squeeze=False keeps `axs` 2D (n_rows, n_cols) for every grid shape,
    # including a single column.
    fig, axs = plt.subplots(n_rows, n_cols, figsize=fig_size, constrained_layout=True, squeeze=False)
    if not time_labels:
        time_labels = ["Period " + str(t + 1) for t in range(n_periods)]

    def _for_period(values, t):
        """Return values[t], or None if no per-period values were supplied."""
        if values is None or len(values) == 0:
            return None
        return values[t]

    used_axes = []
    for t in range(n_periods):
        ax = axs[t // n_cols, t % n_cols]
        draw_map(
            X=X_t[t],
            ax=ax,
            label=None,
            color=_for_period(color_t, t),
            size=_for_period(size_t, t),
            inclusions=_for_period(incl_t, t),
            show_axes=show_axes,
            show_box=show_box,
            axes_at_origin=axes_at_origin,
            show_grid=show_grid,
            cmap=cmap,
            fontdict=fontdict,
            # Always pass a dict: draw_map merges it into its scatter arguments.
            scatter_kws=scatter_kws if scatter_kws is not None else {},
        )
        ax.set_title(time_labels[t])
        used_axes.append(ax)

    # Apply the union of the drawn maps' limits to every drawn map (unused
    # cells are excluded), then re-place the ticks to span the shared range.
    x_min = min(a.get_xlim()[0] for a in used_axes)
    x_max = max(a.get_xlim()[1] for a in used_axes)
    y_min = min(a.get_ylim()[0] for a in used_axes)
    y_max = max(a.get_ylim()[1] for a in used_axes)
    for a in used_axes:
        a.set_xlim(x_min, x_max)
        a.set_ylim(y_min, y_max)
        set_equal_ticks(a)

    for idx in range(n_periods, n_rows * n_cols):
        axs[idx // n_cols, idx % n_cols].axis('off')
    return fig