# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""End-to-end tests that check model correctness."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import shutil
import tempfile
import unittest

import numpy as np
import tensorflow as tf
from tensorflow.compat import v1 as tfv1

# pylint: disable=g-bad-import-order
from tfltransfer import bases
from tfltransfer import optimizers
from tfltransfer import heads
from tfltransfer import tflite_transfer_converter

# pylint: enable=g-bad-import-order

# Height and width, in pixels, that images are resized to by the Keras
# generators and that the bottleneck model input is resized to.
IMAGE_SIZE = 224
# Batch size used by the image generators, the converter, the classifier head
# and the batching/padding done during head training.
BATCH_SIZE = 128
# Number of classes; also the second dimension of the collected label arrays.
NUM_CLASSES = 5
# Fraction of the data set that Keras assigns to the "validation" subset.
VALIDATION_SPLIT = 0.2
# Learning rate handed to the SGD optimizer in the tests below.
LEARNING_RATE = 0.001
# Per-sample shape of a bottleneck array (presumably the MobileNetV2 feature
# map for an IMAGE_SIZE x IMAGE_SIZE input -- TODO confirm).
BOTTLENECK_SHAPE = (7, 7, 1280)

# Archive fetched with tf.keras.utils.get_file and used as the test data set.
DATASET_URL = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"


