# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: skip-file
# pytype: skip-file
"""Various sampling methods."""
import functools

import torch
import numpy as np
import abc
from scipy import integrate
import sde_lib
import time

import torch.nn.functional as F

# Name -> class registries. Entries are added by the `register_corrector` and
# `register_predictor` decorators and read back by `get_corrector` and
# `get_predictor`.
_CORRECTORS = {}
_PREDICTORS = {}

def get_conditional_model_fn(model, train=False):
  """Wrap `model` so it can be called on a (condition, sample) pair.

  Args:
    model: The network. It is invoked as `model(stacked, side_info, labels)`
      and must provide `eval()` and `train()`.
    train: `True` to put the model in training mode before each call,
      `False` (default) to put it in evaluation mode.

  Returns:
    A function `model_fn(h, x, labels, side_info)` that returns the model
    output.
  """

  def model_fn(h, x, labels, side_info):
    # `h` and `x` are stacked along a new axis inserted at position 1.
    stacked = torch.cat((h.unsqueeze(1), x.unsqueeze(1)), dim=1)

    # The mode is (re)applied on every call, exactly as requested by `train`.
    if train:
      model.train()
    else:
      model.eval()
    return model(stacked, side_info, labels)

  return model_fn

def get_conditional_score_fn(sde, model, train=False):
  """Build the conditional score function of `model` under `sde`.

  Args:
    sde: An `sde_lib.VPSDE`, `sde_lib.subVPSDE` or `sde_lib.VESDE` instance.
    model: The network, wrapped through `get_conditional_model_fn`.
    train: Forwarded to `get_conditional_model_fn` (training vs. eval mode).

  Returns:
    A function `score_fn(h, x, t, side_info)`. Both SDE families expose the
    same four-argument signature so that callers do not need to branch.

  Raises:
    NotImplementedError: If `sde` is none of the supported classes.
  """
  model_fn = get_conditional_model_fn(model, train=train)

  if isinstance(sde, (sde_lib.VPSDE, sde_lib.subVPSDE)):
    def score_fn(h, x, t, side_info):
      # The model is conditioned on the time value scaled by 999.
      labels = t * 999
      score = model_fn(h, x, labels, side_info)
      # Standard deviation returned by `marginal_prob` at time `t`.
      std = sde.marginal_prob(torch.zeros_like(x), t)[1]
      # `std` is indexed per batch element and broadcast over two more axes.
      return -score / std[:, None, None]

  elif isinstance(sde, sde_lib.VESDE):
    def score_fn(h, x, t, side_info):
      # Here the model is conditioned on the standard deviation itself.
      labels = sde.marginal_prob(torch.zeros_like(x), t)[1]
      return model_fn(h, x, labels, side_info)

  else:
    # Previously this fell through to `return score_fn` with the name unbound.
    raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

  return score_fn

def ts_score_fn(sde, score_fn, scale=0.01):
  """Wrap `score_fn` with an additional gradient-based guidance term.

  The guidance energy is the sum of squared differences between
  `sde.reverse(score_fn).restore(h, x, t, side_info)` and `h`, taken only over
  the entries where `h != 0`.

  Args:
    sde: SDE object providing `reverse(score_fn)`; the returned object must
      provide `restore(h, x, t, side_info)`.
    score_fn: Base score function called as `score_fn(h, x, t, side_info)`.
    scale: Weight applied to the negative energy gradient.

  Returns:
    A function `ts_guidence(h, x, t, side_info)` returning
    `scale * (-dE/dx) + score_fn(h, x, t, side_info)`.
  """
  rsde = sde.reverse(score_fn)

  def energy(h, x, t, side_info):
    restored = rsde.restore(h, x, t, side_info)
    squared_error = F.mse_loss(restored, h, reduction="none")
    # Entries where `h` equals zero do not contribute to the energy.
    return squared_error[h != 0].sum()

  def ts_guidence(h, x, t, side_info):
    # Gradients are needed even when the caller runs under `torch.no_grad()`.
    with torch.enable_grad():
      # Differentiate w.r.t. a detached copy so the caller's tensor is not
      # switched to `requires_grad=True` as a side effect.
      x_in = x.detach().requires_grad_(True)
      energy_grad = torch.autograd.grad(energy(h, x_in, t, side_info), x_in)[0]
    return scale * (-energy_grad) + score_fn(h, x, t, side_info)

  return ts_guidence

