# Source code for sksurgerytf.callbacks.segmentation_history

# -*- coding: utf-8 -*-

"""
Module to implement callback to save an image, with segmentation.
"""
import numpy as np
from tensorflow import keras
import tensorflow as tf

#pylint:disable=super-with-arguments
class SegmentationHistory(keras.callbacks.Callback):
    """
    Class to implement Tensorboard callback to save a batch of images and
    their segmentations, so we can monitor progress directly in Tensorboard.

    At the end of each epoch, items are drawn from ``data`` and every
    ``modulo``-th one (up to ``desired_number_images`` of them) is run
    through the model. The first prediction of that batch is thresholded at
    0.5, and the predictions are logged as one image: predictions side by
    side on top, the corresponding first labels side by side underneath.
    """

    def __init__(self,
                 tensor_board_dir,
                 data,
                 number_of_samples,
                 desired_number_images):
        """
        Constructor.

        :param tensor_board_dir: directory to log to
        :param data: iterable yielding (image_batch, label_batch) pairs,
            e.g. an iterator produced by an ImageDataGenerator.
        :param number_of_samples: number of samples coming from generator;
            at most this many items are consumed per epoch.
        :param desired_number_images: the number of images you want logging.
        :raises ValueError: if number_of_samples <= 0 or
            desired_number_images < 1.
        """
        super(SegmentationHistory, self).__init__()
        if number_of_samples <= 0:
            raise ValueError('number_of_samples must be > 0')
        if desired_number_images < 1:
            raise ValueError('desired_number_images must be >= 1')
        self.tensor_board_dir = tensor_board_dir
        self.data = data
        self.number_of_samples = number_of_samples
        self.desired_number_images = desired_number_images
        # Log every `modulo`-th item. Clamped to 1 because asking for more
        # images than there are samples makes the integer division 0, and
        # `counter % 0` in on_epoch_end would raise ZeroDivisionError.
        self.modulo = max(1, number_of_samples // desired_number_images)

    # pylint: disable=unused-argument
    def on_epoch_end(self, epoch, logs=None):
        """
        Called at the end of each epoch, so we can log data.

        Consumes up to ``number_of_samples`` items from ``data``. If no item
        was obtained, nothing is logged.

        :param epoch: number of the epoch, used as the Tensorboard step.
        :param logs: logging info, see docs, currently unused.
        """
        images = []
        labels = []
        for counter, item in enumerate(self.data):
            if counter % self.modulo == 0 \
                    and len(images) < self.desired_number_images:
                image_data = item[0]
                label_data = item[1]
                pred = self.model.predict(image_data)
                # Binarise the first prediction of the batch to 0 / 255.
                mask = (pred[0] > 0.5).astype(np.ubyte) * 255
                images.append(mask)
                labels.append(label_data[0])
            # Stop right after the last wanted item, without pulling one more
            # from the (possibly infinite) generator, so each epoch consumes
            # exactly the same number of items.
            if counter + 1 >= self.number_of_samples:
                break

        if not images:
            # The generator produced nothing; np.concatenate would raise on
            # an empty list, and a monitoring callback should not abort
            # training.
            return

        # Predictions side by side in the top row, labels underneath.
        images_concatenated = np.concatenate(images, axis=1)
        labels_concatenated = np.concatenate(labels, axis=1)
        data = np.concatenate((images_concatenated, labels_concatenated),
                              axis=0)
        self.save_to_tensorboard(data, epoch)

    def save_to_tensorboard(self, npyfile, step):
        """
        Write a set of images, in a format suitable for Tensorboard.

        :param npyfile: block of data, as assembled by on_epoch_end.
            Expected to be (height, width, channels); a 2D (height, width)
            array is treated as a single-channel image.
        :param step: some int to indicate progress, e.g. batch number or epoch.
        """
        if npyfile.ndim == 2:
            # tf.summary.image needs an explicit channel axis.
            npyfile = npyfile[..., np.newaxis]
        image = np.reshape(npyfile, (-1,
                                     npyfile.shape[0],
                                     npyfile.shape[1],
                                     npyfile.shape[2]))
        #pylint:disable=not-context-manager
        writer = tf.summary.create_file_writer(self.tensor_board_dir)
        try:
            with writer.as_default():
                tf.summary.image("Predicted (top), Labelled (bottom)",
                                 image, step=step)
        finally:
            # A new writer is created per call. Closing it flushes the event
            # file immediately, rather than waiting for the writer's periodic
            # flush, and releases the file handle, so writers do not
            # accumulate across epochs.
            writer.close()