# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""All functions related to loss computation and optimization.
"""

import torch
import torch.optim as optim
import math
import numpy as np
from scipy import integrate
from models import utils as mutils
from methods import VESDE, VPSDE
from models import utils_poisson
import datasets
import lamb
import logging


def get_optimizer(config, params):
    """Build the torch optimizer selected by `config.optim.optimizer`.

    Args:
      config: Configuration object; reads `config.optim.optimizer`, `beta1`,
        `lr`, `weight_decay` and (for Adam/Lamb only) `eps`.
      params: Parameters (or parameter groups) handed to the optimizer.

    Returns:
      An `optim.Adam`, `lamb.Lamb` or `optim.SGD` instance.

    Raises:
      NotImplementedError: If the optimizer name is not one of
        'Adam', 'Lamb' or 'SGD'.
    """
    opt_cfg = config.optim
    first_beta = opt_cfg.beta1
    # The second-moment decay is tied to beta1: 0.9 without momentum, 0.999 otherwise.
    second_beta = 0.9 if first_beta == 0 else 0.999
    name = opt_cfg.optimizer

    if name == 'SGD':
        # SGD always uses a fixed momentum of 0.9; the betas are not used here.
        return optim.SGD(
            params,
            lr=opt_cfg.lr,
            momentum=0.9,
            weight_decay=opt_cfg.weight_decay,
        )

    if name in ('Adam', 'Lamb'):
        # Both adaptive optimizers are constructed with the same keyword set.
        factory = optim.Adam if name == 'Adam' else lamb.Lamb
        return factory(
            params,
            lr=opt_cfg.lr,
            betas=(first_beta, second_beta),
            eps=opt_cfg.eps,
            weight_decay=opt_cfg.weight_decay,
        )

    raise NotImplementedError(f'Optimizer {config.optim.optimizer} not supported yet!')

def optimization_manager(config):
    """Returns an optimize_fn based on `config`.

    The defaults of the returned function are taken from `config.optim`
    (`lr`, `warmup`, `grad_clip`, `grad_clip_mode`, `anneal_rate`,
    `anneal_iters`) at the time this factory is called.
    """

    def optimize_fn(optimizer, params, step, lr=config.optim.lr,
                    warmup=config.optim.warmup,
                    grad_clip=config.optim.grad_clip,
                    grad_clip_mode=config.optim.grad_clip_mode,
                    anneal_rate=config.optim.anneal_rate,
                    anneal_iters=config.optim.anneal_iters):
        """Apply LR warmup/annealing and gradient clipping, then step the optimizer.

        Args:
          optimizer: Optimizer whose `param_groups` are updated and stepped.
          params: Parameters whose gradients are clipped in 'norm' mode.
          step: Current optimization step.
          lr: Base learning rate reached at the end of warmup.
          warmup: Number of linear warmup steps (no warmup if not positive).
          grad_clip: Clipping threshold (clipping is disabled if negative).
          grad_clip_mode: 'norm' clips the global gradient norm; 'std' calls
            `clip_grad`. Any other value performs no clipping.
          anneal_rate: Factor the learning rate is multiplied by when annealing.
          anneal_iters: Container of steps at which the learning rate decays.
        """
        # Only touch the learning rate during the warmup phase. Re-applying the
        # warmup rule after it would reset the rate to `lr` on every call and
        # silently undo any annealing decay on the following step.
        if warmup > 0 and step <= warmup:
            warmup_lr = lr * min(step / warmup, 1.0)
            for g in optimizer.param_groups:
                g['lr'] = warmup_lr

        if step > warmup and step in anneal_iters:
            new_lr = None
            for g in optimizer.param_groups:
                new_lr = g['lr'] * anneal_rate
                g['lr'] = new_lr
            # Nothing to report when the optimizer has no parameter groups.
            if new_lr is not None:
                logging.info("Decaying lr to {}".format(new_lr))

        if grad_clip >= 0:
            if grad_clip_mode == 'norm':
                torch.nn.utils.clip_grad_norm_(params, max_norm=grad_clip)
            elif grad_clip_mode == 'std':
                clip_grad(optimizer, grad_clip)
        optimizer.step()

    return optimize_fn

def clip_grad(optimizer, grad_clip):
    """Clamp each gradient element-wise using the optimizer's second-moment estimate.

    For every parameter whose optimizer state holds a `step` of at least 1,
    the gradient is clamped in place to `[-bound, bound]` with
    `bound = grad_clip * sqrt(exp_avg_sq / (1 - beta2 ** step)) + 0.1`.
    Parameters without a gradient or without such state are left untouched.

    Args:
      optimizer: Optimizer whose state provides `step` and `exp_avg_sq` and
        whose param groups provide `betas`.
      grad_clip: Multiplier applied to the bias-corrected root second moment.
    """
    with torch.no_grad():
        for group in optimizer.param_groups:
            for p in group['params']:
                # A parameter can have optimizer state from earlier steps and
                # still carry no gradient on this one; there is nothing to clamp.
                if p.grad is None:
                    continue
                # `.get` avoids inserting an empty entry into the state mapping.
                state = optimizer.state.get(p)
                if not state or 'step' not in state or state['step'] < 1:
                    continue
                step = state['step']
                exp_avg_sq = state['exp_avg_sq']
                _, beta2 = group['betas']
                bound = grad_clip * torch.sqrt(exp_avg_sq / (1 - beta2 ** step)) + 0.1
                p.grad.copy_(torch.max(torch.min(p.grad, bound), -bound))


def get_perturb_batch_loss_fn(sde, train, reduce_mean=True, continuous=True, eps=1e-10, method_name=None, optimize_fn=None, sampling_fn=None):
    """Create a loss function for training with arbitrary SDEs.

    Args:
      sde: An `methods.SDE` object that represents the forward SDE. Its
        `config` attribute (data / training / optim / model sections) drives
        every hyper-parameter used by the loss.
      train: Not used by the factory itself; the returned `loss_fn` takes its
        own `train` argument.
      reduce_mean: Currently unused.
      continuous: Forwarded to `mutils.get_predict_fn`.
      eps: A small `float` added to the divisor that rescales Gamma.
      method_name: Name of the method; only 'homotopy' is implemented.
      optimize_fn: Currently unused.
      sampling_fn: Sampler called when `loss_fn` runs with `sample_bool=True`.

    Returns:
      A loss function.
    """

    def loss_fn(model, batch, state, train, sample_bool=False, ind_bool=True):
        """Compute the loss function.

        Args:
          model: A PFGM or score model.
          batch: A mini-batch of training data.
          state: Training state; `state['step']` is printed and the whole
            state is passed to `sampling_fn` when `sample_bool` is set.
          train: Forwarded to `mutils.get_predict_fn`.
          sample_bool: If `True`, all times are set to `t_end` and the first
            `sample_size` perturbed points are replaced by samples drawn
            with `sampling_fn`.
          ind_bool: If `True`, Gamma averages the squared distances between
            every batch element and every perturbed sample; otherwise the
            distances are taken element-wise (requires
            `batch_size == small_batch_size`).

        Returns:
          A tuple `(Loss, Cov, Reg, Norm, Nll)`: the total loss, the
          Gamma/potential covariance term, the mean squared potential, the
          mean squared drift norm, and the mean squared z-drift (zeros when
          `augment_z` is off).

        Raises:
          NotImplementedError: If `method_name` is not 'homotopy'.
        """
        # Fail loudly instead of implicitly returning None for other methods.
        if method_name != 'homotopy':
            raise NotImplementedError(f'Method {method_name} not supported yet!')

        data_cfg = sde.config.data
        train_cfg = sde.config.training

        # Get configs
        data_dim = data_cfg.channels * data_cfg.image_size * data_cfg.image_size
        batch_size = train_cfg.batch_size
        ensemble_size = train_cfg.small_batch_size
        if not ind_bool:
            assert batch_size == ensemble_size
        num_particles = train_cfg.num_particles
        sample_size = train_cfg.sample_size

        # Get the mini-batch
        batch = batch.reshape(batch_size, -1)
        samples_batch = batch[:ensemble_size, :data_dim]
        device = samples_batch.device

        # Get prior mean and sigma
        mean_prior = train_cfg.mean_prior
        sigma_prior = train_cfg.sigma_prior

        with torch.no_grad():
            # Get dec (conditional likelihood) sigma: constant, or log-uniform
            # in [sigma_min, sigma_max] (shifted by eps_sigma).
            sigma_range = (train_cfg.sigma_min, train_cfg.sigma_max)
            if sigma_range[0] == sigma_range[1]:
                sigma_dec = sigma_range[0] * torch.ones(ensemble_size, num_particles, 1).to(device)
            else:
                eps_sigma = train_cfg.eps_sigma
                eps_log_range = (np.log(sigma_range[0] + eps_sigma), np.log(sigma_range[1] + eps_sigma))
                eps_log_random = eps_log_range[0] + torch.rand(ensemble_size, num_particles, 1).to(device) * (eps_log_range[1] - eps_log_range[0])
                sigma_dec = torch.exp(eps_log_random) - eps_sigma
                if train_cfg.invert_sigma:
                    # The inverted value must be assigned back; a bare
                    # expression here would leave `sigma_dec` unchanged.
                    sigma_dec = sigma_range[1] - sigma_dec

            # Sample t: fixed at t_end when sampling, otherwise uniform
            # (eps_t == inf) or log-uniform (shifted by eps_t) in [0, t_end].
            t_range = (0, train_cfg.t_end)
            if sample_bool:
                t_samples = t_range[1] * torch.ones(ensemble_size, num_particles, 1).to(device)
            else:
                if train_cfg.eps_min == train_cfg.eps_max:
                    eps_t = train_cfg.eps_min
                else:
                    eps_t_range = (train_cfg.eps_min, train_cfg.eps_max)
                    eps_t = eps_t_range[0] + torch.rand(1).item() * (eps_t_range[1] - eps_t_range[0])
                if eps_t == math.inf:
                    t_samples = t_range[0] + torch.rand(ensemble_size, num_particles, 1).to(device) * (t_range[1] - t_range[0])
                else:
                    t_log_range = (np.log(t_range[0] + eps_t), np.log(t_range[1] + eps_t))
                    t_log_random = t_log_range[0] + torch.rand(ensemble_size, num_particles, 1).to(device) * (t_log_range[1] - t_log_range[0])
                    t_samples = torch.exp(t_log_random) - eps_t
                    if train_cfg.invert_t:
                        # Mirror within [0, t_end] (same rule as invert_sigma).
                        # Subtracting from the lower bound 0 would yield
                        # non-positive times, which break -log(t) and the
                        # posterior variance below.
                        t_samples = t_range[1] - t_samples

            # Compute enc (posterior) mean and var
            var_prior = sigma_prior**2
            var_dec = t_range[1] * sigma_dec**2
            mean_enc_y = (t_samples * var_prior) / (var_dec + t_samples * var_prior)
            mean_enc_prior = var_dec / (var_dec + t_samples * var_prior)
            var_enc = var_prior * var_dec / (var_dec + t_samples * var_prior)
            t_samples = t_samples.reshape(ensemble_size * num_particles, -1)

            # Perturb data samples with gaussians
            gaussians_x = torch.randn(ensemble_size, num_particles, data_dim).to(device)
            samples_x = mean_enc_y * samples_batch.unsqueeze(dim=1) + mean_enc_prior * mean_prior
            samples_x += gaussians_x * torch.sqrt(var_enc)
            samples_x = samples_x.reshape(ensemble_size * num_particles, -1)
            if train_cfg.augment_z:
                # Change-of-variable: z = -ln(t), appended as an extra coordinate.
                z_samples = -torch.log(t_samples)
                samples_x = torch.cat([samples_x, z_samples], dim=-1)

            Const = data_dim * var_enc + (1 - mean_enc_y).pow(2) * samples_batch.unsqueeze(dim=1).pow(2).sum(dim=-1, keepdim=True)
            Const = Const.reshape(ensemble_size * num_particles)

            if sample_bool:
                samples_s, nfe = sampling_fn(model, state, sample_size=sample_size, method='RK23', eps=1e-3, rtol=1e-3, atol=1e-3, inverse_scale=False)
                print("step: %d, nfe: %d" % (state['step'], nfe))
                samples_x[:sample_size] = samples_s

        # Gradients are enabled explicitly so the drift can be computed even
        # when the caller runs under torch.no_grad().
        with torch.enable_grad():
            # Get model function
            net_fn = mutils.get_predict_fn(sde, model, train=train, continuous=continuous)

            # Predict scalar potential
            samples_x.requires_grad = True
            if train_cfg.augment_t:
                samples_net = torch.cat([samples_x, t_samples], dim=-1)
            else:
                samples_net = samples_x
            psi = net_fn(samples_net).squeeze(dim=-1)

            # Drift = gradient of the potential w.r.t. the perturbed samples;
            # the graph is kept so the loss can backpropagate through it.
            drift = torch.autograd.grad(psi, samples_x, torch.ones_like(psi), create_graph=True)[0]

        # Compute drift norm
        if train_cfg.augment_z:
            x_drift, z_drift = torch.split(drift, [data_dim, 1], dim=-1)
            norm_x = x_drift.pow(2).sum(dim=-1)
            norm_z = z_drift.pow(2).sum(dim=-1)
            Norm = norm_x + norm_z
        else:
            Norm = drift.pow(2).sum(dim=-1)

        with torch.no_grad():
            # Compute Normalized Innovation Squared (Gamma)
            if train_cfg.augment_z:
                z_batch = -torch.log(t_range[1] * torch.ones(batch_size, 1).to(batch.device))
                batch = torch.cat([batch, z_batch], dim=-1)

            if ind_bool:
                distance = batch.unsqueeze(dim=1) - samples_x
            else:
                distance = batch - samples_x

            if train_cfg.augment_z:
                distance_x, distance_z = torch.split(distance, [data_dim, 1], dim=-1)
                innovation_x = distance_x.pow(2).sum(dim=-1)
                innovation_z = distance_z.pow(2).sum(dim=-1)
                innovation = innovation_x + innovation_z
            else:
                innovation = distance.pow(2).sum(dim=-1)

            if ind_bool:
                Gamma = innovation.mean(dim=0)
                Gamma -= Gamma.mean(dim=0, keepdim=True)
            else:
                Gamma = innovation - Const.mean(dim=0)

            # `eps` keeps the division finite if the configured divisor is 0.
            divisor = train_cfg.divisor * math.sqrt(data_dim)
            Gamma = Gamma / (divisor + eps)

        # Compute sample covariance between potential and NIS
        Cov = torch.sum(Gamma * psi, dim=0) / (ensemble_size * num_particles - 1)

        Loss = 0.5 * Cov

        Norm = Norm.mean(dim=0)
        Loss += 0.5 * sde.config.optim.gamma * Norm

        Reg = psi.pow(2).mean(dim=0)
        if sde.config.optim.alpha < 0:
            # A negative alpha selects the detached Norm / Reg ratio as weight.
            alpha = Norm.detach() / Reg.detach()
        else:
            alpha = sde.config.optim.alpha
        Loss += 0.5 * alpha * Reg

        if sde.config.model.name == 'vaebmwrn':
            Spec = model.module.spectral_norm_parallel()
            Loss += 0.5 * sde.config.optim.delta * Spec

        if train_cfg.augment_z:
            Nll = norm_z.detach().mean(dim=0)
        else:
            Nll = torch.zeros_like(Loss)

        return Loss, Cov, Reg, Norm, Nll

    return loss_fn


def get_step_fn(sde, train, optimize_fn=None, sampling_fn=None, reduce_mean=False, method_name=None):
    """Create a one-step training/evaluation function.

    Args:
      sde: An `methods.SDE` object that represents the forward SDE.
      train: `True` to build a training step, `False` for an evaluation step.
      optimize_fn: An optimization function, called as
        `optimize_fn(optimizer, params, step=...)` in training mode.
      sampling_fn: Sampler forwarded to the loss function.
      reduce_mean: Forwarded to `get_perturb_batch_loss_fn`.
      method_name: Forwarded to `get_perturb_batch_loss_fn`.

    Returns:
      A one-step function for training or evaluation.
    """

    perturb_loss_fn = get_perturb_batch_loss_fn(sde, train, reduce_mean=reduce_mean, continuous=True, method_name=method_name, optimize_fn=optimize_fn, sampling_fn=sampling_fn)

    def step_fn(state, batch, sample_bool=False, ind_bool=False):
        """Running one step of training or evaluation.

        Args:
          state: A dictionary of training information, containing the PFGM or score model, optimizer,
           EMA status, and number of optimization steps.
          batch: A mini-batch of training/evaluation data.
          sample_bool: Forwarded to the loss in training mode; forced to `True`
            on steps selected by `training.sample_freq` once past warmup.
          ind_bool: Forwarded to the loss in training mode.

        Returns:
          The tuple `(Loss, Cov, Reg, Norm, Nll)` produced by the loss function.
        """
        model = state['model']
        if train:
            # Periodically mix sampler outputs into the batch, but only after warmup.
            sample_freq = sde.config.training.sample_freq
            if sample_freq > 0 and state['step'] % sample_freq == 0 and state['step'] > sde.config.optim.warmup:
                sample_bool = True
            Loss, Cov, Reg, Norm, Nll = perturb_loss_fn(model, batch, state, train, sample_bool=sample_bool, ind_bool=ind_bool)
            optimizer = state['optimizer']
            # Clearing gradients after the loss is built but before backward()
            # ensures parameters only hold gradients from this backward pass.
            optimizer.zero_grad()
            Loss.backward()
            optimize_fn(optimizer, model.parameters(), step=state['step'])
            state['step'] += 1
            state['ema'].update(model.parameters())
        else:
            # NOTE(review): the evaluation path does not forward `sample_bool` /
            # `ind_bool`, so the loss function's own defaults apply here —
            # confirm this is intended.
            with torch.no_grad():
                ema = state['ema']
                ema.store(model.parameters())
                ema.copy_to(model.parameters())
                try:
                    Loss, Cov, Reg, Norm, Nll = perturb_loss_fn(model, batch, state, train)
                finally:
                    # Put the stored training weights back; otherwise the model
                    # would keep the EMA weights after evaluation.
                    ema.restore(model.parameters())

        return Loss, Cov, Reg, Norm, Nll

    return step_fn
