Source code for dianna.visualization.timeseries

from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Union
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm


[docs] def plot_timeseries( x: np.ndarray, y: np.ndarray, segments: List[Dict[str, Any]], x_label: str = 't', y_label: Union[str, Iterable[str]] = None, cmap: Optional[str] = 'bwr', show_plot: Optional[bool] = True, output_filename: Optional[str] = None, heatmap_range=(-1, 1), ) -> plt.Figure: """Plot timeseries with segments highlighted. Args: x (np.ndarray): X-values with shape (number of time_steps) y (np.ndarray): Y-values with shape (number_of_time_steps, number_of_channels) segments (List[Dict[str, Any]]): Segment data, must be a list of dicts with the following keys: 'index', 'start', 'end', 'weight', 'channel. Here, `index` is the index of the segment of feature, `start` and `end` determine the location of the segment, `weight` determines the color, and 'channel' determines the channel within the timeseries. x_label (str, optional): Label for the x-axis y_label (Union[str, Iterable[str]], optional): Label or list of labels for the y-axis cmap (str, optional): Matplotlib colormap show_plot (bool, optional): Shows plot if true (for testing or writing plots to disk instead). output_filename (str, optional): Name of the file to save the plot to (optional). heatmap_range (tuple, optional): a tuple (vmin, vmax) to set the range of the heatmap. Returns: plt.Figure """ fig, axs, y_labels, ys = _process_plotting_parameters(x, y, y_label) for y_current, y_label_current, ax_current in zip(ys.T, y_labels, axs): current_ax = ax_current current_ax.plot(x, y_current, label=y_label_current) current_ax.set_xlabel(x_label) current_ax.set_ylabel(y_label_current) current_ax.label_outer() _draw_segments(axs, cmap, segments, heatmap_range) if not show_plot: plt.close() if output_filename: plt.savefig(output_filename) return fig, axs
[docs] def _draw_segments(axs, cmap, segments, heatmap_range): vmin, vmax = heatmap_range cmap = plt.get_cmap(cmap) norm = plt.Normalize(vmin, vmax) for segment in segments: start = segment['start'] stop = segment['stop'] weight = segment['weight'] segment['index'] channel = segment.get('channel', 0) color = cmap(norm(weight)) axs[channel].axvspan(start, stop, color=color, alpha=0.5) plt.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap), ax=axs, label='weights')
[docs] def _process_plotting_parameters(x, y, y_labels): if hasattr(x, 'ndim') and x.ndim != 1: raise ValueError( f'Invalid rank {x.ndim}. Data x can only have 1 dimension.') if y.ndim == 1: ys = np.expand_dims(y, 1) elif y.ndim == 2: if y.shape[0] != len(x): raise ValueError( f'Shape y was {y.shape} but should be ({len(x)}, ?) instead to be compatible with x.' ) ys = y else: raise ValueError( f'Invalid rank {y.ndim}. Data y can only have either 1 or 2 dimensions.' ) if not y_labels: y_labels = [f'channel {c}' for c in range(ys.shape[0])] if isinstance(y_labels, str): y_labels = [y_labels] n_channels = ys.shape[1] fig, axs = plt.subplots(nrows=n_channels, sharex=True) if n_channels == 1: axs = (axs, ) return fig, axs, y_labels, ys