def register_predictor(cls=None, *, name=None):
  """Class decorator that stores a predictor class in `_PREDICTORS`.

  Works both bare (`@register_predictor`) and with a keyword
  (`@register_predictor(name=...)`).

  Args:
    cls: The class being decorated, or `None` when used with arguments.
    name: Registry key; the class `__name__` is used when omitted.

  Returns:
    The class unchanged, or the decorator itself when `cls` is `None`.

  Raises:
    ValueError: If the key is already present in the registry.
  """

  def _register(target):
    key = target.__name__ if name is None else name
    if key in _PREDICTORS:
      raise ValueError(f'Already registered model with name: {key}')
    _PREDICTORS[key] = target
    return target

  return _register if cls is None else _register(cls)


def register_corrector(cls=None, *, name=None):
  """Class decorator that stores a corrector class in `_CORRECTORS`.

  Works both bare (`@register_corrector`) and with a keyword
  (`@register_corrector(name=...)`).

  Args:
    cls: The class being decorated, or `None` when used with arguments.
    name: Registry key; the class `__name__` is used when omitted.

  Returns:
    The class unchanged, or the decorator itself when `cls` is `None`.

  Raises:
    ValueError: If the key is already present in the registry.
  """

  def _register(target):
    key = target.__name__ if name is None else name
    if key in _CORRECTORS:
      raise ValueError(f'Already registered model with name: {key}')
    _CORRECTORS[key] = target
    return target

  return _register if cls is None else _register(cls)


def get_predictor(name):
  """Return the predictor class registered under `name` (KeyError if absent)."""
  predictor_cls = _PREDICTORS[name]
  return predictor_cls


def get_corrector(name):
  """Return the corrector class registered under `name` (KeyError if absent)."""
  corrector_cls = _CORRECTORS[name]
  return corrector_cls


def get_sampling_fn(sde, eps, ode_sampling=False, denoise=False, predictor='euler_maruyama', corrector='none', device='cuda:0',
                    *, snr=0.16, n_steps=1, rtol=1e-5, atol=1e-5, method='RK45'):
  """Create a sampling function for `sde`.

  Args:
    sde: The SDE object handed to the underlying sampler factory.
    eps: Smallest time value the sampler integrates down to.
    ode_sampling: If `True`, build the ODE sampler (`get_ode_sampler`);
      otherwise build the predictor-corrector sampler (`get_pc_sampler`).
    denoise: Forwarded to `get_ode_sampler`; not used by the PC sampler.
    predictor: Registered predictor name (looked up case-insensitively).
    corrector: Registered corrector name (looked up case-insensitively).
    device: Device on which samples are generated.
    snr: Signal-to-noise ratio handed to the corrector (PC sampler only).
    n_steps: Corrector steps per predictor step (PC sampler only).
    rtol: Relative tolerance of the ODE solver (ODE sampler only).
    atol: Absolute tolerance of the ODE solver (ODE sampler only).
    method: Integration method name for the ODE solver (ODE sampler only).

  Returns:
    The sampling function produced by the selected factory.

  Raises:
    KeyError: If `predictor` or `corrector` is not a registered name. The
      lookup happens for both sampler kinds.
  """
  predictor_cls = get_predictor(predictor.lower())
  corrector_cls = get_corrector(corrector.lower())

  if ode_sampling:
    return get_ode_sampler(sde=sde, denoise=denoise, rtol=rtol, atol=atol,
                           method=method, eps=eps, device=device)

  return get_pc_sampler(sde=sde, predictor=predictor_cls, corrector=corrector_cls,
                        snr=snr, n_steps=n_steps, eps=eps, device=device)


