dianna.utils.predict

Utility functions for perturbation-based predictions.

Module Contents

Functions

make_predictions(data, runner, batch_size)

Make predictions with the input data.

dianna.utils.predict.make_predictions(data, runner, batch_size)

Make predictions with the input data.

Pass the masked input data to the model runner in batches of batch_size inputs and return the collected predictions.
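A minimal sketch of the batched processing described above, assuming runner is a plain callable that takes a slice of data and returns one row of predictions per input. The names make_predictions_sketch and toy_runner are hypothetical and not part of DIANNA, and the library's actual implementation may differ.

```python
import numpy as np


def make_predictions_sketch(data, runner, batch_size):
    """Run `runner` over `data` in slices of at most `batch_size` items.

    Assumption: `runner` maps a batch of shape (b, ...) to an array of
    predictions of shape (b, n_outputs).
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer")
    predictions = []
    # Step through the first axis; the last slice may be shorter than batch_size.
    for start in range(0, data.shape[0], batch_size):
        batch = data[start:start + batch_size]
        predictions.append(np.asarray(runner(batch)))
    # Join the per-batch results along the first axis (fails if data is empty).
    return np.concatenate(predictions, axis=0)


# Hypothetical toy runner: returns two "scores" per masked input.
def toy_runner(batch):
    total = batch.reshape(len(batch), -1).sum(axis=1)
    return np.stack([total, -total], axis=1)


masked = np.random.default_rng(0).random((10, 4, 4))
preds = make_predictions_sketch(masked, toy_runner, batch_size=3)
print(preds.shape)  # (10, 2)
```

Concatenating along the first axis keeps the i-th prediction aligned with the i-th masked input, so in this sketch the result is the same for any batch_size. Only peak memory per runner call changes.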

Parameters:
  • data (np.ndarray) – An array of masked input data to be processed by the model.

  • runner (object) – An object that runs the model on the input data and returns predictions.

  • batch_size (int) – The number of masked inputs to process in each batch.
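To illustrate how data and batch_size relate, the sketch below builds a hypothetical masked-input array by applying random binary masks to a single image and counts how many runner calls are needed. The image size, the number of masks, and the masking scheme are illustrative assumptions; DIANNA's own mask generation may produce a different layout.

```python
import math

import numpy as np

rng = np.random.default_rng(42)

# Hypothetical setup: one 8x8 RGB image and 20 random binary masks.
image = rng.random((8, 8, 3))
masks = rng.random((20, 8, 8)) < 0.5

# Broadcast each (8, 8) mask over the channel axis -> shape (20, 8, 8, 3).
masked_data = masks[..., np.newaxis] * image[np.newaxis, ...]

batch_size = 6
# Runner calls needed to cover all masked inputs: ceil(20 / 6) = 4.
n_batches = math.ceil(masked_data.shape[0] / batch_size)
print(masked_data.shape, n_batches)  # (20, 8, 8, 3) 4
```

With these numbers the runner would see batches of 6, 6, 6, and 2 masked inputs.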

Returns:

An array of predictions made by the model on the input data.

Return type:

np.ndarray