from typing import Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
[docs]
def _determine_vmax(max_data_value):
vmax = 1
if max_data_value > 255:
vmax = None
elif max_data_value > 1:
vmax = 255
return vmax
[docs]
def plot_image(heatmap: np.ndarray,
               original_data: Optional[np.ndarray] = None,
               heatmap_cmap='bwr',
               heatmap_range=(None, None),  # (vmin, vmax)
               data_cmap=None,
               show_plot: bool = True,
               output_filename=None,
               ax: Optional[plt.Axes] = None,
               ) -> Tuple[plt.Figure, plt.Axes]:
    """Plot a heatmap image, optionally on top of the original data.

    When ``original_data`` is given, it is drawn first (opaque). The heatmap
    (typically a saliency map of an explainer) is then drawn over it
    transparently with alpha = 0.5. Without ``original_data`` the heatmap is
    drawn opaque. A colorbar for the heatmap is added, and all axis ticks and
    tick labels are hidden.

    Args:
        heatmap: the saliency map or other heatmap to be plotted.
        original_data: the data to plot underneath the heatmap (optional).
            It is drawn with vmin=0 and a vmax derived from its maximum
            value (1, 255, or no fixed limit for values above 255).
        heatmap_cmap: color map for the heatmap plot (see mpl.Axes.imshow
            documentation for options).
        heatmap_range: a tuple (vmin, vmax) to set the range of the heatmap.
            By default, the colormap covers the complete value range of
            the supplied heatmap.
        data_cmap: color map for the (optional) data image (see
            mpl.Axes.imshow documentation for options). By default, if the
            image is two dimensional, the color map is set to 'gray'.
        show_plot: if False, the figure is closed after drawing (and after
            saving, when ``output_filename`` is set), so it is not displayed
            afterwards. This function never calls ``plt.show()`` itself.
        output_filename: name of the file to save the figure to (optional).
        ax: matplotlib.Axes object to plot on (optional). If omitted, a new
            figure and axes are created.

    Returns:
        A tuple ``(fig, ax)`` with the figure and the axes drawn on.
    """
    if ax is None:
        fig, ax = plt.subplots()
    else:
        fig = ax.get_figure()

    alpha = 1
    if original_data is not None:
        # Default cmap depends on shape: 2D arrays are shown in grayscale;
        # for other shapes data_cmap is passed to imshow as given.
        if original_data.ndim == 2 and data_cmap is None:
            data_cmap = 'gray'
        ax.imshow(original_data,
                  cmap=data_cmap,
                  vmin=0,
                  vmax=_determine_vmax(original_data.max()))
        # Make the heatmap semi-transparent so the data shows through.
        alpha = .5

    vmin, vmax = heatmap_range
    heatmap_image = ax.imshow(heatmap,
                              vmin=vmin,
                              vmax=vmax,
                              cmap=heatmap_cmap,
                              alpha=alpha)
    # Tie the colorbar to ``ax`` explicitly. Otherwise matplotlib may take
    # space from whichever axes is "current", which need not be ``ax``.
    fig.colorbar(heatmap_image, ax=ax)
    ax.tick_params(bottom=False,
                   left=False,
                   right=False,
                   top=False,
                   labelleft=False,
                   labelbottom=False,
                   labelright=False,
                   labeltop=False)

    # Save through ``fig`` (not plt.savefig) and before closing.
    # plt.savefig acts on the *current* figure: after plt.close() that is a
    # new, empty figure, and with a caller-supplied ``ax`` it may be an
    # unrelated figure.
    if output_filename:
        fig.savefig(output_filename)
    if not show_plot:
        plt.close(fig)
    return fig, ax