import numpy as np
import seaborn as sns
import streamlit as st
from _model_utils import load_data
from _model_utils import load_labels
from _model_utils import load_model
from _model_utils import load_penguins
from _model_utils import load_sunshine
from _model_utils import load_training_data
from _models_tabular import explain_tabular_dispatcher
from _models_tabular import predict
from _shared import _get_top_indices_and_labels
from _shared import _methods_checkboxes
from _shared import add_sidebar_logo
from _shared import reset_example
from _shared import reset_method
from st_aggrid import AgGrid
from st_aggrid import GridOptionsBuilder
from st_aggrid import GridUpdateMode
from dianna.utils.downloader import download
from dianna.visualization import plot_tabular
# First Streamlit-facing call of the page script; it runs on every rerun.
# Presumably renders the project logo in the sidebar (helper comes from
# `_shared`) -- confirm against that module.
add_sidebar_logo()
def description_explainer(open='open'):
    """Render the collapsible description of the explanation chart.

    Emits an HTML ``<details>`` element through ``st.markdown`` and then an
    empty ``st.text`` element that acts as a vertical spacer.

    Args:
        open: Text inserted verbatim into the ``<details>`` tag. The default
            ``'open'`` adds the HTML ``open`` attribute; passing ``""``
            leaves the attribute out.

    Returns:
        A 2-tuple holding the return values of the ``st.markdown`` and
        ``st.text`` calls, in that order.
    """
    # The parameter name shadows the ``open`` builtin; it is kept unchanged
    # so that existing keyword callers keep working.
    # The markdown text stays at column 0 inside the literal so that no
    # leading whitespace is added to what Streamlit receives.
    return (
        st.markdown(
            f"""
<details {open}>
<summary><b>Description of the explanation</b></summary>
The explanation is visualised as a **relevance bar-chart** for the top (up to 10) most
relevant _attributes (features)_. <br>
The chart displays the relevance _attributions_ of the individual features of the tabular data
to a **pretrained model**'s classification or regression prediction.
The attribution chart can be computed for any predicted outcome.
The attribution colormap
assigns :blue[**blue**] color to negative relevances,
and :red[**red**] color to positive values.
</details>
""",
            unsafe_allow_html=True,
        ),
        st.text(""),
    )
st.title('Explaining Tabular data classification/regression')

st.sidebar.header('Input data')

# `input_type` is compared further down against the two option strings below
# and against None, so it has to be bound before those checks run.
# index=None leaves the radio unselected on first load, which is what makes
# the `input_type is None` branch reachable.
# NOTE(review): the label, the key and the `reset_example` callback were
# reconstructed (the assignment was absent from this copy of the file and
# `reset_example` is imported but otherwise unused) -- confirm.
input_type = st.sidebar.radio(
    label='Select which input to use',
    options=('Use an example', 'Use your own data'),
    index=None,
    on_change=reset_example,
    key='Tabular_input_type',
)
# Use the examples
if input_type == 'Use an example':
    # index=None: no example is preselected, so the final `else` below is
    # what runs until the user picks one.
    load_example = st.sidebar.radio(
        label='Use example',
        options=(
            'Sunshine hours prediction (regression)',
            'Penguin identification (classification)',
        ),
        index=None,
        on_change=reset_method,
        key='Tabular_load_example',
    )

    if load_example == 'Sunshine hours prediction (regression)':
        tabular_data_file = download('weather_prediction_dataset_light.csv', 'data')
        tabular_model_file = download('sunshine_hours_regression_model.onnx', 'model')
        # The same CSV serves as both the table shown and the training data.
        tabular_training_data_file = tabular_data_file
        tabular_label_file = None

        training_data, data = load_sunshine(tabular_data_file)
        labels = None
        mode = 'regression'

        # An example is chosen: render the description without `open`.
        description_explainer("")
        st.markdown(
            """
*****************************************************************************
This example demonstrates the use of DIANNA on a pre-trained [regression
model](https://zenodo.org/records/10580833) to predict tomorrow's sunshine hours
based on meteorological data from today.
The model is trained on the
[weather prediction dataset](https://zenodo.org/records/5071376). <br>
The meteorological data includes measurements (features) of
_cloud coverage, humidity, air pressure, global radiation, precipitation_, and
_mean, min_ and _max temperature_
for various European cities.
""",
            unsafe_allow_html=True,
        )
    elif load_example == 'Penguin identification (classification)':
        tabular_model_file = download('penguin_model.onnx', 'model')
        data_penguins = sns.load_dataset('penguins')
        # Class names are the distinct values of the `species` column.
        labels = data_penguins['species'].unique()
        training_data, data = load_penguins(data_penguins)
        mode = 'classification'

        description_explainer("")
        st.markdown(
            """
****************************************************************************
This example demonstrates the use of DIANNA on a pre-trained [classification
model](https://zenodo.org/records/10580743) to identify if a penguin belongs to one of three different species
based on a number of measurable physical characteristics. <br>
The model is trained on the
[penguin dataset](https://www.kaggle.com/code/parulpandey/penguin-dataset-the-new-iris).
The penguin characteristics include the _bill length_, _bill depth_, _flipper length_, and _body mass_.
""",
            unsafe_allow_html=True,
        )
    else:
        # Nothing picked yet: show the description with its default
        # argument and halt, since no data or model is available.
        description_explainer()
        st.info('Select an example in the left panel to continue')
        st.stop()
