# Copyright 2018 The SPL Authors.
#
# All rights reserved.
#
# This is the code for reproducing results of the paper.
"""Main method training code."""

import functools
from absl import app
from absl import flags
from absl import logging
import beam_inference_run as beam_inf

import gin
import tensorflow.compat.v2 as tf

import training
from augment import create_augmenter
from datasets import datasets
from networks.networks import make_network
from utils import ema
from utils.learning_rate import make_learning_rate_schedule
from training import gin_fun_register
from training import DEFAULT_COMMON_HPARAMS
from datasets.weaklabel_datasets import preprocess_image

# Flags consumed by the Beam batch-inference path (see `main`).
FLAGS = flags.FLAGS
flags.DEFINE_boolean(
    name='beam_inference', default=False, help='Process sstable.')
flags.DEFINE_string(
    name='checkpoint', default=None, help='Checkpoint pattern.')

# Short alias for the Keras metrics namespace.
metrics = tf.keras.metrics


# Default hyper-parameters specific to this (SPL) training script.
@gin_fun_register()
@gin.configurable('hparams.spl')
def gin_hparams(imagenet_weight=False):
  """Copies this function's arguments into `DEFAULT_COMMON_HPARAMS`.

  Each argument whose name does not start with an underscore is stored in
  `DEFAULT_COMMON_HPARAMS` under its own name. Because the function is
  `gin.configurable('hparams.spl')`, its argument values can be bound from
  gin (e.g. `hparams.spl.imagenet_weight = True`). Presumably
  `gin_fun_register` arranges for this function to be invoked when hparams
  are assembled; confirm in `training`.

  Args:
    imagenet_weight: Read by `SupervisedExperiment.create_model` as
      `hparams.imagenet_weight` and forwarded to `hparams.arch`.
  """
  # Snapshot the arguments before any other local name is bound.
  _arguments = dict(locals())
  for name, value in _arguments.items():
    if name.startswith('_'):
      continue
    DEFAULT_COMMON_HPARAMS[name] = value



