"""Source code for _image_utils: image loading and ResNet/ImageNet preprocessing helpers."""

import numpy as np
from keras import utils as keras_utils
from PIL import Image
from PIL import ImageStat
from dianna.utils import move_axis
from dianna.utils import to_xarray


def preprocess_img_resnet(path):
    """Resnet specific function for preprocessing.

    Load the image resized to 224x224, move the colour channel to axis 0
    (channels-first) and normalise it with the ImageNet per-channel mean and
    standard deviation. This works specifically for ImageNet. See:
    https://github.com/onnx/models/tree/main/vision/classification/resnet

    Args:
        path: Image location, passed unchanged to ``keras.utils.load_img``.

    Returns:
        tuple: ``(norm_img_data, img)`` where ``norm_img_data`` is a float32
        channels-first array of shape ``(channels, 224, 224)`` holding
        ``(pixel / 255 - mean) / std`` for each channel, and ``img`` is the
        image object returned by ``keras.utils.load_img``.
    """
    img = keras_utils.load_img(path, target_size=(224, 224))
    img_data = keras_utils.img_to_array(img)
    if img_data.shape[0] != 3:
        # Colour channel is not in position 0 (channels-last layout);
        # move it to the front so the data is channels-first.
        xarray = to_xarray(img_data, {0: 'height', 1: 'width', 2: 'channels'})
        img_data = np.array(move_axis(xarray, 'channels', 0))

    # ImageNet per-channel statistics, shaped (3, 1, 1) so they broadcast
    # over the (channels, height, width) image in a single vectorised step
    # instead of a Python loop over channels.
    mean_vec = np.array([0.485, 0.456, 0.406]).reshape(-1, 1, 1)
    stddev_vec = np.array([0.229, 0.224, 0.225]).reshape(-1, 1, 1)
    # Scale pixel values to [0, 1], then standardise each channel.
    norm_img_data = ((img_data / 255 - mean_vec) / stddev_vec).astype(np.float32)
    return norm_img_data, img
def open_image(file):
    """Open an image from a file and return it as a numpy array.

    Grayscale images (identical R, G and B values at every pixel) are
    returned as a single-channel array scaled to [0, 1]. Any other image is
    handed to ``preprocess_img_resnet`` for ResNet/ImageNet preprocessing.

    Args:
        file: Filename or file object accepted by ``PIL.Image.open``. For
            colour images it is also passed to ``preprocess_img_resnet``.

    Returns:
        tuple: For grayscale input, ``(gray, rgb)`` where ``gray`` is a
        float32 array of shape ``(height, width, 1)`` scaled to [0, 1] and
        ``rgb`` is the unscaled float32 ``(height, width, 3)`` RGB array.
        For colour input, whatever ``preprocess_img_resnet(file)`` returns.
    """
    # Context manager releases the file handle; convert() returns a new,
    # fully loaded image that does not depend on it.
    with Image.open(file) as src:
        im = np.asarray(src.convert('RGB')).astype(np.float32)

    red, green, blue = im[:, :, 0], im[:, :, 1], im[:, :, 2]
    # Grayscale means R == G == B at every pixel. Comparing channel sums is
    # not enough: a colour image whose green and blue sums merely average to
    # the red sum would be misclassified as grayscale.
    if np.array_equal(red, green) and np.array_equal(red, blue):
        return np.expand_dims(red, axis=2) / 255, im
    # Colour image: resize to 224x224 and normalise for resnet.
    return preprocess_img_resnet(file)