"""Source code for rise_timeseries."""

from typing import Optional
import numpy as np
from dianna import utils
from dianna.utils.maskers import generate_time_series_masks
from dianna.utils.maskers import mask_data
from dianna.utils.predict import make_predictions
from dianna.utils.rise_utils import normalize


class RISETimeseries:
    """RISE implementation for timeseries adapted from the image version of RISE."""

    def __init__(
        self,
        n_masks: int = 1000,
        feature_res: int = 8,
        p_keep: float = 0.5,
        preprocess_function: Optional[callable] = None,
        keep_masks: bool = False,
        keep_masked_data: bool = False,
        keep_predictions: bool = False,
    ) -> None:
        """RISE initializer.

        Args:
            n_masks: Number of masks to generate.
            feature_res: Resolution of features in masks.
            p_keep: Fraction of input data to keep in each mask (default: 0.5).
            preprocess_function: Function to preprocess input data with.
            keep_masks: Keep masks in memory for the user to inspect.
            keep_masked_data: Keep masked data in memory for the user to inspect.
            keep_predictions: Keep model predictions in memory for the user
                to inspect.
        """
        self.n_masks = n_masks
        self.feature_res = feature_res
        self.p_keep = p_keep
        self.preprocess_function = preprocess_function
        # Filled in by ``explain`` only when the matching ``keep_*`` flag is
        # set; otherwise they stay None so large arrays are not retained.
        self.masks = None
        self.masked = None
        self.predictions = None
        self.keep_masks = keep_masks
        self.keep_masked_data = keep_masked_data
        self.keep_predictions = keep_predictions

    def explain(self, model_or_function, input_timeseries, labels,
                batch_size=100, mask_type='mean'):
        """Run the RISE explainer on a timeseries.

        The model will be called with masked timeseries, with a shape defined
        by `batch_size` and the shape of `input_timeseries`.

        Args:
            model_or_function (callable or str): The function that runs the
                model to be explained _or_ the path to a ONNX model on disk.
            input_timeseries (np.ndarray): Input timeseries data to be explained.
            labels (Iterable(int)): Labels to be explained. A single int is
                also accepted and drops the label axis from the result.
            batch_size (int): Batch size to use for running the model.
            mask_type: Masking strategy for masked values. Choose from 'mean'
                or a callable(input_timeseries).

        Returns:
            Explanation heatmap for each requested label (np.ndarray), label
            axis first followed by the dimensions of `input_timeseries`,
            as returned by `normalize`.
        """
        runner = utils.get_function(
            model_or_function, preprocess_function=self.preprocess_function)

        masks = generate_time_series_masks(input_timeseries.shape,
                                           number_of_masks=self.n_masks,
                                           feature_res=self.feature_res,
                                           p_keep=self.p_keep)
        self.masks = masks if self.keep_masks else None

        masked = mask_data(input_timeseries, masks, mask_type=mask_type)
        self.masked = masked if self.keep_masked_data else None

        predictions = make_predictions(masked, runner, batch_size)
        self.predictions = predictions if self.keep_predictions else None

        n_labels = predictions.shape[1]
        # Weight each flattened mask by the score it produced and sum over
        # masks: (n_labels, n_masks) @ (n_masks, n_features), then restore
        # the original timeseries dimensions behind the label axis.
        saliency = predictions.T.dot(masks.reshape(self.n_masks, -1)).reshape(
            n_labels, *input_timeseries.shape)

        # np.asarray so a tuple of labels selects rows along the label axis
        # instead of being interpreted by NumPy as a multi-axis index.
        selected_saliency = saliency[np.asarray(labels)]
        return normalize(selected_saliency, self.n_masks, self.p_keep)