dianna.visualization
Tools for visualization of model explanations.
Submodules
Package Contents
Functions
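| plot_image(heatmap[, original_data, heatmap_cmap, ...]) | Plots a heatmap image. |
| plot_tabular(x, y[, x_label, y_label, num_features, ...]) | Plot feature importance scores. |
| highlight_text(explanation[, input_tokens, show_plot, ...]) | Highlights a given text based on values in a given explanation object. |
| plot_timeseries(x, y, segments[, x_label, y_label, ...]) | Plot timeseries with segments highlighted. |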
- dianna.visualization.plot_image(heatmap, original_data=None, heatmap_cmap=None, heatmap_range=(None, None), data_cmap=None, show_plot=True, output_filename=None)
Plots a heatmap image.
Optionally, the heatmap (typically a saliency map produced by an explainer) can be plotted on top of the original data. In that case, both images are drawn semi-transparently with alpha = 0.5.
- Parameters:
heatmap – the saliency map or other heatmap to be plotted.
original_data – the data to plot together with the heatmap, both with alpha = 0.5 (optional).
heatmap_cmap – color map for the heatmap plot (see mpl.Axes.imshow documentation for options).
heatmap_range – a tuple (vmin, vmax) to set the range of the heatmap. By default, the colormap covers the complete value range of the supplied heatmap.
data_cmap – color map for the (optional) data image (see mpl.Axes.imshow documentation for options). By default, if the image is two dimensional, the color map is set to ‘gray’.
show_plot – If true, shows the plot; set to false for testing or when only writing the plot to disk.
output_filename – Name of the file to save the plot to (optional).
- Returns:
None
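As a minimal sketch of an overlay call, assuming dianna is installed and the inputs are a 2-D grayscale image with a same-shaped saliency map: the hypothetical helper `symmetric_range` builds a `heatmap_range` centred on zero, so that with a diverging colormap such as matplotlib's `"bwr"` a zero score lands on the neutral midpoint instead of the default full value range. The helper and `save_overlay` are illustrative names, not part of dianna.

```python
import numpy as np


def symmetric_range(heatmap: np.ndarray) -> tuple:
    """Return (vmin, vmax) centred on zero so 0 maps to the colormap midpoint."""
    bound = float(np.max(np.abs(heatmap))) or 1.0  # fall back to 1 for an all-zero map
    return (-bound, bound)


def save_overlay(image: np.ndarray, heatmap: np.ndarray, filename: str) -> None:
    # Imported here so symmetric_range can be used without dianna installed.
    from dianna.visualization import plot_image

    plot_image(
        heatmap,
        original_data=image,                     # plotted together with the heatmap, both at alpha = 0.5
        heatmap_cmap="bwr",                      # diverging: negative blue, positive red
        heatmap_range=symmetric_range(heatmap),  # keep 0 at the colormap midpoint
        show_plot=False,                         # do not open a window
        output_filename=filename,                # write the figure to disk instead
    )


# Synthetic stand-ins for a 2-D grayscale image and its saliency map.
rng = np.random.default_rng(seed=0)
image = rng.random((28, 28))
heatmap = rng.normal(size=(28, 28))
```

With dianna installed, `save_overlay(image, heatmap, "overlay.png")` saves the overlay without opening a window; since `image` is two-dimensional, `data_cmap` falls back to gray as described above.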
- dianna.visualization.plot_tabular(x: numpy.ndarray, y: List[str], x_label: str = 'Importance score', y_label: str = 'Features', num_features: int | None = None, show_plot: bool | None = True, output_filename: str | None = None) → matplotlib.pyplot.Figure
Plot feature importance scores, one per feature name.
- Parameters:
x (np.ndarray) – Array of feature importance scores
y (List[str]) – List of feature names
x_label (str) – Label for the x-axis
y_label (str) – Label for the y-axis
num_features (Optional[int]) – Number of most salient features to display
show_plot (bool, optional) – If true, shows the plot; set to false for testing or when only writing the plot to disk.
output_filename (str, optional) – Name of the file to save the plot to.
- Returns:
plt.Figure
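A minimal usage sketch, assuming dianna is installed; the feature names and scores are made up for illustration. `x` and `y` are paired by position, so each score must line up with its feature name. The hypothetical helper `rank_by_magnitude` orders features by absolute score, which is one common reading of "most salient"; the exact ranking rule `plot_tabular` applies for `num_features` is not stated here.

```python
from typing import List, Sequence

import numpy as np

# Hypothetical importance scores for a model with five named input features.
feature_names = ["age", "income", "tenure", "balance", "num_products"]
scores = np.array([0.12, -0.40, 0.05, 0.31, -0.08])


def rank_by_magnitude(x: np.ndarray, names: Sequence[str]) -> List[str]:
    """Feature names ordered by decreasing absolute importance score."""
    if x.shape != (len(names),):
        raise ValueError("expected exactly one score per feature name")
    order = np.argsort(-np.abs(x), kind="stable")
    return [names[i] for i in order]


def save_importance_plot(filename: str):
    # Imported here so the helper above can be used without dianna installed.
    from dianna.visualization import plot_tabular

    return plot_tabular(
        scores,
        feature_names,          # y: one name per entry of x, same order
        num_features=3,         # only show the 3 most salient features
        show_plot=False,        # do not open a window
        output_filename=filename,
    )  # a matplotlib Figure, so it can be adjusted further
```

Because `plot_tabular` returns the `Figure`, the caller can still change titles or layout after the call.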
- dianna.visualization.highlight_text(explanation, input_tokens=None, show_plot=True, output_filename=None, colormap='RdBu', alpha=1.0, heatmap_range=(-1, 1))
Highlights a given text based on values in a given explanation object.
- Parameters:
explanation – list of tuples of (word, index of word in original data, importance)
input_tokens – list of all tokens (including those without importance)
show_plot – If true, shows the plot; set to false for testing or when only writing the plot to disk.
output_filename – Name of the file to save the plot to (optional).
colormap – color map for the heatmap plot (see mpl.Axes.imshow documentation for options).
alpha – alpha value for the heatmap plot.
heatmap_range – a tuple (vmin, vmax) to set the range of the heatmap.
- Returns:
None
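The `explanation` argument is a list of `(word, index of word in original data, importance)` tuples. As a sketch, assuming per-token scores are already available from some explainer, the hypothetical helper `to_explanation` builds that list and divides by the largest magnitude so the scores fit the default `heatmap_range` of `(-1, 1)`; this scaling is a design choice of the example, not something `highlight_text` requires. `save_highlighted` is likewise an illustrative name.

```python
from typing import List, Sequence, Tuple


def to_explanation(tokens: Sequence[str], scores: Sequence[float]) -> List[Tuple[str, int, float]]:
    """Build (word, index in original data, importance) tuples, scaled into [-1, 1]."""
    if len(tokens) != len(scores):
        raise ValueError("expected exactly one score per token")
    # Divide by the largest magnitude so the scores fit the default heatmap_range (-1, 1).
    bound = max((abs(s) for s in scores), default=0.0) or 1.0
    return [(token, i, s / bound) for i, (token, s) in enumerate(zip(tokens, scores))]


def save_highlighted(tokens: Sequence[str], scores: Sequence[float], filename: str) -> None:
    # Imported here so to_explanation can be used without dianna installed.
    from dianna.visualization import highlight_text

    highlight_text(
        to_explanation(tokens, scores),
        input_tokens=list(tokens),   # all tokens of the text, including any without importance
        show_plot=False,             # do not open a window
        output_filename=filename,
    )


tokens = ["a", "surprisingly", "good", "movie"]
scores = [0.0, 0.3, 1.2, -0.4]
explanation = to_explanation(tokens, scores)
```

With the default `colormap='RdBu'`, positive and negative importances are drawn at opposite ends of the red-blue scale.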
- dianna.visualization.plot_timeseries(x: numpy.ndarray, y: numpy.ndarray, segments: List[Dict[str, Any]], x_label: str = 't', y_label: str | Iterable[str] = None, cmap: str | None = None, show_plot: bool | None = True, output_filename: str | None = None, heatmap_range=(-1, 1)) → matplotlib.pyplot.Figure
Plot timeseries with segments highlighted.
- Parameters:
x (np.ndarray) – X-values with shape (number_of_time_steps,)
y (np.ndarray) – Y-values with shape (number_of_time_steps, number_of_channels)
segments (List[Dict[str, Any]]) – Segment data: a list of dicts with the keys ‘index’, ‘start’, ‘end’, ‘weight’ and ‘channel’. Here, ‘index’ is the index of the segment or feature, ‘start’ and ‘end’ determine the location of the segment, ‘weight’ determines its color, and ‘channel’ determines the channel within the timeseries.
x_label (str, optional) – Label for the x-axis
y_label (Union[str, Iterable[str]], optional) – Label or list of labels for the y-axis
cmap (str, optional) – Matplotlib colormap
show_plot (bool, optional) – If true, shows the plot; set to false for testing or when only writing the plot to disk.
output_filename (str, optional) – Name of the file to save the plot to.
heatmap_range (tuple, optional) – a tuple (vmin, vmax) to set the range of the heatmap.
- Returns:
plt.Figure
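A sketch of building `segments` in the documented dict format, assuming importance weights have been computed per fixed-length window and per channel (for example by an explainer that masks windows). The hypothetical helper `window_segments` emits one dict per window and channel. Whether `end` is inclusive or exclusive is not stated above; this sketch treats it as exclusive, so check against the dianna version in use. `save_timeseries_plot` and the synthetic data are illustrative only.

```python
from typing import Any, Dict, List

import numpy as np


def window_segments(weights: np.ndarray, window: int) -> List[Dict[str, Any]]:
    """Turn an (n_windows, n_channels) weight array into one segment dict per window and channel."""
    n_windows, n_channels = weights.shape
    segments: List[Dict[str, Any]] = []
    for w in range(n_windows):
        for c in range(n_channels):
            segments.append({
                "index": len(segments),          # running segment index
                "start": w * window,             # first time step of the window
                "end": (w + 1) * window,         # assumed exclusive upper bound
                "weight": float(weights[w, c]),  # determines the highlight colour
                "channel": c,                    # column of y the segment belongs to
            })
    return segments


def save_timeseries_plot(filename: str):
    # Imported here so window_segments can be used without dianna installed.
    from dianna.visualization import plot_timeseries

    t = np.arange(40)                                                # x: shape (40,)
    y = np.stack([np.sin(t / 5), np.cos(t / 5)], axis=1)             # y: shape (40, 2)
    weights = np.random.default_rng(0).uniform(-1, 1, size=(4, 2))   # 4 windows x 2 channels
    return plot_timeseries(
        t,
        y,
        window_segments(weights, window=10),
        y_label=["sin", "cos"],   # one label per channel
        cmap="bwr",
        show_plot=False,          # do not open a window
        output_filename=filename,
    )
```

Drawing the weights from [-1, 1] keeps them inside the default `heatmap_range` of `(-1, 1)`; real explainer weights may need rescaling or a matching `heatmap_range`.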