# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""All functions related to loss computation and optimization.
"""

import torch
import torch.optim as optim
from torch.nn import functional as F
from torch import nn
import numpy as np
from models import utils as mutils
from sde_lib import VESDE, VPSDE
import math
from torch.utils.tensorboard import SummaryWriter
from torch.autograd import grad

def get_optimizer(config, params):
  """Build a PyTorch optimizer over `params` from the `config.optim` section.

  Supported values of `config.optim.optimizer`:
    * 'Adam'  -- uses `lr`, betas `(beta1, 0.999)`, `eps` and `weight_decay` from the config.
    * 'AdamW' -- uses only `lr` and `weight_decay`; every other hyper-parameter keeps the
      `torch.optim.AdamW` default.

  Args:
    config: Configuration object exposing an `optim` namespace.
    params: Iterable of parameters (or param-group dicts) to optimize.

  Returns:
    A `torch.optim.Optimizer` instance.

  Raises:
    NotImplementedError: If `config.optim.optimizer` names an unsupported optimizer.
  """
  opt_cfg = config.optim
  name = opt_cfg.optimizer
  if name == 'Adam':
    return optim.Adam(
        params,
        lr=opt_cfg.lr,
        betas=(opt_cfg.beta1, 0.999),
        eps=opt_cfg.eps,
        weight_decay=opt_cfg.weight_decay,
    )
  if name == 'AdamW':
    return optim.AdamW(params, lr=opt_cfg.lr, weight_decay=opt_cfg.weight_decay)
  raise NotImplementedError(
    f'Optimizer {config.optim.optimizer} not supported yet!')


def optimization_manager(config):
  """Returns an optimize_fn based on `config`.

  The returned function sets the learning rate of every param group, optionally
  clips the global gradient norm, and then takes one optimizer step. The
  schedule (applied only when `warmup > 0`) is a linear warmup for the first
  `warmup` steps. It is followed by a linear decay that reaches 0 at `max_step`
  and stays at 0 afterwards.
  """

  def optimize_fn(optimizer, params, step, lr=config.optim.lr,
                  warmup=config.optim.warmup,
                  grad_clip=config.optim.grad_clip,
                  max_step=1000000):
    """Optimizes with warmup, linear annealing and gradient clipping (disabled if negative).

    Args:
      optimizer: A `torch.optim.Optimizer` whose param groups receive the scheduled lr.
      params: Parameters whose gradients are clipped (passed to `clip_grad_norm_`).
      step: Current training step.
      lr: Peak learning rate.
      warmup: Number of warmup steps; the lr schedule is skipped entirely if `warmup <= 0`.
      grad_clip: Max global gradient norm; clipping is skipped if negative.
      max_step: Step at which the linearly annealed lr reaches 0.
    """
    if warmup > 0:
      if step < warmup:
        new_lr = lr * step / warmup
      elif max_step > warmup:
        # Linear anneal. Clamp at 0 so steps beyond `max_step` never produce a
        # negative lr, which would make the optimizer ascend the loss.
        new_lr = lr * max(max_step - step, 0) / (max_step - warmup)
      else:
        # Degenerate schedule (max_step <= warmup <= step): annealing is already
        # over; this also avoids a division by zero when max_step == warmup.
        new_lr = 0.
      # Alternative cosine schedule:
      # new_lr = lr * 0.5 * (1.0 + math.cos(math.pi * (step - warmup) / (max_step - warmup)))
      for g in optimizer.param_groups:
        g['lr'] = new_lr
    if grad_clip >= 0:
      torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)
    optimizer.step()

  return optimize_fn


def get_sde_loss_fn(sde, train, num_classes, reduce_mean=True, continuous=True, likelihood_weighting=True, eps=1e-5, use_score=True, weighting_dlsm=0., use_dlsm=False, semi_sc=False):
  """Create a loss function for training a time-dependent classifier with arbitrary SDEs.

  Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE.
    train: `True` for training loss and `False` for evaluation loss.
    num_classes: Number of classes; used to one-hot encode `targets` for the DLSM loss.
    reduce_mean: If `True`, average the score-calibration loss across data dimensions. Otherwise sum it
      (scaled by 0.5) across data dimensions.
    continuous: `True` indicates that the model is defined to take continuous time steps. Otherwise it requires
      ad-hoc interpolation to take continuous time steps.
    likelihood_weighting: Accepted for interface compatibility; it is not used by the returned loss.
    eps: A `float` number. The smallest time step to sample from.
    use_score: If `True`, also compute the score-calibration loss on the gradient of logsumexp(logits).
    weighting_dlsm: Exponent applied to the noise std when weighting the DLSM loss.
    use_dlsm: If `True`, also compute the DLSM loss (requires `score_model` at call time).
    semi_sc: If `True`, only the first `bs // 2` samples of each batch are treated as labeled.

  Returns:
    A loss function `loss_fn(model, batch, targets, score_model=None)`.
  """
  reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)

  def loss_fn(model, batch, targets, score_model=None):
    """Compute the classifier losses on a noise-perturbed mini-batch.

    Args:
      model: A time-dependent classifier model.
      batch: A mini-batch of training data. The noise std is broadcast as `std[:, None, None, None]`,
        so a 4-D batch is assumed.
      targets: Integer class labels for the labeled part of the batch (the first `bs // 2` samples
        when `semi_sc`, otherwise all samples).
      score_model: Pretrained score model; required only when `use_dlsm` is `True`.

    Returns:
      A tuple `(ce_loss, score_loss, dlsm_loss, logits)`. `score_loss` is `0` when `use_score` is
      `False`, and `dlsm_loss` is `0.` when `use_dlsm` is `False`.

    Raises:
      ValueError: If `use_dlsm` is `True` but no `score_model` is given.
    """
    bs = batch.shape[0]
    class_fn = mutils.get_score_fn(sde, model, train=train, continuous=continuous)
    t = torch.rand(bs, device=batch.device) * (sde.T - eps) + eps
    z = torch.randn_like(batch)
    mean, std = sde.marginal_prob(batch, t)
    cur_std = std[:, None, None, None]
    perturbed_data = (mean + cur_std * z).requires_grad_()

    labeled_indices = bs // 2 if semi_sc else bs
    logits = class_fn(perturbed_data, t)
    ce_loss = F.cross_entropy(logits[:labeled_indices, :], targets, reduction='mean')

    if use_dlsm:
      # derived from https://github.com/chen-hao-chao/dlsm
      if score_model is None:
        raise ValueError('`score_model` is required when `use_dlsm=True`.')
      with torch.no_grad():
        score_fn = mutils.get_score_fn(sde, score_model, train=False, continuous=continuous)
        score = score_fn(perturbed_data[:labeled_indices, ...], t[:labeled_indices, ...])
      log_probs = F.log_softmax(logits[:labeled_indices, :], dim=-1)
      label_mask = F.one_hot(targets, num_classes=num_classes).to(log_probs.dtype)
      # Differentiate w.r.t. `perturbed_data` itself: a fresh slice of it is a new
      # graph node that `log_probs` does not depend on, so autograd would reject it.
      # Slice the labeled rows out of the gradient afterwards instead.
      grads_prob_class = torch.autograd.grad(
          log_probs, perturbed_data, grad_outputs=label_mask, create_graph=True)[0][:labeled_indices]
      # Use the 4-D `cur_std` so weights broadcast per sample (the 1-D `std`
      # would broadcast against the last spatial axis instead).
      lab_std = cur_std[:labeled_indices]
      w = lab_std ** weighting_dlsm
      w_z = lab_std ** (weighting_dlsm - 1)
      dlsm_loss = torch.mean(0.5 * torch.square(grads_prob_class * w + score * w + z[:labeled_indices] * w_z))
    else:
      dlsm_loss = 0.

    if use_score:
      # Conditional score alternative:
      # joint_score = grad(torch.gather(logits, 1, targets[:, None]).sum(), perturbed_data, create_graph=True)[0]
      # Unconditional score: gradient of logsumexp over classes.
      joint_score = grad(logits.logsumexp(1).sum(), perturbed_data, create_graph=True)[0]
      losses = torch.square(joint_score * cur_std + z)
      score_loss = torch.mean(reduce_op(losses.reshape(losses.shape[0], -1), dim=-1))
    else:
      score_loss = 0

    return ce_loss, score_loss, dlsm_loss, logits
  return loss_fn

def get_step_fn(sde, train, num_classes, optimize_fn=None, reduce_mean=False, continuous=True, likelihood_weighting=False, max_step=1000000, score_calib_lambda=5, use_score=True, dlsm_lambda=0., weighting_dlsm=0., use_dlsm=False, semi_sc=False):
  """Create a one-step training/evaluation function.

  Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE.
    train: `True` to build a training step, `False` for an evaluation step.
    num_classes: Number of classes of the classifier.
    optimize_fn: An optimization function called as `optimize_fn(optimizer, params, step=..., max_step=...)`;
      required when `train` is `True`.
    reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions.
    continuous: `True` indicates that the model is defined to take continuous time steps.
    likelihood_weighting: Forwarded to the loss builder.
    max_step: Forwarded to `optimize_fn` as the end of the learning-rate schedule.
    score_calib_lambda: Weight of the score-calibration loss in the training objective.
    use_score: Whether the score-calibration loss is computed.
    dlsm_lambda: Weight of the DLSM loss in the training objective.
    weighting_dlsm: Exponent of the noise std used to weight the DLSM loss.
    use_dlsm: Whether the DLSM loss is computed (requires `score_model` at call time).
    semi_sc: If `True`, only the first half of each batch is treated as labeled.

  Returns:
    A one-step function for training or evaluation.
  """
  if continuous:
    loss_fn = get_sde_loss_fn(sde, train, num_classes, reduce_mean=reduce_mean,
                              continuous=True, likelihood_weighting=likelihood_weighting,
                              use_score=use_score,
                              weighting_dlsm=weighting_dlsm,
                              use_dlsm=use_dlsm,
                              semi_sc=semi_sc)
  else:
    # NOTE(review): get_smld_loss_fn / get_ddpm_loss_fn are not defined in this
    # file as shown, so this branch presumably raises NameError. Confirm they are
    # provided elsewhere and return the 4-tuple `step_fn` expects.
    assert not likelihood_weighting, "Likelihood weighting is not supported for original SMLD/DDPM training."
    if isinstance(sde, VESDE):
      loss_fn = get_smld_loss_fn(sde, train, reduce_mean=reduce_mean)
    elif isinstance(sde, VPSDE):
      loss_fn = get_ddpm_loss_fn(sde, train, reduce_mean=reduce_mean)
    else:
      raise ValueError(f"Discrete training for {sde.__class__.__name__} is not recommended.")

  def step_fn(state, batch, targets, score_model=None):
    """Run one step of training or evaluation.

    Args:
      state: A dictionary with the keys 'model', 'ema' and, for training, 'optimizer' and 'step'.
      batch: A mini-batch of training/evaluation data.
      targets: Integer labels for the labeled part of `batch`.
      score_model: Pretrained score model, needed only when `use_dlsm` is `True`.

    Returns:
      Training: `(ce_loss, score_loss, dlsm_loss)` as Python numbers. The disabled losses are reported as 0.
      Evaluation: the same three values, followed by the accuracy on the perturbed inputs and the
      accuracy on the clean inputs (evaluated at t=1e-5). Both accuracies use the EMA weights.
    """
    model = state['model']
    if train:
      optimizer = state['optimizer']
      ce_loss, score_loss, dlsm_loss, _ = loss_fn(model, batch, targets, score_model=score_model)
      total_loss = ce_loss + score_calib_lambda * score_loss + dlsm_lambda * dlsm_loss
      total_loss.backward()
      optimize_fn(optimizer, model.parameters(), step=state['step'], max_step=max_step)
      optimizer.zero_grad()
      state['step'] += 1
      state['ema'].update(model.parameters())
      return ce_loss.item(), (score_loss.item()) if use_score else 0, (dlsm_loss.item()) if use_dlsm else 0

    ema = state['ema']
    ema.store(model.parameters())
    ema.copy_to(model.parameters())
    # Always restore the live weights, even if evaluation raises; otherwise the
    # model would silently continue training from the EMA weights.
    try:
      ce_loss, score_loss, dlsm_loss, logits = loss_fn(model, batch, targets, score_model=score_model)
      with torch.no_grad():
        t_clean = torch.full((batch.shape[0],), 1e-5, device=batch.device)
        ori_logits = mutils.get_score_fn(sde, model, train=train, continuous=continuous)(batch, t_clean)
    finally:
      ema.restore(model.parameters())

    # Only the first `len(targets)` samples are labeled (fewer than the batch when `semi_sc`).
    n_labeled = targets.shape[0]
    acc = (logits[:n_labeled].argmax(dim=-1) == targets).float().mean().item()
    ori_acc = (ori_logits[:n_labeled].argmax(dim=-1) == targets).float().mean().item()
    return ce_loss.item(), (score_loss.item()) if use_score else 0, (dlsm_loss.item()) if use_dlsm else 0, acc, ori_acc

  return step_fn
