Interpreting a movie review sentiment model with RISE
This notebook demonstrates the use of DIANNA with the RISE method on the Stanford Sentiment Treebank dataset, which contains one-sentence movie reviews. See also their paper. A pre-trained neural network classifier is used to identify whether a movie review is positive or negative.
RISE is short for Randomized Input Sampling for Explanation of Black-box Models. It estimates each word’s relevance to the model’s decision empirically, by probing the model with randomly masked versions of the input text and obtaining the corresponding outputs.
More details about this method can be found in the paper https://arxiv.org/abs/1806.07421.
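A minimal sketch of the idea behind RISE for text, assuming word-level binary masks in which each word is kept with probability p_keep and masked words are replaced by a placeholder token. The function `rise_text`, the `toy_model` and the `"<unk>"` placeholder are hypothetical illustrations, not DIANNA's implementation. The relevance of a word is the model score summed over the masks in which that word was kept, normalised by `n_masks * p_keep`, following the Monte Carlo estimator described in the RISE paper.

```python
import numpy as np


def rise_text(model_fn, tokens, n_masks=1000, p_keep=0.5, mask_token="<unk>", seed=0):
    """Simplified RISE-style relevance estimate for a list of word tokens.

    model_fn maps a list of sentences to an array of shape (n_sentences, n_classes).
    Returns an array of shape (n_classes, n_tokens).
    """
    rng = np.random.default_rng(seed)
    # Each row is a binary mask: True keeps the word, False replaces it with mask_token
    masks = rng.random((n_masks, len(tokens))) < p_keep
    masked_sentences = [
        " ".join(tok if keep else mask_token for tok, keep in zip(tokens, mask))
        for mask in masks
    ]
    scores = np.asarray(model_fn(masked_sentences))  # shape (n_masks, n_classes)
    # Sum the scores of the masks in which each word was kept, then normalise by
    # the expected number of times a word is kept (n_masks * p_keep)
    return scores.T @ masks.astype(float) / (n_masks * p_keep)


def toy_model(sentences):
    # Hypothetical classifier: "positive" if and only if the word "great" is present
    pos = np.array([1.0 if "great" in s.split() else 0.0 for s in sentences])
    return np.stack([1 - pos, pos], axis=1)


relevance = rise_text(toy_model, ["a", "great", "movie"], n_masks=2000, p_keep=0.5)
print(relevance[1])  # positive-class relevance of "a", "great", "movie"
```

With this toy model, "great" should come out close to 1.0 and the other two words close to 0.5, because they only receive credit in masks where "great" happens to be kept as well. This also shows why unnormalised RISE scores can be positive for words that do not matter at all: every word inherits a baseline proportional to the model's average output.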
NOTE: This tutorial is still a work in progress; the final results need to be improved by tweaking the RISE parameters.
Colab Setup
[1]:
running_in_colab = 'google.colab' in str(get_ipython())
if running_in_colab:
    # install dianna
    !python3 -m pip install dianna[notebooks]
    # download data used in this demo
    import os
    base_url = 'https://raw.githubusercontent.com/dianna-ai/dianna/main/dianna/'
    paths_to_download = ['./data/movie_reviews_word_vectors.txt', './models/movie_review_model.onnx']
    for path in paths_to_download:
        !wget {base_url + path} -P {os.path.dirname(path)}
0 - Imports and paths
[1]:
import os
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import spacy
from torchtext.vocab import Vectors
from scipy.special import expit as sigmoid
import dianna
from dianna import visualization
from dianna import utils
from dianna.utils.tokenizers import SpacyTokenizer
[2]:
model_path = Path('..','..','..','dianna','models', 'movie_review_model.onnx')
word_vector_path = Path('..','..','..','dianna','data', 'movie_reviews_word_vectors.txt')
labels = ("negative", "positive")
1 - Loading the model
[4]:
# ensure the tokenizer for english is available
spacy.cli.download('en_core_web_sm')
✔ Download and installation successful
You can now load the package via spacy.load('en_core_web_sm')
[3]:
class MovieReviewsModelRunner:
    def __init__(self, model, word_vectors, max_filter_size):
        self.run_model = utils.get_function(model)
        self.vocab = Vectors(word_vectors, cache=os.path.dirname(word_vectors))
        self.max_filter_size = max_filter_size
        self.tokenizer = SpacyTokenizer(name='en_core_web_sm')

    def __call__(self, sentences):
        # ensure the input has a batch axis
        if isinstance(sentences, str):
            sentences = [sentences]
        output = []
        for sentence in sentences:
            # tokenize and pad to minimum length
            tokens = self.tokenizer.tokenize(sentence)
            if len(tokens) < self.max_filter_size:
                tokens += ['<pad>'] * (self.max_filter_size - len(tokens))
            # numericalize the tokens, mapping unknown words to <unk>
            tokens_numerical = [self.vocab.stoi[token] if token in self.vocab.stoi else self.vocab.stoi['<unk>']
                                for token in tokens]
            # run the model, applying a sigmoid because the model outputs logits;
            # .item() extracts the single value, removing any remaining batch axis
            pred = sigmoid(self.run_model([tokens_numerical])).item()
            output.append(pred)
        # output two classes
        positivity = np.array(output)
        negativity = 1 - positivity
        return np.transpose([negativity, positivity])
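The preprocessing and output steps of the runner can be tried in isolation, assuming a small toy vocabulary in place of the word vectors. The names `stoi`, `numericalize` and `two_class_output` are hypothetical and chosen for this illustration only. Unknown words fall back to the `<unk>` index, short inputs are padded with `<pad>`, and each sentence's single sigmoid probability p is expanded to the two-class row [1 - p, p].

```python
import numpy as np
from scipy.special import expit as sigmoid

# Hypothetical toy vocabulary standing in for self.vocab.stoi
stoi = {"<unk>": 0, "<pad>": 1, "a": 2, "great": 3, "movie": 4}


def numericalize(tokens, max_filter_size, stoi):
    """Pad to at least max_filter_size tokens, then map tokens to indices (<unk> fallback)."""
    if len(tokens) < max_filter_size:
        tokens = tokens + ["<pad>"] * (max_filter_size - len(tokens))
    return [stoi.get(token, stoi["<unk>"]) for token in tokens]


def two_class_output(logits):
    """Turn one logit per sentence into rows of [P(negative), P(positive)]."""
    positivity = sigmoid(np.asarray(logits, dtype=float))
    return np.transpose([1 - positivity, positivity])


ids = numericalize(["a", "great", "film"], max_filter_size=5, stoi=stoi)
print(ids)  # [2, 3, 0, 1, 1]
probs = two_class_output([0.0, 2.0])
print(probs.shape)  # (2, 2)
```

Returning both columns lets the explainer ask for either label by index, which is why the labels tuple is ordered ("negative", "positive").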
[4]:
# define model runner. max_filter_size is a property of the model
model_runner = MovieReviewsModelRunner(model_path, word_vector_path, max_filter_size=5)
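Why padding to max_filter_size matters can be sketched under the assumption that the model is a 1D convolutional text classifier, as the name max_filter_size suggests: a "valid" convolution of width k over n token embeddings has n - k + 1 positions, so an input shorter than the widest filter leaves it with nothing to convolve. The helper `conv_windows` below is hypothetical and uses NumPy's `sliding_window_view` only to enumerate those positions.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def conv_windows(embeddings, filter_size):
    """Return every window a width-`filter_size` 'valid' 1D convolution would see.

    embeddings has shape (n_tokens, embedding_dim); the result has shape
    (n_tokens - filter_size + 1, embedding_dim, filter_size).
    """
    return sliding_window_view(embeddings, filter_size, axis=0)


short = np.zeros((3, 8))                        # 3 tokens, 8-dimensional embeddings
padded = np.vstack([short, np.zeros((2, 8))])   # padded to max_filter_size = 5
print(conv_windows(padded, 5).shape)            # (1, 8, 5)
try:
    conv_windows(short, 5)
except ValueError as err:
    print("input too short for the filter:", err)
```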
2 - Applying RISE with DIANNA
The simplest way to run DIANNA on text data is with dianna.explain_text. The arguments are:
* The function that runs the model (a path to a model in ONNX format is also accepted)
* The text we want to explain
* The tokenizer used to split the text into words
* The name of the explainable-AI method we want to use, here RISE
* The numerical indices of the classes we want an explanation for
dianna.explain_text returns one explanation per requested class. Each explanation is a list of tuples, where each tuple contains a word, its position in the input text, and its relevance for that class.
[5]:
review = "A delectable and intriguing thriller filled with surprises"
[6]:
# An explanation is returned for each label, but we ask for just one label so the output is a list of length one.
explanation_relevances = dianna.explain_text(model_runner, review, model_runner.tokenizer, 'RISE',
                                             labels=[labels.index('positive')])[0]
explanation_relevances
Rise parameter p_keep was automatically determined at 0.2
[6]:
[('A', 0, 0.7278900909423828),
('delectable', 1, 0.8564747416973114),
('and', 2, 0.6521024966239928),
('intriguing', 3, 0.9845284354686736),
('thriller', 4, 0.8320333343744277),
('filled', 5, 0.5821573086082935),
('with', 6, 0.7066416105628013),
('surprises', 7, 0.7483201465010643)]
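Because each explanation is a plain list of (word, position, relevance) tuples, ranking the words needs no DIANNA-specific tooling. A short sketch using the scores printed above:

```python
# Relevance scores copied from the output above: (word, position, relevance)
explanation_relevances = [
    ('A', 0, 0.7278900909423828),
    ('delectable', 1, 0.8564747416973114),
    ('and', 2, 0.6521024966239928),
    ('intriguing', 3, 0.9845284354686736),
    ('thriller', 4, 0.8320333343744277),
    ('filled', 5, 0.5821573086082935),
    ('with', 6, 0.7066416105628013),
    ('surprises', 7, 0.7483201465010643),
]

# Sort by relevance, most relevant word first
ranked = sorted(explanation_relevances, key=lambda item: item[2], reverse=True)
for word, position, relevance in ranked[:3]:
    print(f"{word!r} (position {position}): {relevance:.2f}")
# 'intriguing' (position 3): 0.98
# 'delectable' (position 1): 0.86
# 'thriller' (position 4): 0.83
```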
3 - Visualization
DIANNA includes a visualization package capable of highlighting each word of a text based on its relevance score. The visualization is in HTML format. Words in favour of the selected class are highlighted in red; words against it would be highlighted in blue, but none are present in this example.
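The red/blue colour scheme described above can be sketched as a small HTML generator. This is a hypothetical illustration, not how `visualization.highlight_text` is implemented: `highlight_html` is an invented name, and it assumes opacity proportional to each score's magnitude relative to the largest absolute score.

```python
import html


def highlight_html(explanation):
    """Wrap each word in a <span> whose background colour encodes its relevance.

    explanation is a list of (word, position, relevance) tuples. Positive relevance
    is shown in red, negative in blue, with opacity |relevance| / max |relevance|.
    """
    max_abs = max(abs(relevance) for _, _, relevance in explanation) or 1.0
    spans = []
    for word, _, relevance in explanation:
        rgb = "255, 0, 0" if relevance >= 0 else "0, 0, 255"
        alpha = abs(relevance) / max_abs
        spans.append(
            f'<span style="background-color: rgba({rgb}, {alpha:.2f})">{html.escape(word)}</span>'
        )
    return " ".join(spans)


print(highlight_html([("good", 0, 0.9), ("not", 1, -0.3)]))
```

Scaling by the largest absolute score keeps the strongest word fully saturated, but it also means that when every score is high, as in this example, all words look similarly highlighted.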
[7]:
fig, _ = visualization.highlight_text(explanation_relevances, model_runner.tokenizer.tokenize(review))
The visualization is not very clear, as all words seem relevant to the review’s outcome. The numerical values above confirm that, according to RISE, all words contribute positively, with “intriguing” as the most important word at a score of 0.98.