# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file

import gc
import io
import os
import time

import numpy as np
import tensorflow as tf
import tensorflow_gan as tfgan
import logging
# Keep the import below for registering all model definitions
from models import ddpm, ncsnv2, ncsnpp, unet_classifier
import losses_classifier as losses
import sampling
from models import utils as mutils
import datasets
import evaluation
import likelihood
from models.ema import ExponentialMovingAverage
import sde_lib
from absl import flags
import torch
from torch.utils import tensorboard
from torchvision.utils import make_grid, save_image
from utils import save_checkpoint, restore_checkpoint

# Module-level handle to absl's global flag values.
FLAGS = flags.FLAGS


def _create_sde(config):
  """Builds the forward SDE named by `config.training.sde` (case-insensitive).

  Args:
    config: Configuration; reads `config.training.sde` and the matching
      noise-schedule fields of `config.model`.

  Returns:
    An `sde_lib.VPSDE`, `sde_lib.subVPSDE` or `sde_lib.VESDE` instance.

  Raises:
    NotImplementedError: if the SDE name is not one of the above.
  """
  sde_name = config.training.sde.lower()
  if sde_name == 'vpsde':
    return sde_lib.VPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max,
                         N=config.model.num_scales)
  if sde_name == 'subvpsde':
    return sde_lib.subVPSDE(beta_min=config.model.beta_min, beta_max=config.model.beta_max,
                            N=config.model.num_scales)
  if sde_name == 'vesde':
    return sde_lib.VESDE(sigma_min=config.model.sigma_min, sigma_max=config.model.sigma_max,
                         N=config.model.num_scales)
  raise NotImplementedError(f"SDE {config.training.sde} unknown.")


def _to_torch(tf_tensor, device):
  """Converts an eager TF tensor to a torch tensor on `device`.

  `._numpy()` is used to avoid an extra copy before `torch.from_numpy`.
  The dtype is left unchanged; callers cast with `.float()` where needed.
  """
  return torch.from_numpy(tf_tensor._numpy()).to(device)


