import onnxruntime as ort
[docs]
class SimpleModelRunner:
"""Runs an onnx model with a set of inputs and outputs."""
def __init__(self, filename, preprocess_function=None):
"""Generates function to run ONNX model with one set of inputs and outputs.
Args:
filename (str): Path to ONNX model on disk
preprocess_function (callable, optional): Function to preprocess input data with
Returns:
function
Examples:
>>> runner = SimpleModelRunner('path_to_model.onnx')
>>> predictions = runner(input_data)
"""
[docs]
self.filename = filename
[docs]
self.preprocess_function = preprocess_function
[docs]
def __call__(self, input_data):
"""Get ONNX predictions."""
sess_options = ort.SessionOptions()
sess_options.enable_cpu_mem_arena = False # disables pre-allocation of memory
sess = ort.InferenceSession(self.filename, sess_options=sess_options)
input_name = sess.get_inputs()[0].name
output_name = sess.get_outputs()[0].name
if self.preprocess_function is not None:
input_data = self.preprocess_function(input_data)
onnx_input = {input_name: input_data}
pred_onnx = sess.run([output_name], onnx_input)[0]
return pred_onnx