dianna.utils
DIANNA utilities.
Submodules
Package Contents
Classes
Runs an onnx model with a set of inputs and outputs.
Functions
get_function(model_or_function, preprocess_function=None) – Converts input to callable function.
get_kwargs_applicable_to_function(function, kwargs) – Returns a subset of kwargs of only arguments and keyword arguments of function.
locate_channels_axis(data_shape) – Determine index of (colour) channels axis in input data.
move_axis(data, label, new_position) – Moves a named axis to a new position in an xarray DataArray object.
onnx_model_node_loader(model_path) – Onnx model and node labels loader.
Converts numpy data and axes labels to an xarray object.
- dianna.utils.get_function(model_or_function, preprocess_function=None)
Converts the input to a callable function.
If the input is a model path, any keyword arguments are passed to the ModelRunner class.
- Parameters:
model_or_function – Either a path to a model or a function. If the input is a function, it is returned unchanged.
preprocess_function – Function to run on the data as a preprocessing step.
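The dispatch described above can be sketched as follows. This is a minimal illustration, not DIANNA's implementation: as_callable and EchoRunner are hypothetical names, and EchoRunner only stands in for a real model runner (it loads no model).

```python
class EchoRunner:
    """Hypothetical stand-in for a model runner; it loads no real model."""

    def __init__(self, model_path, preprocess_function=None):
        self.model_path = model_path
        self.preprocess_function = preprocess_function

    def __call__(self, data):
        # Apply the optional preprocessing step, then return a fake "prediction".
        if self.preprocess_function is not None:
            data = self.preprocess_function(data)
        return (self.model_path, data)


def as_callable(model_or_function, preprocess_function=None):
    """Return a callable for either a function or a model path (sketch)."""
    if callable(model_or_function):
        # A function is returned unchanged.
        return model_or_function
    # Anything else is treated as a model path and wrapped in a runner.
    return EchoRunner(model_or_function, preprocess_function=preprocess_function)
```

With this sketch, passing a function gives back the very same object, while passing a path gives back a runner that can be called like a function.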
- dianna.utils.get_kwargs_applicable_to_function(function, kwargs)
Returns the subset of kwargs whose keys are arguments or keyword arguments of function.
Note that if function accepts **kwargs, this filtering should not be necessary (provided the function handles **kwargs robustly).
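One way to do this kind of filtering is with the standard inspect module. The sketch below is an assumption about the general approach, not DIANNA's code; filter_kwargs, resize and the option names are hypothetical.

```python
import inspect


def filter_kwargs(function, kwargs):
    """Keep only the entries of kwargs that function can accept by name."""
    parameters = inspect.signature(function).parameters
    accepted = {
        name
        for name, parameter in parameters.items()
        if parameter.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )
    }
    return {key: value for key, value in kwargs.items() if key in accepted}


def resize(image, width, height=100):
    # Hypothetical function used only to demonstrate the filtering.
    return (image, width, height)


options = {"width": 50, "height": 20, "n_masks": 1000}
print(filter_kwargs(resize, options))  # {'width': 50, 'height': 20}
```

The unrelated key n_masks is dropped, so resize("img", **filter_kwargs(resize, options)) does not raise a TypeError.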
- dianna.utils.locate_channels_axis(data_shape)
Determines the index of the (colour) channels axis in the input data.
The channels axis is assumed to have size 3 (for colour images) or 1 (for greyscale images). An error is raised if no axis has such a size or if the channels axis cannot be found.
- Parameters:
data_shape (tuple) – The shape of one data item, without a batch axis
- Returns:
0 or -1 indicating the index of the channels axis.
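The rule above can be sketched as a check on the first and last entries of the shape. This is an illustration under stated assumptions, not DIANNA's implementation: guess_channels_axis is a hypothetical name, and treating a shape where both or neither end matches as an error is a choice made for this sketch.

```python
def guess_channels_axis(data_shape):
    """Return 0 or -1 for the axis whose size is 1 or 3 (sketch)."""
    channel_sizes = (1, 3)
    first_matches = data_shape[0] in channel_sizes
    last_matches = data_shape[-1] in channel_sizes
    if first_matches and not last_matches:
        return 0
    if last_matches and not first_matches:
        return -1
    # Assumption: an ambiguous or missing channels axis is an error here.
    raise ValueError(f"Cannot locate channels axis in shape {data_shape}")
```

For example, a channels-first shape such as (3, 224, 224) yields 0 and a channels-last shape such as (28, 28, 1) yields -1.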
- dianna.utils.move_axis(data, label, new_position)
Moves a named axis to a new position in an xarray DataArray object.
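The core of such a move is computing the new order of the dimension names. The sketch below does only that, in plain Python; reorder_dims is a hypothetical helper, and handling negative positions the way numpy.moveaxis does is an assumption.

```python
def reorder_dims(dims, label, new_position):
    """Return dims with label moved to new_position (sketch)."""
    remaining = [dim for dim in dims if dim != label]
    if len(remaining) == len(dims):
        raise ValueError(f"{label!r} is not one of {dims}")
    # Map a negative position onto the final tuple, so -1 means "last".
    if new_position < 0:
        new_position += len(dims)
    remaining.insert(new_position, label)
    return tuple(remaining)


print(reorder_dims(("channels", "y", "x"), "channels", -1))
# ('y', 'x', 'channels')
```

With xarray, the resulting order could then be applied to a DataArray via data.transpose(*order).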
- dianna.utils.onnx_model_node_loader(model_path)
ONNX model and node labels loader.
Loads an ONNX model and returns the label of its output node and the data type of its input node.
- Parameters:
model_path (str) – The path to an ONNX model on disk.
- Returns:
The loaded ONNX model and the label of the output node.