# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""All functions related to loss computation and optimization.
"""

import torch
import torch.optim as optim
import numpy as np
import utils as mutils
from sde_lib import RectifiedFlow
# from . 


def get_optimizer(config, params):
  """Returns a flax optimizer object based on `config`."""
  if config.optim.optimizer == 'Adam':
    optimizer = optim.Adam(params, lr=config.optim.lr, betas=(config.optim.beta1, 0.999), eps=config.optim.eps,
                           weight_decay=config.optim.weight_decay)
  else:
    raise NotImplementedError(
      f'Optimizer {config.optim.optimizer} not supported yet!')

  return optimizer


def optimization_manager(config):
  """Returns an optimize_fn based on `config`."""

  def optimize_fn(optimizer, params, step, lr=config.optim.lr,
                  warmup=config.optim.warmup,
                  grad_clip=config.optim.grad_clip):
    """Optimizes with warmup and gradient clipping (disabled if negative)."""
    if warmup > 0:
      for g in optimizer.param_groups:
        g['lr'] = lr * np.minimum(step / warmup, 1.0)
    if grad_clip >= 0:
      torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)
    optimizer.step()

  return optimize_fn


def get_rectified_flow_loss_fn(sde, train, reduce_mean=True, eps=1e-3):
  """Create a loss function for training with rectified flow.

  Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE.
    train: `True` for training loss and `False` for evaluation loss.
    reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions.
    eps: A `float` number. The smallest time step to sample from.

  Returns:
    A loss function.
  """

  reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)

  def loss_fn(model, data):
    """Compute the loss function.

    Args:
      model: A velocity model.
      batch: A mini-batch of training data.

    Returns:
      loss: A scalar that represents the average loss value across the mini-batch.
    """
    (attention_mask, last_layer_hidden_states, gt_action, action_prob, position_embeddings, mixed) = data#
    batch = action_prob
    attention_mask = attention_mask
    gt_action = gt_action
    last_layer_hidden_states = last_layer_hidden_states
    position_embeddings = position_embeddings

    if sde.reflow_flag:
        z0 = batch[0]
        data = batch[1]
        batch = data.detach().clone()

    else:
        z0 = sde.get_z0(batch).to(batch.device)
    
    if sde.reflow_flag:
        if sde.reflow_t_schedule=='t0': ### distill for t = 0 (k=1)
            t = torch.zeros(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
        elif sde.reflow_t_schedule=='t1': ### reverse distill for t=1 (fast embedding)
            t = torch.ones(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
        elif sde.reflow_t_schedule=='uniform': ### train new rectified flow with reflow
            t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
        elif type(sde.reflow_t_schedule)==int: ### k > 1 distillation
            t = torch.randint(0, sde.reflow_t_schedule, (batch.shape[0], ), device=batch.device) * (sde.T - eps) / sde.reflow_t_schedule + eps
        else:
            assert False, 'Not implemented'
    else:
        ### standard rectified flow loss
        t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
    
   
    t_expand = t.view(-1, 1).repeat(1, batch.shape[1])
    perturbed_data = t_expand * batch + (1.-t_expand) * z0
    target = batch - z0 
    
    model_fn = mutils.get_model_fn(model, train=train)
    data = (perturbed_data, t*999, last_layer_hidden_states, attention_mask,position_embeddings)#
    score = model_fn(data) ### Copy from models/utils.py 
    if mixed:
      f_accuracy = (torch.argmax(score+z0, dim = -1)[:32]==gt_action[:32]).sum()/32
      g_accuracy = (torch.argmax(score+z0, dim = -1)[32:]==gt_action[32:]).sum()/32
      h_accuracy = (torch.argmax(score+z0, dim = -1)==gt_action).sum()/gt_action.shape[0]
      accuracy = (torch.argmax(score+z0, dim = -1)==gt_action).sum()/gt_action.shape[0]
    else:
      f_accuracy = (torch.argmax(score+z0, dim = -1)==gt_action).sum()/gt_action.shape[0]
      g_accuracy = (torch.argmax(score+z0, dim = -1)==gt_action).sum()/gt_action.shape[0]
      h_accuracy = (torch.argmax(score+z0, dim = -1)==gt_action).sum()/gt_action.shape[0]
      accuracy = (torch.argmax(score+z0, dim = -1)==gt_action).sum()/gt_action.shape[0]
    if sde.reflow_flag:
        ### we found LPIPS loss is the best for distillation when k=1; but good to have a try
        if sde.reflow_loss=='l2':
            ### train new rectified flow with reflow or distillation with L2 loss
            losses = torch.square(score - target)
        elif sde.reflow_loss=='lpips':
            assert sde.reflow_t_schedule=='t0'
            losses = sde.lpips_model(z0 + score, batch)
        elif sde.reflow_loss=='lpips+l2':
            assert sde.reflow_t_schedule=='t0'
            lpips_losses = sde.lpips_model(z0 + score, batch).view(batch.shape[0], 1)
            l2_losses = torch.square(score - target).view(batch.shape[0], -1).mean(dim=1, keepdim=True)
            losses = lpips_losses + l2_losses
        else:
            assert False, 'Not implemented'
    else:
        losses = torch.square(score - target)
    
    losses = reduce_op(losses.reshape(losses.shape[0], -1), dim=-1)

    loss = torch.mean(losses)
   
    del data
   
    return loss, f_accuracy, g_accuracy, h_accuracy, accuracy

  return loss_fn



def get_step_fn(sde, train, optimize_fn=None, reduce_mean=False, continuous=True, likelihood_weighting=False):
  """Create a one-step training/evaluation function.

  Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE.
    optimize_fn: An optimization function.
    reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions.
    continuous: `True` indicates that the model is defined to take continuous time steps.
    likelihood_weighting: If `True`, weight the mixture of score matching losses according to
      https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended by our paper.

  Returns:
    A one-step function for training or evaluation.
  """
  if continuous:
    loss_fn = get_sde_loss_fn(sde, train, reduce_mean=reduce_mean,
                              continuous=True, likelihood_weighting=likelihood_weighting)
  else:
    assert not likelihood_weighting, "Likelihood weighting is not supported for original SMLD/DDPM training."
    if isinstance(sde, RectifiedFlow):
      loss_fn = get_rectified_flow_loss_fn(sde, train, reduce_mean=reduce_mean)
    else:
      raise ValueError(f"Discrete training for {sde.__class__.__name__} is not recommended.")

  def step_fn(state, batch):
    """Running one step of training or evaluation.

    This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and jit-compiled together
    for faster execution.

    Args:
      state: A dictionary of training information, containing the score model, optimizer,
       EMA status, and number of optimization steps.
      batch: A mini-batch of training/evaluation data.

    Returns:
      loss: The average loss value of this state.
    """
    model = state['model']
    if train:
      optimizer = state['optimizer']
      optimizer.zero_grad()
      loss, f_accuracy, g_accuracy, k_accuracy, accuracy = loss_fn(model, batch)
      loss.backward()
      optimize_fn(optimizer, model.parameters(), step=state['step'])
      state['step'] += 1
     
    else:
      with torch.no_grad():
      
        loss, f_accuracy, g_accuracy, k_accuracy, accuracy  = loss_fn(model, batch)
       

    return loss, f_accuracy, g_accuracy, k_accuracy, accuracy

  return step_fn


def get_n_step_fn(sde, train, optimize_fn=None, reduce_mean=False, continuous=True, likelihood_weighting=False):
  """Create a one-step training/evaluation function.

  Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE.
    optimize_fn: An optimization function.
    reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions.
    continuous: `True` indicates that the model is defined to take continuous time steps.
    likelihood_weighting: If `True`, weight the mixture of score matching losses according to
      https://arxiv.org/abs/2101.09258; otherwise use the weighting recommended by our paper.

  Returns:
    A one-step function for training or evaluation.
  """
  if continuous:
    loss_fn = get_sde_loss_fn(sde, train, reduce_mean=reduce_mean,
                              continuous=True, likelihood_weighting=likelihood_weighting)
  else:
    assert not likelihood_weighting, "Likelihood weighting is not supported for original SMLD/DDPM training."
    if isinstance(sde, RectifiedFlow):
      loss_fn = get_n_rectified_flow_loss_fn(sde, train, reduce_mean=reduce_mean)
    else:
      raise ValueError(f"Discrete training for {sde.__class__.__name__} is not recommended.")

  def step_fn(state, batch):
    """Running one step of training or evaluation.

    This function will undergo `jax.lax.scan` so that multiple steps can be pmapped and jit-compiled together
    for faster execution.

    Args:
      state: A dictionary of training information, containing the score model, optimizer,
       EMA status, and number of optimization steps.
      batch: A mini-batch of training/evaluation data.

    Returns:
      loss: The average loss value of this state.
    """
    model = state['model']
    if train:
      optimizer = state['optimizer']
      optimizer.zero_grad()
      loss, accuracy = loss_fn(model, batch)
      loss.backward()
      optimize_fn(optimizer, model.parameters(), step=state['step'])
      state['step'] += 1
      
    else:
      with torch.no_grad():
        accuracy = loss_fn(model, batch)
      

    return accuracy

  return step_fn


def get_n_rectified_flow_loss_fn(sde, train, reduce_mean=True, eps=1e-3):
  """Create a loss function for training with rectified flow.

  Args:
    sde: An `sde_lib.SDE` object that represents the forward SDE.
    train: `True` for training loss and `False` for evaluation loss.
    reduce_mean: If `True`, average the loss across data dimensions. Otherwise sum the loss across data dimensions.
    eps: A `float` number. The smallest time step to sample from.

  Returns:
    A loss function.
  """

  reduce_op = torch.mean if reduce_mean else lambda *args, **kwargs: 0.5 * torch.sum(*args, **kwargs)

  def loss_fn(model, data):
    """Compute the loss function.

    Args:
      model: A velocity model.
      batch: A mini-batch of training data.

    Returns:
      loss: A scalar that represents the average loss value across the mini-batch.
    """
   

    (attention_mask, last_layer_hidden_states, gt_action, position_embeddings) = data#
    batch = attention_mask[:,0,:4]
    attention_mask = attention_mask
    gt_action = gt_action
    last_layer_hidden_states = last_layer_hidden_states
    position_embeddings = position_embeddings
    

    if sde.reflow_flag:
        z0 = batch[0]
        data = batch[1]
        batch = data.detach().clone()

    else:
        z0 = sde.get_z0(batch).to(batch.device)
    
    if sde.reflow_flag:
        if sde.reflow_t_schedule=='t0': ### distill for t = 0 (k=1)
            t = torch.zeros(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
        elif sde.reflow_t_schedule=='t1': ### reverse distill for t=1 (fast embedding)
            t = torch.ones(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
        elif sde.reflow_t_schedule=='uniform': ### train new rectified flow with reflow
            t = torch.rand(batch.shape[0], device=batch.device) * (sde.T - eps) + eps
        elif type(sde.reflow_t_schedule)==int: ### k > 1 distillation
            t = torch.randint(0, sde.reflow_t_schedule, (batch.shape[0], ), device=batch.device) * (sde.T - eps) / sde.reflow_t_schedule + eps
        else:
            assert False, 'Not implemented'
    else:
        ### standard rectified flow loss
        t = torch.zeros(batch.shape[0], device=batch.device) * (sde.T - eps) + eps

    t_expand = t.view(-1, 1).repeat(1, batch.shape[1])
    perturbed_data =  (1.-t_expand) * z0
    
    
    model_fn = mutils.get_model_fn(model, train=train)
    data = (perturbed_data, t*999, last_layer_hidden_states, attention_mask, position_embeddings)#
    score = model_fn(data) ### Copy from models/utils.py 
   
    accuracy = (torch.argmax(score+z0, dim = -1)==gt_action).sum()/gt_action.shape[0]
   
    
    del data
    
    return accuracy

  return loss_fn