def train(config, workdir):
  """Runs the classifier training pipeline.

  Trains the model built by `mutils.create_model(config)` with the step
  function from `losses_classifier.get_step_fn`. When `config.model.use_dlsm`
  is set, a pretrained score model is restored from `config.model.score_path`,
  put in eval mode, and passed to every train/eval step for the DLSM loss term.

  Args:
    config: Configuration to use.
    workdir: Working directory for checkpoints and TensorBoard summaries. If it
      contains `checkpoints-meta/checkpoint.pth`, training resumes from it.

  Raises:
    ValueError: if `config.data.mix` is set but the dataset name does not
      contain 'SEMI' (only those datasets provide an unlabeled split here).
    NotImplementedError: if `config.training.sde` is unknown.
  """
  # Create directories for experimental logs
  sample_dir = os.path.join(workdir, "samples")
  tf.io.gfile.makedirs(sample_dir)

  tb_dir = os.path.join(workdir, "tensorboard")
  tf.io.gfile.makedirs(tb_dir)
  writer = tensorboard.SummaryWriter(tb_dir)

  # Initialize the classifier being trained.
  class_model = mutils.create_model(config)
  ema = ExponentialMovingAverage(class_model.parameters(), decay=config.model.ema_rate)
  optimizer = losses.get_optimizer(config, class_model.parameters())
  state = dict(optimizer=optimizer, model=class_model, ema=ema, step=0)

  # Optionally restore a pretrained score model used by the DLSM loss term.
  score_model = None
  use_dlsm = bool(getattr(config.model, 'use_dlsm', False))
  dlsm_lambda = 0.
  weighting_dlsm = 0.
  if use_dlsm:
    # Side effect: overwrites `config.model.name` (presumably what
    # `create_model` dispatches on) for the remainder of the run.
    config.model.name = config.model.score
    dlsm_lambda = config.model.dlsm_lambda
    weighting_dlsm = config.model.weighting_dlsm
    score_model = mutils.create_model(config)
    ema_s = ExponentialMovingAverage(score_model.parameters(), decay=config.model.ema_rate)
    optimizer_s = losses.get_optimizer(config, score_model.parameters())
    state_s = dict(optimizer=optimizer_s, model=score_model, ema=ema_s, step=0)
    restore_checkpoint(config.model.score_path, state_s, config.device)
    # NOTE(review): `ema_s` is never copied into `score_model` here, so any
    # EMA weights in the checkpoint are presumably unused — confirm intended.
    score_model.eval()

  # Create checkpoints directory
  checkpoint_dir = os.path.join(workdir, "checkpoints")
  # Intermediate checkpoints to resume training after pre-emption in cloud environments
  checkpoint_meta_dir = os.path.join(workdir, "checkpoints-meta", "checkpoint.pth")
  tf.io.gfile.makedirs(checkpoint_dir)
  tf.io.gfile.makedirs(os.path.dirname(checkpoint_meta_dir))
  # Resume training when intermediate checkpoints are detected
  state = restore_checkpoint(checkpoint_meta_dir, state, config.device)
  initial_step = int(state['step'])

  # Build data iterators. Only 'SEMI' datasets return an unlabeled split,
  # which mixing labeled + unlabeled batches requires.
  use_mix = bool(getattr(config.data, 'mix', False))
  if 'SEMI' not in config.data.dataset:
    if use_mix:
      raise ValueError(
          "config.data.mix requires a dataset whose name contains 'SEMI', "
          f"got {config.data.dataset!r}.")
    train_ds, eval_ds, _ = datasets.get_dataset(
        config, uniform_dequantization=config.data.uniform_dequantization)
  else:
    train_ds, train_unlabeled_ds, eval_ds = datasets.get_dataset(
        config, uniform_dequantization=config.data.uniform_dequantization)
  if use_mix:
    print('using unlabeled data')
    train_labeled_iter = iter(train_ds)
    train_unlabeled_iter = iter(train_unlabeled_ds)
  else:
    train_iter = iter(train_ds)
  eval_iter = iter(eval_ds)  # pytype: disable=wrong-arg-types
  # Create data normalizer
  scaler = datasets.get_data_scaler(config)

  # Setup SDE
  sde = _create_sde(config)

  # Build one-step training and evaluation functions; they share every option
  # except that mixed (semi-supervised) training also passes `semi_sc=True`.
  step_fn_kwargs = dict(
      optimize_fn=losses.optimization_manager(config),
      reduce_mean=config.training.reduce_mean,
      continuous=config.training.continuous,
      likelihood_weighting=config.training.likelihood_weighting,
      max_step=config.training.n_iters,
      score_calib_lambda=config.model.score_calib_lambda,
      dlsm_lambda=dlsm_lambda,
      weighting_dlsm=weighting_dlsm,
      use_dlsm=use_dlsm,
      use_score=config.model.use_score)
  train_only_kwargs = {'semi_sc': True} if use_mix else {}
  train_step_fn = losses.get_step_fn(sde, True, config.data.num_classes,
                                     **step_fn_kwargs, **train_only_kwargs)
  eval_step_fn = losses.get_step_fn(sde, False, config.data.num_classes, **step_fn_kwargs)

  def weighted_total(ce, score, dlsm):
    # Weighted sum logged as the train total; eval uses it too so the two
    # logged totals are directly comparable.
    return ce + config.model.score_calib_lambda * score + dlsm_lambda * dlsm

  num_train_steps = config.training.n_iters
  logging.info("Starting training loop at step %d.", initial_step)

  try:
    for step in range(initial_step, num_train_steps + 1):
      # Fetch the next batch as torch tensors and normalize it.
      if use_mix:
        next_labeled_train = next(train_labeled_iter)
        next_unlabeled_train = next(train_unlabeled_iter)
        # Labeled images come first; targets exist only for the labeled part.
        batch = torch.concat([
            _to_torch(next_labeled_train['image'], config.device).float(),
            _to_torch(next_unlabeled_train['image'], config.device).float(),
        ])
        targets = _to_torch(next_labeled_train['label'], config.device)
      else:
        next_train = next(train_iter)
        batch = _to_torch(next_train['image'], config.device).float()
        targets = _to_torch(next_train['label'], config.device)
      # Move the last axis to position 1 (presumably NHWC -> NCHW).
      batch = scaler(batch.permute(0, 3, 1, 2)).contiguous()
      # Execute one training step
      ce_loss, score_loss, dlsm_loss = train_step_fn(state, batch, targets, score_model)

      if step % config.training.log_freq == 0:
        writer.add_scalar("train/ce_loss", ce_loss, step)
        log_info = "step: %d, ce_loss: %.5e" % (step, ce_loss)
        if config.model.use_score:
          writer.add_scalar("train/score_loss", score_loss, step)
          log_info += ", score_loss: %.5e" % (score_loss)
        if use_dlsm:
          writer.add_scalar("train/dlsm_loss", dlsm_loss, step)
          log_info += ", dlsm_loss: %.5e" % (dlsm_loss)
        writer.add_scalar("train/training_loss", weighted_total(ce_loss, score_loss, dlsm_loss), step)
        logging.info(log_info)

      # Save a temporary checkpoint to resume training after pre-emption periodically
      if step != 0 and step % config.training.snapshot_freq_for_preemption == 0:
        save_checkpoint(checkpoint_meta_dir, state)

      # Report the loss on an evaluation dataset periodically
      if step % config.training.eval_freq == 0:
        next_eval = next(eval_iter)
        eval_batch = _to_torch(next_eval['image'], config.device).float()
        eval_targets = _to_torch(next_eval['label'], config.device)
        eval_batch = scaler(eval_batch.permute(0, 3, 1, 2))
        (eval_ce_loss, eval_score_loss, eval_dlsm_loss,
         eval_acc, eval_ori_acc) = eval_step_fn(state, eval_batch, eval_targets, score_model)

        writer.add_scalar("eval_loss/ce_loss", eval_ce_loss, step)
        log_info = "step: %d, eval_ce_loss: %.5e" % (step, eval_ce_loss)
        if config.model.use_score:
          writer.add_scalar("eval_loss/score_loss", eval_score_loss, step)
          log_info += ", eval_score_loss: %.5e" % (eval_score_loss)
        if use_dlsm:
          writer.add_scalar("eval_loss/dlsm_loss", eval_dlsm_loss, step)
          log_info += ", eval_dlsm_loss: %.5e" % (eval_dlsm_loss)
        writer.add_scalar("eval_loss/total_loss",
                          weighted_total(eval_ce_loss, eval_score_loss, eval_dlsm_loss), step)
        writer.add_scalar("eval_acc/eval_acc", eval_acc, step)
        writer.add_scalar("eval_acc/eval_ori_acc", eval_ori_acc, step)
        log_info += ", eval_acc: %.5e, eval_ori_acc: %.5e" % (eval_acc, eval_ori_acc)
        logging.info(log_info)

      # Save a checkpoint periodically and at the final step.
      if (step != 0 and step % config.training.snapshot_freq == 0) or step == num_train_steps:
        # NOTE: if num_train_steps is not a multiple of snapshot_freq, the
        # final save reuses the index of the last periodic snapshot and
        # overwrites that file.
        save_step = step // config.training.snapshot_freq
        save_checkpoint(os.path.join(checkpoint_dir, f'checkpoint_{save_step}.pth'), state)
  finally:
    # Flush pending TensorBoard events even if training is interrupted.
    writer.close()
