# Lint as: python3
# Copyright 2018 The SPL Authors.
#
# All rights reserved.
#
# This is the code for reproducing results of the paper.
"""Base code for training of all experiments."""

import abc
import copy
import json
import math
import os
import time

from absl import flags
from absl import logging
from easydict import EasyDict
import gin
import numpy as np

import tensorflow.compat.v2 as tf


# Command-line flags shared by all experiments.

# Required: output directory for checkpoints, hparams.json and summaries.
flags.DEFINE_string(
    'model_dir', None,
    ('The directory where the model weights and training/evaluation summaries '
     'are stored.'))
flags.mark_flag_as_required('model_dir')
flags.DEFINE_boolean('save_summaries_with_epoch', True,
                     'If true then summaries will be reported with epoch*1000 '
                     'as x-coordinate, otherwise current step will be used as '
                     'x-coordinate.')
# When unset (None), train_and_eval() runs one epoch worth of steps per call
# of the training step function.
flags.DEFINE_integer(
    'steps_per_run', None,
    'Number of steps per one run of training op. Set it only for debugging.')
# Deprecated: get_hparams() asserts that this flag is empty.
flags.DEFINE_string('hparams', '', 'JSON with list of hyperparameters.')
flags.DEFINE_multi_string('gin_bindings', [],
                          'Newline separated list of Gin parameter bindings.')
flags.DEFINE_multi_string('gin_config', ['config.gin'],
                          'List of paths to the config files.')


FLAGS = flags.FLAGS


# Short alias used when constructing the Keras metric objects below.
metrics = tf.keras.metrics

# NOTE(review): the two flags below are not read in this part of the file;
# presumably they are consumed by the experiment-specific code.
flags.DEFINE_enum(
    'train_mode', 'pretrain', ['pretrain', 'finetune', 'diagnose'],
    'The train mode controls different objectives and trainable components.')
flags.DEFINE_string(
    'ft_ckpt_dir', None,
    'Finetune checkpoint saved folder.')


# Default values of the hyperparameters common to all experiments.
# The gin_* functions below overwrite entries of this dict in place, and
# get_hparams() returns a deep copy of it with further overrides applied.
DEFAULT_COMMON_HPARAMS = EasyDict({
    'bfloat16':
        True,
    'num_epochs':
        300,
    'use_ema':
        True,
    'ema_decay':
        0.999,  # Reasonable default; higher values do not work
    'weight_decay':
        0.0003,
    'per_worker_batch_size':
        128,
    'per_worker_eval_batch_size':
        128,
    'dataset':
        None,
    'network':
        'keras_resnet50',  # set arch
    'optimizer':
        'sgd',  # set optimizer
    'finetune':
        EasyDict({  # finetune hyperparameters
            'ckpt_subdir': '',
            'ckpt': '',
            'after_layer': 'fc',
            'tasks': '',  # evaluate on multiple datasets [dataset1,dataset2]
        }),
    'arch':
        EasyDict({
            'normalization': 'batch_norm',
        }),
    'input':
        EasyDict({
            'saturate_uint8': True,
            'scale_and_center': True,
            'use_default_augment': True,
            'addition_datasets': None
        }),
    'learning_rate':
        EasyDict({
            'schedule_type': 'step',  # 'step', 'cosine_pi', 'cosine_half_pi'
            'base_lr': 0.1,
            'use_warmup': True,
            'warmup_epochs': 5,
            'decay_rate': 0.1,
            'decay_epochs': 50,
            'distributed_mode': True,  # whether distributed training.
        }),
    'augment':
        EasyDict({
            'type': 'noop',
        }),
})

# Registry of gin-configurable functions that get_hparams() calls to copy gin
# values into DEFAULT_COMMON_HPARAMS. When creating a new gin_xxx function,
# decorate it with @gin_fun_register() so that it is appended to this list.
GAIN_BINDDINGS_FUNS = []


