rise_image

Module Contents

Classes

RISEImage

RISE implementation for images based on https://github.com/eclique/RISE/blob/master/Easy_start.ipynb.

class rise_image.RISEImage(n_masks=1000, feature_res=8, p_keep=None, axis_labels=None, preprocess_function=None)

RISE implementation for images based on https://github.com/eclique/RISE/blob/master/Easy_start.ipynb.
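RISE scores an image by repeatedly hiding random parts of it. The sketch below shows one way to generate the random masks that the constructor parameters n_masks, feature_res and p_keep suggest. It follows the approach in the linked Easy_start notebook: sample a coarse feature_res by feature_res binary grid whose cells are kept with probability p_keep, upsample it bilinearly, and crop it at a random offset. It is plain NumPy written under assumptions. The function names are hypothetical, the parameter meanings are inferred from their names, and a simple corner-aligned bilinear interpolation stands in for the image-resize call used in the notebook. It is not DIANNA's implementation.

```python
import numpy as np


def bilinear_resize(grid: np.ndarray, out_h: int, out_w: int) -> np.ndarray:
    """Corner-aligned bilinear resize of a 2-D array to (out_h, out_w)."""
    in_h, in_w = grid.shape
    ys = np.linspace(0, in_h - 1, out_h)
    xs = np.linspace(0, in_w - 1, out_w)
    y0 = np.floor(ys).astype(int)
    x0 = np.floor(xs).astype(int)
    y1 = np.minimum(y0 + 1, in_h - 1)
    x1 = np.minimum(x0 + 1, in_w - 1)
    wy = (ys - y0)[:, None]  # vertical interpolation weights, shape (out_h, 1)
    wx = (xs - x0)[None, :]  # horizontal interpolation weights, shape (1, out_w)
    top = grid[np.ix_(y0, x0)] * (1 - wx) + grid[np.ix_(y0, x1)] * wx
    bottom = grid[np.ix_(y1, x0)] * (1 - wx) + grid[np.ix_(y1, x1)] * wx
    return top * (1 - wy) + bottom * wy


def generate_masks(n_masks, input_size, feature_res, p_keep, rng=None):
    """Hypothetical helper: n_masks soft masks of shape input_size in [0, 1]."""
    rng = np.random.default_rng() if rng is None else rng
    h, w = input_size
    cell_h, cell_w = int(np.ceil(h / feature_res)), int(np.ceil(w / feature_res))
    up_h, up_w = (feature_res + 1) * cell_h, (feature_res + 1) * cell_w
    # Low-resolution binary grids: each cell is kept with probability p_keep.
    grid = (rng.random((n_masks, feature_res, feature_res)) < p_keep).astype(np.float32)
    masks = np.empty((n_masks, h, w), dtype=np.float32)
    for i in range(n_masks):
        # A random shift of up to one cell avoids masks aligned to a fixed grid.
        dy, dx = rng.integers(0, cell_h), rng.integers(0, cell_w)
        upsampled = bilinear_resize(grid[i], up_h, up_w)
        masks[i] = upsampled[dy:dy + h, dx:dx + w]
    return masks
```

For example, generate_masks(1000, (224, 224), feature_res=8, p_keep=0.5) would produce 1000 smooth masks for a 224 by 224 image; the upsampled grid is one cell larger than needed so that every random crop stays inside it.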

explain(model_or_function, input_data, labels, batch_size=100)

Runs the RISE explainer on images.

The model is called with batches of masked images; each batch has the shape of input_data with a leading batch axis of length batch_size.

Parameters:
  • model_or_function (callable or str) – The function that runs the model to be explained, or the path to an ONNX model on disk.

  • input_data (np.ndarray) – Image to be explained.

  • labels (Iterable(int)) – Labels to be explained.

  • batch_size (int) – Batch size to use for running the model.

Returns:

Explanation heatmap for each label in labels (np.ndarray).
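The core RISE computation behind explain can be sketched in NumPy as follows, using the weighting from the linked Easy_start notebook: run the model on batches of masked images, then average the masks weighted by the score of each requested label and divide by p_keep. This is a minimal sketch under assumptions. rise_saliency is a hypothetical name, the image is assumed to be channels-last with shape (H, W, C) and no batch axis, the masks are assumed to be precomputed, and model_fn is assumed to return an array of shape (batch, n_classes).

```python
import numpy as np


def rise_saliency(model_fn, image, masks, labels, batch_size=100, p_keep=0.5):
    """Hypothetical sketch of RISE scoring.

    image:    (H, W, C) array without a batch axis (channels-last assumed).
    masks:    (N, H, W) array with values in [0, 1].
    model_fn: maps a (batch, H, W, C) array to (batch, n_classes) scores.
    Returns one (H, W) heatmap per entry in labels.
    """
    n_masks = masks.shape[0]
    scores = []
    for start in range(0, n_masks, batch_size):
        batch_masks = masks[start:start + batch_size]
        # Each call sees a batch of masked copies of the image; the last
        # batch is smaller when batch_size does not divide n_masks.
        masked = batch_masks[..., None] * image[None, ...]
        scores.append(np.asarray(model_fn(masked)))
    scores = np.concatenate(scores, axis=0)  # (N, n_classes)
    # Average of the masks weighted by each label's score, divided by p_keep
    # as in the RISE notebook.
    label_scores = scores[:, list(labels)]  # (N, n_labels)
    return np.einsum("nl,nhw->lhw", label_scores, masks) / n_masks / p_keep
```

With a toy model whose class 0 score is the value of pixel (0, 0), the class 0 heatmap peaks at that pixel, because masks that keep it are exactly the ones that receive a high weight.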

_prepare_input_data_and_model(input_data, model_or_function)

Prepares the input data as an xarray with an added batch dimension and creates a preprocessing function.

_set_axis_labels(input_data)

_determine_p_keep(input_data, runner, n_masks=100)

For the choice of the default n_masks value, see https://github.com/dianna-ai/dianna/issues/24#issuecomment-1000152233.
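The reference above gives only the names of _determine_p_keep and _calculate_max_class_std. One plausible reading, which is an assumption and not confirmed here, is that several candidate p_keep values are tried with a small number of masks, and the one whose masked images make some class score vary the most (largest per-class standard deviation) is kept. The sketch below implements that reading. The helper names, the candidate grid and the seed parameter are hypothetical, and per-pixel Bernoulli masks replace RISE's upsampled masks for brevity.

```python
import numpy as np


def max_class_std(p_keep, model_fn, image, n_masks, rng):
    """Hypothetical stand-in for _calculate_max_class_std: the largest
    per-class standard deviation of the model scores over masked images."""
    h, w = image.shape[:2]
    # Per-pixel Bernoulli masks keep the sketch short; RISE itself upsamples
    # low-resolution random grids.
    masks = (rng.random((n_masks, h, w)) < p_keep).astype(image.dtype)
    scores = np.asarray(model_fn(masks[..., None] * image[None, ...]))
    return float(scores.std(axis=0).max())


def determine_p_keep(model_fn, image, n_masks=100, candidates=None, seed=0):
    """Pick the candidate p_keep whose masks make the model output vary most."""
    rng = np.random.default_rng(seed)
    if candidates is None:
        candidates = np.arange(0.1, 1.0, 0.1)  # hypothetical search grid
    stds = [max_class_std(p, model_fn, image, n_masks, rng) for p in candidates]
    return round(float(candidates[int(np.argmax(stds))]), 2)
```

The intuition behind this criterion is that a p_keep close to 0 or 1 makes almost every masked image look alike to the model, so its scores barely change and the explanation carries little signal.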

_calculate_max_class_std(p_keep, runner, input_data, n_masks)

_prepare_image_data(input_data)

Transforms the data to be of the shape and type RISE expects.

Parameters:

input_data (xarray) – Data to be explained

Returns:

Transformed input data, and the preprocessing function to use with utils.get_function().
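The reference does not state which shape and type RISE expects. The sketch below assumes a channels-last float32 layout, so that masks of shape (H, W) broadcast against the image once a trailing axis is added. prepare_image_data is a hypothetical name, and the channel axis is passed in explicitly rather than derived from axis labels. Instead of a preprocessing function, it returns the original channel axis and dtype, which is the information such a function would need to undo the conversion.

```python
import numpy as np


def prepare_image_data(image: np.ndarray, channel_axis_index: int):
    """Hypothetical sketch: return a channels-last float32 copy of image,
    plus the original channel axis and dtype needed to undo the change."""
    original_dtype = image.dtype
    # Move the channel axis to the end so masks of shape (H, W) broadcast
    # against the image after adding a trailing singleton axis.
    channels_last = np.moveaxis(image, channel_axis_index, -1).astype(np.float32)
    return channels_last, channel_axis_index, original_dtype
```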

_get_full_preprocess_function(channel_axis_index, dtype)

Creates a full preprocessing function.

Creates a preprocessing function that incorporates both the user’s (optional) preprocessing function and any needed dtype and shape conversions.

Parameters:
  • channel_axis_index (int) – Axis index of the channels in the input data

  • dtype (type) – Data type of the input data (e.g. np.float32)

Returns:

Function that first ensures the data has the same shape and type as the input data, then runs the user’s preprocessing function.
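The composition described above can be sketched as a closure: restore the user's layout and dtype on each batch, then apply the optional user function. This is a minimal sketch under assumptions. It assumes the batches produced during explanation are channels-last with shape (N, H, W, C), and get_full_preprocess_function here is a standalone stand-in for the private method, with the user's function passed in as preprocess_function (the name of the constructor argument).

```python
from typing import Callable, Optional

import numpy as np


def get_full_preprocess_function(channel_axis_index: int, dtype,
                                 preprocess_function: Optional[Callable] = None):
    """Hypothetical sketch: undo the channels-last/float conversion on a
    batch, then apply the user's optional preprocessing function."""
    # The batch axis is prepended, so non-negative indices shift by one,
    # while negative indices (counted from the end) stay the same.
    target_axis = channel_axis_index + 1 if channel_axis_index >= 0 else channel_axis_index

    def full_preprocess(batch: np.ndarray) -> np.ndarray:
        # batch has shape (N, H, W, C); restore the user's layout and dtype.
        restored = np.moveaxis(batch, -1, target_axis).astype(dtype)
        if preprocess_function is not None:
            restored = preprocess_function(restored)
        return restored

    return full_preprocess
```

Restoring the layout before calling the user's function means that function, and the model behind it, only ever sees data shaped like the original input_data plus a batch axis.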