
Model Interpretation for Pretrained Binary MNIST Model using KernelSHAP

This notebook demonstrates how to apply the KernelSHAP method to a pretrained binary MNIST model, using a handwritten digit image. It visualizes the relevance attribution of each pixel or superpixel by overlaying it on the image.

SHapley Additive exPlanations (SHAP) is a model-agnostic explainable AI approach that explains the predictions of black-box models by estimating Shapley values, a concept from cooperative game theory that distributes a prediction fairly among the input features.
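As a concrete illustration of the quantity SHAP estimates, the sketch below computes exact Shapley values for a hypothetical three-feature toy model by enumerating every coalition of features. The helper names (`exact_shapley`, `toy_value`) and the convention that absent features are set to a baseline of zero are assumptions for this example, not part of DIANNA. The cost grows exponentially with the number of features, which is why approximations such as KernelSHAP are needed for images.

```python
from itertools import combinations
from math import factorial


def exact_shapley(value, n_features):
    """Exact Shapley values by enumerating all coalitions (exponential cost).

    value: callable taking a frozenset of "present" feature indices and
           returning the model output for that coalition (hypothetical helper).
    """
    phi = [0.0] * n_features
    for i in range(n_features):
        others = [j for j in range(n_features) if j != i]
        for size in range(n_features):
            # Shapley weight |S|! (M - |S| - 1)! / M! for coalitions S that exclude i
            weight = factorial(size) * factorial(n_features - size - 1) / factorial(n_features)
            for subset in combinations(others, size):
                s = frozenset(subset)
                # marginal contribution of feature i when it joins coalition S
                phi[i] += weight * (value(s | {i}) - value(s))
    return phi


# Hypothetical toy model f(x) = 2*x0 + 3*x1 + x0*x2, explained at x = (1, 1, 1)
# against a baseline in which absent features are set to 0.
x = (1.0, 1.0, 1.0)


def toy_value(present):
    z = [x[j] if j in present else 0.0 for j in range(3)]
    return 2 * z[0] + 3 * z[1] + z[0] * z[2]


print(exact_shapley(toy_value, 3))
```

The interaction term x0*x2 is split equally between features 0 and 2, and the values sum to f(x) - f(baseline), which is the efficiency property of Shapley values.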

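KernelSHAP is a variant of SHAP that estimates Shapley values within the LIME framework: it fits a weighted linear model to the model's outputs on perturbed inputs, using a specific weighting kernel (the Shapley kernel).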

More details about this method can be found in the paper https://arxiv.org/abs/1705.07874.
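The regression view of KernelSHAP can be sketched for a small toy model. Each coalition is a binary vector z' over M features, weighted by the Shapley kernel π(z') = (M - 1) / (C(M, |z'|) |z'| (M - |z'|)) from the SHAP paper; solving the weighted least-squares problem over all coalitions recovers the Shapley values. This is only an illustration under stated assumptions: it enumerates all 2^M coalitions instead of sampling `nsamples` of them, it approximates the infinite weight of the empty and full coalitions with a large constant, and the names `kernel_shap`, `shapley_kernel` and `toy_value` are hypothetical.

```python
import itertools
from math import comb

import numpy as np


def shapley_kernel(m, s):
    """Shapley kernel weight for a coalition with s of m features present."""
    if s == 0 or s == m:
        # large constant standing in for the infinite weight that pins down f(baseline) and f(x)
        return 1e6
    return (m - 1) / (comb(m, s) * s * (m - s))


def kernel_shap(value, m):
    """KernelSHAP by weighted least squares over all 2**m coalitions (tiny m only)."""
    coalitions = np.array(list(itertools.product([0, 1], repeat=m)), dtype=float)
    # model output for every coalition (1 = feature present, 0 = set to baseline)
    y = np.array([value(frozenset(int(j) for j in np.flatnonzero(z))) for z in coalitions])
    w = np.array([shapley_kernel(m, int(z.sum())) for z in coalitions])
    # intercept column estimates the base value phi_0 = f(baseline)
    X = np.hstack([np.ones((len(coalitions), 1)), coalitions])
    # weighted least squares: scale each row by the square root of its weight
    sqrt_w = np.sqrt(w)
    coef, *_ = np.linalg.lstsq(X * sqrt_w[:, None], y * sqrt_w, rcond=None)
    return coef[0], coef[1:]


# Hypothetical toy model f(x) = 2*x0 + 3*x1 + x0*x2, explained at x = (1, 1, 1)
x = (1.0, 1.0, 1.0)


def toy_value(present):
    z = [x[j] if j in present else 0.0 for j in range(3)]
    return 2 * z[0] + 3 * z[1] + z[0] * z[2]


base_value, phi = kernel_shap(toy_value, 3)
print(round(base_value, 4), np.round(phi, 4))
```

For this toy model the base value comes out close to 0 and the estimates close to (2.5, 3.0, 0.5), which are the exact Shapley values of the model; with sampled coalitions, as in practice, the result is an approximation.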

Colab Setup

[1]:
running_in_colab = 'google.colab' in str(get_ipython())
if running_in_colab:
  # install dianna
  !python3 -m pip install dianna[notebooks]

  # download data used in this demo
  import os
  base_url = 'https://raw.githubusercontent.com/dianna-ai/dianna/main/dianna/'
  paths_to_download = ['./data/binary-mnist.npz', './models/mnist_model_tf.onnx']
  for path in paths_to_download:
      !wget {base_url + path} -P {os.path.dirname(path)}

0 - Libraries

[ ]:
import warnings
warnings.filterwarnings('ignore')  # disable warnings related to TensorFlow versions
import numpy as np
import dianna
import onnx
from onnx_tf.backend import prepare
import matplotlib.pyplot as plt
from pathlib import Path

1 - Loading the model and the dataset

Load the pretrained binary MNIST model and the image to be explained.

Load the saved binary MNIST data.

[2]:
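# load dataset (Colab downloads it to ./data, a local checkout keeps it in dianna/data)
data_dir = Path('data') if running_in_colab else Path('..', '..', '..', 'dianna', 'data')
data = np.load(data_dir / 'binary-mnist.npz')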
# load testing data and the related labels
X_test = data['X_test'].astype(np.float32).reshape([-1, 28, 28, 1]) / 255
y_test = data['y_test']

Load the pretrained binary MNIST model.

[3]:
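# Load the saved ONNX model (Colab downloads it to ./models)
model_dir = Path('models') if running_in_colab else Path('..', '..', '..', 'dianna', 'models')
onnx_model_path = model_dir / 'mnist_model_tf.onnx'
onnx_model = onnx.load(onnx_model_path)
# get the name of the model's output node
output_node = prepare(onnx_model, gen_tensor_dict=True).outputs[0]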

Print the predicted class of a single test instance and display its image as a preview.

[4]:
# class name
class_name = ['digit 0', 'digit 1']
# instance index
i_instance = 1
# select instance for testing
test_sample = X_test[i_instance].copy().astype(np.float32)
# run the model on the test sample (with an added batch axis)
predictions = prepare(onnx_model).run(test_sample[None, ...])[f'{output_node}']
pred_class = class_name[np.argmax(predictions)]
print("The predicted class is:", pred_class)
plt.imshow(X_test[i_instance][:,:,0], cmap='gray')  # 0 for channel
The predicted class is: digit 0
[Figure: the selected test image shown in grayscale]

2 - Compute Shapley values and visualize the relevance attributions

Approximate Shapley values using KernelSHAP and visualize the relevance attributions on the image.

KernelSHAP approximates Shapley values within the LIME framework. The user needs to specify how many times the model is re-evaluated when explaining each prediction (nsamples); more samples give a more accurate estimate at a higher computational cost. A binary mask is applied to the image to indicate which image regions are hidden, and the hidden regions are filled with a background value, which is specified by background.
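The masking step can be sketched in NumPy as follows. Given a segmentation map that assigns every pixel to a superpixel, each binary vector z decides which superpixels stay visible (1) and which are hidden (0); hidden pixels are replaced by the `background` value, and the model is evaluated once per masked image. The function names, the uniform random sampling of coalitions and the toy 4x4 image are assumptions for illustration only; KernelSHAP additionally weights each sampled coalition with the Shapley kernel when fitting its linear model.

```python
import numpy as np


def mask_image(image, segments, z, background=0.0):
    """Return a copy of `image` in which superpixels with z[k] == 0 are set to `background`.

    image:    (H, W, C) array
    segments: (H, W) integer array with segment ids 0..K-1
    z:        (K,) binary vector, 1 = keep segment, 0 = hide it
    """
    keep = np.asarray(z, dtype=bool)[segments]  # (H, W) pixel-level mask
    masked = np.where(keep[..., None], image, background)
    return masked.astype(image.dtype)


def sample_masked_images(image, segments, nsamples, background=0.0, seed=0):
    """Draw `nsamples` random coalitions (hypothetical uniform sampling) and mask the image."""
    rng = np.random.default_rng(seed)
    n_segments = int(segments.max()) + 1
    z = rng.integers(0, 2, size=(nsamples, n_segments))
    images = np.stack([mask_image(image, segments, zi, background) for zi in z])
    return z, images  # one model evaluation per masked image


# Toy 4x4 single-channel image split into four 2x2 superpixels (ids 0..3)
image = np.arange(16, dtype=np.float32).reshape(4, 4, 1) + 1
segments = np.repeat(np.repeat(np.array([[0, 1], [2, 3]]), 2, axis=0), 2, axis=1)
# keep the top-left and bottom-right superpixels, hide the other two
masked = mask_image(image, segments, z=[1, 0, 0, 1], background=0)
print(masked[..., 0])
```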

Running KernelSHAP on individual pixels is inefficient, because every pixel becomes a separate feature and the number of possible masks doubles with each added feature. It is therefore good practice to segment the input image first and compute attributions for the resulting superpixels. This requires the user to specify keyword arguments related to the segmentation, such as the approximate number of segments in the output image (n_segments) and the width of the Gaussian smoothing kernel applied to each image dimension during pre-processing (sigma).
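To see why superpixels help, the sketch below replaces the real segmentation algorithm with a simple regular grid and counts the resulting features. The grid is a deliberate, hypothetical simplification (the helper `grid_segments` is not part of DIANNA, and it ignores the image content that `n_segments` and `sigma` act on). With one feature per pixel, a 28x28 image has 784 features and 2^784 possible masks; grouping pixels into 4x4 blocks leaves 49 features, so the linear model has far fewer coefficients to estimate from the same `nsamples` model evaluations.

```python
import numpy as np


def grid_segments(height, width, block):
    """Label each pixel with the id of the block x block tile it falls in."""
    rows = np.arange(height) // block
    cols = np.arange(width) // block
    n_cols = -(-width // block)  # ceil(width / block)
    return rows[:, None] * n_cols + cols[None, :]


segments = grid_segments(28, 28, block=4)
n_features = len(np.unique(segments))
print(segments.shape, n_features)  # prints: (28, 28) 49
print(28 * 28, "pixel features vs", n_features, "superpixel features")
```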

[7]:
# use KernelSHAP to explain the network's predictions
shap_values = dianna.explain_image(onnx_model_path, test_sample, labels=[0, 1],
                                  method="KernelSHAP", nsamples=1000,
                                  background=0, n_segments=200, sigma=0,
                                  axis_labels=('height','width','channels'))

Visualize Shapley scores on the images.

[8]:
# sort the class indices by predicted score, highest first
top_preds = np.argsort(-predictions)
inds = top_preds[0]
# Visualize the explanations
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10,4))
axes[0].imshow(test_sample, cmap='gray')
axes[0].axis('off')
# symmetric color range, so that a zero attribution maps to white in the 'bwr' colormap
max_val = np.max([np.max(np.abs(s)) for s in shap_values])
# plot the attributions on top of the faded test image for each class
for i in range(2):
    m = shap_values[inds[i]]
    axes[i+1].set_title(class_name[inds[i]])
    axes[i+1].imshow(test_sample, alpha=0.15)
    im = axes[i+1].imshow(m, vmin=-max_val, vmax=max_val, cmap='bwr')
    axes[i+1].set_xticks([])
    axes[i+1].set_yticks([])
cb = fig.colorbar(im, ax=axes.ravel().tolist(), label="SHAP value", orientation="horizontal", aspect=60)
cb.outline.set_visible(False)
plt.show()
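[Figure: the test image (left) and the KernelSHAP attribution maps for the two classes, ordered by predicted score (middle and right), with a shared horizontal "SHAP value" colorbar]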

3 - Conclusions

In this notebook, Shapley values were estimated with KernelSHAP for a model that classifies binary MNIST digits. The example shows that KernelSHAP assigns an importance score to each segment (superpixel) for the classification, and that the resulting attributions agree reasonably well with human visual perception of the chosen handwritten test digit.