class Predictor(abc.ABC):
  """Abstract base class for predictor algorithms.

  Instances keep the SDE, the score function, and the object returned by
  `sde.reverse(score_fn)` (stored as `rsde`).
  """

  def __init__(self, sde, score_fn):
    super().__init__()
    self.sde = sde
    self.score_fn = score_fn
    self.rsde = sde.reverse(score_fn)

  @abc.abstractmethod
  def update_fn(self, x, t):
    """Perform one unconditional predictor update.

    Args:
      x: A PyTorch tensor representing the current state.
      t: A PyTorch tensor representing the current time step.

    Returns:
      x: A PyTorch tensor of the next state.
      x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
    """

  @abc.abstractmethod
  def conditional_update_fn(self, h, x, t, side_info):
    """Perform one conditional predictor update.

    Args:
      h: The conditioning tensor passed through to the score function.
      x: A PyTorch tensor representing the current state.
      t: A PyTorch tensor representing the current time step.
      side_info: Extra conditioning information passed through to the score
        function.

    Returns:
      x: A PyTorch tensor of the next state.
      x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
    """

class Corrector(abc.ABC):
  """Abstract base class for corrector algorithms.

  Instances keep the SDE, the score function, the target signal-to-noise
  ratio `snr` and the number of corrector steps `n_steps`.
  """

  def __init__(self, sde, score_fn, snr, n_steps):
    super().__init__()
    self.sde, self.score_fn = sde, score_fn
    self.snr, self.n_steps = snr, n_steps

  @abc.abstractmethod
  def update_fn(self, x, t):
    """Perform one unconditional corrector update.

    Args:
      x: A PyTorch tensor representing the current state.
      t: A PyTorch tensor representing the current time step.

    Returns:
      x: A PyTorch tensor of the next state.
      x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
    """

  @abc.abstractmethod
  def conditional_update_fn(self, h, x, t, side_info):
    """Perform one conditional corrector update.

    Args:
      h: The conditioning tensor passed through to the score function.
      x: A PyTorch tensor representing the current state.
      t: A PyTorch tensor representing the current time step.
      side_info: Extra conditioning information passed through to the score
        function.

    Returns:
      x: A PyTorch tensor of the next state.
      x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
    """

@register_predictor(name='euler_maruyama')
class EulerMaruyamaPredictor(Predictor):
  """Predictor taking one Euler-Maruyama step with step size `-1 / rsde.N`."""

  def __init__(self, sde, score_fn):
    super().__init__(sde, score_fn)

  def _euler_step(self, x, noise, drift, diffusion):
    """Apply the drift and the scaled noise; returns `(x, x_mean)`."""
    dt = -1. / self.rsde.N
    mean = x + drift * dt
    # `diffusion` is per batch element and broadcast over two more axes.
    noisy = mean + diffusion[:, None, None] * np.sqrt(-dt) * noise
    return noisy, mean

  def update_fn(self, x, t):
    """Unconditional step using `rsde.sde(x, t)`."""
    # The noise is drawn before the drift is evaluated.
    noise = torch.randn_like(x)
    drift, diffusion = self.rsde.sde(x, t)
    return self._euler_step(x, noise, drift, diffusion)

  def conditional_update_fn(self, h, x, t, side_info):
    """Conditional step using `rsde.conditional_sde(h, x, t, side_info)`."""
    noise = torch.randn_like(x)
    drift, diffusion = self.rsde.conditional_sde(h, x, t, side_info)
    return self._euler_step(x, noise, drift, diffusion)


@register_predictor(name='reverse_diffusion')
class ReverseDiffusionPredictor(Predictor):
  """Predictor that steps with the discretized drift and diffusion of `rsde`."""

  def __init__(self, sde, score_fn):
    super().__init__(sde, score_fn)

  @staticmethod
  def _apply(x, f, G):
    """Subtract the discretized drift `f` and add noise scaled by `G`."""
    noise = torch.randn_like(x)
    mean = x - f
    # `G` is per batch element and broadcast over two more axes.
    return mean + G[:, None, None] * noise, mean

  def update_fn(self, x, t):
    """Unconditional step using `rsde.discretize(x, t)`."""
    f, G = self.rsde.discretize(x, t)
    return self._apply(x, f, G)

  def conditional_update_fn(self, h, x, t, side_info):
    """Conditional step using `rsde.conditional_discretize(h, x, t, side_info)`."""
    f, G = self.rsde.conditional_discretize(h, x, t, side_info)
    return self._apply(x, f, G)


