import abc
import torch
import numpy as np
from functools import *

def get_t2dt_map(discretisation):
    """Return a map from each time in ``discretisation`` to its (negative) time step dt.

    Args:
        discretisation: 1-D sequence of times whose elements support ``.item()``
            (e.g. a torch tensor or numpy array). It is expected to be ordered from
            the biggest time to the smallest, so every step is negative.

    Returns:
        dict mapping ``discretisation[i].item()`` to
        ``discretisation[i+1] - discretisation[i]``. Keys are plain Python numbers
        (not tensors tied to a device); values keep the element type of
        ``discretisation`` (e.g. 0-d tensors). The last time has no successor, so
        it reuses the step of the time before it. An empty input gives ``{}``.

    Raises:
        ValueError: if ``discretisation`` has exactly one element, since no step
            can be derived from a single time.
    """
    n_times = len(discretisation)
    if n_times == 1:
        raise ValueError(
            "get_t2dt_map needs at least two times to derive a step size; got one.")

    map_t_to_negative_dt = {}
    # Consecutive pairs (t_i, t_{i+1}); keys must be floats, not tensors with a device.
    for t_cur, t_next in zip(discretisation[:-1], discretisation[1:]):
        map_t_to_negative_dt[t_cur.item()] = t_next - t_cur

    if n_times >= 2:
        # The final time has no successor: reuse the preceding step.
        map_t_to_negative_dt[discretisation[-1].item()] = \
            map_t_to_negative_dt[discretisation[-2].item()]

    return map_t_to_negative_dt

def find_dt(map_t_to_negative_dt, discretisation, t):
  """Return the (negative) time step for time ``t``.

  Args:
    map_t_to_negative_dt: dict from float time to step, as built by ``get_t2dt_map``.
    discretisation: the time grid the map was built from. Kept for interface
      compatibility (callers bind it positionally via ``functools.partial``). The
      nearest-time search uses the map's float keys instead, so it works regardless
      of the grid's device or autograd state.
    t: scalar time; anything ``float()`` accepts (0-d or 1-element tensor, numpy
      scalar, Python number).

  Returns:
    The step stored for ``t`` if ``t`` is an exact key. Otherwise, the step of the
    key closest to ``t``; ties resolve to the earliest time in the grid.

  Raises:
    ValueError: if the map is empty.
  """
  t_val = float(t)
  try:
    return map_t_to_negative_dt[t_val]
  except KeyError:
    pass
  # Dict keys preserve grid insertion order, so min() keeps argmin's
  # "first occurrence wins" tie-breaking.
  closest_t_key = min(map_t_to_negative_dt, key=lambda key: abs(key - t_val))
  return map_t_to_negative_dt[closest_t_key]


class Predictor(abc.ABC):
  """The abstract class for a predictor algorithm.

  Subclasses implement ``update_fn`` to advance the state by one step.

  Args:
    sde: SDE object the predictor works with; stored as ``self.sde``.
    probability_flow: mode flag for the predictor (probability-flow ODE when True);
      stored as ``self.probability_flow`` so every subclass exposes it uniformly.
  """

  def __init__(self, sde, probability_flow=False):
    super().__init__()
    self.sde = sde
    # Store the flag on the base instead of silently discarding it.
    self.probability_flow = probability_flow

  @abc.abstractmethod
  def update_fn(self, x, t):
    """One update of the predictor.
    Args:
      x: A PyTorch tensor representing the current state
      t: A Pytorch tensor representing the current time step.
    Returns:
      x: A PyTorch tensor of the next state.
      x_mean: A PyTorch tensor. The next state without random noise. Useful for denoising.
    """
    pass

