Model Interpretation for Pretrained ImageNet Model using RISE

This notebook demonstrates how to apply the RISE explainability method to a pretrained ImageNet model, using an image of a bee. It visualizes the relevance scores for all pixels/super-pixels by displaying them on top of the image.

RISE is short for Randomized Input Sampling for Explanation of Black-box Models. It estimates importance empirically by probing the model with randomly masked versions of the input image and obtaining the corresponding outputs.
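
In the notation of the RISE paper, this empirical estimate can be written as a score-weighted average of the masks. Here $I$ is the input image, $f$ the model's score for the class of interest, $M_i$ the $i$-th of $N$ random masks, $\odot$ element-wise multiplication, and $\lambda$ a pixel. This summarizes the paper's Monte Carlo estimate and is not a description of any particular implementation:

```latex
S_{I,f}(\lambda) \approx \frac{1}{\mathbb{E}[M] \cdot N} \sum_{i=1}^{N} f(I \odot M_i) \, M_i(\lambda)
```

A pixel that stays visible in the masks that produce high scores therefore accumulates a high saliency value. $\mathbb{E}[M]$ is the expected mask value, which corresponds to the probability of keeping a feature unmasked.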

More details about this method can be found in the paper https://arxiv.org/abs/1806.07421.

Colab Setup

[1]:
running_in_colab = 'google.colab' in str(get_ipython())
if running_in_colab:
  # install dianna
  !python3 -m pip install dianna[notebooks]

  # download data used in this demo
  import os
  base_url = 'https://raw.githubusercontent.com/dianna-ai/dianna/main/dianna/'
  paths_to_download = ['./data/bee.jpg']
  for path in paths_to_download:
      !wget {base_url + path} -P {os.path.dirname(path)}

0 - Libraries

[ ]:
import warnings
warnings.filterwarnings('ignore') # disable warnings related to versions of tf
import numpy as np
from pathlib import Path
# keras model and preprocessing tools
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input, decode_predictions
from keras import backend as K
from keras import utils
# dianna library for explanation
import dianna
from dianna import visualization
# for plotting
%matplotlib inline
from matplotlib import pyplot as plt

1 - Loading the model and the dataset

Load the pretrained ImageNet model and the image to be explained.

Initialize the pretrained model.

[2]:
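class Model():
    """Thin wrapper exposing the batch-prediction function passed to the explainer."""
    def __init__(self):
        K.set_learning_phase(0)  # put Keras in inference mode
        self.model = ResNet50()  # loads ImageNet weights by default
        self.input_size = (224, 224)  # spatial input size used for ResNet50

    def run_on_batch(self, x):
        # return the class probabilities for a batch of preprocessed images
        return self.model.predict(x)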
[3]:
model = Model()

Load and preprocess image.

[4]:
def load_img(path):
    img = utils.load_img(path, target_size=model.input_size)
    x = utils.img_to_array(img)
    x = preprocess_input(x)
    return img, x

Call the function to load the bee image, the single instance explained in this notebook, from the data folder.

[5]:
img, x = load_img(Path('..','..','..','dianna','data', 'bee.jpg'))
plt.imshow(img)
[5]:
<matplotlib.image.AxesImage at 0x7f5808160cd0>
../_images/tutorials_rise_imagenet_12_1.png

2 - Compute and visualize the relevance scores

Compute the pixel relevance scores using RISE and visualize them on the input image.

RISE masks random portions of the input image and passes each masked image through the model. Masks that keep the important portions of the image visible yield higher scores for a class, so those portions receive higher relevance. To call the explainer and generate the relevance score maps, the user needs to specify the number of randomly generated masks (n_masks), the resolution of the features in the masks (feature_res), and the probability that each feature in each mask is kept unmasked (p_keep).
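
To make the roles of these three parameters concrete, here is a minimal sketch of how such masks could be generated with NumPy. The helper `generate_masks` is hypothetical and is not DIANNA's implementation. It also uses plain nearest-neighbour upsampling of the coarse grid, whereas the original RISE paper uses bilinear upsampling with random shifts to obtain soft masks.

```python
import numpy as np

def generate_masks(n_masks, feature_res, p_keep, image_size, seed=0):
    """Hypothetical helper: random coarse binary grids upsampled to image size."""
    rng = np.random.default_rng(seed)
    height, width = image_size
    # One Bernoulli(p_keep) draw per coarse cell: 1 keeps the cell, 0 masks it.
    grid = (rng.random((n_masks, feature_res, feature_res)) < p_keep).astype(np.float32)
    # Nearest-neighbour upsampling: map every pixel to the coarse cell covering it.
    rows = np.arange(height) * feature_res // height
    cols = np.arange(width) * feature_res // width
    return grid[:, rows[:, None], cols[None, :]]

masks = generate_masks(n_masks=1000, feature_res=6, p_keep=0.1, image_size=(224, 224))
print(masks.shape)   # one (224, 224) mask per sample
print(masks.mean())  # fraction of pixels kept, close to p_keep
```

In this sketch a larger `feature_res` gives finer-grained masks, a larger `p_keep` leaves more of the image visible in each mask, and a larger `n_masks` reduces the noise of the estimate at the cost of more model evaluations.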

[6]:
relevances = dianna.explain_image(model.run_on_batch, x, method="RISE",
                                  labels=list(range(1000)),
                                  n_masks=1000, feature_res=6, p_keep=.1,
                                  axis_labels={2: 'channels'})
Explaining: 100%|█████████████████████████████████████████████████████████████████| 10/10 [00:49<00:00,  4.92s/it]

Make predictions and select the top prediction.

[7]:
def class_name(idx):
    return decode_predictions(np.eye(1, 1000, idx))[0][0][1]

# print the name of predicted class, taking care of adding a batch axis to the model input
class_name(np.argmax(model.model.predict(x[None, ...])))
1/1 [==============================] - 0s 56ms/step
[7]:
'bee'

Visualize the relevance scores for the five most probable classes on top of the input image.

[8]:
predictions = model.model.predict(x[None, ...])
prediction_ids = np.argsort(predictions)[0][-1:-6:-1]
prediction_ids
1/1 [==============================] - 0s 57ms/step
[8]:
array([309, 946, 308, 319,  74])
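
The slice `[-1:-6:-1]` used above can be read as "the last five entries of the ascending sort, in reverse order". The following sketch shows the same pattern on a small, made-up probability vector; the values and the choice of `k = 3` are illustrative only.

```python
import numpy as np

# Toy batch holding one prediction vector over six hypothetical classes.
predictions = np.array([[0.05, 0.40, 0.10, 0.25, 0.15, 0.05]])

# argsort sorts in ascending order, so walking backwards from the last index
# yields the k largest entries in descending order of probability.
k = 3
top_ids = np.argsort(predictions)[0][-1:-(k + 1):-1]
print(top_ids)  # [1 3 4]
```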
[12]:
for class_idx in prediction_ids:
    print(f'Explanation for `{class_name(class_idx)}` ({predictions[0][class_idx]})')
    visualization.plot_image(relevances[class_idx], utils.img_to_array(img)/255., heatmap_cmap='jet')
    plt.show()
Explanation for `bee` (0.9229555130004883)
../_images/tutorials_rise_imagenet_20_1.png
Explanation for `cardoon` (0.03968876227736473)
../_images/tutorials_rise_imagenet_20_3.png
Explanation for `fly` (0.01597355678677559)
../_images/tutorials_rise_imagenet_20_5.png
Explanation for `dragonfly` (0.007476466707885265)
../_images/tutorials_rise_imagenet_20_7.png
Explanation for `garden_spider` (0.005400042049586773)
../_images/tutorials_rise_imagenet_20_9.png

3 - Conclusions

The relevance scores are generated by passing multiple randomly masked inputs to the black-box model and averaging the masks, weighted by the scores the model assigns to the masked inputs. The idea behind this is that whenever a mask preserves important parts of the image, the masked image gets a higher score.
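
The averaging step can be sketched on a toy problem. The black box `toy_model` below is hypothetical and only looks at the top-left quarter of an 8x8 image, so that region should come out with higher relevance than the rest. The masks are simple block-upsampled binary grids, which is a simplification of the RISE procedure.

```python
import numpy as np

rng = np.random.default_rng(0)
size, res, p_keep, n_masks = 8, 4, 0.5, 4000

def toy_model(batch):
    """Hypothetical black box: score is the mean brightness of the top-left 4x4 patch."""
    return batch[:, :4, :4].mean(axis=(1, 2))

image = np.ones((size, size), dtype=np.float32)

# Random coarse grids, upsampled by repeating each cell (size // res) times.
grid = (rng.random((n_masks, res, res)) < p_keep).astype(np.float32)
masks = grid.repeat(size // res, axis=1).repeat(size // res, axis=2)

# Score every masked image, then average the masks weighted by their scores.
scores = toy_model(image[None] * masks)
saliency = np.tensordot(scores, masks, axes=1) / (n_masks * p_keep)

print(saliency.round(2))
```

Pixels inside the top-left patch are visible whenever the score is high, so their weighted average ends up above that of pixels the toy model never looks at.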

The example here shows that the RISE method evaluates the relevance of each pixel/super-pixel to the classification. Pixels characterizing the bee are highlighted by the XAI approach, which gives an intuition of how the model classifies the image. The results are reasonable, based on human visual perception of the image.