
Model Interpretation for Pretrained Model to Recognize Geometric Shapes using KernelSHAP

This notebook demonstrates how to apply the KernelSHAP explainability method on a pretrained model used to classify geometric shapes. The relevance attributions for each pixel/super-pixel are visualized by displaying them on the image.

SHapley Additive exPlanations (SHAP) is a model-agnostic explainable AI approach that explains the predictions of black-box models by estimating Shapley values.

KernelSHAP is a variant of SHAP that uses the LIME framework to compute Shapley values.

More details about this method can be found in the paper https://arxiv.org/abs/1705.07874.
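To make the notion of a Shapley value concrete, here is a minimal sketch that computes exact Shapley values by enumerating every coalition of a three-feature toy "model". The function names (`exact_shapley`, `toy_value`) and the toy value function are illustrative assumptions, not part of dianna. The enumeration grows exponentially with the number of features, which is why KernelSHAP estimates the values from samples instead.

```python
from itertools import combinations
from math import factorial


def exact_shapley(value_fn, n_players):
    """Exact Shapley values by enumerating all coalitions (small n only)."""
    players = range(n_players)
    phi = [0.0] * n_players
    for i in players:
        others = [p for p in players if p != i]
        for size in range(n_players):
            # Weight of a coalition of this size in the Shapley formula
            weight = (factorial(size) * factorial(n_players - size - 1)
                      / factorial(n_players))
            for coalition in combinations(others, size):
                s = frozenset(coalition)
                # Marginal contribution of player i when joining coalition s
                phi[i] += weight * (value_fn(s | {i}) - value_fn(s))
    return phi


def toy_value(coalition):
    """Hypothetical model output: features 0 and 1 only matter together,
    feature 2 contributes 1 on its own."""
    return 2.0 * (0 in coalition and 1 in coalition) + 1.0 * (2 in coalition)


phi = exact_shapley(toy_value, 3)
print(phi)
```

The attributions sum to the difference between the output with all features present and the output with none, which is the additivity property that SHAP relies on.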

Colab Setup

[2]:
running_in_colab = 'google.colab' in str(get_ipython())
if running_in_colab:
  # install dianna
  !python3 -m pip install dianna[notebooks]

  # download data used in this demo
  import os
  base_url = 'https://raw.githubusercontent.com/dianna-ai/dianna/main/dianna/'
  paths_to_download = ['./data/shapes.npz', './models/geometric_shapes_model.onnx']
  for path in paths_to_download:
      !wget {base_url + path} -P {os.path.dirname(path)}

0 - Libraries

[ ]:
import warnings
warnings.filterwarnings('ignore') # disable warnings related to versions of tf
import numpy as np
import dianna
import onnx
from onnx_tf.backend import prepare
import matplotlib.pyplot as plt
from pathlib import Path

1 - Loading the model and the dataset

Load the pretrained model and the image to be explained.

Load saved geometric shapes data.

[2]:
# load dataset
data = np.load(Path('..','..','..','dianna', 'data', 'shapes.npz'))
# load testing data and the related labels
X_test = data['X_test'].astype(np.float32).reshape([-1, 1, 64, 64])
y_test = data['y_test']

Load the pretrained geometric shapes model.

[3]:
# Load saved onnx model
onnx_model_path = Path('..','..','..','dianna','models', 'geometric_shapes_model.onnx')
onnx_model = onnx.load(onnx_model_path)
# get the output node
output_node = prepare(onnx_model, gen_tensor_dict=True).outputs[0]

Print class and image of a single instance in the test data for preview.

[5]:
# class name
class_name = ['circle', 'triangle']
# instance index
i_instance = 4
# select instance for testing
test_sample = X_test[i_instance].copy().astype(np.float32)
# model predictions with added batch axis to test sample
predictions = prepare(onnx_model).run(test_sample[None, ...])[f'{output_node}']
pred_class = class_name[np.argmax(predictions)]
print("The predicted class is:", pred_class)
plt.imshow(X_test[i_instance][0,:,:], cmap='gray')
The predicted class is: triangle
../_images/tutorials_kernelshap_geometric_shapes_11_3.png

2 - Compute Shapley values and visualize the relevance attributions

Approximate Shapley values using KernelSHAP and visualize the relevance attributions on the image.

KernelSHAP approximates Shapley values within the LIME framework. The user needs to specify the number of times the model is re-evaluated when explaining each prediction (nsamples). A binary mask is applied to the image to indicate which image regions are hidden. The hidden regions are filled with a background color, which can be specified with background.
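As a rough illustration of how sampled masks are weighted in the LIME-style regression, the sketch below evaluates the Shapley kernel defined in the paper linked above for coalitions of different sizes. The helper name `shapley_kernel_weight` is an assumption made for this example; it is not a dianna function and does not show how dianna implements the method internally.

```python
from math import comb


def shapley_kernel_weight(n_features, coalition_size):
    """Shapley kernel weight of a mask that keeps `coalition_size`
    out of `n_features` regions visible."""
    if coalition_size in (0, n_features):
        # The empty and the full coalition have infinite weight; in practice
        # they are enforced as constraints rather than sampled.
        return float('inf')
    return (n_features - 1) / (
        comb(n_features, coalition_size)
        * coalition_size
        * (n_features - coalition_size)
    )


# Weights for masks keeping 1, 2 or 3 of 4 regions visible
weights = [shapley_kernel_weight(4, k) for k in range(1, 4)]
print(weights)
```

Masks that hide almost everything or almost nothing receive the largest weights, because they isolate the effect of individual regions most clearly.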

Performing KernelSHAP on each pixel is inefficient. It is usually good practice to segment the input image and perform the computations on the resulting superpixels. This requires the user to specify some keyword arguments related to the segmentation, such as the (approximate) number of labels in the segmented output image (n_segments) and the width of the Gaussian smoothing kernel used to pre-process each dimension of the image (sigma).
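The masking step can be sketched with plain NumPy. The example below assumes a simple square grid as a stand-in for the real segmentation algorithm, and the helpers `grid_segments` and `mask_image` are hypothetical names introduced here for illustration. It shows how one binary mask over superpixels hides regions by filling them with a background value.

```python
import numpy as np


def grid_segments(height, width, cell):
    """Label map that assigns each pixel to a square cell
    (a stand-in for a real superpixel segmentation)."""
    rows = np.arange(height) // cell
    cols = np.arange(width) // cell
    n_cols = -(-width // cell)  # ceiling division
    return rows[:, None] * n_cols + cols[None, :]


def mask_image(image, segments, keep, background=0.0):
    """Fill every superpixel whose entry in `keep` is 0 with `background`."""
    pixel_mask = keep[segments].astype(bool)
    return np.where(pixel_mask, image, background)


rng = np.random.default_rng(0)
image = rng.random((64, 64)).astype(np.float32)     # placeholder 64x64 image
segments = grid_segments(64, 64, cell=16)           # 16 superpixels
keep = rng.integers(0, 2, size=segments.max() + 1)  # one random binary mask
masked = mask_image(image, segments, keep, background=0.0)
```

With 16 superpixels instead of 4096 pixels, each sampled mask has far fewer entries, so a given budget of model evaluations covers the space of masks much better.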

[7]:
# use KernelSHAP to explain the network's predictions
shap_values = dianna.explain_image(onnx_model_path, test_sample,
                                  method="KernelSHAP", labels=[0, 1], nsamples=2000,
                                  n_segments=200, sigma=0,
                                  axis_labels=('channels','height','width'))

Visualize Shapley scores on the images.

[11]:
# get the index of predictions
top_preds = np.argsort(-predictions)
inds = top_preds[0]
# Visualize the explanations
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(10,4))
axes[0].imshow(test_sample[0], cmap='gray')
axes[0].axis('off')
# get the range for color bar
max_val = np.max([np.max(np.abs(shap_values[i][:,:-1])) for i in range(len(shap_values))])
# plot the test image and the attributions on the image for each class
for i in range(2):
    m = shap_values[inds[i]]
    axes[i+1].set_title(class_name[inds[i]])
    axes[i+1].imshow(test_sample[0], alpha=0.15)
    im = axes[i+1].imshow(m, cmap='bwr',vmin=-max_val, vmax=max_val)
    #axes[i+1].axis('off')
    axes[i+1].set_xticks([])
    axes[i+1].set_yticks([])
cb = fig.colorbar(im, ax=axes.ravel().tolist(), label="SHAP value", orientation="horizontal", aspect=60)
cb.outline.set_visible(False)
plt.show()
../_images/tutorials_kernelshap_geometric_shapes_15_0.png

3 - Conclusions

The Shapley scores are estimated using KernelSHAP. In this example, KernelSHAP evaluates the importance of each segment (superpixel) to the classification of geometric shapes, and the results indicate that the model determines the shape by checking whether a (sharp) angle is present. For instance, the figure above shows that the sharp angle yields negative scores for the circle class, and therefore the prediction is triangle. This interpretation agrees with human visual perception of the chosen image.