class EulerMaruyamaPredictor(Predictor):
  """Predictor class to perform the P part of PC-sampling with Euler-Maruyama steps.

  Each ``update_fn`` call takes one step of the reverse-time SDE (or, in
  probability-flow mode, the deterministic ODE). The step size for a time is
  looked up from the grid given to ``set_discretisation``, which must be called
  before ``update_fn``.
  """

  def __init__(self, sde, probability_flow=False):
    """Args:
      sde: SDE object with forward and reverse equations; must provide
        ``get_reverse_sde_coefficients(x, t, probability_flow)``.
      probability_flow: whether to use the probability flow ODE; must be False
        for the default PC-Sampling.
    """
    super().__init__(sde, probability_flow)
    self.probability_flow = probability_flow

  def set_discretisation(self, discretisation):
    """Set the time grid used to derive step sizes.

    Args:
      discretisation: sequence of times ordered from biggest to smallest
        (see ``get_t2dt_map``).
    """
    t2dt_map = get_t2dt_map(discretisation)
    # Bind the map and grid so update_fn can call self.find_dt(t) with just a time.
    self.find_dt = partial(find_dt, t2dt_map, discretisation)

  def change_probability_flow(self, probability_flow):
    """Switch from SDE to ODE working mode and vice versa."""
    self.probability_flow = probability_flow

  def update_fn(self, x, t, eta=1.0):
    """Take one reverse-time Euler-Maruyama step.

    Args:
      x: current state tensor. Dimensions after the first are broadcast against
        ``diffusion``, so the first dimension is presumably the batch.
      t: tensor of times. Only ``t[0]`` is used to look up the step size, so all
        entries are assumed to share the same time.
      eta: the hacky 'noise scale' multiplying the stochastic term (the curiosity
        under investigation). Should be 1.0 for sampling to be mathematically
        correct. Ignored in probability-flow mode.

    Returns:
      (x, x_mean): the next state and the next state without random noise. In
      probability-flow mode both entries are the same tensor.

    Raises:
      AttributeError: if ``set_discretisation`` has not been called yet.
    """
    step_lookup = getattr(self, 'find_dt', None)
    if step_lookup is None:
      raise AttributeError(
          "EulerMaruyamaPredictor.set_discretisation() must be called before update_fn().")

    # as_tensor avoids torch.tensor's copy-construct warning when the step is
    # already a tensor; type_as matches t's dtype and device.
    dt = torch.as_tensor(step_lookup(t[0])).type_as(t)
    drift, diffusion = self.sde.get_reverse_sde_coefficients(x, t, self.probability_flow)
    x_mean = x + drift * dt

    if self.probability_flow:
      return x_mean, x_mean

    z = torch.randn_like(x)
    # Append trailing singleton dims so diffusion broadcasts against x;
    # dt is negative (times run big -> small), hence sqrt(-dt).
    diffusion_b = diffusion[(...,) + (None,) * len(x.shape[1:])]
    x = x_mean + diffusion_b * torch.sqrt(-dt) * z * eta
    return x, x_mean

# @register_predictor(name='reverse_diffusion')  # NOTE: this entire block may be removed
# class ReverseDiffusionPredictor(Predictor):
#   def __init__(self, sde, score_fn, probability_flow=False):
#     super().__init__(sde, score_fn, probability_flow)

#   def update_fn(self, x, t):
#     f, G = self.rsde.discretize(x, t)
#     z = torch.randn_like(x)
#     x_mean = x - f
#     x = x_mean + G[(...,) + (None,) * len(x.shape[1:])] * z
#     return x, x_mean

# @register_predictor(name='none')
class NonePredictor(Predictor):
  """An empty predictor that does nothing.

  Mirrors the ``EulerMaruyamaPredictor`` interface (``set_discretisation``,
  ``change_probability_flow``, ``update_fn(x, t, eta)``) so it can be swapped in
  when the predictor step should be skipped.
  """

  def __init__(self, sde, score_fn=None, discretisation=None, probability_flow=False):
    """Args:
      sde: SDE object; stored as ``self.sde`` for parity with other predictors.
      score_fn: accepted for signature compatibility only; unused.
      discretisation: accepted for signature compatibility only; unused.
      probability_flow: stored as ``self.probability_flow``; has no effect here.
    """
    super().__init__(sde, probability_flow)
    self.probability_flow = probability_flow

  def set_discretisation(self, discretisation):
    """No-op: this predictor takes no time steps, so it needs no grid."""

  def change_probability_flow(self, probability_flow):
    """Record the mode flag; it does not change this predictor's output."""
    self.probability_flow = probability_flow

  def update_fn(self, x, t, eta=1.0):
    """Return the state unchanged as ``(x, x_mean)``; ``t`` and ``eta`` are ignored."""
    x_mean = x
    return x, x_mean
  
# ===================== NOT SUPPORTED YET =====================
# @register_predictor(name='conditional_euler_maruyama')
# class conditionalEulerMaruyamaPredictor(Predictor):
#   def __init__(self, sde, score_fn, probability_flow=False, discretisation=None):
#     super().__init__(sde, score_fn, probability_flow, discretisation)
#     self.probability_flow=probability_flow

#   def update_fn(self, x, y, t):
#     dt = torch.tensor(self.inverse_step_fn(t[0].cpu().item())).type_as(t) #dt = -(1-self.sde.sampling_eps) / self.rsde.N
#     z = torch.randn_like(x)
#     drift, diffusion = self.sde.get_reverse_sde_coefficients(x, y, t)
#     x_mean = x + drift * dt
#     x = x_mean + diffusion[(...,) + (None,) * len(x.shape[1:])] * torch.sqrt(-dt) * z
#     return x, x_mean

# @register_predictor(name='conditional_reverse_diffusion')
# class conditionalReverseDiffusionPredictor(Predictor):
#   def __init__(self, sde, score_fn, probability_flow=False):
#     super().__init__(sde, score_fn, probability_flow)

#   def update_fn(self, x, y, t):
#     f, G = self.rsde.discretize(x, y, t)
#     z = torch.randn_like(x)
#     x_mean = x - f
#     x = x_mean + G[(...,) + (None,) * len(x.shape[1:])] * z
#     return x, x_mean