@register_predictor(name='ancestral_sampling')
class AncestralSamplingPredictor(Predictor):
  """The ancestral sampling predictor. Currently only supports VE/VP SDEs."""

  def __init__(self, sde, score_fn):
    super().__init__(sde, score_fn)
    if not isinstance(sde, (sde_lib.VPSDE, sde_lib.VESDE)):
      raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

  def _timestep(self, t):
    """Map the time tensor `t` to integer indices into the discrete schedule."""
    sde = self.sde
    return (t * (sde.N - 1) / sde.T).long()

  def _vesde_step(self, x, t, score):
    """Ancestral step for the VE SDE given a precomputed `score`."""
    timestep = self._timestep(t)
    # Move the schedule once so both lookups happen on the device of `t`.
    sigmas = self.sde.discrete_sigmas.to(t.device)
    sigma = sigmas[timestep]
    # Index 0 has no predecessor; a sigma of zero is used there instead.
    adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), sigmas[timestep - 1])
    x_mean = x + score * (sigma ** 2 - adjacent_sigma ** 2)[:, None, None]
    std = torch.sqrt((adjacent_sigma ** 2 * (sigma ** 2 - adjacent_sigma ** 2)) / (sigma ** 2))
    noise = torch.randn_like(x)
    x = x_mean + std[:, None, None] * noise
    return x, x_mean

  def _vpsde_step(self, x, t, score):
    """Ancestral step for the VP SDE given a precomputed `score`."""
    timestep = self._timestep(t)
    beta = self.sde.discrete_betas.to(t.device)[timestep]
    x_mean = (x + beta[:, None, None] * score) / torch.sqrt(1. - beta)[:, None, None]
    noise = torch.randn_like(x)
    x = x_mean + torch.sqrt(beta)[:, None, None] * noise
    return x, x_mean

  def vesde_update_fn(self, x, t):
    """Unconditional VE step; the score is `score_fn(x, t)`."""
    return self._vesde_step(x, t, self.score_fn(x, t))

  def conditional_vesde_update_fn(self, h, x, t, side_info=None):
    """Conditional VE step; the score is `score_fn(h, x, t, side_info)`."""
    return self._vesde_step(x, t, self.score_fn(h, x, t, side_info))

  def vpsde_update_fn(self, x, t):
    """Unconditional VP step; the score is `score_fn(x, t)`."""
    return self._vpsde_step(x, t, self.score_fn(x, t))

  def conditional_vpsde_update_fn(self, h, x, t, side_info=None):
    """Conditional VP step; the score is `score_fn(h, x, t, side_info)`."""
    return self._vpsde_step(x, t, self.score_fn(h, x, t, side_info))

  def update_fn(self, x, t):
    """Dispatch to the VE or VP unconditional step.

    Returns:
      `(x, x_mean)`: the next state with and without the added noise.
    """
    if isinstance(self.sde, sde_lib.VESDE):
      return self.vesde_update_fn(x, t)
    elif isinstance(self.sde, sde_lib.VPSDE):
      return self.vpsde_update_fn(x, t)

  def conditional_update_fn(self, h, x, t, side_info):
    """Dispatch to the VE or VP conditional step.

    Returns:
      `(x, x_mean)`: the next state with and without the added noise.
    """
    if isinstance(self.sde, sde_lib.VESDE):
      return self.conditional_vesde_update_fn(h, x, t, side_info)
    elif isinstance(self.sde, sde_lib.VPSDE):
      return self.conditional_vpsde_update_fn(h, x, t, side_info)


@register_predictor(name='none')
class NonePredictor(Predictor):
  """A no-op predictor: both update methods return the input state twice."""

  def __init__(self, sde, score_fn):
    """Accept the standard arguments but store nothing.

    The base-class initializer is intentionally not called here.
    """

  def update_fn(self, x, t):
    """Return `(x, x)` unchanged."""
    unchanged = x
    return unchanged, unchanged

  def conditional_update_fn(self, h, x, t, side_info):
    """Return `(x, x)` unchanged; the conditioning inputs are ignored."""
    unchanged = x
    return unchanged, unchanged

