import base64
import sys
import streamlit as st
from _image_utils import open_image
from _model_utils import load_labels
from _model_utils import load_model
from _models_image import explain_image_dispatcher
from _models_image import predict
from _shared import _get_top_indices_and_labels
from _shared import _methods_checkboxes
from _shared import add_sidebar_logo
from _shared import reset_example
from _shared import reset_method
from dianna.utils.downloader import download
from dianna.visualization import plot_image
if sys.version_info < (3, 10):
from importlib_resources import files
else:
from importlib.resources import files
[docs]
data_directory = files('dianna.data')
[docs]
colormap_path = str(data_directory / 'colormap.png')
with open(colormap_path, "rb") as img_file:
[docs]
colormap = base64.b64encode(img_file.read()).decode()
[docs]
def description_explainer(open='open'):
"""Expandable text section with image."""
return (st.markdown(
f"""
<details {open}>
<summary><b>Description of the explanation</b></summary>
The explanation is visualised as a **relevance heatmap** overlayed on top of the time series. <br>
The heatmap consists of the relevance _attributions_ of all individual pixels/super-pixels of the image
to a **pretrained model**'s classification. <br>
The attribution heatmap can be computed for any class. <br><br>
The _bwr (blue white red)_ attribution colormap
assigns :blue[**blue**] color to negative relevances, **white** color to near-zero values,
and :red[**red**] color to positive values. <br><br>
<img src="data:image/png;base64,{colormap}" alt="Colormap" width="600" ><br>
</details>
""",
unsafe_allow_html=True
),
st.text("")
)
add_sidebar_logo()
st.title('Explaining Image data classification')
st.sidebar.header('Input data')
# Use the examples
if input_type == 'Use an example':
load_example = st.sidebar.radio(
label='Load example',
options=('Hand-written digit recognition',),
index = None,
on_change = reset_method,
key='Image_load_example'
)
if load_example == 'Hand-written digit recognition':
image_file = download('digit0.jpg', 'data')
image_model_file = download('mnist_model_tf.onnx', 'model')
image_label_file = download('labels_mnist.txt', 'label')
imagekey = 'Digits_Image_cb'
description_explainer("")
st.markdown(
"""
*********************************************************************************************
This example demonstrates the use of DIANNA on explaining a
[**binary MNIST model**](https://zenodo.org/records/5907177) pretrained on **only** images of
the hand-written digits 0 and 1. <br>
The model classifies an image of a hand-written digit as displaying 0 or 1.
""",
unsafe_allow_html=True
)
else:
description_explainer()
st.info('Select an example in the left panel to coninue')
st.stop()
# Option to upload your own data
if input_type == 'Use your own data':
image_file = st.sidebar.file_uploader('Select image',
type=('png', 'jpg', 'jpeg'))
if image_file:
st.sidebar.image(image_file)
image_model_file = st.sidebar.file_uploader('Select model',
type='onnx')
image_label_file = st.sidebar.file_uploader('Select labels',
type='txt')
imagekey = 'Image_cb'
if not (image_file and image_model_file and image_label_file):
description_explainer()
st.info('Add your input data in the left panel to continue')
st.stop()
else:
description_explainer("")
if input_type is None:
description_explainer()
st.info('Select which input type to use in the left panel to continue')
st.stop()
image, _ = open_image(image_file)
[docs]
model = load_model(image_model_file)
[docs]
serialized_model = model.SerializeToString()
[docs]
labels = load_labels(image_label_file)
[docs]
choices = ('RISE', 'KernelSHAP', 'LIME')
st.text("")
st.text("")
with st.container(border=True):
[docs]
prediction_placeholder = st.empty()
methods, method_params = _methods_checkboxes(choices=choices, key=imagekey)
with st.spinner('Predicting class'):
predictions = predict(model=model, image=image)
with prediction_placeholder:
top_indices, top_labels = _get_top_indices_and_labels(
predictions=predictions,labels=labels)
st.text("")
st.text("")
# check which axis is color channel
[docs]
original_data = image[:, :, 0] if image.shape[2] <= 3 else image[1, :, :]
[docs]
axis_labels = {2: 'channels'} if image.shape[2] <= 3 else {0: 'channels'}
[docs]
weight = 0.9 / len(methods)
[docs]
column_spec = [0.1, *[weight for _ in methods]]
_, *columns = st.columns(column_spec)
for col, method in zip(columns, methods):
col.markdown(f"<h4 style='text-align: center; '>{method}</h4>", unsafe_allow_html=True)
for index, label in zip(top_indices, top_labels):
index_col, *columns = st.columns(column_spec)
index_col.markdown(f'##### Class: {label}')
for col, method in zip(columns, methods):
[docs]
kwargs = method_params[method].copy()
kwargs['axis_labels'] = axis_labels
kwargs['labels'] = [index]
func = explain_image_dispatcher[method]
with col:
with st.spinner(f'Running {method}'):
heatmap = func(serialized_model, image, index, **kwargs)
fig, _ = plot_image(heatmap,
original_data=original_data,
heatmap_cmap='bwr',
show_plot=False)
st.pyplot(fig)
st.stop()