rise_image
==========

.. py:module:: rise_image


Classes
-------

.. autoapisummary::

   rise_image.RISEImage


Module Contents
---------------

.. py:class:: RISEImage(n_masks=1000, feature_res=8, p_keep=None, axis_labels=None, preprocess_function=None)

   RISE implementation for images based on https://github.com/eclique/RISE/blob/master/Easy_start.ipynb.

   .. py:attribute:: n_masks

   .. py:attribute:: feature_res

   .. py:attribute:: p_keep

   .. py:attribute:: preprocess_function

   .. py:attribute:: masks
      :value: None

   .. py:attribute:: predictions
      :value: None

   .. py:attribute:: axis_labels

   .. py:method:: explain(model_or_function, input_data, labels, batch_size=100)

      Runs the RISE explainer on images.

      The model is called with batches of masked images whose shape is determined by ``batch_size`` and the shape of ``input_data``.

      :param model_or_function: The function that runs the model to be explained *or* the path to an ONNX model on disk.
      :type model_or_function: callable or str
      :param input_data: Image to be explained.
      :type input_data: np.ndarray
      :param labels: Labels to be explained.
      :type labels: Iterable[int]
      :param batch_size: Batch size to use for running the model.
      :type batch_size: int
      :returns: Explanation heatmap for each class.
      :rtype: np.ndarray

   .. py:method:: _prepare_input_data_and_model(input_data, model_or_function)

      Prepares the input data as an xarray with an added batch dimension and creates a preprocessing function.

   .. py:method:: _set_axis_labels(input_data)

   .. py:method:: _determine_p_keep(input_data, runner, n_masks=100)

      For the choice of the default ``n_masks`` value, see https://github.com/dianna-ai/dianna/issues/24#issuecomment-1000152233.

   .. py:method:: _calculate_max_class_std(p_keep, runner, input_data, n_masks)

   .. py:method:: _prepare_image_data(input_data)

      Transforms the data into the shape and type that RISE expects.

      :param input_data: Data to be explained.
      :type input_data: xarray
      :returns: The transformed input data and the preprocessing function to use with ``utils.get_function()``.

   .. py:method:: _get_full_preprocess_function(channel_axis_index, dtype)

      Creates a full preprocessing function.

      The returned function combines the user's (optional) preprocessing function with any dtype and shape conversions that are needed.

      :param channel_axis_index: Axis index of the channels in the input data.
      :type channel_axis_index: int
      :param dtype: Data type of the input data (e.g. ``np.float32``).
      :type dtype: type
      :returns: Function that first ensures the data has the same shape and type as the input data, then runs the user's preprocessing function.
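
The RISE procedure that ``RISEImage`` is based on (random low-resolution binary masks, upsampled and randomly shifted, then weighted by the model's class scores) can be sketched in plain NumPy. This is a minimal sketch, not the ``RISEImage`` implementation: the helpers ``generate_masks`` and ``rise_saliency`` and the two-class toy model are hypothetical, images are assumed to be channels-last ``(height, width, channels)``, and masks are upsampled by nearest-neighbour repetition instead of the bilinear resize used in the RISE notebook. The arguments ``n_masks``, ``feature_res`` and ``p_keep`` are named after the constructor arguments and are assumed to play the same roles.

.. code-block:: python

   import numpy as np


   def generate_masks(n_masks, feature_res, p_keep, image_shape, rng):
       """Hypothetical helper: RISE-style masks, nearest-neighbour variant."""
       height, width = image_shape
       # Pixel size of one coarse grid cell, rounded up so the grid covers the image.
       cell_h = int(np.ceil(height / feature_res))
       cell_w = int(np.ceil(width / feature_res))
       # Each grid cell is kept with probability p_keep; one extra cell per axis
       # leaves room for a random sub-cell shift.
       grid = (rng.random((n_masks, feature_res + 1, feature_res + 1)) < p_keep)
       grid = grid.astype(np.float32)
       masks = np.empty((n_masks, height, width), dtype=np.float32)
       for i in range(n_masks):
           # Nearest-neighbour upsampling: each cell becomes a cell_h x cell_w block.
           up = np.repeat(np.repeat(grid[i], cell_h, axis=0), cell_w, axis=1)
           # Random shift, then crop back to the image size.
           dy = rng.integers(0, cell_h)
           dx = rng.integers(0, cell_w)
           masks[i] = up[dy:dy + height, dx:dx + width]
       return masks


   def rise_saliency(model_fn, image, labels, masks, p_keep, batch_size=100):
       """Hypothetical helper: score-weighted average of the masks."""
       n_masks = masks.shape[0]
       # Broadcast (N, H, W, 1) masks over the channel axis of an (H, W, C) image.
       masked = image[np.newaxis] * masks[..., np.newaxis]
       # Run the model batch by batch; preds has shape (N, n_classes).
       preds = np.concatenate(
           [model_fn(masked[i:i + batch_size]) for i in range(0, n_masks, batch_size)]
       )
       # For each label: sum over masks of score * mask, normalised by N * p_keep.
       sal = np.einsum("nk,nhw->khw", preds[:, labels], masks)
       return sal / n_masks / p_keep


   def toy_model(batch):
       # Hypothetical two-class "model": class 0 looks only at the top-left
       # 8x8 block, class 1 only at the bottom-right 8x8 block.
       top_left = batch[:, :8, :8].mean(axis=(1, 2, 3))
       bottom_right = batch[:, 24:, 24:].mean(axis=(1, 2, 3))
       return np.stack([top_left, bottom_right], axis=1)


   rng = np.random.default_rng(0)
   image = np.ones((32, 32, 3), dtype=np.float32)
   masks = generate_masks(n_masks=1000, feature_res=4, p_keep=0.5,
                          image_shape=(32, 32), rng=rng)
   sal = rise_saliency(toy_model, image, labels=[0, 1], masks=masks,
                       p_keep=0.5, batch_size=64)
   print(sal.shape)  # (2, 32, 32)

Dividing by ``n_masks * p_keep`` turns the weighted sum into a Monte Carlo estimate of the expected class score given that a pixel is kept, so the heatmap for class 0 should come out brighter over the top-left block than elsewhere.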