def gin_fun_register():
  """Returns a decorator that records a function in GAIN_BINDDINGS_FUNS.

  The decorated function is appended to the module-level registry and is
  returned unchanged.
  """
  def _register(fn):
    GAIN_BINDDINGS_FUNS.append(fn)
    return fn
  return _register


@gin_fun_register()
@gin.configurable('hparams')
def gin_hparams(
    dataset=None,  # required input from gin
    bfloat16=True,
    num_epochs=300,
    use_ema=True,
    network='keras_resnet50',
    optimizer='sgd',
    ema_decay=0.999,
    weight_decay=0.0003,
    per_worker_batch_size=128,
    per_worker_eval_batch_size=128,
    ft_ckpt_subdir='',
    eval_only=False):
  """Copies the top-level `hparams` gin values into DEFAULT_COMMON_HPARAMS.

  Every argument is stored under its own name at the top level of
  DEFAULT_COMMON_HPARAMS, overwriting any existing entry.
  """
  # Snapshot the arguments first, before any other local name exists.
  _hp = dict(locals())
  for name, value in _hp.items():
    if name.startswith('_'):
      continue
    DEFAULT_COMMON_HPARAMS[name] = value


@gin_fun_register()
@gin.configurable('hparams.arch')
def gin_arch(normalization='batch_norm'):
  """Copies the `hparams.arch` gin values into DEFAULT_COMMON_HPARAMS['arch']."""
  # Snapshot the arguments first, before any other local name exists.
  _hp = dict(locals())
  for name, value in _hp.items():
    if name.startswith('_'):
      continue
    DEFAULT_COMMON_HPARAMS['arch'][name] = value


@gin_fun_register()
@gin.configurable('hparams.input')
def gin_input(saturate_uint8=True,
              scale_and_center=True,
              use_default_augment=True,
              addition_datasets=None):
  """Copies the `hparams.input` gin values into DEFAULT_COMMON_HPARAMS['input']."""
  # Snapshot the arguments first, before any other local name exists.
  _hp = dict(locals())
  for name, value in _hp.items():
    if name.startswith('_'):
      continue
    DEFAULT_COMMON_HPARAMS['input'][name] = value


@gin_fun_register()
@gin.configurable('hparams.learning_rate')
def gin_learning_rate(schedule_type='step',
                      decay_rate=0.1,
                      use_warmup=True,
                      decay_epochs=50,
                      base_lr=0.1,
                      warmup_epochs=5,
                      distributed_mode=True):
  """Copies `hparams.learning_rate` gin values into the default hparams.

  Every argument is stored under its own name in
  DEFAULT_COMMON_HPARAMS['learning_rate']. A string `decay_epochs` is split on
  commas and stored as a list of ints.
  """
  # Snapshot the arguments first, before any other local name exists.
  _hp = dict(locals())
  for name, value in _hp.items():
    if name.startswith('_'):
      continue
    if name == 'decay_epochs' and isinstance(decay_epochs, (str,)):
      # e.g. '30,60,90' -> [30, 60, 90]
      value = [int(piece) for piece in value.split(',')]
    DEFAULT_COMMON_HPARAMS['learning_rate'][name] = value


@gin_fun_register()
@gin.configurable('hparams.augment')
def gin_augment(type='noop'):  # pylint: disable=redefined-builtin
  """Copies `hparams.augment` gin values into DEFAULT_COMMON_HPARAMS['augment'].

  Args:
    type: value stored under DEFAULT_COMMON_HPARAMS['augment']['type']. The
      name shadows the builtin but is kept because it is the gin binding name.
  """
  # Snapshot the arguments once and read from the snapshot, like the sibling
  # gin_* functions, instead of calling locals() again on every iteration.
  _hp = dict(locals())
  for k in _hp:
    if not k.startswith('_'):
      DEFAULT_COMMON_HPARAMS['augment'][k] = _hp[k]


@gin_fun_register()
@gin.configurable('hparams.finetune')
def gin_finetune(ckpt_subdir='', ckpt='', after_layer='fc', tasks=''):
  """Copies `hparams.finetune` gin values into the default finetune hparams."""
  # Snapshot the arguments first, before any other local name exists.
  _hp = dict(locals())
  for name, value in _hp.items():
    if name.startswith('_'):
      continue
    DEFAULT_COMMON_HPARAMS['finetune'][name] = value