@register_corrector(name='langevin')
class LangevinCorrector(Corrector):
  """Langevin corrector whose step size is set from the target `snr`.

  Each step uses the batch-mean norms of the score and of the freshly drawn
  noise to compute the step size.
  """

  def __init__(self, sde, score_fn, snr, n_steps):
    super().__init__(sde, score_fn, snr, n_steps)
    if not isinstance(sde, (sde_lib.VPSDE, sde_lib.VESDE, sde_lib.subVPSDE)):
      raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

  def _alpha(self, t):
    """Return the per-sample `alpha` factor of the step size at time `t`."""
    sde = self.sde
    if isinstance(sde, (sde_lib.VPSDE, sde_lib.subVPSDE)):
      timestep = (t * (sde.N - 1) / sde.T).long()
      return sde.alphas.to(t.device)[timestep]
    return torch.ones_like(t)

  def _run(self, x, t, grad_fn):
    """Run `n_steps` Langevin steps; `grad_fn(x)` evaluates the score."""
    target_snr = self.snr
    alpha = self._alpha(t)
    # With `n_steps == 0` the state is returned untouched instead of failing
    # on an unbound `x_mean`.
    x_mean = x
    for _ in range(self.n_steps):
      grad = grad_fn(x)
      noise = torch.randn_like(x)
      grad_norm = torch.norm(grad.reshape(grad.shape[0], -1), dim=-1).mean()
      noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
      step_size = (target_snr * noise_norm / grad_norm) ** 2 * 2 * alpha
      x_mean = x + step_size[:, None, None] * grad
      x = x_mean + torch.sqrt(step_size * 2)[:, None, None] * noise
    return x, x_mean

  def update_fn(self, x, t):
    """Unconditional correction; the score is `score_fn(x, t)`.

    Returns:
      `(x, x_mean)`: the final state with and without the last noise term.
    """
    return self._run(x, t, lambda state: self.score_fn(state, t))

  def conditional_update_fn(self, h, x, t, side_info):
    """Conditional correction; the score is `score_fn(h, x, t, side_info)`.

    Returns:
      `(x, x_mean)`: the final state with and without the last noise term.
    """
    return self._run(x, t, lambda state: self.score_fn(h, state, t, side_info))


@register_corrector(name='ald')
class AnnealedLangevinDynamics(Corrector):
  """The original annealed Langevin dynamics corrector in NCSN/NCSNv2.

  We include this corrector only for completeness. It was not directly used in our paper.
  """

  def __init__(self, sde, score_fn, snr, n_steps):
    super().__init__(sde, score_fn, snr, n_steps)
    if not isinstance(sde, (sde_lib.VPSDE, sde_lib.VESDE, sde_lib.subVPSDE)):
      raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")

  def _alpha(self, t):
    """Return the per-sample `alpha` factor of the step size at time `t`."""
    sde = self.sde
    if isinstance(sde, (sde_lib.VPSDE, sde_lib.subVPSDE)):
      timestep = (t * (sde.N - 1) / sde.T).long()
      return sde.alphas.to(t.device)[timestep]
    return torch.ones_like(t)

  def _run(self, x, t, grad_fn):
    """Run `n_steps` Langevin steps; `grad_fn(x)` evaluates the score."""
    alpha = self._alpha(t)
    std = self.sde.marginal_prob(x, t)[1]
    # The step size depends only on `snr`, `std` and `alpha`, none of which
    # change inside the loop, so it is computed once.
    step_size = (self.snr * std) ** 2 * 2 * alpha
    noise_scale = torch.sqrt(step_size * 2)[:, None, None]
    # With `n_steps == 0` the state is returned untouched instead of failing
    # on an unbound `x_mean`.
    x_mean = x
    for _ in range(self.n_steps):
      grad = grad_fn(x)
      noise = torch.randn_like(x)
      x_mean = x + step_size[:, None, None] * grad
      x = x_mean + noise * noise_scale
    return x, x_mean

  def update_fn(self, x, t):
    """Unconditional correction; the score is `score_fn(x, t)`.

    Returns:
      `(x, x_mean)`: the final state with and without the last noise term.
    """
    return self._run(x, t, lambda state: self.score_fn(state, t))

  def conditional_update_fn(self, h, x, t, side_info):
    """Conditional correction; the score is `score_fn(h, x, t, side_info)`.

    Returns:
      `(x, x_mean)`: the final state with and without the last noise term.
    """
    return self._run(x, t, lambda state: self.score_fn(h, state, t, side_info))