class TransferModel(object):
    """Test consumer of models generated by the converter.

    Converts a base model, a head model and an optimizer into a set of TFLite
    models and drives them through ``tf.lite.Interpreter`` to extract
    bottlenecks, train the head and measure validation accuracy.

    Expected call order: construct, ``prepare_bottlenecks()``, then
    ``train_head()`` and/or ``measure_inference_accuracy()``.
    """

    def __init__(self, dataset_dir, base_model, head_model, optimizer):
        """Creates a wrapper for a set of models and a data set.

        Args:
          dataset_dir: Directory given to Keras ``flow_from_directory``; it is
            read twice, once as the "training" and once as the "validation"
            subset.
          base_model: Base model handed to the converter.
          head_model: Head model handed to the converter.
          optimizer: Optimizer handed to the converter.
        """
        self.dataset_dir = dataset_dir

        datagen = tf.keras.preprocessing.image.ImageDataGenerator(
            rescale=1.0 / 255, validation_split=VALIDATION_SPLIT
        )
        self.train_img_generator = datagen.flow_from_directory(
            self.dataset_dir,
            target_size=(IMAGE_SIZE, IMAGE_SIZE),
            batch_size=BATCH_SIZE,
            subset="training",
        )
        self.val_img_generator = datagen.flow_from_directory(
            self.dataset_dir,
            target_size=(IMAGE_SIZE, IMAGE_SIZE),
            batch_size=BATCH_SIZE,
            subset="validation",
        )

        converter = tflite_transfer_converter.TFLiteTransferConverter(
            NUM_CLASSES, base_model, head_model, optimizer, BATCH_SIZE
        )
        models = converter._convert()
        self.initialize_model = models["initialize"]
        self.bottleneck_model = models["bottleneck"]
        self.train_head_model = models["train_head"]
        self.inference_model = models["inference"]
        self.optimizer_model = models["optimizer"]
        self.variables = self._generate_initial_variables()

        # The optimizer's mutable state starts out as zeros of the shapes the
        # optimizer model declares for its state inputs.
        optim_state_shapes = self._optimizer_state_shapes()
        self.optim_state = [
            np.zeros(shape, dtype=np.float32) for shape in optim_state_shapes
        ]

    def _generate_initial_variables(self):
        """Generates the initial model variables.

        Runs the "initialize" model once, feeding a float32 zero into its
        first input.

        Returns:
          A list with the value of every output of the initialize model.
        """
        interpreter = tf.lite.Interpreter(model_content=self.initialize_model)
        zero_in = interpreter.get_input_details()[0]
        variable_outs = interpreter.get_output_details()
        interpreter.allocate_tensors()
        interpreter.set_tensor(zero_in["index"], np.float32(0.0))
        interpreter.invoke()
        return [interpreter.get_tensor(var["index"]) for var in variable_outs]

    def _optimizer_state_shapes(self):
        """Reads the shapes of the optimizer parameters (mutable state).

        The optimizer model's inputs are sliced as
        ``[variables..., gradients..., state...]``, so everything after the
        first ``2 * len(self.variables)`` inputs is optimizer state.

        Returns:
          A list with the declared shape of every optimizer state input.
        """
        interpreter = tf.lite.Interpreter(model_content=self.optimizer_model)
        num_variables = len(self.variables)
        optim_state_inputs = interpreter.get_input_details()[num_variables * 2 :]
        return [input_["shape"] for input_ in optim_state_inputs]

    def prepare_bottlenecks(self):
        """Passes all images through the base model and saves the bottlenecks.

        This method has to be called before any training or inference.
        """
        (
            self.train_bottlenecks,
            self.train_labels,
        ) = self._collect_and_generate_bottlenecks(self.train_img_generator)
        self.val_bottlenecks, self.val_labels = self._collect_and_generate_bottlenecks(
            self.val_img_generator
        )

    def _collect_and_generate_bottlenecks(self, image_gen):
        """Consumes a generator and converts all images to bottlenecks.

        Args:
          image_gen: A Keras data generator for images to process

        Returns:
          Two NumPy arrays: (bottlenecks, labels).
        """
        collected_bottlenecks = np.zeros(
            (image_gen.samples,) + BOTTLENECK_SHAPE, dtype=np.float32
        )
        collected_labels = np.zeros((image_gen.samples, NUM_CLASSES), dtype=np.float32)

        # make_finite limits the otherwise endless Keras generator to one pass,
        # so the preallocated arrays are filled exactly once.
        next_idx = 0
        for bottlenecks, truth in self._generate_bottlenecks(make_finite(image_gen)):
            batch_size = bottlenecks.shape[0]
            collected_bottlenecks[next_idx : next_idx + batch_size] = bottlenecks
            collected_labels[next_idx : next_idx + batch_size] = truth
            next_idx += batch_size

        return collected_bottlenecks, collected_labels

    def _generate_bottlenecks(self, image_gen):
        """Generator adapter that passes images through the bottleneck model.

        Args:
          image_gen: A generator that returns images to be processed. Images are
            paired with ground truth labels.

        Yields:
          Bottlenecks from input images, paired with ground truth labels.
        """
        interpreter = tf.lite.Interpreter(model_content=self.bottleneck_model)
        [x_in] = interpreter.get_input_details()
        [bottleneck_out] = interpreter.get_output_details()

        for (x, y) in image_gen:
            # Batches may differ in size (e.g. the last one), so the input is
            # resized and the tensors reallocated for every batch.
            batch_size = x.shape[0]
            interpreter.resize_tensor_input(
                x_in["index"], (batch_size, IMAGE_SIZE, IMAGE_SIZE, 3)
            )
            interpreter.allocate_tensors()
            interpreter.set_tensor(x_in["index"], x)
            interpreter.invoke()
            bottleneck = interpreter.get_tensor(bottleneck_out["index"])
            yield bottleneck, y

    def train_head(self, num_epochs):
        """Trains the head model for a given number of epochs.

        The optimizer model produced by the converter (i.e. the optimizer
        passed to the constructor) is used to update the variables.

        Args:
          num_epochs: how many epochs should be trained

        Returns:
          A list of train_loss values after every epoch trained.

        Raises:
          RuntimeError: when prepare_bottlenecks() has not been called.
        """
        if not hasattr(self, "train_bottlenecks"):
            raise RuntimeError("prepare_bottlenecks has not been called")
        results = []
        for _ in range(num_epochs):
            loss = self._train_one_epoch(
                self._generate_batches(self.train_bottlenecks, self.train_labels)
            )
            results.append(loss)
        return results

    def _generate_batches(self, x, y):
        """Creates a generator that iterates over the data in batches.

        Args:
          x: Array of samples, batched along the first dimension.
          y: Array of labels aligned with ``x``.

        Yields:
          (x_batch, y_batch) slices of at most BATCH_SIZE samples; the last
          batch may be smaller.
        """
        num_total = x.shape[0]
        for begin in range(0, num_total, BATCH_SIZE):
            end = min(begin + BATCH_SIZE, num_total)
            yield x[begin:end], y[begin:end]

    def _train_one_epoch(self, train_gen):
        """Performs one training epoch.

        Args:
          train_gen: Iterable of (bottlenecks, labels) batches.

        Returns:
          The per-batch losses averaged with each batch weighted by its number
          of real (unpadded) samples.
        """
        interpreter = tf.lite.Interpreter(model_content=self.train_head_model)
        interpreter.allocate_tensors()
        # Inputs are laid out as [bottlenecks, labels, variables...];
        # outputs as [loss, gradients...].
        input_details = interpreter.get_input_details()
        x_in, y_in = input_details[:2]
        variable_ins = input_details[2:]
        output_details = interpreter.get_output_details()
        loss_out = output_details[0]
        gradient_outs = output_details[1:]

        epoch_loss = 0.0
        num_processed = 0
        for bottlenecks, truth in train_gen:
            batch_size = bottlenecks.shape[0]
            if batch_size < BATCH_SIZE:
                # The train model is fed fixed-size batches: fill a short batch
                # by repeating its own samples.
                bottlenecks = pad_batch(bottlenecks, BATCH_SIZE)
                truth = pad_batch(truth, BATCH_SIZE)

            interpreter.set_tensor(x_in["index"], bottlenecks)
            interpreter.set_tensor(y_in["index"], truth)
            for variable_in, variable_value in zip(variable_ins, self.variables):
                interpreter.set_tensor(variable_in["index"], variable_value)
            interpreter.invoke()

            loss = interpreter.get_tensor(loss_out["index"])
            gradients = [
                interpreter.get_tensor(gradient_out["index"])
                for gradient_out in gradient_outs
            ]

            self._apply_gradients(gradients)
            # Weight by the number of real samples, not by the padded size.
            epoch_loss += loss * batch_size
            num_processed += batch_size

        epoch_loss /= num_processed
        return epoch_loss

    def _apply_gradients(self, gradients):
        """Applies the optimizer to the model parameters.

        Runs the optimizer model once and replaces ``self.variables`` and
        ``self.optim_state`` with its outputs.

        Args:
          gradients: One gradient array per entry of ``self.variables``, in
            the same order.
        """
        interpreter = tf.lite.Interpreter(model_content=self.optimizer_model)
        interpreter.allocate_tensors()
        num_variables = len(self.variables)
        # Inputs: [variables..., gradients..., state...];
        # outputs: [new variables..., new state...].
        input_details = interpreter.get_input_details()
        variable_ins = input_details[:num_variables]
        gradient_ins = input_details[num_variables : num_variables * 2]
        state_ins = input_details[num_variables * 2 :]
        output_details = interpreter.get_output_details()
        variable_outs = output_details[:num_variables]
        state_outs = output_details[num_variables:]

        for variable, gradient, variable_in, gradient_in in zip(
            self.variables, gradients, variable_ins, gradient_ins
        ):
            interpreter.set_tensor(variable_in["index"], variable)
            interpreter.set_tensor(gradient_in["index"], gradient)

        for optim_state_elem, state_in in zip(self.optim_state, state_ins):
            interpreter.set_tensor(state_in["index"], optim_state_elem)

        interpreter.invoke()
        self.variables = [
            interpreter.get_tensor(variable_out["index"])
            for variable_out in variable_outs
        ]
        self.optim_state = [
            interpreter.get_tensor(state_out["index"]) for state_out in state_outs
        ]

    def measure_inference_accuracy(self):
        """Runs the inference model and measures accuracy on the validation set.

        Returns:
          The fraction of validation samples whose predicted class (argmax of
          the model output) equals the labelled class (argmax of the label).

        Raises:
          RuntimeError: when prepare_bottlenecks() has not been called.
        """
        # Same precondition and message as train_head().
        if not hasattr(self, "val_bottlenecks"):
            raise RuntimeError("prepare_bottlenecks has not been called")

        interpreter = tf.lite.Interpreter(model_content=self.inference_model)
        # Inputs are laid out as [bottlenecks, variables...].
        input_details = interpreter.get_input_details()
        bottleneck_in = input_details[0]
        variable_ins = input_details[1:]
        [y_out] = interpreter.get_output_details()

        num_correct = 0
        num_processed = 0
        for bottleneck, truth in self._generate_batches(
            self.val_bottlenecks, self.val_labels
        ):
            batch_size = bottleneck.shape[0]
            interpreter.resize_tensor_input(
                bottleneck_in["index"], (batch_size,) + BOTTLENECK_SHAPE
            )
            interpreter.allocate_tensors()

            interpreter.set_tensor(bottleneck_in["index"], bottleneck)
            for variable_in, variable_value in zip(variable_ins, self.variables):
                interpreter.set_tensor(variable_in["index"], variable_value)
            interpreter.invoke()

            preds = interpreter.get_tensor(y_out["index"])

            # Count hits directly rather than averaging per batch and scaling
            # the average back up by the batch size.
            num_correct += (
                np.argmax(preds, axis=1) == np.argmax(truth, axis=1)
            ).sum()
            num_processed += batch_size

        return num_correct / num_processed