def update_dict(dict_to_update, new_values):
  """Recursively merges `new_values` into `dict_to_update` in place.

  Args:
    dict_to_update: dictionary which is modified.
    new_values: dictionary with the values to write. A dict value is merged
      key by key into an existing dict value under the same key; every other
      value replaces the existing entry (or creates a new one).

  Returns:
    None. `dict_to_update` is mutated.
  """
  for k, v in new_values.items():
    # Only recurse when BOTH sides are dicts. If the existing value is not a
    # dict (e.g. a None placeholder), the new dict simply replaces it instead
    # of failing on item assignment into a non-dict.
    if isinstance(v, dict) and isinstance(dict_to_update.get(k), dict):
      update_dict(dict_to_update[k], v)
    else:
      dict_to_update[k] = v


def get_hparams(default_model_hparams=None):
  """Returns dictionary with all hyperparameters.

  Every function registered in GAIN_BINDDINGS_FUNS is called first, so that
  gin-provided values are written into DEFAULT_COMMON_HPARAMS. The --hparams
  flag is deprecated and is asserted to be empty. The result is a deep copy
  of DEFAULT_COMMON_HPARAMS, overlaid with `default_model_hparams` and then
  with the JSON object parsed from --hparams (an empty object when the flag
  is empty).

  Args:
    default_model_hparams: optional dictionary with default model-specific
      hyperparameters.

  Returns:
    dictionary with all hyperparameters.
  """
  # Let gin overwrite the module-level defaults.
  for bind_fn in GAIN_BINDDINGS_FUNS:
    bind_fn()
  assert not FLAGS.hparams, 'Deprecated FLAGS.hparams'

  # Legacy path: wrap a brace-less flag value so that it parses as an object.
  raw_json = FLAGS.hparams.strip()
  raw_json = raw_json if raw_json.startswith('{') else '{ ' + raw_json + ' }'

  merged = copy.deepcopy(DEFAULT_COMMON_HPARAMS)
  if default_model_hparams:
    update_dict(merged, default_model_hparams)
  update_dict(merged, json.loads(raw_json))
  return merged


def safe_mean(losses):
  """Returns the mean over all elements of `losses`.

  The division uses `tf.math.divide_no_nan`, so a tensor with zero elements
  gives 0 rather than NaN.
  """
  summed = tf.reduce_sum(losses)
  count = tf.dtypes.cast(tf.size(losses), dtype=losses.dtype)
  return tf.math.divide_no_nan(summed, count)


def create_distribution_strategy():
  """Creates and returns a `tf.distribute.MirroredStrategy`, logging it."""
  logging.info('Using MirroredStrategy on local devices.')
  strategy = tf.distribute.MirroredStrategy()
  logging.info('Created distribution strategy: %s', strategy)
  return strategy