@register_corrector(name='none')
class NoneCorrector(Corrector):
  """A no-op corrector: both update methods return the input state twice."""

  def __init__(self, sde, score_fn, snr, n_steps):
    """Accept the standard arguments but store nothing.

    The base-class initializer is intentionally not called here.
    """

  def update_fn(self, x, t):
    """Return `(x, x)` unchanged."""
    unchanged = x
    return unchanged, unchanged

  def conditional_update_fn(self, h, x, t, side_info):
    """Return `(x, x)` unchanged; the conditioning inputs are ignored."""
    unchanged = x
    return unchanged, unchanged


def shared_predictor_update_fn(sde, model, predictor, eps, conditional, ts_diff=False):
  """Configure a predictor and return one of its bound update methods.

  Args:
    sde: The SDE object.
    model: The score network.
    predictor: Predictor class, or `None` for a corrector-only sampler (a
      `NonePredictor` is used in that case).
    eps: Accepted for interface symmetry; not read by this function.
    conditional: Selects `conditional_update_fn` when truthy, `update_fn`
      otherwise.
    ts_diff: When truthy, the score function is wrapped with `ts_score_fn`.

  Returns:
    The selected bound update method of the predictor instance.
  """
  score_fn = get_conditional_score_fn(sde, model, train=False)
  if ts_diff:
    score_fn = ts_score_fn(sde, score_fn)

  # `None` means "no predictor": fall back to the pass-through implementation.
  predictor_cls = NonePredictor if predictor is None else predictor
  predictor_obj = predictor_cls(sde, score_fn)

  if not conditional:
    return predictor_obj.update_fn
  return predictor_obj.conditional_update_fn


def shared_corrector_update_fn(sde, model, corrector, snr, n_steps, conditional, ts_diff=True):
  """Configure a corrector and return one of its bound update methods.

  Args:
    sde: The SDE object.
    model: The score network.
    corrector: Corrector class, or `None` for a predictor-only sampler (a
      `NoneCorrector` is used in that case).
    snr: Signal-to-noise ratio handed to the corrector.
    n_steps: Number of corrector steps handed to the corrector.
    conditional: Selects `conditional_update_fn` when truthy, `update_fn`
      otherwise.
    ts_diff: When truthy (the default here), the score function is wrapped
      with `ts_score_fn`.

  Returns:
    The selected bound update method of the corrector instance.
  """
  score_fn = get_conditional_score_fn(sde, model, train=False)
  if ts_diff:
    score_fn = ts_score_fn(sde, score_fn)

  # `None` means "no corrector": fall back to the pass-through implementation.
  corrector_cls = NoneCorrector if corrector is None else corrector
  corrector_obj = corrector_cls(sde, score_fn, snr, n_steps)

  if not conditional:
    return corrector_obj.update_fn
  return corrector_obj.conditional_update_fn

def to_flattened_numpy(x):
  """Detach tensor `x`, move it to the CPU and return it as a 1-D numpy array."""
  array = x.detach().cpu().numpy()
  return array.reshape(-1)


def from_flattened_numpy(x, shape):
  """Reshape the flat numpy array `x` to `shape` and wrap it as a torch tensor."""
  reshaped = x.reshape(shape)
  return torch.from_numpy(reshaped)

