import logging
import numpy as np
import shap
import skimage.segmentation
from dianna import utils
from dianna._logging_utils import LoggingContext
class KERNELSHAPImage:
    """Kernel SHAP implementation based on shap https://github.com/slundberg/shap."""

    def __init__(self, axis_labels=None, preprocess_function=None):
        """Kernelshap initializer.

        Arguments:
            axis_labels (dict/list, optional): If a dict, key,value pairs of axis index, name.
                                               If a list, the name of each axis where the index
                                               in the list is the axis index
            preprocess_function (callable, optional): Function to preprocess input data with
        """
        self.preprocess_function = preprocess_function
        # An empty list means "not provided"; the channels axis is then
        # located automatically in `_prepare_image_data`.
        self.axis_labels = axis_labels if axis_labels is not None else []

    @staticmethod
    def _segment_image(image, n_segments, compactness, sigma, **kwargs):
        """Create segmentation to explain by segment, not every pixel.

        This could help speed-up the calculation when the input size is very large.
        The image is segmented with `skimage.segmentation.slic` (k-means
        clustering in Color-(x,y,z) space).

        Args:
            image (np.ndarray): Input image to be segmented.
            n_segments (int): The (approximate) number of labels in the segmented output image
            compactness (int): Balances color proximity and space proximity.
            sigma (float): Width of Gaussian smoothing kernel
            kwargs: Extra keyword arguments passed on unchanged to
                    `skimage.segmentation.slic`, see
                    https://scikit-image.org/docs/dev/api/skimage.segmentation.html#skimage.segmentation.slic

        Returns:
            Integer mask of segment labels as returned by `slic`. The number of
            segments is less than or equal to n_segments.
        """
        return skimage.segmentation.slic(image=image,
                                         n_segments=n_segments,
                                         compactness=compactness,
                                         sigma=sigma,
                                         **kwargs)

    def explain(
        self,
        model_or_function,
        input_data,
        labels,
        nsamples='auto',
        background=None,
        n_segments=100,
        compactness=10.0,
        sigma=0,
        **kwargs,
    ):
        """Run the KernelSHAP explainer.

        The image is split into segments; SHAP then explains the model output as
        a function of which segments are visible.

        Args:
            model_or_function (str): The path to a ONNX model on disk.
            input_data (np.ndarray): Data to be explained. It is mandatory to only
                                     provide a single example as input. This is because
                                     KernelShap is generally used for sample-based
                                     interpretability, training a separate interpretable
                                     model to explain a model prediction on each individual
                                     example. The position of the channels axis is taken
                                     from axis_labels or located automatically.
                                     NOTE(review): earlier documentation listed a leading
                                     batch axis, but `_mask_image` indexes the prepared
                                     data as a 3-D (height, width, channels) array -
                                     confirm the expected layout against callers.
            labels (Iterable(int) or None): Indices of classes to be explained. With
                                     None, the heatmaps of all classes are returned.
            nsamples ("auto" or int): Number of times to re-evaluate the model when
                                      explaining each prediction. More samples lead
                                      to lower variance estimates of the SHAP values.
                                      The "auto" setting uses
                                      `nsamples = 2 * X.shape[1] + 2048`
            background (int): Background color for the masked image
            n_segments (int): The (approximate) number of labels in the segmented output image
            compactness (int): Balances color proximity and space proximity. Higher values give
                               more weight to space proximity, making superpixel shapes more
                               square/cubic.
            sigma (float): Width of Gaussian smoothing kernel for pre-processing for
                           each dimension of the image. Zero means no smoothing.
            kwargs: Extra keyword arguments. Those accepted by
                    `skimage.segmentation.slic` are forwarded to the segmentation, see
                    https://scikit-image.org/docs/dev/api/skimage.segmentation.html#skimage.segmentation.slic
                    NOTE(review): earlier documentation also pointed to the SHAP
                    kernel explainer keyword arguments
                    (https://github.com/slundberg/shap/blob/master/shap/explainers/_kernel.py),
                    but this method does not forward any keyword to SHAP.

        Returns:
            Explanation heatmap of Shapley values for each class (np.ndarray).
        """
        (self.onnx_model, self.input_node_name, self.input_node_dtype,
         self.output_node) = utils.onnx_model_node_loader(model_or_function)
        self.labels = labels
        self.input_data = self._prepare_image_data(input_data)
        self.background = background

        # create onnxruntime session once for efficient repeated inference
        import onnxruntime as rt  # pylint: disable=import-outside-toplevel
        self.onnx_session = rt.InferenceSession(
            self.onnx_model.SerializeToString())

        # keep only the keyword arguments that slic understands
        slic_kwargs = utils.get_kwargs_applicable_to_function(
            skimage.segmentation.slic, kwargs)
        self.image_segments = self._segment_image(self.input_data, n_segments,
                                                  compactness, sigma,
                                                  **slic_kwargs)
        # slic may return fewer segments than requested; SHAP needs one
        # feature per segment that actually exists
        n_segments = np.unique(self.image_segments).size

        # One background row (all segments hidden) and one sample row (all
        # segments visible) describe the single image being explained. The row
        # count must not depend on `labels`: `_create_heatemaps` only reads the
        # first explained row, and `labels` may be None.
        explainer = shap.KernelExplainer(self._runner,
                                         np.zeros((1, n_segments)))

        # Temporarily hide warnings, because shap is very spammy
        # `shap_values_list` has shape (n_samples, n_features, n_classes) in shap 0.46+
        with LoggingContext(level=logging.CRITICAL):
            shap_values_list = explainer.shap_values(np.ones((1, n_segments)),
                                                     nsamples=nsamples)

        # create heat_maps where shape is (n_classes, *image_segments.shape)
        heat_maps = _create_heatemaps(shap_values_list, self.image_segments)
        if labels is not None:
            heat_maps = heat_maps[list(labels)]
        return heat_maps

    def _prepare_image_data(self, input_data):
        """Transforms the data to be of the shape and type KernelSHAP expects.

        Side effects: may replace `self.axis_labels` with the automatically
        located channels axis, and sets `self.channels_axis_index` to the
        position the channels axis had before it was moved.

        Args:
            input_data (NumPy-compatible array): Data to be explained

        Returns:
            transformed input data, with the channels axis moved to the end

        Raises:
            ValueError: if axis_labels were provided without a 'channels' entry.
        """
        if isinstance(self.axis_labels, dict):
            axis_label_names = self.axis_labels.values()
        else:
            axis_label_names = self.axis_labels

        # automatically determine the location of the channels axis if no axis_labels were provided
        if not axis_label_names:
            channels_axis_index = utils.locate_channels_axis(input_data.shape)
            self.axis_labels = {channels_axis_index: 'channels'}
        elif 'channels' not in axis_label_names:
            raise ValueError(
                'When providing axis_labels it is required to provide the location'
                ' of the channels axis')

        input_data = utils.to_xarray(input_data, self.axis_labels)
        # ensure channels axis is last and keep track of where it was so we can move it back
        self.channels_axis_index = input_data.dims.index('channels')
        input_data = utils.move_axis(input_data, 'channels', -1)
        return input_data

    def _mask_image(self,
                    features,
                    segmentation,
                    image,
                    background=None,
                    channels_axis_index=2,
                    datatype=np.float32):
        """Build one masked copy of the image per row of a binary feature matrix.

        Column j of `features` controls the j-th segment, where segments are
        ordered by their (sorted) label value. A 0 hides the segment by painting
        it with `background`; any other value leaves it untouched. Columns
        beyond the number of segments present in `segmentation` are ignored.

        Args:
            features (np.ndarray): A matrix of samples (# samples x # features)
                                   on which to explain the model's output.
            segmentation (np.ndarray): Image segmentations generated by
                                       the function _segment_image
            image (np.ndarray): Image to be explained, indexed here as
                                (height, width, channels)
            background (int): Background color for the masked image. With None,
                              the mean of the image over its first two axes is used.
            channels_axis_index (int): See the function _prepare_image_data
            datatype (np.dtype): Datatype for the returned value

        Returns:
            np.ndarray of dtype `datatype` with shape
            (# samples, height, width, channels), or
            (# samples, channels, height, width) when channels_axis_index != 2.
        """
        # check the background color
        if background is None:
            background = image.mean(axis=(0, 1))

        # Segment labels need not start at 0 (see the `start_label` argument of
        # slic), so columns are matched to labels by rank, not by raw value.
        segment_labels = np.unique(segmentation)
        n_controlled = min(features.shape[1], segment_labels.size)
        segment_labels = segment_labels[:n_controlled]

        # one full copy of the image per sample, channels last
        out = np.zeros((features.shape[0], image.shape[0], image.shape[1],
                        image.shape[2]))
        for i in range(features.shape[0]):
            out[i] = image
            hidden_labels = segment_labels[features[i, :n_controlled] == 0]
            if hidden_labels.size:
                out[i][np.isin(segmentation, hidden_labels), :] = background

        # the output shape should satisfy the requirement from onnx model input shape
        if channels_axis_index != 2:
            out = np.transpose(out, (0, 3, 1, 2))
        return out.astype(datatype)

    def _runner(self, features):
        """Evaluate the model on the masked images described by `features`.

        This is the function handed to SHAP: it masks the stored input image,
        applies the optional preprocessing function and runs the ONNX session.

        Args:
            features (np.ndarray): A matrix of samples (# samples x # features)
                                   on which to explain the model's output.

        Returns:
            The first output of the ONNX session for the masked images.
        """
        model_input = self._mask_image(features, self.image_segments,
                                       self.input_data, self.background,
                                       self.channels_axis_index,
                                       self.input_node_dtype)
        if self.preprocess_function is not None:
            model_input = self.preprocess_function(model_input)
        return self.onnx_session.run(
            [self.output_node],
            {self.input_node_name: model_input})[0]
[docs]
def _create_heatemaps(shap_values_list, image_segments):
"""Create heatmaps from the shap values and the image segments.
The final heatmaps has a shape of (n_classes, *image_segments.shape).
"""
# shap 0.46+ returns an array of shape (n_samples, n_features, n_classes).
# Take sample 0 and transpose to (n_classes, n_features/n_segments).
shap_array = np.asarray(shap_values_list)
per_class_values = shap_array[0].T # (n_classes, n_segments)
n_classes = per_class_values.shape[0]
heat_maps = np.zeros((n_classes, *image_segments.shape))
# fill the heat_maps with shap values for each class and segment
for i, shap_values_for_class in enumerate(per_class_values):
class_heat_map = heat_maps[i]
for index in image_segments.flat:
class_heat_map[image_segments == index] = shap_values_for_class[index - 1]
heat_maps[i] = class_heat_map
return heat_maps