class SupervisedExperiment(training.Experiment):
  """Supervised training experiment.

  Builds the train/eval datasets, the network (optionally with an EMA copy),
  and a Nesterov-momentum SGD optimizer. Restores from `model_dir` or from a
  fine-tuning checkpoint, and runs a distributed step that minimizes
  cross-entropy plus L2 weight decay.
  """

  def create_dataset(self):
    """Creates the train/eval datasets and the data augmenter.

    Returns:
      A `(train_eval_datasets, augmenter_state)` tuple. `train_eval_datasets`
      is whatever `datasets.make_train_eval_datasets` returns, and
      `augmenter_state` is the first value returned by `create_augmenter`.
    """
    augmenter_state, augmenter = create_augmenter(**self.hparams.augment)
    train_eval_datasets = datasets.make_train_eval_datasets(
        self.strategy,
        self.dataset,
        self.batch_size,
        augmenter, {
            'use_bfloat16': self.hparams.bfloat16,
            'saturate_uint8': self.hparams.input.saturate_uint8,
            'scale_and_center': self.hparams.input.scale_and_center,
            'use_default_augment': self.hparams.input.use_default_augment,
        },
        eval_batch_size=self.eval_batch_size)
    return train_eval_datasets, augmenter_state

  def create_or_load_checkpoint(self, **kwargs):
    """Creates a checkpoint and restores either the latest or fine-tune weights.

    If `self.model_dir` already holds a checkpoint, every object in `kwargs`
    is restored from it. Otherwise, if `hparams.finetune.ckpt` is set, only
    the model object(s) are restored from that checkpoint and the optimizer
    state is skipped. The classifier heads are then recreated. When EMA is
    enabled, the restored EMA weights are also copied into `self.net`.

    Args:
      **kwargs: Trackable objects keyed by checkpoint name. Presumably this is
        the dict returned by `create_model` ('model', 'optimizer' and
        optionally 'model_ema'); confirm against the base class.

    Returns:
      The `tf.train.Checkpoint` tracking all objects in `kwargs`.
    """
    checkpoint = tf.train.Checkpoint(**kwargs)
    latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope() so that optimizer
      # slot variables are mirrored.
      checkpoint.restore(latest_checkpoint).expect_partial()
      logging.info('Loaded checkpoint %s', latest_checkpoint)
      return checkpoint

    if self.hparams.finetune.ckpt:
      # Restore model weights only; the pre-trained optimizer state is not
      # reused. `pop(..., None)` on a copy tolerates a missing key, whereas
      # the previous `del kwargs[...]` raised KeyError.
      restore_objects = dict(kwargs)
      restore_objects.pop('optimizer', None)
      if not self.hparams.use_ema:
        restore_objects.pop('model_ema', None)
      local_checkpoint = tf.train.Checkpoint(**restore_objects)
      local_checkpoint.restore(self.hparams.finetune.ckpt).expect_partial()
      # The backbone was built without its dense layer (see `create_model`),
      # so build a fresh classifier head for this dataset.
      self.net.recreate_classifier()
      if self.hparams.use_ema:
        self.net_ema.recreate_classifier()
      logging.info('Loaded finetune checkpoint %s',
                   self.hparams.finetune.ckpt)
      if self.hparams.use_ema:
        # Reuse the EMA-initialization helper in reverse (destination first):
        # copy the loaded EMA network's variables into `self.net`.
        ema.assign_ema_vars_from_initial_values(self.net.variables,
                                                self.net_ema.variables)
    return checkpoint

  def create_model(self):
    """Creates the network(s), the optimizer and the training metrics.

    Sets `self.net`, `self.net_ema` (None when EMA is disabled) and
    `self.optimizer`, and registers the training metrics.

    Returns:
      A dict of objects to checkpoint: 'model', 'optimizer' and, when
      `hparams.use_ema` is set, 'model_ema'.
    """
    logging.info('Building model')
    if self.hparams.finetune.ckpt:
      # Do not create the dense layer when loading a pre-trained backbone,
      # which has a different dense-layer shape.
      self.hparams.arch.skip_dense = True
    self.hparams.arch.imagenet_weight = self.hparams.imagenet_weight
    # Create network.
    self.net = make_network(self.hparams, self.datasets.num_classes,
                            self.datasets.batch_shape)
    if self.hparams.use_ema:
      self.net_ema = make_network(self.hparams, self.datasets.num_classes,
                                  self.datasets.batch_shape)
      # Initialize the EMA copy from the freshly built network.
      ema.assign_ema_vars_from_initial_values(self.net_ema.variables,
                                              self.net.variables)
    else:
      self.net_ema = None
    self.net.summary()

    # Create optimizer.
    self.optimizer = tf.keras.optimizers.SGD(
        learning_rate=make_learning_rate_schedule(self.batch_size,
                                                  self.datasets.steps_per_epoch,
                                                  self.hparams.num_epochs,
                                                  self.hparams.learning_rate),
        momentum=0.9,
        nesterov=True)
    logging.info('Finished building model')

    # Create training metrics.
    self.train_metrics.add(
        tf.keras.metrics.Mean('train/total_loss', dtype=tf.float32))
    self.train_metrics.add(
        tf.keras.metrics.SparseCategoricalAccuracy(
            'train/accuracy', dtype=tf.float32))
    self.train_metrics.add(tf.keras.metrics.Mean('loss/xe', dtype=tf.float32))
    self.train_metrics.add(tf.keras.metrics.Mean('loss/wd', dtype=tf.float32))

    # Prepare checkpointed data.
    checkpointed_data = {
        'model': self.net,
        'optimizer': self.optimizer,
    }
    if self.hparams.use_ema:
      checkpointed_data['model_ema'] = self.net_ema
    return checkpointed_data

  @tf.function
  def train_step(self, iterator, num_steps_to_run):
    """Runs `num_steps_to_run` distributed training steps.

    Args:
      iterator: Iterator yielding per-replica input dicts with 'image' and
        'label' entries and an optional per-example 'weight' entry.
      num_steps_to_run: Number of steps to run inside this `tf.function`.
    """

    def step_fn(inputs):
      """Per-replica training step function."""
      images, labels = inputs['image'], inputs['label']
      with tf.GradientTape() as tape:
        logits = tf.cast(self.net(images, is_training=True), tf.float32)
        # Part 1: prediction loss.
        loss_xe = tf.nn.sparse_softmax_cross_entropy_with_logits(
            labels=labels, logits=logits)
        if 'weight' in inputs:
          # Weighted sum rather than mean. This assumes the per-example
          # weights are already normalized over the batch; TODO confirm in
          # the input pipeline.
          loss_xe = tf.reduce_sum(loss_xe * inputs['weight'])
        else:
          loss_xe = tf.reduce_mean(loss_xe)
        # Part 2: L2 regularization of the model weights. `loss_wd` defaults
        # to zero so it is always defined. Previously, weight_decay <= 0 left
        # it unbound, and an empty variable list made `tf.add_n` raise.
        loss_wd = tf.constant(0.0, dtype=tf.float32)
        if self.hparams.weight_decay > 0:
          regularized_variables = list(self.net.regularized_variables)
          if regularized_variables:
            loss_wd = self.hparams.weight_decay * tf.add_n([
                tf.reduce_sum(tf.square(v)) for v in regularized_variables
            ])
        loss = loss_xe + loss_wd
        # Scale the loss given the TPUStrategy will reduce sum all gradients.
        scaled_loss = loss / self.strategy.num_replicas_in_sync

      grads = tape.gradient(scaled_loss, self.net.trainable_variables)
      self.optimizer.apply_gradients(zip(grads, self.net.trainable_variables))
      if self.hparams.use_ema:
        ema.update_ema_variables(self.net_ema.variables, self.net.variables,
                                 self.hparams.ema_decay)
      self.train_metrics['train/total_loss'].update_state(loss)
      self.train_metrics['loss/xe'].update_state(loss_xe)
      self.train_metrics['loss/wd'].update_state(loss_wd)
      self.train_metrics['train/accuracy'].update_state(labels, logits)

    # `Strategy.run` replaced `experimental_run_v2` in TF 2.2. Fall back to the
    # old name on TensorFlow releases that predate it.
    run_fn = getattr(self.strategy, 'run', None)
    if run_fn is None:
      run_fn = self.strategy.experimental_run_v2
    for _ in tf.range(num_steps_to_run):
      run_fn(step_fn, args=(next(iterator),))

  def get_current_train_step(self):
    """Returns the optimizer's iteration counter as a Python/NumPy value."""
    return self.optimizer.iterations.numpy()


