Source code for _models_ts

import numpy as np
import onnxruntime as ort
import streamlit as st
import dianna


@st.cache_data
def predict(*, model, ts_data):
    """Run an ONNX model on ``ts_data`` and return the model's first output.

    The model must receive data ordered as ``[batch, timeseries, channels]``.
    ``ts_data`` is passed through without any axis reordering, so callers are
    responsible for supplying that layout.

    Args:
        model: Anything ``onnxruntime.InferenceSession`` accepts (for example
            a path to an ``.onnx`` file). Because the name has no leading
            underscore, it is part of the ``st.cache_data`` key.
        ts_data: Array supporting ``.astype``; it is cast to ``float32``
            before being fed to the model. It is also part of the cache key.

    Returns:
        The array produced for the graph's first output.
    """
    session = ort.InferenceSession(model)
    # Only the first declared input and the first declared output are used.
    feed_name = session.get_inputs()[0].name
    fetch_name = session.get_outputs()[0].name
    feed = {feed_name: ts_data.astype(np.float32)}
    outputs = session.run([fetch_name], feed)
    return outputs[0]
@st.cache_data
def _run_rise_timeseries(_model, ts_data, **kwargs):
    """Explain ``ts_data[0]`` with dianna's RISE method.

    Args:
        _model: Forwarded to ``predict`` as its ``model`` argument. The
            leading underscore keeps it out of the ``st.cache_data`` key
            (Streamlit skips hashing underscore-prefixed arguments).
        ts_data: Indexable collection of time series. Only ``ts_data[0]`` is
            handed to dianna as the instance to explain.
        **kwargs: Extra options for ``dianna.explain_timeseries``. A
            ``_preprocess_function`` entry is renamed to
            ``preprocess_function`` before the call.

    Returns:
        Whatever ``dianna.explain_timeseries`` returns.
    """
    # Callers pass the preprocessing function as ``_preprocess_function`` so
    # Streamlit leaves it out of the cache key; dianna expects the un-prefixed
    # name, so translate it back here.
    if "_preprocess_function" in kwargs:
        kwargs["preprocess_function"] = kwargs.pop("_preprocess_function")

    def _model_fn(batch):
        # dianna calls this with its own inputs; delegate to the ONNX runner.
        return predict(model=_model, ts_data=batch)

    return dianna.explain_timeseries(
        _model_fn,
        input_timeseries=ts_data[0],
        method='RISE',
        **kwargs,
    )
@st.cache_data
def _run_lime_timeseries(_model, ts_data, **kwargs):
    """Explain ``ts_data[0]`` with dianna's LIME method.

    Args:
        _model: Forwarded to ``predict`` as its ``model`` argument. The
            leading underscore keeps it out of the ``st.cache_data`` key
            (Streamlit skips hashing underscore-prefixed arguments).
        ts_data: Indexable collection of time series. Only ``ts_data[0]`` is
            handed to dianna as the instance to explain.
        **kwargs: Extra options for ``dianna.explain_timeseries``. A
            ``_preprocess_function`` entry is renamed to
            ``preprocess_function``. ``num_features`` and ``num_slices``
            default to ``len(ts_data[0])``, and ``distance_method`` defaults to
            ``'dtw'``. Any of the three may be overridden by the caller.

    Returns:
        Whatever ``dianna.explain_timeseries`` returns.
    """
    # Callers pass the preprocessing function as ``_preprocess_function`` so
    # Streamlit leaves it out of the cache key; dianna expects the un-prefixed
    # name, so translate it back here.
    if "_preprocess_function" in kwargs:
        kwargs["preprocess_function"] = kwargs.pop("_preprocess_function")

    instance = ts_data[0]
    # Presumably the number of time steps, if ts_data follows the
    # [batch, timeseries, channels] layout noted in predict. TODO confirm.
    n_steps = len(instance)

    # These used to be hard-coded keyword arguments next to **kwargs, so a
    # caller supplying any of them hit "got multiple values for keyword
    # argument". setdefault keeps the same defaults but allows overrides.
    kwargs.setdefault("num_features", n_steps)
    kwargs.setdefault("num_slices", n_steps)
    kwargs.setdefault("distance_method", "dtw")

    def run_model(batch):
        # dianna calls this with its own inputs; delegate to the ONNX runner.
        return predict(model=_model, ts_data=batch)

    return dianna.explain_timeseries(
        run_model,
        input_timeseries=instance,
        method='LIME',
        **kwargs,
    )
# Maps an explanation-method name to the cached function that runs it for
# time-series data.
explain_ts_dispatcher = dict(
    RISE=_run_rise_timeseries,
    LIME=_run_lime_timeseries,
)