# Source code for _models_tabular

import numpy as np
import onnxruntime as ort
import streamlit as st
from dianna import explain_tabular


@st.cache_data
def predict(*, model, tabular_input):
    """Run an ONNX model on tabular data and return its first output.

    Both arguments are keyword-only.

    Args:
        model: Handed unchanged to ``ort.InferenceSession``
            (presumably a model path or serialized model bytes; verify
            against callers).
        tabular_input: Data to feed to the model. It must provide an
            ``astype`` method, because it is cast to float32 before
            inference.

    Returns:
        The first element of the list returned by ``InferenceSession.run``
        when only the model's first declared output is requested.
    """
    session = ort.InferenceSession(model)

    # Only the first declared input and output of the graph are used.
    first_input = session.get_inputs()[0].name
    first_output = session.get_outputs()[0].name

    # The input is always cast to float32 before being fed to the session.
    feeds = {first_input: tabular_input.astype(np.float32)}
    results = session.run([first_output], feeds)
    return results[0]
@st.cache_data
def _run_rise_tabular(_model, table, training_data, _feature_names, **kwargs):
    """Explain ``table`` with dianna's RISE method for tabular data.

    Parameters whose names start with an underscore are excluded from
    Streamlit's cache key (see the ``st.cache_data`` documentation).

    Args:
        _model: Model forwarded to ``predict`` for every evaluation.
        table: Data item to explain; passed on to ``explain_tabular``.
        training_data: Passed on as ``training_data``.
        _feature_names: Passed on as ``feature_names``.
        **kwargs: Additional options for ``explain_tabular``. A
            ``_preprocess_function`` entry is renamed to
            ``preprocess_function`` before forwarding.

    Returns:
        Whatever ``explain_tabular`` returns for ``method='RISE'``.
    """
    # Streamlit needs the underscore-prefixed spelling; dianna expects the
    # plain keyword, so translate it back here.
    if "_preprocess_function" in kwargs:
        kwargs["preprocess_function"] = kwargs.pop("_preprocess_function")

    def model_runner(batch):
        return predict(model=_model, tabular_input=batch)

    return explain_tabular(
        model_runner,
        table,
        method='RISE',
        training_data=training_data,
        feature_names=_feature_names,
        **kwargs,
    )
@st.cache_data
def _run_lime_tabular(_model, table, training_data, _feature_names, **kwargs):
    """Explain ``table`` with dianna's LIME method for tabular data.

    Parameters whose names start with an underscore are excluded from
    Streamlit's cache key (see the ``st.cache_data`` documentation).

    Args:
        _model: Model forwarded to ``predict`` for every evaluation.
        table: Data item to explain; passed on to ``explain_tabular``.
        training_data: Passed on as ``training_data``.
        _feature_names: Passed on as ``feature_names``.
        **kwargs: Additional options for ``explain_tabular``. A
            ``_preprocess_function`` entry is renamed to
            ``preprocess_function`` before forwarding.

    Returns:
        Whatever ``explain_tabular`` returns for ``method='LIME'``.
    """
    # Streamlit needs the underscore-prefixed spelling; dianna expects the
    # plain keyword, so translate it back here.
    if "_preprocess_function" in kwargs:
        kwargs["preprocess_function"] = kwargs.pop("_preprocess_function")

    def model_runner(batch):
        return predict(model=_model, tabular_input=batch)

    return explain_tabular(
        model_runner,
        table,
        method='LIME',
        training_data=training_data,
        feature_names=_feature_names,
        **kwargs,
    )
@st.cache_data
def _run_kernelshap_tabular(model, table, training_data, _feature_names, **kwargs):
    """Explain ``table`` with dianna's KernelSHAP method for tabular data.

    The result of ``explain_tabular`` is wrapped in ``np.array`` before it
    is returned.

    Args:
        model: Model forwarded to ``predict`` for every evaluation. The name
            has no leading underscore, so Streamlit includes it in the
            cache key.
        table: Data item to explain; passed on to ``explain_tabular``.
        training_data: Passed on as ``training_data``.
        _feature_names: Passed on as ``feature_names``.
        **kwargs: Additional options for ``explain_tabular``. A
            ``_preprocess_function`` entry is renamed to
            ``preprocess_function`` before forwarding.

    Returns:
        ``np.ndarray`` built from the relevances returned by
        ``explain_tabular`` for ``method='KernelSHAP'``.
    """
    # Streamlit needs the underscore-prefixed spelling; dianna expects the
    # plain keyword, so translate it back here.
    if "_preprocess_function" in kwargs:
        kwargs["preprocess_function"] = kwargs.pop("_preprocess_function")

    def model_runner(batch):
        return predict(model=model, tabular_input=batch)

    attributions = explain_tabular(
        model_runner,
        table,
        method='KernelSHAP',
        training_data=training_data,
        feature_names=_feature_names,
        **kwargs,
    )
    return np.array(attributions)
# Lookup table from explanation method name to the cached runner for it.
explain_tabular_dispatcher = {
    'RISE': _run_rise_tabular,
    'LIME': _run_lime_tabular,
    'KernelSHAP': _run_kernelshap_tabular,
}