dianna.visualization

Tools for visualization of model explanations.

Submodules

Package Contents

Functions

plot_image(heatmap[, original_data, heatmap_cmap, ...])

Plots a heatmap image.

plot_tabular(x, y[, x_label, y_label, ...]) → matplotlib.pyplot.Figure

Plot feature importance scores for tabular data.

highlight_text(explanation[, input_tokens, show_plot, ...])

Highlights a given text based on values in a given explanation object.

plot_timeseries(x, y, segments[, x_label, y_label, ...]) → matplotlib.pyplot.Figure

Plot timeseries with segments highlighted.

dianna.visualization.plot_image(heatmap, original_data=None, heatmap_cmap=None, heatmap_range=(None, None), data_cmap=None, show_plot=True, output_filename=None)

Plots a heatmap image.

Optionally, the heatmap (typically a saliency map produced by an explainer) can be plotted on top of the original data. In that case both images are plotted transparently with alpha = 0.5.

Parameters:
  • heatmap – the saliency map or other heatmap to be plotted.

  • original_data – the data to plot together with the heatmap, both with alpha = 0.5 (optional).

  • heatmap_cmap – color map for the heatmap plot (see mpl.Axes.imshow documentation for options).

  • heatmap_range – a tuple (vmin, vmax) to set the range of the heatmap. By default, the colormap covers the complete value range of the supplied heatmap.

  • data_cmap – color map for the (optional) data image (see mpl.Axes.imshow documentation for options). By default, if the image is two-dimensional, the color map is set to ‘gray’.

  • show_plot – If true, display the plot. Set it to false for testing, or to only write the plot to disk.

  • output_filename – Name of the file to save the plot to (optional).

Returns:

None
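The `heatmap_range` behaviour can be pictured with a small NumPy sketch. It assumes the range is applied like `vmin`/`vmax` in Matplotlib's `imshow`: a `None` limit falls back to the heatmap's own minimum or maximum, and, with the default colormap settings, values outside the range saturate at the ends of the colormap. The helper `resolve_range` is hypothetical and not part of `dianna`.

```python
import numpy as np


def resolve_range(heatmap, heatmap_range=(None, None)):
    """Hypothetical helper: resolve (vmin, vmax) and map the heatmap onto [0, 1].

    A None limit falls back to the heatmap's own minimum or maximum, which is
    how Matplotlib's imshow autoscales vmin/vmax when they are not given.
    """
    heatmap = np.asarray(heatmap, dtype=float)
    vmin, vmax = heatmap_range
    vmin = float(heatmap.min()) if vmin is None else float(vmin)
    vmax = float(heatmap.max()) if vmax is None else float(vmax)
    if vmax == vmin:
        # Degenerate range: avoid division by zero.
        return vmin, vmax, np.zeros_like(heatmap)
    # Values outside [vmin, vmax] are clipped, i.e. they saturate the colormap.
    scaled = np.clip((heatmap - vmin) / (vmax - vmin), 0.0, 1.0)
    return vmin, vmax, scaled


saliency = np.array([[0.0, 0.5], [1.0, 2.0]])
print(resolve_range(saliency))          # default: full value range of the heatmap
print(resolve_range(saliency, (0, 1)))  # 2.0 lies above vmax and is clipped
```

Fixing the range, for example to `(0, 1)` or a symmetric `(-m, m)`, keeps colors comparable across several heatmaps, whereas the default rescales each heatmap to its own extremes.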

dianna.visualization.plot_tabular(x: numpy.ndarray, y: List[str], x_label: str = 'Importance score', y_label: str = 'Features', num_features: int | None = None, show_plot: bool | None = True, output_filename: str | None = None) → matplotlib.pyplot.Figure

Plot feature importance scores for tabular data.

Parameters:
  • x (np.ndarray) – Array of feature importance scores

  • y (List[str]) – List of feature names

  • x_label (str) – Label for the x-axis

  • y_label (str) – Label for the y-axis

  • num_features (Optional[int]) – Number of most salient features to display

  • show_plot (bool, optional) – If true, display the plot. Set it to false for testing, or to only write the plot to disk.

  • output_filename (str, optional) – Name of the file to save the plot to (optional).

Returns:

plt.Figure
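The effect of `num_features` can be sketched as a top-k selection. This sketch assumes that "most salient" means largest absolute importance score, so strongly negative features are kept alongside strongly positive ones; the helper `top_features` is hypothetical and not part of `dianna`.

```python
from typing import List, Optional, Tuple

import numpy as np


def top_features(
    x: np.ndarray, y: List[str], num_features: Optional[int] = None
) -> Tuple[np.ndarray, List[str]]:
    """Hypothetical helper: keep the num_features largest scores by magnitude.

    Assumes salience means absolute importance, so strongly negative
    features rank as highly as strongly positive ones.
    """
    x = np.asarray(x, dtype=float)
    if len(x) != len(y):
        raise ValueError("x and y must have the same length")
    if num_features is None:
        num_features = len(x)
    # Sorting the negated magnitudes gives descending order;
    # a stable sort keeps tied features in their input order.
    order = np.argsort(-np.abs(x), kind="stable")[:num_features]
    return x[order], [y[i] for i in order]


scores = np.array([0.1, -0.8, 0.4, 0.05])
names = ["age", "income", "tenure", "region"]
print(top_features(scores, names, num_features=2))
```

The selected scores and names keep the same pairing as the `x` and `y` arguments, so they have the form `plot_tabular` expects.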

dianna.visualization.highlight_text(explanation, input_tokens=None, show_plot=True, output_filename=None, colormap='RdBu', alpha=1.0, heatmap_range=(-1, 1))

Highlights a given text based on values in a given explanation object.

Parameters:
  • explanation – list of tuples of (word, index of word in original data, importance)

  • input_tokens – list of all tokens (including those without importance)

  • show_plot – If true, display the plot. Set it to false for testing, or to only write the plot to disk.

  • output_filename – Name of the file to save the plot to (optional).

  • colormap – color map for the heatmap plot (see mpl.Axes.imshow documentation for options).

  • alpha – alpha value for the heatmap plot.

  • heatmap_range – a tuple (vmin, vmax) to set the range of the heatmap.

Returns:

None
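A sketch of how the `explanation` argument might be assembled from a list of tokens and per-position scores. The helpers `build_explanation` and `to_unit_interval` are hypothetical; the second assumes `heatmap_range` is used like a Matplotlib `vmin`/`vmax` pair, so that with the default `(-1, 1)` an importance of 0 lands in the middle of the diverging `'RdBu'` colormap.

```python
from typing import Dict, List, Tuple


def build_explanation(
    input_tokens: List[str], scores: Dict[int, float]
) -> List[Tuple[str, int, float]]:
    """Hypothetical helper: build (word, index, importance) tuples.

    `scores` maps a token position to its importance; tokens without a
    score are left out of the explanation but remain in input_tokens.
    """
    return [(input_tokens[i], i, float(s)) for i, s in sorted(scores.items())]


def to_unit_interval(value: float, heatmap_range=(-1, 1)) -> float:
    """Map an importance onto [0, 1] for a colormap lookup, clipping outside values."""
    vmin, vmax = heatmap_range
    return min(max((value - vmin) / (vmax - vmin), 0.0), 1.0)


tokens = ["the", "movie", "was", "great"]
explanation = build_explanation(tokens, {1: 0.2, 3: 0.9})
print(explanation)  # [('movie', 1, 0.2), ('great', 3, 0.9)]
print([to_unit_interval(w) for _, _, w in explanation])
```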

dianna.visualization.plot_timeseries(x: numpy.ndarray, y: numpy.ndarray, segments: List[Dict[str, Any]], x_label: str = 't', y_label: str | Iterable[str] = None, cmap: str | None = None, show_plot: bool | None = True, output_filename: str | None = None, heatmap_range=(-1, 1)) → matplotlib.pyplot.Figure

Plot timeseries with segments highlighted.

Parameters:
  • x (np.ndarray) – X-values with shape (number_of_time_steps,)

  • y (np.ndarray) – Y-values with shape (number_of_time_steps, number_of_channels)

  • segments (List[Dict[str, Any]]) – Segment data: a list of dicts with the keys ‘index’, ‘start’, ‘end’, ‘weight’ and ‘channel’. Here, ‘index’ is the index of the segment or feature, ‘start’ and ‘end’ determine the location of the segment, ‘weight’ determines its color, and ‘channel’ determines the channel within the timeseries.

  • x_label (str, optional) – Label for the x-axis

  • y_label (Union[str, Iterable[str]], optional) – Label or list of labels for the y-axis

  • cmap (str, optional) – Matplotlib colormap

  • show_plot (bool, optional) – If true, display the plot. Set it to false for testing, or to only write the plot to disk.

  • output_filename (str, optional) – Name of the file to save the plot to (optional).

  • heatmap_range (tuple, optional) – a tuple (vmin, vmax) to set the range of the heatmap.

Returns:

plt.Figure
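A sketch that turns a per-time-step saliency array into the `segments` format described above. It assumes `x` is `np.arange(number_of_time_steps)`, so that time-step indices and x-values coincide, and treats `end` as exclusive; how `plot_timeseries` actually interprets the boundaries should be checked against its implementation. Averaging the saliency over each window to obtain `weight` is just one possible choice, and `saliency_to_segments` is a hypothetical helper, not part of `dianna`.

```python
from typing import Any, Dict, List

import numpy as np


def saliency_to_segments(saliency: np.ndarray, window: int) -> List[Dict[str, Any]]:
    """Hypothetical helper: split per-time-step saliency into fixed-width segments.

    saliency has shape (number_of_time_steps, number_of_channels), like y.
    Each segment's weight is the mean saliency over its window (an assumption;
    any scalar per segment fits the documented format). 'end' is exclusive here.
    """
    n_steps, n_channels = saliency.shape
    segments = []
    for channel in range(n_channels):
        for start in range(0, n_steps, window):
            end = min(start + window, n_steps)  # last window may be shorter
            segments.append({
                "index": len(segments),
                "start": start,
                "end": end,
                "weight": float(saliency[start:end, channel].mean()),
                "channel": channel,
            })
    return segments


saliency = np.array([[0.1, -0.2], [0.3, -0.4], [0.5, 0.0], [0.7, 0.2], [0.9, 0.4]])
segments = saliency_to_segments(saliency, window=2)
print(len(segments))  # 3 windows per channel x 2 channels = 6
print(segments[0])
```

The returned list has the shape expected by the `segments` argument of `plot_timeseries`, to be passed together with `x` and `y` arrays of matching length.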