class MetricManager(object):
  """Metric management for TensorBoard.

  Keeps a dictionary of named metric objects, writes them (plus optional
  extra values) as TensorBoard summaries and renders them as a printable
  string. When `work_unit` is set, scalar values are additionally reported to
  measurement series obtained from it.

  Use example:
    test_metrics = MetricManager()

    test_metrics.add(metrics.Mean('test/loss', dtype=tf.float32))

    test_metrics['test/loss'].update_state(loss)

    # call to write Tensorboard
    metric_txt = test_metrics.summary_metrics(step)

    print(metric_txt)

    test_metrics.reset_states()  # call after all data evaluated

  """

  def __init__(self):
    self._metrics = dict()
    self._step = 0
    self._last_step = 0
    self.work_unit = None
    # Must exist from the start: add() and _save_metrics() use it as soon as
    # `work_unit` is set, not only after clear() has been called.
    self.work_measurements = dict()

  def __getitem__(self, name):
    return self._metrics[name]

  def _to_float(self, value):
    """Converts a single-element value to a float rounded to 3 decimals.

    Args:
      value: a Python number, a numpy array, a tensor, or an object with a
        `result()` method returning a tensor (e.g. a Keras metric).

    Returns:
      The rounded float if `value` holds exactly one element, otherwise the
      empty string ''.
    """
    round_float = lambda x: round(float(x), 3)

    if isinstance(value, np.ndarray):
      # numpy arrays have no `.rank`; decide on the number of elements.
      return round_float(value.item()) if value.size == 1 else ''

    try:
      return round_float(value)
    except Exception:  # pylint: disable=broad-except
      # Not directly convertible (metric object, non-scalar tensor, ...);
      # fall through to the tensor based handling below.
      pass

    # Metric objects expose their value through result(); tensors are used
    # as they are.
    result_fn = getattr(value, 'result', None)
    tensor = result_fn() if callable(result_fn) else value
    shape = getattr(tensor, 'shape', None)
    if shape is not None and getattr(shape, 'rank', None) == 0:
      return round_float(tensor.numpy())
    return ''

  def _save_metrics(self, save_metrics, step):
    """Internal func to communicate Tensorboard.

    Args:
      save_metrics: dictionary mapping summary names to values. Tensors of
        rank >= 2 are written as images, rank-1 tensors and MeanTensor metrics
        as histograms (name suffixed with '_h'), everything else as scalars.
      step: step used as x-coordinate of the summaries.

    Raises:
      ValueError: if a scalar metric value cannot be read.
    """
    for k, v in save_metrics.items():
      if isinstance(v, tf.Tensor) and v.shape.rank > 0:
        if v.shape.rank >= 2:
          # image
          img = v / (tf.reduce_max(v) + 1e-6)
          img = tf.expand_dims(img, -1)  # insert color channel
          while img.shape.rank < 4:
            img = tf.expand_dims(img, 0)
          tf.summary.image(k, img, step=step)
        else:
          # histogram: repeat every index int(value * 100) times so that the
          # histogram bars follow the values of the vector.
          hist_elements = []
          for idx, val in enumerate(v.numpy()):
            count = int(val * 100)
            if count > 0:
              hist_elements.extend([idx] * count)
          tf.summary.histogram(k + '_h', hist_elements, step=step)
      elif isinstance(v, tf.keras.metrics.MeanTensor):
        tf.summary.histogram(k + '_h', v.result(), step=step)
      else:
        if not isinstance(v, (float, int, tf.Tensor)):
          try:
            v = v.result().numpy()
          except Exception as e:
            # Keep the original error as the cause for easier debugging.
            raise ValueError(f'Metric {k} has errors to access. ') from e
        tf.summary.scalar(k, v, step=step)
        # if scalar and work_unit exists, write to it
        if self.work_unit is not None:
          if k not in self.work_measurements:
            self.work_measurements[k] = self.work_unit.get_measurement_series(
                label=k)
          self.work_measurements[k].create_measurement(
              objective_value=v, step=step)

  def set_notes(self, message):
    """Forwards `message` to `work_unit.set_notes` if a work unit is set."""
    if self.work_unit is not None:
      self.work_unit.set_notes(message)

  def clear(self):
    """Drops all registered metrics and measurement series."""
    self._metrics = dict()
    self.work_measurements = dict()

  def add(self, obj):
    """Registers metric `obj` under `obj.name`.

    If a work unit is set, a measurement series is also created for every
    metric that is not a MeanTensor.
    """
    self._metrics[obj.name] = obj
    if self.work_unit is not None and not isinstance(
        obj, tf.keras.metrics.MeanTensor):
      self.work_measurements[obj.name] = self.work_unit.get_measurement_series(
          label=obj.name)

  def reset_states(self):
    """Reset all states."""
    for v in self._metrics.values():
      v.reset_states()
    self._last_step = self._to_float(self._step)

  def numpy(self, name):
    """Returns the rounded float value of metric `name` ('' if not scalar)."""
    assert name in self._metrics.keys()
    return self._to_float(self._metrics[name])

  def summary_metrics(self, step, extra_metrics=None):
    """Summarize to Tensorboard.

    Args:
      step: step used as x-coordinate of the summaries.
      extra_metrics: optional dictionary with additional values to report
        together with the registered metrics.

    Returns:
      String with the step and all scalar values, one per line.
    """
    self._step = step
    all_metrics = {}
    all_metrics.update(self._metrics)
    if extra_metrics:
      all_metrics.update(extra_metrics)
    self._save_metrics(all_metrics, step)

    return self.summary_metrics_to_string(all_metrics)

  def summary_metrics_to_string(self, all_metrics=None):
    """Summarize to a string for printing.

    Args:
      all_metrics: dictionary of values to print; defaults to the registered
        metrics.

    Returns:
      'Step:<step>' followed by one '<name>:<value>' line per scalar value.
      Values which are not scalars are skipped.
    """
    string = ['Step:{}'.format(self._to_float(self._step))]
    if all_metrics is None:
      all_metrics = self._metrics
    for k, v in all_metrics.items():
      v = self._to_float(v)
      # _to_float returns '' for non-scalars. Test the type rather than the
      # truthiness so that legitimate 0.0 values are still printed.
      if isinstance(v, float):
        string.append('{}:{:.3f}'.format(k, v))
    string = '\n'.join(string)
    return string


