dianna.utils
============

.. py:module:: dianna.utils

.. autoapi-nested-parse::

   DIANNA utilities.


Submodules
----------

.. toctree::
   :maxdepth: 1

   /autoapi/dianna/utils/downloader/index
   /autoapi/dianna/utils/maskers/index
   /autoapi/dianna/utils/misc/index
   /autoapi/dianna/utils/onnx_runner/index
   /autoapi/dianna/utils/predict/index
   /autoapi/dianna/utils/rise_utils/index
   /autoapi/dianna/utils/tokenizers/index


Classes
-------

.. autoapisummary::

   dianna.utils.SimpleModelRunner


Functions
---------

.. autoapisummary::

   dianna.utils.get_function
   dianna.utils.get_kwargs_applicable_to_function
   dianna.utils.locate_channels_axis
   dianna.utils.move_axis
   dianna.utils.onnx_model_node_loader
   dianna.utils.to_xarray


Package Contents
----------------

.. py:function:: get_function(model_or_function, preprocess_function=None)

   Converts the input to a callable function.

   Any keyword arguments are passed to the ModelRunner class if the input is a model path.

   :param model_or_function: Either a model path or a function.
                             If the input is a function, it is returned unchanged.
   :param preprocess_function: Function to run to preprocess the data.


.. py:function:: get_kwargs_applicable_to_function(function, kwargs)

   Returns the subset of ``kwargs`` whose keys are names of arguments or keyword arguments of ``function``.

   If ``function`` itself accepts ``**kwargs``, this filtering should not be necessary (provided ``function`` handles ``**kwargs`` robustly).


.. py:function:: locate_channels_axis(data_shape)

   Determines the index of the (colour) channels axis in the input data.

   The channels axis is assumed to have size 3 (for colour images) or 1 (for greyscale images).
   An error is raised if this is not the case or if the channels axis cannot be found.

   :param data_shape: The shape of one data item, without a batch axis
   :type data_shape: tuple

   :returns: 0 or -1, the index of the channels axis.


.. py:function:: move_axis(data, label, new_position)

   Moves a named axis to a new position in an xarray DataArray object.
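
   The ordering that ``move_axis`` describes can be illustrated with a small
   hypothetical helper, ``reordered_dims``, which is not part of DIANNA. It is a
   sketch under the assumption that negative positions follow Python indexing
   (``-1`` means "last"), and it only computes the new order of the dimension
   names; it is not DIANNA's implementation.

   .. code-block:: python

      def reordered_dims(dims, label, new_position):
          """Return ``dims`` with ``label`` moved to ``new_position``.

          Hypothetical helper: negative positions count from the end,
          as in Python indexing (-1 is the last position).
          """
          dims = tuple(dims)
          if label not in dims:
              raise ValueError(f"{label!r} is not one of {dims}")
          if not -len(dims) <= new_position < len(dims):
              raise IndexError(f"position {new_position} is out of range for {len(dims)} axes")
          # Normalise a negative position to its non-negative equivalent.
          if new_position < 0:
              new_position += len(dims)
          # Drop the axis, then re-insert it at the requested position.
          remaining = [dim for dim in dims if dim != label]
          remaining.insert(new_position, label)
          return tuple(remaining)


      print(reordered_dims(("channels", "y", "x"), "channels", -1))  # ('y', 'x', 'channels')

   With xarray, such an order can be applied using ``DataArray.transpose``,
   e.g. ``data.transpose(*reordered_dims(data.dims, "channels", -1))``.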

   :param data: Object with named axes
   :type data: DataArray
   :param label: Name of the axis to move
   :type label: str
   :param new_position: Numerical new position of the axis. Negative indices are accepted.
   :type new_position: int

   :returns: ``data`` with the axis moved to its new position


.. py:function:: onnx_model_node_loader(model_path)

   Loader for an ONNX model and its node labels.

   Loads an ONNX model and returns the labels of its input and output nodes and the data type of its input node.

   :param model_path: The path to an ONNX model on disk.
   :type model_path: str

   :returns: A 4-tuple of:

      - onnx_model (onnx.ModelProto): The loaded ONNX model.
      - label_input_node (str): Name of the first input node.
      - dtype_input_node (numpy.dtype): NumPy dtype of the first input node.
      - label_output_node (str): Name of the first output node.
   :rtype: tuple


.. py:function:: to_xarray(data: numpy.typing.ArrayLike, axis_labels, required_labels=None)

   Converts NumPy data and axis labels to an xarray object.


.. py:class:: SimpleModelRunner(filename, preprocess_function=None)

   Runs an ONNX model on a given set of inputs and returns its outputs.


   .. py:attribute:: filename


   .. py:attribute:: preprocess_function
      :value: None


   .. py:method:: __call__(input_data)

      Gets the ONNX model's predictions for ``input_data``.
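
The keyword filtering described for ``get_kwargs_applicable_to_function`` can be
sketched with the standard library's ``inspect.signature``. This is a hypothetical
stand-in (``applicable_kwargs`` and ``scale`` are not DIANNA functions), assuming
that only parameters which can be passed by keyword should be kept; DIANNA's own
implementation may differ.

.. code-block:: python

   import inspect


   def applicable_kwargs(function, kwargs):
       """Return only the entries of ``kwargs`` that ``function`` accepts by name."""
       parameters = inspect.signature(function).parameters
       # Positional-only parameters and *args/**kwargs are not matched by name here.
       accepted = {
           name
           for name, parameter in parameters.items()
           if parameter.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                                 inspect.Parameter.KEYWORD_ONLY)
       }
       return {key: value for key, value in kwargs.items() if key in accepted}


   def scale(data, factor=1.0, *, offset=0.0):
       return data * factor + offset


   print(applicable_kwargs(scale, {"factor": 2.0, "offset": 1.0, "axis_labels": ["x"]}))
   # {'factor': 2.0, 'offset': 1.0}

For a function that declares ``**kwargs``, this sketch would still drop keys not
named in its signature; such functions should not need the filtering at all.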