def make_finite(data_gen):
    """An adapter for Keras data generators that makes them finite.

    The default behavior in Keras is to keep looping infinitely through
    the data. This adapter yields batches until exactly ``data_gen.samples``
    samples have been produced, truncating the last batch if needed, and
    stops without requesting any further batch from ``data_gen``.

    Args:
      data_gen: An infinite Keras data generator. It must expose a
        ``samples`` attribute and yield tuples of arrays that share the same
        first (batch) dimension.

    Yields:
      Same values as the parameter generator, as tuples.
    """
    num_samples = data_gen.samples
    num_processed = 0
    for batch in data_gen:
        # Never hand out more than the samples still missing from this pass.
        batch_size = min(batch[0].shape[0], num_samples - num_processed)
        if batch_size <= 0:
            return

        yield tuple(x[:batch_size] for x in batch)
        num_processed += batch_size
        if num_processed >= num_samples:
            # Stop right away: asking the generator for another batch would
            # load data only to throw it away.
            return


# TODO(b/135138207) investigate if we can get rid of this.
def pad_batch(batch, batch_size):
    """Resize batch to a given size, tiling present samples over missing.

    Example:
      Suppose batch_size is 5, batch is [1, 2].
      Then the return value is [1, 2, 1, 2, 1].

    Args:
      batch: An ndarray with first dimension size <= batch_size. A larger
        batch is truncated to its first batch_size entries.
      batch_size: Desired size for first dimension.

    Returns:
      A new ndarray of the same shape and dtype, except first dimension has
      the desired size.

    Raises:
      ValueError: when batch is empty but batch_size is positive (there is
        nothing to tile), or when batch_size is negative.
    """
    num_present = batch.shape[0]
    if batch_size > 0 and num_present == 0:
        raise ValueError("Cannot pad an empty batch: there are no samples to tile")

    # Allocating the result first keeps NumPy's own validation of batch_size
    # (a negative size raises ValueError).
    padded = np.zeros((batch_size,) + batch.shape[1:], dtype=batch.dtype)
    if batch_size > 0:
        # Output row i is input row (i mod num_present): the present samples
        # are repeated in order until the batch is full.
        padded[:] = batch[np.arange(batch_size) % num_present]
    return padded


