
General demo of using dianna

DIANNA is a Python package that brings explainable AI (XAI) to your research project. It wraps carefully selected XAI methods in a simple, uniform interface. It’s built by, with and for (academic) researchers and research software engineers working on machine learning projects.

Colab Setup

[1]:
running_in_colab = 'google.colab' in str(get_ipython())
if running_in_colab:
  # install dianna
  !python3 -m pip install dianna[notebooks]

  # download data used in this demo
  import os
  base_url = 'https://raw.githubusercontent.com/dianna-ai/dianna/main/dianna/'
  paths_to_download = ['./data/binary-mnist.npz', './models/mnist_model_tf.onnx']
  for path in paths_to_download:
      !wget {base_url + path} -P {os.path.dirname(path)}

Libraries

[ ]:
import warnings
warnings.filterwarnings('ignore')  # disable warnings related to versions of tf
import numpy as np
import dianna
import onnx
import onnxruntime
from onnx_tf.backend import prepare
from scipy.special import softmax
import matplotlib.pyplot as plt
%matplotlib inline
from dianna import visualization
import random
from pathlib import Path
random.seed(42)

Model Interpretation for Pretrained Binary MNIST Model

Load the pretrained binary MNIST model and the image to be explained.

Load saved binary MNIST data.

[2]:
# load dataset
data_path = Path('..', 'dianna','data', 'binary-mnist.npz')
data = np.load(data_path)
# load testing data and the related labels
X_test = data['X_test'].astype(np.float32).reshape([-1, 28, 28, 1]) / 256
y_test = data['y_test']
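
Before continuing, it can be useful to check the shapes and label values of the arrays loaded above:

[ ]:
# inspect the loaded arrays: images are 28x28 grayscale, labels are 0/1
print('X_test shape:', X_test.shape)
print('y_test shape:', y_test.shape)
print('classes present:', np.unique(y_test))
print('pixel value range:', X_test.min(), '-', X_test.max())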

Load the pretrained binary MNIST model.

[3]:
# Load saved onnx model
onnx_model_path = Path('..','dianna','models', 'mnist_model_tf.onnx')
onnx_model = onnx.load(onnx_model_path)
# get the name of the output node
output_node = prepare(onnx_model, gen_tensor_dict=True).outputs[0]
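
The model's input and output names and shapes can also be inspected with onnxruntime (imported above); this is only an alternative check, assuming the CPU execution provider is available:

[ ]:
# inspect the ONNX model's input/output names and shapes with onnxruntime
sess = onnxruntime.InferenceSession(str(onnx_model_path), providers=['CPUExecutionProvider'])
print('inputs:', [(i.name, i.shape) for i in sess.get_inputs()])
print('outputs:', [(o.name, o.shape) for o in sess.get_outputs()])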

Print class and image of a single instance in the test data for preview.

[4]:
# class names
class_name = ['digit 0', 'digit 1']
# instance index
i_instance = 3
# select instance for testing
test_sample = X_test[i_instance]
# model predictions with added batch axis to test sample
predictions = prepare(onnx_model).run(test_sample[None, ...])[f'{output_node}']
pred_class = class_name[np.argmax(predictions)]
# rank the classes by descending prediction score
top_preds = np.argsort(-predictions)
inds = top_preds[0]
print("The predicted class for this test image is:", pred_class)
plt.imshow(X_test[i_instance][:, :, 0], cmap='gray')  # drop the channel axis for plotting
The predicted class for this test image is: digit 0
[4]:
<matplotlib.image.AxesImage at 0x7f750567ca90>
../_images/tutorials_demo_12_2.png
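
The raw model outputs can be turned into class probabilities with scipy.special.softmax (imported above). A small check, assuming the model returns unnormalised scores:

[ ]:
# convert the raw model outputs to class probabilities (assuming they are logits)
probabilities = softmax(predictions, axis=1)
for name, prob in zip(class_name, probabilities[0]):
    print(f'{name}: {prob:.3f}')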

1. KernelSHAP

SHapley Additive exPlanations, in short SHAP, is a model-agnostic explainable AI approach that interprets black-box models by estimating Shapley values.

KernelSHAP is a variant of SHAP. It uses the LIME framework to compute the Shapley values and visualizes the relevance attributions for each pixel/superpixel by displaying them on the image.

The user needs to specify the number of times the model is re-evaluated when explaining each prediction (nsamples). A binary mask is applied to the image to indicate whether an image region is hidden, and the background color used for the masked regions can be set with background.

Performing KernelSHAP on each pixel is inefficient. It is good practice to segment the input image and perform the computations on the obtained superpixels. This requires the user to specify some keyword arguments related to the segmentation, like the (approximate) number of labels in the segmented output image (n_segments) and the width of the Gaussian smoothing kernel used to pre-process each dimension of the image (sigma).
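
To get a feeling for what these superpixels look like, the segmentation can be previewed with scikit-image using the same n_segments and sigma values. This is only an illustrative sketch, assuming scikit-image 0.19 or newer is installed (for the channel_axis argument); it is not necessarily identical to the segmentation computed inside the explainer.

[ ]:
# preview a SLIC segmentation of the test image with the same settings as the explainer call below
from skimage.segmentation import slic, mark_boundaries

image_2d = test_sample[:, :, 0]                                  # drop the channel axis
segments = slic(image_2d, n_segments=200, sigma=0, channel_axis=None)
print('number of superpixels:', len(np.unique(segments)))
plt.imshow(mark_boundaries(image_2d, segments))
plt.title('SLIC superpixels')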

[24]:
# use KernelSHAP to explain the network's predictions
shap_values = dianna.explain_image(onnx_model_path, test_sample,
                                  method="KernelSHAP", labels=[0], nsamples=1000,
                                  background=0, n_segments=200, sigma=0,
                                  axis_labels=('height','width','channels'))
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.48s/it]

Visualize Shapley scores on the images.

[25]:
# Visualize the explanations
# get the range for color bar
max_val = np.max([np.max(np.abs(shap_values[i][:,:-1])) for i in range(len(shap_values))])
# plot the test image and the attributions on the image for class 0
print(f'Explanation for `{pred_class}` with KernelSHAP')

fig, _ = visualization.plot_image(shap_values[0], heatmap_cmap='bwr', heatmap_range=(-max_val, max_val))
Explanation for `digit 0` with KernelSHAP
../_images/tutorials_demo_17_1.png

The Shapley scores are estimated with KernelSHAP for the model used to classify the binary MNIST digits. The example shows that KernelSHAP assigns an importance to each segment/superpixel of the image, and the result agrees reasonably well with human visual perception of the chosen hand-written digit.

2. RISE

RISE is short for Randomized Input Sampling for Explanation of Black-box Models. It estimates the relevance empirically by probing the model with randomly masked versions of the input image to obtain the corresponding outputs.

RISE masks random portions of the input image and passes the masked images through the model; the portions whose masking lowers the predicted class score the most are the most "important" ones. To call the explainer and generate the relevance scores, the user needs to specify the number of randomly generated masks (n_masks), the resolution of features in the masks (feature_res) and, for each mask and each feature in the image, the probability of being kept unmasked (p_keep).
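
As an illustration of the masking principle (not the explainer's internal mask generation, which uses smooth, shifted upsampling), the sketch below builds a single RISE-style mask by hand: a coarse 8x8 grid where each cell is kept with probability 0.1, upsampled to the image size and applied to the test image.

[ ]:
# illustrate a single RISE-style mask: coarse random grid, upsampled and applied to the image
from scipy.ndimage import zoom

rng = np.random.default_rng(0)
grid = (rng.random((8, 8)) < 0.1).astype(np.float32)  # keep ~10% of the cells (p_keep=0.1)
mask = zoom(grid, 28 / 8, order=1)                    # bilinear upsampling to 28x28
masked_image = X_test[i_instance][:, :, 0] * mask

fig, axes = plt.subplots(1, 3, figsize=(9, 3))
for ax, img, title in zip(axes, [grid, mask, masked_image], ['coarse grid', 'upsampled mask', 'masked image']):
    ax.imshow(img, cmap='gray')
    ax.set_title(title)
    ax.axis('off')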

[27]:
relevances = dianna.explain_image(onnx_model_path, test_sample, method="RISE",
                                labels=[i for i in range(2)],
                                n_masks=5000, feature_res=8, p_keep=.1,
                                axis_labels=('height','width','channels'))[0]
Explaining: 100%|████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:00<00:00, 53.72it/s]

Visualize the relevance scores for the predicted class on top of the image.

[28]:
print(f'Explanation for `{pred_class}` with RISE')
fig, _ = visualization.plot_image(relevances, X_test[i_instance][:,:,0], data_cmap='gray', heatmap_cmap='bwr')
Explanation for `digit 0` with RISE
../_images/tutorials_demo_23_1.png

3. LIME

LIME (Local Interpretable Model-agnostic Explanations) is an explainable-AI method that aims to create an interpretable model that locally represents the classifier. It perturbs the input (here by switching superpixels on and off), observes the model's predictions on the perturbed samples, and fits a simple surrogate model that is interpretable in the neighbourhood of the instance being explained.

[29]:
# LIME works on images with values in the 0-255 (RGB) range, but the model expects inputs divided by 256;
# this preprocess function rescales the perturbed images before they are passed to the model
def preprocess_function(image):
    return (image / 256).astype(np.float32)

# An explanation is returned for each requested label; below we plot the one for label 0 (the predicted class).
explanation_heatmap = dianna.explain_image(onnx_model_path, test_sample * 256, 'LIME',
                                           axis_labels=('height','width','channels'),
                                           random_state=2,
                                           labels=[i for i in range(2)], preprocess_function=preprocess_function)
100%|██████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:03<00:00, 1299.05it/s]
[30]:
print(f'Explanation for `{pred_class}` with LIME')
fig, _ = visualization.plot_image(explanation_heatmap[0], X_test[i_instance][:,:,0], heatmap_cmap='bwr')

Explanation for `digit 0` with LIME
../_images/tutorials_demo_26_1.png
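
For a quick visual comparison, the three explanations can be overlaid on the test image side by side. This assumes each explainer above returned a 2D heatmap (as used by visualization.plot_image), and note that the color scales are not normalised across methods.

[ ]:
# compare the three heatmaps for the predicted class side by side (color scales differ per method)
heatmaps = {'KernelSHAP': shap_values[0], 'RISE': relevances, 'LIME': explanation_heatmap[0]}
fig, axes = plt.subplots(1, len(heatmaps), figsize=(12, 4))
for ax, (name, heatmap) in zip(axes, heatmaps.items()):
    ax.imshow(X_test[i_instance][:, :, 0], cmap='gray')
    ax.imshow(np.squeeze(heatmap), cmap='bwr', alpha=0.5)
    ax.set_title(name)
    ax.axis('off')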