
Interpreting a movie review sentiment model with RISE

This notebook demonstrates the use of DIANNA with the RISE method on the Stanford Sentiment Treebank dataset, which contains one-sentence movie reviews. See also the paper introducing the dataset. A pre-trained neural network classifier predicts whether a movie review is positive or negative.

RISE is short for Randomized Input Sampling for Explanation of Black-box Models. It estimates each word’s relevance to the model’s decision empirically: it probes the model with many randomly masked versions of the input text, records the corresponding outputs, and averages the masks weighted by those outputs.

More details about this method can be found in the paper https://arxiv.org/abs/1806.07421.
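The estimate can be sketched in a few lines of NumPy. This is a simplified illustration under stated assumptions, not DIANNA's implementation: it assumes a `predict` function that takes a list of sentences and returns one row of class scores per sentence, it drops masked words outright, and it samples one independent keep/drop value per word. The names `rise_text_sketch` and `toy_model` are hypothetical.

```python
import numpy as np


def rise_text_sketch(predict, words, class_index, n_masks=1000, p_keep=0.5, seed=0):
    """Simplified RISE-style word relevance (hypothetical helper, not DIANNA's code)."""
    rng = np.random.default_rng(seed)
    # each row is a random binary mask: 1.0 keeps a word, 0.0 removes it
    masks = (rng.random((n_masks, len(words))) < p_keep).astype(float)
    # assumption: masked words are simply dropped from the sentence
    sentences = [' '.join(w for w, keep in zip(words, mask) if keep) for mask in masks]
    scores = np.asarray(predict(sentences), dtype=float)[:, class_index]
    # RISE estimate: score-weighted sum of masks, normalised by the expected keep count
    return scores @ masks / (n_masks * p_keep)


def toy_model(sentences):
    # hypothetical classifier: "positive" only when the word "great" survives masking
    positivity = np.array([0.9 if 'great' in s.split() else 0.1 for s in sentences])
    return np.stack([1 - positivity, positivity], axis=1)


words = ['a', 'great', 'movie']
relevances = rise_text_sketch(toy_model, words, class_index=1)
for word, relevance in zip(words, relevances):
    print(f'{word}: {relevance:.2f}')
```

Because "great" is the only word the toy model reacts to, it should receive a relevance close to 0.9, while the other words end up near the average score (about 0.5 here).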

NOTE: This tutorial is still a work in progress; the final results need to be improved by tweaking the RISE parameters.

Colab Setup

[1]:
running_in_colab = 'google.colab' in str(get_ipython())
if running_in_colab:
  # install dianna
  !python3 -m pip install dianna[notebooks]

  # download data used in this demo
  import os
  base_url = 'https://raw.githubusercontent.com/dianna-ai/dianna/main/dianna/'
  paths_to_download = ['./data/movie_reviews_word_vectors.txt', './models/movie_review_model.onnx']
  for path in paths_to_download:
      !wget {base_url + path} -P {os.path.dirname(path)}

0 - Imports and paths

[1]:
import os
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import spacy
from torchtext.vocab import Vectors
from scipy.special import expit as sigmoid

import dianna
from dianna import visualization
from dianna import utils
from dianna.utils.tokenizers import SpacyTokenizer
[2]:
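# locally the files live in the DIANNA repository; in Colab they were downloaded next to the notebook
data_root = Path('.') if running_in_colab else Path('..', '..', '..', 'dianna')
model_path = data_root / 'models' / 'movie_review_model.onnx'
word_vector_path = data_root / 'data' / 'movie_reviews_word_vectors.txt'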
labels = ("negative", "positive")

1 - Loading the model

The classifier is stored in ONNX format. It accepts numerical tokens as input and outputs a single logit; applying a sigmoid turns this into a score between 0 and 1, where 0 means the review is negative and 1 that it is positive.
Here we define a class to run the model that instead accepts a sentence (i.e. a string) as input and returns a score for each of the two classes: negative and positive.
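The conversion from the model's logit to two class scores can be illustrated on its own. This sketch mirrors the last lines of the runner below, using hypothetical logit values; the `sigmoid` helper here is a plain NumPy equivalent of `scipy.special.expit`, and `np.stack(..., axis=-1)` gives the same result as `np.transpose([negativity, positivity])` for one-dimensional inputs.

```python
import numpy as np


def sigmoid(x):
    """Logistic function; equivalent to scipy.special.expit for these inputs."""
    return 1.0 / (1.0 + np.exp(-x))


def logits_to_two_classes(logits):
    """Turn positive-class logits into [negative, positive] scores per review."""
    positivity = sigmoid(np.asarray(logits, dtype=float))
    negativity = 1.0 - positivity
    # one row per review, columns ordered like labels = ("negative", "positive")
    return np.stack([negativity, positivity], axis=-1)


# hypothetical logits for three reviews: clearly negative, undecided, clearly positive
probs = logits_to_two_classes([-3.0, 0.0, 3.0])
print(probs.round(3))
```

Each row sums to 1, and a logit of 0 corresponds to an undecided review with equal scores for both classes.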
[4]:
# ensure the tokenizer for english is available
spacy.cli.download('en_core_web_sm')
Collecting en-core-web-sm==3.2.0
  Downloading https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.2.0/en_core_web_sm-3.2.0-py3-none-any.whl (13.9 MB)
✔ Download and installation successful
You can now load the package via spacy.load('en_core_web_sm')
[3]:
class MovieReviewsModelRunner:
    def __init__(self, model, word_vectors, max_filter_size):
        self.run_model = utils.get_function(model)
        self.vocab = Vectors(word_vectors, cache=os.path.dirname(word_vectors))
        self.max_filter_size = max_filter_size

        self.tokenizer = SpacyTokenizer(name='en_core_web_sm')

    def __call__(self, sentences):
        # ensure the input has a batch axis
        if isinstance(sentences, str):
            sentences = [sentences]

        output = []
        for sentence in sentences:
            # tokenize and pad to minimum length
            tokens = self.tokenizer.tokenize(sentence)
            if len(tokens) < self.max_filter_size:
                tokens += ['<pad>'] * (self.max_filter_size - len(tokens))

            # numericalize the tokens
            tokens_numerical = [self.vocab.stoi[token] if token in self.vocab.stoi else self.vocab.stoi['<unk>']
                                for token in tokens]

            # run the model and apply a sigmoid because the model outputs logits;
            # np.ravel(...)[0] extracts the single score, avoiding NumPy's deprecated array-to-scalar conversion
            pred = float(np.ravel(sigmoid(self.run_model([tokens_numerical])))[0])
            output.append(pred)

        # output two classes
        positivity = np.array(output)
        negativity = 1 - positivity
        return np.transpose([negativity, positivity])
[4]:
# define the model runner. max_filter_size is a property of the model: shorter inputs are padded with '<pad>' tokens up to this length
model_runner = MovieReviewsModelRunner(model_path, word_vector_path, max_filter_size=5)
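The padding and numericalization steps inside the runner can be tried without the model or the word-vector file. This is a sketch with a hypothetical toy vocabulary standing in for `self.vocab.stoi`; unlike the runner, it copies the token list before padding so the caller's list is not modified in place.

```python
def numericalize(tokens, stoi, max_filter_size):
    """Pad a token list to at least max_filter_size and map tokens to indices."""
    # copy so the caller's list is not modified in place
    tokens = list(tokens)
    if len(tokens) < max_filter_size:
        tokens += ['<pad>'] * (max_filter_size - len(tokens))
    # out-of-vocabulary tokens fall back to the index of '<unk>'
    return [stoi.get(token, stoi['<unk>']) for token in tokens]


# hypothetical toy vocabulary; the real one comes from the word-vector file
toy_stoi = {'<unk>': 0, '<pad>': 1, 'a': 2, 'great': 3, 'film': 4}
print(numericalize(['a', 'great', 'movie'], toy_stoi, max_filter_size=5))  # → [2, 3, 0, 1, 1]
```

Inputs that are already at least `max_filter_size` tokens long are passed through unpadded.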

2 - Applying RISE with DIANNA

The simplest way to run DIANNA on text data is with dianna.explain_text. The arguments are:
* The function that runs the model (a path to a model in ONNX format is also accepted)
* The text we want to explain
* The tokenizer used to split the text into words
* The name of the explainable-AI method we want to use, here RISE
* The numerical indices of the classes we want an explanation for (the labels argument)

dianna.explain_text returns one explanation per requested class. Each explanation is a list of tuples, and each tuple contains a word, its position in the input text, and its relevance for the selected output class.

[5]:
review = "A delectable and intriguing thriller filled with surprises"
[6]:
# An explanation is returned for each label, but we ask for just one label so the output is a list of length one.
explanation_relevances = dianna.explain_text(model_runner, review, model_runner.tokenizer, 'RISE',
                                             labels=[labels.index('positive')])[0]
explanation_relevances
Rise parameter p_keep was automatically determined at 0.2
Explaining: 100%|█████████████████████████████████████████████████████████████████| 10/10 [00:10<00:00,  1.08s/it]
[6]:
[('A', 0, 0.7278900909423828),
 ('delectable', 1, 0.8564747416973114),
 ('and', 2, 0.6521024966239928),
 ('intriguing', 3, 0.9845284354686736),
 ('thriller', 4, 0.8320333343744277),
 ('filled', 5, 0.5821573086082935),
 ('with', 6, 0.7066416105628013),
 ('surprises', 7, 0.7483201465010643)]
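Since each tuple holds a word, its position and its relevance, a plain sort is enough to pick out the most relevant words. The values below are copied from the output above so the snippet runs on its own; in the notebook, the `explanation_relevances` variable can be used directly.

```python
# relevance tuples copied from the output above: (word, position, relevance)
explanation_relevances = [
    ('A', 0, 0.7278900909423828),
    ('delectable', 1, 0.8564747416973114),
    ('and', 2, 0.6521024966239928),
    ('intriguing', 3, 0.9845284354686736),
    ('thriller', 4, 0.8320333343744277),
    ('filled', 5, 0.5821573086082935),
    ('with', 6, 0.7066416105628013),
    ('surprises', 7, 0.7483201465010643),
]

# sort by relevance, highest first
ranked = sorted(explanation_relevances, key=lambda item: item[2], reverse=True)
for word, position, relevance in ranked[:3]:
    print(f"{word!r} (position {position}): {relevance:.2f}")
# 'intriguing' (position 3): 0.98
# 'delectable' (position 1): 0.86
# 'thriller' (position 4): 0.83
```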

3 - Visualization

DIANNA includes a visualization package that highlights each word of a text according to its relevance score. The visualization is in HTML format. Words in favour of the selected class are highlighted in red and words against it in blue; this example contains no words against the selected class.
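The colour scheme can be illustrated with a small, hypothetical HTML renderer. This is not DIANNA's `highlight_text`; it only shows the idea that the sign of a relevance score picks the colour (red in favour, blue against) and its magnitude picks the intensity. Scaling the intensity to the largest absolute score is an assumption made for this sketch.

```python
from html import escape


def highlight_html(relevances):
    """Hypothetical renderer: shade each word by its (word, position, relevance) score."""
    # scale opacity to the largest absolute score; fall back to 1.0 if all scores are 0
    max_abs = max(abs(score) for _, _, score in relevances) or 1.0
    spans = []
    for word, _, score in relevances:
        colour = '255, 0, 0' if score >= 0 else '0, 0, 255'  # red in favour, blue against
        alpha = abs(score) / max_abs
        spans.append(f'<span style="background-color: rgba({colour}, {alpha:.2f})">{escape(word)}</span>')
    return ' '.join(spans)


# hypothetical scores, including one word that argues against the class
print(highlight_html([('good', 0, 0.8), ('plot', 1, 0.2), ('but', 2, -0.4)]))
```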

[7]:
fig, _ = visualization.highlight_text(explanation_relevances, model_runner.tokenizer.tokenize(review))
[Figure: the review with each word highlighted according to its RISE relevance for the positive class]

The visualization is not very informative, because all words appear relevant for the review’s outcome. The numerical values above confirm that, according to RISE, every word contributes positively, with “intriguing” as the most important word (relevance 0.98).