import os
import tempfile

import onnxruntime as rt
import streamlit as st
from dianna import explain_image

from _model_utils import preprocess_function
@st.cache_data
def predict(*, model, image):
    """Run an ONNX model on a single image and return its first output.

    Results are memoized by ``st.cache_data`` for identical arguments.

    Args:
        model: The ONNX model. Either an object exposing
            ``SerializeToString()`` (e.g. an ``onnx.ModelProto``), or
            something ``onnxruntime.InferenceSession`` accepts directly
            (serialized model bytes, such as the bytes
            ``_run_kernelshap_image`` writes to disk, or a file path).
        image: A single input sample without a batch axis. It is indexed
            with ``image[None, ...]``, so it is presumably a NumPy array
            (TODO confirm against callers).

    Returns:
        The model's first output for this sample, with the batch axis
        removed.
    """
    # Serialize model objects; pass bytes / paths through unchanged so
    # callers holding an already-serialized model work too.
    if hasattr(model, 'SerializeToString'):
        model = model.SerializeToString()
    session = rt.InferenceSession(model)

    # Only the first declared input and output of the graph are used.
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    # Add a batch axis of size 1, run, then take the single sample back out.
    outputs = session.run([output_name], {input_name: image[None, ...]})
    return outputs[0][0]
@st.cache_data
def _run_rise_image(model, image, i, **kwargs):
    """Explain ``image`` with DIANNA's RISE method.

    Args:
        model: Passed unchanged to ``dianna.explain_image``.
        image: The image to explain, passed unchanged.
        i: Not used in the body. NOTE(review): presumably present only so it
            becomes part of the ``st.cache_data`` cache key; confirm with
            callers.
        **kwargs: Extra keyword arguments forwarded to ``explain_image``.

    Returns:
        The first element of the relevances returned by ``explain_image``.
    """
    relevances = explain_image(
        model,
        image,
        method='RISE',
        **kwargs,
    )
    return relevances[0]
@st.cache_data
def _run_lime_image(model, image, i, **kwargs):
    """Explain ``image`` with DIANNA's LIME method.

    Args:
        model: Passed unchanged to ``dianna.explain_image``.
        image: The image to explain. It is multiplied by 256 before being
            handed to the explainer.
        i: Not used in the body. NOTE(review): presumably present only so it
            becomes part of the ``st.cache_data`` cache key; confirm with
            callers.
        **kwargs: Extra keyword arguments forwarded to ``explain_image``.

    Returns:
        The first element of the relevances returned by ``explain_image``.
    """
    relevances = explain_image(
        model,
        # NOTE(review): the x256 scaling presumably undoes a [0, 1]
        # normalisation that ``preprocess_function`` re-applies; confirm the
        # factor matches ``_model_utils.preprocess_function``.
        image * 256,
        preprocess_function=preprocess_function,
        method='LIME',
        **kwargs,
    )
    return relevances[0]
@st.cache_data
def _run_kernelshap_image(model, image, i, **kwargs):
    """Explain ``image`` with DIANNA's KernelSHAP method.

    Unlike the other helpers, this one hands ``explain_image`` a file path
    rather than the model itself. The model is therefore written to a
    temporary file, which is removed again once the explanation finishes
    (or fails).

    Args:
        model: The serialized ONNX model as bytes; it is written to disk
            verbatim.
        image: The image to explain, passed unchanged.
        i: Not used in the body. NOTE(review): presumably present only so it
            becomes part of the ``st.cache_data`` cache key; confirm with
            callers.
        **kwargs: Extra keyword arguments forwarded to ``explain_image``.

    Returns:
        The first element of the relevances returned by ``explain_image``.
    """
    # The explainer only receives the path, so it must reopen the file by
    # name. An open NamedTemporaryFile(delete=True) cannot be reopened that
    # way on Windows, so write and close the file first, then delete it
    # ourselves in ``finally``.
    fd, path = tempfile.mkstemp(suffix='.onnx')
    try:
        with os.fdopen(fd, 'wb') as f:
            f.write(model)
        relevances = explain_image(path,
                                   image,
                                   method='KernelSHAP',
                                   **kwargs)
    finally:
        os.remove(path)
    return relevances[0]
# Maps an explanation method name to the cached helper that runs it on an
# image. Every helper is called as ``helper(model, image, i, **kwargs)`` and
# returns the first element of the relevances from ``explain_image``.
explain_image_dispatcher = {
    'RISE': _run_rise_image,
    'LIME': _run_lime_image,
    'KernelSHAP': _run_kernelshap_image,
}