# Source code for _models_image

import os
import tempfile

import onnxruntime as rt
import streamlit as st
from dianna import explain_image

from _model_utils import preprocess_function


@st.cache_data
def predict(*, model, image):
    """Run ``model`` on a single ``image`` with ONNX Runtime.

    The model is serialized and loaded into a new ``InferenceSession``.
    ``image`` gets a leading axis of size 1 before it is fed to the first
    input node, and only the first output node is requested.

    Args:
        model: Object exposing ``SerializeToString()`` (presumably an
            ``onnx.ModelProto`` — confirm against callers).
        image: Array supporting ``image[None, ...]`` (e.g. a NumPy array).

    Returns:
        Entry 0 of the first requested output, i.e. the result for the single
        image that was passed in.
    """
    session = rt.InferenceSession(model.SerializeToString())
    input_name = session.get_inputs()[0].name
    output_name = session.get_outputs()[0].name

    # Add a leading axis so the single image is fed as a batch of one.
    batched_image = image[None, ...]
    outputs = session.run([output_name], {input_name: batched_image})

    # outputs[0] is the first requested output; [0] strips the batch axis again.
    first_output = outputs[0]
    return first_output[0]
@st.cache_data
def _run_rise_image(model, image, i, **kwargs):
    """Explain ``image`` with DIANNA's RISE method.

    Args:
        model: Model passed unchanged to ``explain_image``.
        image: Input image passed unchanged to ``explain_image``.
        i: Not referenced in the body. Presumably present so it becomes part
            of the ``st.cache_data`` cache key — TODO confirm with callers.
        **kwargs: Extra options forwarded to ``explain_image``.

    Returns:
        The first element of the relevances returned by ``explain_image``.
    """
    return explain_image(model, image, method='RISE', **kwargs)[0]
@st.cache_data
def _run_lime_image(model, image, i, **kwargs):
    """Explain ``image`` with DIANNA's LIME method.

    The image is multiplied by 256 before being handed to ``explain_image``,
    together with ``preprocess_function`` from ``_model_utils``.

    Args:
        model: Model passed unchanged to ``explain_image``.
        image: Input image; it must support multiplication by a scalar.
        i: Not referenced in the body. Presumably present so it becomes part
            of the ``st.cache_data`` cache key — TODO confirm with callers.
        **kwargs: Extra options forwarded to ``explain_image``.

    Returns:
        The first element of the relevances returned by ``explain_image``.
    """
    # NOTE(review): the factor 256 presumably mirrors a division by 256 inside
    # preprocess_function (so the model still sees the original scale) —
    # confirm in _model_utils before changing either side.
    scaled_image = image * 256
    relevances = explain_image(
        model,
        scaled_image,
        method='LIME',
        preprocess_function=preprocess_function,
        **kwargs,
    )
    return relevances[0]
@st.cache_data
def _run_kernelshap_image(model, image, i, **kwargs):
    """Explain ``image`` with DIANNA's KernelSHAP method.

    Unlike the RISE and LIME runners, this one passes ``explain_image`` a file
    path instead of an in-memory model. The model bytes are therefore written
    to a temporary file first, and the file is removed afterwards.

    Args:
        model: Bytes-like serialized model, written verbatim to the temporary
            file (presumably ONNX — confirm against callers).
        image: Input image passed unchanged to ``explain_image``.
        i: Not referenced in the body. Presumably present so it becomes part
            of the ``st.cache_data`` cache key — TODO confirm with callers.
        **kwargs: Extra options forwarded to ``explain_image``.

    Returns:
        The first element of the relevances returned by ``explain_image``.
    """
    # delete=False: on Windows a NamedTemporaryFile cannot be opened a second
    # time by name while it is still open. So the data is written, the handle
    # closed, the path handed over, and the file deleted explicitly.
    tmp = tempfile.NamedTemporaryFile(delete=False)
    try:
        with tmp:
            tmp.write(model)
        relevances = explain_image(tmp.name, image, method='KernelSHAP', **kwargs)
    finally:
        # Runs even if writing or explaining fails, so no stray file is left.
        os.remove(tmp.name)
    return relevances[0]
# Maps each explanation method name to its runner. Every runner takes
# (model, image, i, **kwargs) and returns the first relevance map.
explain_image_dispatcher = dict(
    RISE=_run_rise_image,
    LIME=_run_lime_image,
    KernelSHAP=_run_kernelshap_image,
)