class Experiment(object):
  """Helper class with most training routines.

  Subclasses are expected to implement the methods marked with
  `@abc.abstractmethod`. Note that this class does not use `abc.ABCMeta`, so
  these markers are not enforced when an instance is created.

  Besides the attributes assigned here, the methods read `self.net`,
  `self.net_ema` and `self.optimizer`, which are not assigned in this class
  (presumably subclasses set them in `create_model()` - verify against the
  subclass).
  """

  def __init__(self,
               distribution_strategy,
               hparams):
    """Stores the configuration and prepares the model directory.

    Args:
      distribution_strategy: distribution strategy used for training and
        evaluation.
      hparams: hyperparameters; `dataset`, `per_worker_batch_size` and
        `per_worker_eval_batch_size` are read via attribute access.
    """
    self.hparams = hparams
    self.strategy = distribution_strategy
    self.model_dir = FLAGS.model_dir
    self.dataset = hparams.dataset
    # Global batch sizes are the per-worker sizes times the number of worker
    # devices of the strategy.
    num_workers = len(distribution_strategy.extended.worker_devices)
    self.batch_size = self.hparams.per_worker_batch_size * num_workers
    self.eval_batch_size = self.hparams.per_worker_eval_batch_size * num_workers

    self.clean_model_dir()
    self.save_hparams()
    logging.info('Saving checkpoints at %s', self.model_dir)
    logging.info('Hyper parameters: %s', self.hparams)

    # set train and test metric manager.
    self.train_metrics = MetricManager()
    self.test_metrics = MetricManager()
    self._step_offset = 0

  @abc.abstractmethod
  def create_dataset(self):
    """Creates dataset.

    Returns:
      datasets: datasets.Datasets named tuple.
      augmenter_state: state of the stateful augmenter.
    """
    pass

  @abc.abstractmethod
  def create_model(self):
    """Creates model and everything needed to train it.

    Returns:
      checkpointed_data: dictionary with model data which needs to be saved to
        checkpoints.
    """
    pass

  def clean_model_dir(self):
    """Creates the model directory if it does not exist yet.

    Despite the name nothing is deleted: an existing directory is kept as is.
    """
    if not tf.io.gfile.exists(self.model_dir):
      logging.warning('Create checkpoint dir: %s', self.model_dir)
      tf.io.gfile.makedirs(self.model_dir)
    else:
      logging.warning('checkpoint dir exists: %s', self.model_dir)

  def create_or_load_checkpoint(self, **kwargs):
    """Creates and maybe loads checkpoint.

    Args:
      **kwargs: objects to track, passed to `tf.train.Checkpoint`.

    Returns:
      The checkpoint object, restored from the latest checkpoint found in
      `self.model_dir` if there is one.
    """
    checkpoint = tf.train.Checkpoint(**kwargs)

    latest_checkpoint = tf.train.latest_checkpoint(self.model_dir)
    if latest_checkpoint:
      # checkpoint.restore must be within a strategy.scope() so that optimizer
      # slot variables are mirrored.
      checkpoint.restore(latest_checkpoint)
      logging.info('Loaded checkpoint %s', latest_checkpoint)
    return checkpoint

  def save_hparams(self):
    """Saves hyperparameters as a json file.

    The file `hparams.json` in the model directory is only written if it does
    not exist yet; an existing file is left untouched.
    """
    filename = os.path.join(self.model_dir, 'hparams.json')
    if not tf.io.gfile.exists(filename):
      with tf.io.gfile.GFile(filename, 'w') as f:
        json.dump(self.hparams, f, indent=2)

  @abc.abstractmethod
  def train_step(self, iterator, num_steps_to_run):
    """Training StepFn."""
    pass

  @tf.function
  def test_step(self, iterator, num_steps_to_run):
    """Evaluation StepFn.

    Args:
      iterator: iterator over the evaluation dataset; every element provides
        the keys 'image' and 'label'.
      num_steps_to_run: number of batches to evaluate.
    """

    def step_fn(inputs):
      """Per-Replica evaluation step function."""
      images, labels = inputs['image'], inputs['label']
      logits = self.net(images, is_training=False)
      logits = tf.cast(logits, tf.float32)
      loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
          labels=labels, logits=logits)
      loss = safe_mean(loss)
      self.test_metrics['test/loss'].update_state(loss)
      self.test_metrics['test/accuracy'].update_state(labels, logits)
      self.test_metrics['test/accuracy_top5'].update_state(labels, logits)
      if self.hparams.use_ema:
        ema_logits = self.net_ema(images, is_training=False)
        self.test_metrics['test/ema_accuracy'].update_state(labels, ema_logits)
        self.test_metrics['test/ema_accuracy_top5'].update_state(
            labels, ema_logits)

    # `experimental_run_v2` is the deprecated name of `Strategy.run`. Prefer
    # `run` when the strategy provides it and fall back to the old name so
    # that older TensorFlow releases keep working.
    if hasattr(self.strategy, 'run'):
      run_fn = self.strategy.run
    else:
      run_fn = self.strategy.experimental_run_v2

    for _ in tf.range(num_steps_to_run):
      run_fn(step_fn, args=(next(iterator),))

  @abc.abstractmethod
  def save_train_metrics(self):
    """Saves and resets all training metrics.

    Writes the training metrics and the current learning rate as summaries,
    logs a note with the training loss and accuracy, and resets the metrics.
    Expects the metrics 'train/total_loss' and 'train/accuracy' to be
    registered in `self.train_metrics`.
    """
    step = self.get_current_train_step()
    extra_metrics = {
        'train/learning_rate':
            self.optimizer.learning_rate(self.optimizer.iterations)
    }
    self.train_metrics.summary_metrics(step, extra_metrics)
    train_loss = self.train_metrics['train/total_loss'].result().numpy()
    train_accuracy = self.train_metrics['train/accuracy'].result().numpy()
    notes = 'Training loss: %s, accuracy: %s%% at epoch %d' % (round(
        train_loss, 4), round(train_accuracy * 100,
                              2), step//self.datasets.steps_per_epoch)
    logging.info(notes)
    self.train_metrics.set_notes(notes)
    self.train_metrics.reset_states()

  @abc.abstractmethod
  def get_current_train_step(self):
    """Returns current training step."""
    pass

  @abc.abstractmethod
  def after_epoch(self):
    """Hook called after every run of `train_step` in `train_and_eval`."""
    pass

  @abc.abstractmethod
  def before_training(self):
    """Hook called once right before the training loop starts."""
    pass

  def _create_eval_metrics(self):
    """Registers the evaluation metrics updated by `test_step`."""
    self.test_metrics.add(metrics.Mean('test/loss', dtype=tf.float32))
    self.test_metrics.add(metrics.SparseCategoricalAccuracy(
        'test/accuracy', dtype=tf.float32))
    self.test_metrics.add(metrics.SparseTopKCategoricalAccuracy(
        k=5, name='test/accuracy_top5', dtype=tf.float32))
    self.test_metrics.add(metrics.SparseCategoricalAccuracy(
        'test/ema_accuracy', dtype=tf.float32))
    self.test_metrics.add(metrics.SparseTopKCategoricalAccuracy(
        k=5, name='test/ema_accuracy_top5', dtype=tf.float32))

  def train_and_eval(self):
    """Runs training loop with periodic evaluation.

    Every iteration runs `steps_per_run` training steps (one epoch unless
    --steps_per_run is set), calls `after_epoch()`, writes the training
    summaries, evaluates on the eval dataset, writes the evaluation summaries
    and saves a checkpoint.
    """

    with self.strategy.scope():

      # Create datasets
      self.datasets, self.augmenter_state = self.create_dataset()
      self.total_train_steps = int(math.ceil(
          self.datasets.steps_per_epoch * self.hparams.num_epochs))
      if FLAGS.steps_per_run is None:
        steps_per_run = int(self.datasets.steps_per_epoch)
      else:
        steps_per_run = FLAGS.steps_per_run
      # Create model
      checkpointed_data = self.create_model()
      checkpoint = self.create_or_load_checkpoint(**checkpointed_data)

      # Create eval metrics
      self._create_eval_metrics()

      self.summary_writer = tf.summary.create_file_writer(
          os.path.join(self.model_dir, 'summaries'))

      self.before_training()

      # training loop
      train_iterator = iter(self.datasets.train_dataset)
      # Reported with the summaries of the NEXT iteration (0.0 for the first
      # one); the measured time includes evaluation and checkpoint saving.
      steps_per_second = 0.0
      initial_step = self.get_current_train_step()

      for next_step_to_run in range(initial_step,
                                    self.total_train_steps,
                                    steps_per_run):
        logging.info('Running steps %d - %d (total %d)',
                     next_step_to_run + 1,
                     next_step_to_run + steps_per_run,
                     self.total_train_steps)
        start_time = time.time()
        self.train_step(train_iterator, tf.constant(steps_per_run))
        self.after_epoch()
        with self.summary_writer.as_default():
          self.save_train_metrics()

          test_iterator = iter(self.datasets.eval_dataset)
          self.test_step(test_iterator,
                         tf.constant(self.datasets.steps_per_eval))
          step = self.get_current_train_step()
          extra_metrics = {
              'train/steps_per_second':
                  steps_per_second,
              'train/cur_epoch': (step / self.datasets.steps_per_epoch),
              'train/learning_rate':
                  self.optimizer.learning_rate(self.optimizer.iterations)
          }
          summary_string = self.test_metrics.summary_metrics(
              step, extra_metrics=extra_metrics)
        logging.info(summary_string)
        self.test_metrics.reset_states()

        checkpoint_name = checkpoint.save(
            os.path.join(self.model_dir, 'checkpoint'))
        logging.info('Saved checkpoint to %s', checkpoint_name)

        end_time = time.time()
        steps_per_second = steps_per_run / (end_time - start_time)

  def evalulation(self):
    """Runs evaluation.

    Creates the datasets and the model, restores the latest checkpoint from
    the model directory if there is one, evaluates on the eval dataset and
    logs the resulting metrics.
    """

    with self.strategy.scope():

      # Create datasets
      self.datasets, _ = self.create_dataset()
      # Create model
      checkpointed_data = self.create_model()
      _ = self.create_or_load_checkpoint(**checkpointed_data)

      # Create eval metrics
      self._create_eval_metrics()

      test_iterator = iter(self.datasets.eval_dataset)
      self.test_step(test_iterator, tf.constant(self.datasets.steps_per_eval))
      step = self.get_current_train_step()
      summary_string = self.test_metrics.summary_metrics(
          step, extra_metrics=None)
      logging.info(summary_string)
      self.test_metrics.reset_states()