def get_pc_sampler(sde, predictor, corrector, snr,
                   n_steps=1, eps = 1e-5, device='cuda'):
  """Create a conditional predictor-corrector (PC) sampler.

  Args:
    sde: The SDE object; must provide `N`, `T` and `prior_sampling(shape)`.
    predictor: Predictor class, or `None` for a corrector-only sampler.
    corrector: Corrector class, or `None` for a predictor-only sampler.
    snr: Signal-to-noise ratio handed to the corrector.
    n_steps: Number of corrector steps per predictor step.
    eps: Final time value of the sampling schedule.
    device: Device on which the samples are generated.

  Returns:
    A function `pc_sampler(conditional_model, condition, side_info, ts_diff)`.
  """

  def pc_sampler(conditional_model, condition, side_info, ts_diff=False):
    """Draw samples shaped like `condition` with the PC scheme.

    Args:
      conditional_model: The conditional score network.
      condition: Conditioning tensor; its shape is the shape of the samples.
      side_info: Extra conditioning information passed to the update steps.
      ts_diff: When truthy, both update steps use the score wrapped by
        `ts_score_fn`.

    Returns:
      Samples, number of function evaluations (`sde.N * (n_steps + 1)`).
    """
    predictor_step = shared_predictor_update_fn(sde=sde,
                                                model=conditional_model,
                                                predictor=predictor,
                                                eps=eps,
                                                conditional=True,
                                                ts_diff=ts_diff)
    corrector_step = shared_corrector_update_fn(sde=sde,
                                                model=conditional_model,
                                                corrector=corrector,
                                                snr=snr,
                                                n_steps=n_steps,
                                                conditional=True,
                                                ts_diff=ts_diff)
    with torch.no_grad():
      shape = condition.shape
      x = sde.prior_sampling(shape).to(device)
      # `sde.N` time values running from `sde.T` down to `eps`.
      timesteps = torch.linspace(sde.T, eps, sde.N, device=device)
      for t in timesteps:
        # One copy of the current time value per batch element.
        vec_t = torch.ones(shape[0], device=t.device) * t
        x, _ = predictor_step(condition, x, vec_t, side_info)
        x, _ = corrector_step(condition, x, vec_t, side_info)
    return x, sde.N * (n_steps + 1)
  return pc_sampler

def get_ode_sampler(sde, denoise=False, rtol=1e-5, atol=1e-5, method='RK45', eps=1e-3, device='cuda'):
  """Create a conditional sampler driven by a black-box ODE solver.

  Args:
    sde: The SDE object; must provide `T`, `prior_sampling(shape)` and
      `reverse(score_fn)`.
    denoise: Currently has no effect; the denoising step after integration is
      not invoked.
    rtol: Relative tolerance passed to `scipy.integrate.solve_ivp`.
    atol: Absolute tolerance passed to `scipy.integrate.solve_ivp`.
    method: Integration method name passed to `scipy.integrate.solve_ivp`.
    eps: Final time value of the integration interval `(sde.T, eps)`.
    device: Device on which the samples are generated.

  Returns:
    A function `ode_sampler(model, condition, side_info)` returning the samples
    and the number of function evaluations reported by the solver.
  """

  def denoise_update_fn(model, prev, x, side_info):
    # Not called at the moment (see `denoise` above); kept so the denoising
    # step can be enabled again without rewriting it.
    score_fn = get_conditional_score_fn(sde, model, train=False)
    # Reverse diffusion predictor for denoising
    predictor_obj = ReverseDiffusionPredictor(sde, score_fn)
    vec_eps = torch.ones(x.shape[0], device=x.device) * eps
    _, x = predictor_obj.conditional_update_fn(prev, x, vec_eps, side_info)
    return x

  def ode_sampler(model, condition, side_info):
    with torch.no_grad():
      shape = condition.shape
      x = sde.prior_sampling(shape).to(device)

      # Built once per sampling call instead of on every solver evaluation.
      # NOTE(review): `sde.reverse` is called with its default arguments;
      # confirm that the resulting drift is the probability-flow ODE drift.
      score_fn = get_conditional_score_fn(sde, model, train=False)
      rsde = sde.reverse(score_fn)

      def ode_func(t, flat_state):
        # The solver works on flat float arrays; convert to a tensor and back.
        state = from_flattened_numpy(flat_state, shape).to(device).type(torch.float32)
        vec_t = torch.ones(shape[0], device=state.device) * t
        drift = rsde.conditional_sde(condition, state, vec_t, side_info)[0]
        return to_flattened_numpy(drift)

      # Black-box ODE solver for the probability flow ODE
      solution = integrate.solve_ivp(ode_func, (sde.T, eps), to_flattened_numpy(x),
                                     rtol=rtol, atol=atol, method=method)
      # The last column of `solution.y` is the state at the final time.
      x = torch.tensor(solution.y[:, -1]).reshape(shape).to(device).type(torch.float32)

      return x, solution.nfev

  return ode_sampler