# Option to upload your own data
if input_type == 'Use your own data':
    tabular_data_file = st.sidebar.file_uploader('Select tabular data', type='csv')
    tabular_model_file = st.sidebar.file_uploader('Select model', type='onnx')
    tabular_training_data_file = st.sidebar.file_uploader('Select training data', type='npy')
    tabular_label_file = st.sidebar.file_uploader('Select labels in case of classification model', type='txt')

    # Data, model and training data are mandatory; the label file is not.
    if not (tabular_data_file and tabular_model_file and tabular_training_data_file):
        description_explainer()
        st.info('Add your input data in the left panel to continue')
        st.stop()

    # Only reached once all three mandatory uploads are present, because
    # st.stop() above ends the current script run.
    description_explainer("")
    data = load_data(tabular_data_file)
    # The model itself is loaded once, after the input-type branches, from
    # `tabular_model_file`; loading it here as well was redundant.
    training_data = load_training_data(tabular_training_data_file)

    # The presence of a label file is what selects classification mode.
    if tabular_label_file:
        labels = load_labels(tabular_label_file)
        mode = 'classification'
    else:
        labels = None
        mode = 'regression'
# No input type picked yet: show the description and halt this run.
if input_type is None:
    description_explainer()
    st.info('Select which input type to use in the left panel to continue')
    st.stop()

# `tabular_model_file` was bound by whichever input branch ran above.
model = load_model(tabular_model_file)
# The serialized form is what `predict` and the explainer functions are
# given further down.
serialized_model = model.SerializeToString()

# Explanation methods offered as checkboxes.
choices = ('RISE', 'LIME', 'KernelSHAP')

st.text("")
# Get predictions and create parameter box
with st.container(border=True):
    # Empty slot reserved here; it is entered later, once a table row has
    # been selected and a prediction exists.
    prediction_placeholder = st.empty()
    methods, method_params = _methods_checkboxes(choices=choices, key='Tabular_cb')
# Configure Ag-Grid options
# NOTE(review): the original nesting was lost in this copy of the file; the
# grid is placed at top level here -- confirm it was not meant to live inside
# the bordered container above.
gb = GridOptionsBuilder.from_dataframe(data)
# Only one row can be selected at a time.
gb.configure_selection('single')
grid_options = gb.build()

# Display the grid with the DataFrame
grid_response = AgGrid(
    data,
    gridOptions=grid_options,
    update_mode=GridUpdateMode.SELECTION_CHANGED,
    theme='streamlit',
)
selected_rows = grid_response['selected_rows']

# Halt until a row is selected. An empty selection is treated like a missing
# one, because `.index[0]` below would raise on it.
if selected_rows is None or len(selected_rows) == 0:
    st.info("Select the input data by clicking a row in the table.")
    st.stop()

# NOTE(review): assumes `selected_rows` is a DataFrame whose index holds the
# row position within `data` -- confirm for the installed st_aggrid version.
selected_row = int(selected_rows.index[0])
# The first column is dropped here, matching the `[1:]` applied to the
# column names when the feature names are built.
selected_data = data.iloc[selected_row].to_numpy()[1:]

with st.spinner('Predicting class'):
    # The model receives a single-row 2-D array.
    predictions = predict(model=serialized_model, tabular_input=selected_data.reshape(1, -1))

# Entered as a context so that anything the helper draws goes into the slot
# reserved earlier.
with prediction_placeholder:
    top_indices, top_labels = _get_top_indices_and_labels(
        predictions=predictions[0], labels=labels)
st.text("")
st.text("")

# With no method selected there is nothing to lay out, and the width
# computation below would divide by zero.
if len(methods) == 0:
    st.info('Select a method to continue')
    st.stop()

# Relative widths: 0.15 for the leading label column, the remaining 0.85
# split evenly across one column per selected method.
weight = 0.85 / len(methods)
column_spec = [0.15, *([weight] * len(methods))]

# Header row: the leading column stays empty, each method column gets a
# centred title.
_, *columns = st.columns(column_spec)
for col, method in zip(columns, methods):
    col.markdown(f"<h4 style='text-align: center; '>{method}</h4>", unsafe_allow_html=True)
# Feature names skip the first column, matching the `[1:]` applied to the
# selected row; they are identical for every plot, so build them once.
feature_names = data.columns.to_list()[1:]

# One row of plots per (index, label) pair, one column per selected method.
for index, label in zip(top_indices, top_labels):
    index_col, *columns = st.columns(column_spec)

    if mode == 'classification':
        index_col.markdown(f'##### Class: {label}')

    for col, method in zip(columns, methods):
        # Copy so the per-method parameter dict is not mutated.
        kwargs = method_params[method].copy()
        kwargs['mode'] = mode
        kwargs['_feature_names'] = feature_names
        func = explain_tabular_dispatcher[method]

        with col:
            with st.spinner(f'Running {method}'):
                relevances = func(serialized_model, selected_data, training_data, **kwargs)

            if mode == 'classification':
                # NOTE(review): the loop variable `index` is not used, so
                # every class row plots the relevances of the highest-scoring
                # class -- confirm whether `relevances[index]` was intended.
                plot_relevances = relevances[np.argmax(predictions)]
            else:
                plot_relevances = relevances

            fig, _ = plot_tabular(x=plot_relevances, y=feature_names,
                                  num_features=10, show_plot=False)
            st.pyplot(fig)

    # add some white space to separate rows
    # NOTE(review): nesting was lost in this copy of the file; placed inside
    # the row loop to match the comment above -- confirm.
    st.markdown('')

st.stop()