def _run_beam_inference(experiment):
  """Runs batch inference via `beam_inference_run` for `experiment`.

  Builds the experiment's datasets, which supply `num_classes` and
  `batch_shape`. It then passes eval-mode preprocessing and network factory
  callables to `beam_inf.run`.

  Args:
    experiment: The `SupervisedExperiment` created by `main`.
  """
  experiment.datasets, experiment.augmenter_state = experiment.create_dataset()
  hparams = experiment.hparams

  # Eval-mode preprocessing: no training-time augmentation.
  preprocess_image_fn = functools.partial(
      preprocess_image,
      is_training=False,
      augmentation=None,
      use_bfloat16=hparams.bfloat16,
      saturate_uint8=hparams.input.saturate_uint8,
      scale_and_center=hparams.input.scale_and_center,
      use_default_augment=hparams.input.use_default_augment)

  num_classes = experiment.datasets.num_classes
  # A plain Python function did not work with Beam here, so the network
  # factory is a `functools.partial` (presumably a serialization issue).
  make_network_fn = functools.partial(
      make_network,
      hparams=hparams,
      num_classes=num_classes,
      input_shape=experiment.datasets.batch_shape)

  # NOTE(review): FLAGS.checkpoint is not passed here; presumably
  # beam_inference_run reads it directly. Confirm.
  beam_inf.run(
      preprocess_image_fn,
      make_network_fn,
      num_classes=num_classes,
      checkpoint_model_name='model')


def main(unused_argv):
  """Parses gin config, builds the experiment and dispatches on run mode.

  Modes, checked in order: Beam inference (`--beam_inference`), evaluation
  only (`hparams.eval_only`), otherwise training with evaluation.
  """
  del unused_argv  # Unused.
  config_files = FLAGS.gin_config
  bindings = FLAGS.gin_bindings
  logging.info('Gin config: %s\nGin bindings: %s', config_files, bindings)
  gin.parse_config_files_and_bindings(config_files, bindings)
  tf.enable_v2_behavior()

  experiment = SupervisedExperiment(
      distribution_strategy=training.create_distribution_strategy(),
      hparams=training.get_hparams())

  if FLAGS.beam_inference:
    _run_beam_inference(experiment)
  elif experiment.hparams.eval_only:
    # NOTE(review): `evalulation` must match the spelling of the method on
    # training.Experiment; confirm it is not a typo for `evaluation`.
    experiment.evalulation()
  else:
    experiment.train_and_eval()


if __name__ == '__main__':
  # Enable INFO logging before absl parses flags and calls `main`.
  logging.set_verbosity(logging.INFO)
  app.run(main)