# @register_predictor(name='ancestral_sampling')
# class AncestralSamplingPredictor(Predictor):
#   """The ancestral sampling predictor. Currently only supports VE/VP SDEs."""

#   def __init__(self, sde, score_fn, probability_flow=False):
#     super().__init__(sde, score_fn, probability_flow)
#     if not isinstance(sde, sde_lib.VPSDE) and not isinstance(sde, sde_lib.VESDE):
#       raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
#     assert not probability_flow, "Probability flow not supported by ancestral sampling"

#   def vesde_update_fn(self, x, t):
#     sde = self.sde
#     timestep = (t * (sde.N - 1) / sde.T).long()
#     sigma = sde.discrete_sigmas[timestep]
#     adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), sde.discrete_sigmas.to(t.device)[timestep - 1])
#     score = self.score_fn(x, t)
#     x_mean = x + score * (sigma ** 2 - adjacent_sigma ** 2)[(...,) + (None,) * len(x.shape[1:])]
#     std = torch.sqrt((adjacent_sigma ** 2 * (sigma ** 2 - adjacent_sigma ** 2)) / (sigma ** 2))
#     noise = torch.randn_like(x)
#     x = x_mean + std[(...,) + (None,) * len(x.shape[1:])] * noise
#     return x, x_mean

#   def vpsde_update_fn(self, x, t):
#     sde = self.sde
#     timestep = (t * (sde.N - 1) / sde.T).long()
#     beta = sde.discrete_betas.to(t.device)[timestep]
#     score = self.score_fn(x, t)
#     x_mean = (x + beta[(...,) + (None,) * len(x.shape[1:])] * score) / torch.sqrt(1. - beta)[(...,) + (None,) * len(x.shape[1:])]
#     noise = torch.randn_like(x)
#     x = x_mean + torch.sqrt(beta)[(...,) + (None,) * len(x.shape[1:])] * noise
#     return x, x_mean

#   def update_fn(self, x, t):
#     if isinstance(self.sde, sde_lib.VESDE):
#       return self.vesde_update_fn(x, t)
#     elif isinstance(self.sde, sde_lib.VPSDE):
#       return self.vpsde_update_fn(x, t)

# @register_predictor(name='conditional_ancestral_sampling')
# class conditionalAncestralSamplingPredictor(Predictor):
#   """The ancestral sampling predictor. Currently only supports VE/VP SDEs."""

#   def __init__(self, sde, score_fn, probability_flow=False):
#     super().__init__(sde, score_fn, probability_flow)
#     if not isinstance(sde, sde_lib.VPSDE) and not isinstance(sde, sde_lib.VESDE) \
#       and not isinstance(sde, sde_lib.cVESDE):
#       raise NotImplementedError(f"SDE class {sde.__class__.__name__} not yet supported.")
#     assert not probability_flow, "Probability flow not supported by ancestral sampling"

#   def vesde_update_fn(self, x, y, t):
#     sde = self.sde
#     timestep = (t * (sde.N - 1) / sde.T).long()
#     sigma = sde.discrete_sigmas[timestep]
#     adjacent_sigma = torch.where(timestep == 0, torch.zeros_like(t), sde.discrete_sigmas.to(t.device)[timestep - 1])
#     score = self.score_fn(x, y, t)
#     x_mean = x + score * (sigma ** 2 - adjacent_sigma ** 2)[(...,) + (None,) * len(x.shape[1:])]
#     std = torch.sqrt((adjacent_sigma ** 2 * (sigma ** 2 - adjacent_sigma ** 2)) / (sigma ** 2))
#     noise = torch.randn_like(x)
#     x = x_mean + std[(...,) + (None,) * len(x.shape[1:])] * noise
#     return x, x_mean

#   def vpsde_update_fn(self, x, y, t):
#     sde = self.sde
#     timestep = (t * (sde.N - 1) / sde.T).long()
#     beta = sde.discrete_betas.to(t.device)[timestep]
#     score = self.score_fn(x, y, t)
#     x_mean = (x + beta[(...,) + (None,) * len(x.shape[1:])] * score) / torch.sqrt(1. - beta)[(...,) + (None,) * len(x.shape[1:])]
#     noise = torch.randn_like(x)
#     x = x_mean + torch.sqrt(beta)[(...,) + (None,) * len(x.shape[1:])] * noise
#     return x, x_mean

#   def update_fn(self, x, t):
#     if isinstance(self.sde, sde_lib.VESDE):
#       return self.vesde_update_fn(x, t)
#     elif isinstance(self.sde, sde_lib.VPSDE):
#       return self.vpsde_update_fn(x, t)


# @register_predictor(name='conditional_none')
# class NonePredictor(Predictor):
#   """An empty predictor that does nothing."""

#   def __init__(self, sde, score_fn, probability_flow=False):
#     pass

#   def update_fn(self, x, y, t):
#     return x, x