class ModelCorrectnessTest(unittest.TestCase):
    """End-to-end accuracy checks for converted transfer-learning models.

    Every test converts a (base model, head, optimizer) combination, trains
    the head on the downloaded data set and asserts a minimum validation
    accuracy.
    """

    @classmethod
    def setUpClass(cls):
        # Download the data set and export MobileNetV2 as a SavedModel once
        # for the whole class; individual tests only read these paths.
        super(ModelCorrectnessTest, cls).setUpClass()
        zip_file = tf.keras.utils.get_file(
            origin=DATASET_URL, fname="flower_photos.tgz", extract=True
        )
        cls.dataset_dir = os.path.join(os.path.dirname(zip_file), "flower_photos")

        # mkdtemp's first positional argument is the *suffix*; pass the name
        # as a prefix explicitly so the directory is recognizable.
        mobilenet_dir = tempfile.mkdtemp(prefix="tflite-transfer-test-")
        try:
            mobilenet_keras = tf.keras.applications.MobileNetV2(
                input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
                include_top=False,
                weights="imagenet",
            )
            tfv1.keras.experimental.export_saved_model(mobilenet_keras, mobilenet_dir)
        except Exception:
            # unittest does not run tearDownClass when setUpClass fails, so
            # the directory has to be removed here.
            shutil.rmtree(mobilenet_dir, ignore_errors=True)
            raise
        cls.mobilenet_dir = mobilenet_dir

    @classmethod
    def tearDownClass(cls):
        # mkdtemp leaves deletion to the caller: drop the exported SavedModel.
        shutil.rmtree(cls.mobilenet_dir, ignore_errors=True)
        super(ModelCorrectnessTest, cls).tearDownClass()

    def setUp(self):
        super(ModelCorrectnessTest, self).setUp()
        self.mobilenet_dir = ModelCorrectnessTest.mobilenet_dir
        self.dataset_dir = ModelCorrectnessTest.dataset_dir

    def test_mobilenet_v2_saved_model_and_softmax_classifier(self):
        # Base loaded from the exported SavedModel, SGD optimizer.
        base_model = bases.SavedModelBase(self.mobilenet_dir)
        head_model = heads.SoftmaxClassifierHead(
            BATCH_SIZE, BOTTLENECK_SHAPE, NUM_CLASSES
        )
        optimizer = optimizers.SGD(LEARNING_RATE)
        model = TransferModel(self.dataset_dir, base_model, head_model, optimizer)
        self.assertModelAchievesAccuracy(model, 0.80)

    def test_mobilenet_v2_saved_model_quantized_and_softmax_classifier(self):
        # Same as above, with quantize=True on the SavedModel base.
        base_model = bases.SavedModelBase(self.mobilenet_dir, quantize=True)
        head_model = heads.SoftmaxClassifierHead(
            BATCH_SIZE, BOTTLENECK_SHAPE, NUM_CLASSES
        )
        optimizer = optimizers.SGD(LEARNING_RATE)
        model = TransferModel(self.dataset_dir, base_model, head_model, optimizer)
        self.assertModelAchievesAccuracy(model, 0.80)

    def test_mobilenet_v2_base_and_softmax_classifier(self):
        # Built-in MobileNetV2 base, SGD optimizer.
        base_model = bases.MobileNetV2Base()
        head_model = heads.SoftmaxClassifierHead(
            BATCH_SIZE, BOTTLENECK_SHAPE, NUM_CLASSES
        )
        optimizer = optimizers.SGD(LEARNING_RATE)
        model = TransferModel(self.dataset_dir, base_model, head_model, optimizer)
        self.assertModelAchievesAccuracy(model, 0.80)

    def test_mobilenet_v2_base_and_softmax_classifier_l2(self):
        # Built-in MobileNetV2 base, head created with l2_reg=0.1.
        base_model = bases.MobileNetV2Base()
        head_model = heads.SoftmaxClassifierHead(
            BATCH_SIZE, BOTTLENECK_SHAPE, NUM_CLASSES, l2_reg=0.1
        )
        optimizer = optimizers.SGD(LEARNING_RATE)
        model = TransferModel(self.dataset_dir, base_model, head_model, optimizer)
        self.assertModelAchievesAccuracy(model, 0.80)

    def test_mobilenet_v2_base_quantized_and_softmax_classifier(self):
        # Built-in MobileNetV2 base created with quantize=True.
        base_model = bases.MobileNetV2Base(quantize=True)
        head_model = heads.SoftmaxClassifierHead(
            BATCH_SIZE, BOTTLENECK_SHAPE, NUM_CLASSES
        )
        optimizer = optimizers.SGD(LEARNING_RATE)
        model = TransferModel(self.dataset_dir, base_model, head_model, optimizer)
        self.assertModelAchievesAccuracy(model, 0.80)

    def test_mobilenet_v2_base_and_softmax_classifier_adam(self):
        # Built-in MobileNetV2 base, Adam optimizer with its default arguments.
        base_model = bases.MobileNetV2Base()
        head_model = heads.SoftmaxClassifierHead(
            BATCH_SIZE, BOTTLENECK_SHAPE, NUM_CLASSES
        )
        optimizer = optimizers.Adam()
        model = TransferModel(self.dataset_dir, base_model, head_model, optimizer)
        self.assertModelAchievesAccuracy(model, 0.80)

    def assertModelAchievesAccuracy(self, model, target_accuracy, num_epochs=30):
        """Trains the model's head and asserts its validation accuracy.

        Args:
          model: A TransferModel whose bottlenecks have not been prepared yet.
          target_accuracy: Accuracy that must be strictly exceeded.
          num_epochs: Number of epochs to train the head for.
        """
        model.prepare_bottlenecks()
        print("Bottlenecks prepared")
        history = model.train_head(num_epochs)
        print("Training completed, history = {}".format(history))
        accuracy = model.measure_inference_accuracy()
        print("Final accuracy = {:.2f}".format(accuracy))
        self.assertGreater(accuracy, target_accuracy)


# Run the tests through unittest when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()
