import abc
import math
import numbers
import re

import numpy as np
import torch
from easydict import EasyDict as edict
from tqdm import tqdm

class DiffusionMixture(object):
  """Mixture of diffusion bridges that share one noise schedule.

  Stores the schedule hyperparameters and, via :meth:`bridge`, builds a
  concrete bridge (Brownian or Ornstein-Uhlenbeck) towards a destination.
  """

  def __init__(self, bridge, eta, drift_coeff, schedule, sigma_0, sigma_1, N=1000):
    """Construct a Mixture of Diffusion Bridges.

    Args:
      bridge: type of bridge; a string containing 'BB' selects a Brownian
        bridge, otherwise one containing 'OU' selects an Ornstein-Uhlenbeck
        bridge (see :meth:`bridge`).
      eta: hyperparameter for noise-schedule scaling. A real scalar (Python or
        NumPy) is expanded into an EasyDict with keys ``eta``, ``sigma_0`` and
        ``sigma_1``, where the sigmas are scaled by sqrt(eta). Any other object
        is stored unchanged; it presumably already exposes those attributes.
      drift_coeff: drift coefficient forwarded to the bridge.
      schedule: type of noise schedule. Only 'sqlinear' is supported by
        :meth:`diffusion` and :meth:`eta_scale`.
      sigma_0, sigma_1: hyperparameters for the noise schedule (noise levels
        at t = 0 and t = 1).
      N: number of discretization steps.
    """
    super().__init__()
    self.bridge_type = bridge
    self.drift_coeff = drift_coeff
    self.schedule = schedule
    self.sigma_0 = sigma_0
    self.sigma_1 = sigma_1
    self.N = N

    self.eta = eta
    # numbers.Real also covers NumPy scalars (e.g. np.float32, np.int64).
    # Those are not subclasses of int/float, so an isinstance check against
    # (int, float) would store them raw and break attribute access later.
    if isinstance(eta, numbers.Real):
      self.eta = edict({'eta': eta, 'sigma_0': self.sigma_0 * math.sqrt(eta), 'sigma_1': self.sigma_1 * math.sqrt(eta)})

  def diffusion(self, t):
    """Return the noise level of the eta-scaled reference schedule at time(s) t.

    Raises:
      NotImplementedError: if the schedule is not 'sqlinear'.
    """
    if self.schedule == 'sqlinear':
      sigma_t = torch.sqrt((1-t) * self.eta.sigma_0**2 + t * self.eta.sigma_1**2)
    else:
      raise NotImplementedError(f"Schedule type: {self.schedule} not supported.")
    return sigma_t

  def eta_scale(self, t):
    """Return the ratio of reference variance to base variance at time(s) t.

    Raises:
      NotImplementedError: if the schedule is not 'sqlinear'.
    """
    if self.schedule == 'sqlinear':
      var_ref = (1-t) * self.eta.sigma_0**2 + t * self.eta.sigma_1**2
      var = (1-t) * self.sigma_0**2 + t * self.sigma_1**2
    else:
      raise NotImplementedError(f"Schedule type: {self.schedule} not supported.")
    return var_ref / var

  def bridge(self, destination):
    """Build a bridge of the configured type that terminates at ``destination``.

    'BB' is tested before 'OU', so a type string containing both selects the
    Brownian bridge.

    Raises:
      NotImplementedError: if the bridge type contains neither 'BB' nor 'OU'.
    """
    bridge_args = {'eta': self.eta, 'drift_coeff': self.drift_coeff, 'schedule': self.schedule,
                   'sigma_0': self.sigma_0, 'sigma_1': self.sigma_1, 'destination': destination, 'N': self.N}
    if 'BB' in self.bridge_type:
      bridge = BrownianBridge(**bridge_args)
    elif 'OU' in self.bridge_type:
      bridge = OUBridge(**bridge_args)
    else:
      raise NotImplementedError(f'Bridge type {self.bridge_type} not implemented.')
    return bridge


class Bridge(abc.ABC):
  """Abstract diffusion bridge; every method works on a mini-batch of times.

  Subclasses provide ``sde``, ``marginal_prob`` and ``prior_sampling``. This
  base class provides the noise-schedule helpers, selected by ``schedule``:

    'sqlinear'  : sigma_t**2 = (1-t) * sigma_0**2 + t * sigma_1**2
    'linear'    : sigma_t    = (1-t) * sigma_0    + t * sigma_1
    'geometric' : sigma_t    = sigma_0 * (sigma_1 / sigma_0)**t

  The eta-scaled helpers (``eta_diffusion``, ``eta_t``) support 'sqlinear' only.
  """

  def __init__(self, eta, drift_coeff, schedule, sigma_0, sigma_1, destination, N):
    """Store the hyperparameters shared by all bridges.

    Args:
      eta: reference-schedule parameters; ``eta.sigma_0`` and ``eta.sigma_1``
        are read by the eta-scaled helpers.
      drift_coeff: drift coefficient (interpreted by subclasses).
      schedule: noise-schedule name.
      sigma_0, sigma_1: noise levels at t = 0 and t = 1.
      destination: terminal point, kept as ``self.dest``.
      N: number of discretization time steps.
    """
    super().__init__()
    self.eta = eta
    self.drift_coeff = drift_coeff
    self.schedule = schedule
    self.sigma_0 = sigma_0
    self.sigma_1 = sigma_1
    self.dest = destination
    self.N = N

  # The bridge always lives on the unit time interval; no rescaling happens here.
  @property
  def T(self):
    """Terminal time of the bridge (always 1)."""
    return 1.

  def diffusion(self, t):
    """Return the noise level sigma_t at time(s) ``t``."""
    if self.schedule == 'sqlinear':
      return torch.sqrt((1-t) * self.sigma_0**2 + t * self.sigma_1**2)
    if self.schedule == 'linear':
      return (1-t) * self.sigma_0 + t * self.sigma_1
    if self.schedule == 'geometric':
      return self.sigma_0 * (self.sigma_1 / self.sigma_0)**t
    raise NotImplementedError(f"Schedule type: {self.schedule} not supported.")

  def variance(self, t):
    """Return the squared noise level sigma_t**2 at time(s) ``t``."""
    if self.schedule == 'sqlinear':
      return (1-t) * self.sigma_0**2 + t * self.sigma_1**2
    if self.schedule == 'linear':
      return ((1-t) * self.sigma_0 + t * self.sigma_1)**2
    if self.schedule == 'geometric':
      return self.sigma_0**2 * (self.sigma_1 / self.sigma_0)**(2*t)
    raise NotImplementedError(f"Schedule type: {self.schedule} not supported.")

  def beta_t(self, t):
    """Return the integral of sigma_u**2 over u in [0, t], in closed form per schedule."""
    if self.schedule == 'sqlinear':
      return t * self.sigma_0**2 - 0.5 * t**2 * (self.sigma_0**2 - self.sigma_1**2)
    if self.schedule == 'linear':
      slope = self.sigma_1 - self.sigma_0
      return slope**2 * t**3 / 3 + self.sigma_0 * slope * t**2 + self.sigma_0**2 * t
    if self.schedule == 'geometric':
      if self.sigma_0 == self.sigma_1:
        # Constant sigma: the general formula would divide by log(1) = 0.
        return self.sigma_0**2 * t
      ratio = (self.sigma_1 / self.sigma_0)**2
      return self.sigma_0**2 * (ratio**t - 1) / np.log(ratio)
    raise NotImplementedError(f"Schedule type: {self.schedule} not supported.")

  def eta_diffusion(self, t):
    """Return the noise level of the eta-scaled reference schedule ('sqlinear' only)."""
    if self.schedule != 'sqlinear':
      raise NotImplementedError(f"Schedule type: {self.schedule} not supported.")
    return torch.sqrt((1-t) * self.eta.sigma_0**2 + t * self.eta.sigma_1**2)

  def eta_t(self, t):
    """Return reference variance divided by bridge variance at time(s) ``t`` ('sqlinear' only)."""
    if self.schedule != 'sqlinear':
      raise NotImplementedError(f"Schedule type: {self.schedule} not supported.")
    ref_var = (1-t) * self.eta.sigma_0**2 + t * self.eta.sigma_1**2
    own_var = (1-t) * self.sigma_0**2 + t * self.sigma_1**2
    return ref_var / own_var

  @abc.abstractmethod
  def sde(self, z, t):
    """Return the (drift, diffusion) pair of the bridge SDE at state ``z`` and time ``t``."""

  @abc.abstractmethod
  def marginal_prob(self, z0, t):
    """Parameters to determine the marginal distribution of the Bridge, $p_t(z)$."""

  @abc.abstractmethod
  def prior_sampling(self, shape, device='cpu'):
    """Generate one sample from the prior distribution, $p_T(z)$."""


class BrownianBridge(Bridge):
  """Brownian bridge that reaches ``destination`` at t = 1.

  Noise-schedule helpers come from :class:`Bridge`. ``sde`` and
  ``loss_coeff`` go through ``eta_t`` / ``eta_diffusion`` and therefore only
  support the 'sqlinear' schedule. ``drift_time_scaled`` supports all three
  schedules.
  """

  def __init__(self, eta, drift_coeff, schedule, sigma_0, sigma_1, destination, N=1000):
    """Construct a Brownian Bridge.

    Args:
      eta: reference-schedule parameters (read as ``eta.eta``, ``eta.sigma_0``,
        ``eta.sigma_1``).
      drift_coeff: stored by the base class; not read by this class.
      schedule: noise-schedule name.
      sigma_0, sigma_1: noise levels at t = 0 and t = 1.
      destination: terminal point.
      N: number of discretization steps.
    """
    super().__init__(eta, drift_coeff, schedule, sigma_0, sigma_1, destination, N)

  # -------- sigma_t**2 / (beta_1 - beta_t) --------
  # not scaled
  def drift_time_scaled(self, t):
    """Return sigma_t**2 / (beta_1 - beta_t), the time factor of the bridge drift.

    Not scaled by eta. The denominator vanishes at t = 1.

    Raises:
      NotImplementedError: for an unknown schedule.
    """
    if self.schedule == 'sqlinear':
      seps = 1 - (self.sigma_1 / self.sigma_0)**2
      drift_time_scaled = (1 - t * seps) / (1 - t - 0.5 * seps * (1-t**2))
    elif self.schedule == 'linear':
      if self.sigma_0 == self.sigma_1:
        # Constant sigma: the closed form below degenerates to 0 / 0 (NaN).
        # Its limit is 1 / (1 - t), matching the geometric branch.
        drift_time_scaled = 1. / (1. - t)
      else:
        seps = (self.sigma_0 - self.sigma_1) / self.sigma_0
        r = 1. - seps * t  # sigma_t / sigma_0
        denom = r - (1-seps)**3/r**2
        drift_time_scaled = 3 * seps / denom
    elif self.schedule == 'geometric':
      if not self.sigma_0 == self.sigma_1:
        sigma_r = (self.sigma_1 / self.sigma_0)**2
        drift_time_scaled = np.log(sigma_r) / (sigma_r**(1.-t) - 1.)
      else:
        drift_time_scaled = 1. / (1. - t)
    else:
      raise NotImplementedError(f"Schedule type: {self.schedule} not supported.")
    return drift_time_scaled

  def sde(self, z, t):
    """Return (drift, diffusion) of the eta-scaled bridge SDE.

    drift = (dest - z) * drift_time_scaled(t) * (1 + eta_t(t)) / 2, and
    diffusion = eta_diffusion(t). ``t`` is indexed as ``t[:, None, None]``,
    so z is presumably (batch, n, m) -- TODO confirm with callers.
    """
    drift = (self.dest - z) * self.drift_time_scaled(t)[:, None, None]
    drift = drift * (1. + self.eta_t(t))[:, None, None] * 0.5
    diffusion = self.eta_diffusion(t)
    return drift, diffusion

  # -------- mean, std of the perturbation kernel --------
  def marginal_prob(self, z0, t):
    """Return (mean, std) of the perturbation kernel at time(s) ``t``.

    With ``eta.eta == 1`` this is the standard Brownian-bridge interpolation in
    beta-time. Otherwise an eta-dependent contraction ``rho`` is applied.
    ``std`` has the shape of ``t``.
    """
    beta_1 = self.beta_t(torch.ones_like(t))
    beta = self.beta_t(t)
    if self.eta.eta == 1:
      mean = (z0 * (beta_1 - beta)[:, None, None] + self.dest * (beta)[:, None, None]) / (beta_1)[:, None, None]
      std = torch.sqrt((beta_1 - beta) * beta / beta_1)
    else:
      # clamp(0.) presumably guards against tiny negative round-off near t = 1.
      gamma = torch.sqrt( (beta_1 - beta).clamp(0.) / beta_1 )
      # When self.eta was built from a scalar by DiffusionMixture, this ratio
      # equals eta. NOTE(review): it divides by sigma_0**2 - sigma_1**2, so
      # sigma_0 == sigma_1 raises ZeroDivisionError for float sigmas.
      rho_coeff1 = (self.eta.sigma_0**2 - self.eta.sigma_1**2) / (self.sigma_0**2 - self.sigma_1**2)
      if self.sigma_1 > 0:
        rho_coeff2 = self.eta.sigma_1**2 / self.sigma_1**2 - rho_coeff1
        if rho_coeff1 > 1.0e-6:
          rho = gamma ** rho_coeff1 * (1 - (2 * self.sigma_1**2 * t/ (self.variance(torch.ones_like(t)) + self.variance(t)) )) ** (rho_coeff2 * 0.5)
        else:
          rho = torch.ones_like(t)
      else:
        rho = gamma ** rho_coeff1 * torch.exp(-(t * self.eta.sigma_1**2) / (self.variance(t)))
      coeff = gamma * rho
      mean = z0 * coeff[:, None, None] + self.dest * (1. - coeff)[:, None, None]
      std = torch.sqrt((beta_1 - beta).clamp(0.)) * torch.sqrt(1. - rho**2)
    return mean, std

  def prior_sampling(self, shape, device='cpu'):
    """Draw i.i.d. standard-normal noise of the given shape."""
    return torch.randn(*shape, device=device)

  def prior_sampling_sym(self, shape, device='cpu'):
    """Draw symmetric standard-normal noise with a zero diagonal over the last two dims."""
    z = torch.randn(*shape, device=device).triu(1)
    return z + z.transpose(-1,-2)

  def loss_coeff(self, t):
    """Return sigma_t / (beta_1 - beta_t), scaled by (1 + eta_t(t)) / 2."""
    loss_coeff = self.drift_time_scaled(t) / self.diffusion(t)
    loss_coeff = loss_coeff * (1. + self.eta_t(t)) * 0.5
    return loss_coeff

  # -------- log p_0 for standard normal --------
  def prior_logp(self, z):
    """Return the per-sample log-density of ``z`` under a standard normal.

    Reduces over every non-batch dimension, consistent with ``N`` counting all
    of them. For 3-D input this is the original ``dim=(1, 2)`` reduction.
    Requires ``z.dim() >= 2``.
    """
    shape = z.shape
    N = np.prod(shape[1:])
    logps = -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=tuple(range(1, z.dim()))) / 2.
    return logps


# -------- Eta-scaling is not yet implemented for the OU bridge. --------
class OUBridge(Bridge):
  """Ornstein-Uhlenbeck bridge that reaches ``destination`` at t = 1.

  The OU transition factors (``a_ou``, ``v_ou``) are computed on the
  time-changed clock ``beta_t`` (integrated variance) from :class:`Bridge`.
  ``self.eta`` is stored by the base class but never read here.
  """

  def __init__(self, eta, drift_coeff, schedule, sigma_0, sigma_1, destination, N=1000):
    """Construct an Ornstein-Uhlenbeck Bridge.

    Args:
      eta, schedule, sigma_0, sigma_1: forwarded to :class:`Bridge`.
      drift_coeff: either a real number (constant drift, schedule 'VP') or a
        string. For a string, the leading non-numeric prefix picks the drift
        schedule and each run of digits/dots is a coefficient:
          'm<c>'       -> 'MVP' with drift_coeff = c
          's<c0>_<c1>' -> 'SVP' with coeffs [c0, c1] and drift_coeff = 2
        The separator may be any character other than a digit or '.'.
      destination: terminal point
      N: number of discretization steps

    Raises:
      NotImplementedError: unknown drift-schedule prefix.
      ValueError: the string carries fewer coefficients than the schedule uses.
    """
    super().__init__(eta, drift_coeff, schedule, sigma_0, sigma_1, destination, N)

    # numbers.Real also accepts NumPy scalars, which are not int/float subclasses.
    if isinstance(self.drift_coeff, numbers.Real):
      self.drift_schedule = 'VP'
    else:
      number = r'[\d.]+'
      drift_schedule = re.split(number, self.drift_coeff)[0]
      drift_coeffs = re.findall(number, self.drift_coeff)
      if drift_schedule == 'm':
        if not drift_coeffs:
          raise ValueError(f"Drift coeff '{self.drift_coeff}' needs one coefficient, e.g. 'm0.5'.")
        self.drift_schedule = 'MVP'
        self.drift_coeff = float(drift_coeffs[0])
      elif drift_schedule == 's':
        # alpha_t / alpha_sum read coeffs[0] and coeffs[1].
        if len(drift_coeffs) < 2:
          raise ValueError(f"Drift coeff '{self.drift_coeff}' needs two coefficients, e.g. 's1.0_2.0'.")
        self.drift_schedule = 'SVP'
        self.coeffs = [float(c) for c in drift_coeffs]
        self.drift_coeff = 2.
      else:
        raise NotImplementedError(f'Drift coeff: {drift_schedule} not implemented.')

  def alpha_t(self, t):
    """Return the OU drift rate alpha at time(s) ``t`` for the configured drift schedule."""
    if self.drift_schedule == 'VP':
      alpha_t = -0.5 * self.drift_coeff
    elif self.drift_schedule == 'MVP':
      alpha_t = -0.5 * self.drift_coeff * self.variance(t)
    elif self.drift_schedule == 'SVP':
      alpha_t = -0.5 * self.drift_coeff * (-self.coeffs[0] * t + self.coeffs[1])
    else:
      raise NotImplementedError(f"Drift schedule type: {self.drift_schedule} not supported.")
    return alpha_t

  # -------- Integration of alpha_t from time s to t --------
  def alpha_sum(self, s, t):
    """Return the integral of ``alpha_t`` over [s, t], in closed form per drift schedule."""
    if self.drift_schedule == 'VP':
      alpha_sum = -0.5 * self.drift_coeff * (t - s)
    elif self.drift_schedule == 'MVP':
      alpha_sum = -0.5 * self.drift_coeff * (self.beta_t(t) - self.beta_t(s))
    elif self.drift_schedule == 'SVP':
      alpha_sum = -0.5 * self.drift_coeff * (-0.5 * self.coeffs[0] * (t**2 - s**2) + self.coeffs[1]* (t - s))
    else:
      raise NotImplementedError(f"Drift schedule type: {self.drift_schedule} not supported.")
    return alpha_sum

  def a_ou(self, s, t):
    """Return the OU mean-decay factor from s to t, exp(alpha_sum(beta_t(s), beta_t(t)))."""
    log_coeff = self.alpha_sum(self.beta_t(s), self.beta_t(t))
    return torch.exp(log_coeff)

  def v_ou(self, s, t):
    """Return the OU transition variance from s to t on the beta_t clock.

    For the constant 'VP' drift this reduces to (1 - a_ou(s, t)**2) / drift_coeff.
    """
    v_ou = 0.5 * ( self.a_ou(s,t)**2 / self.alpha_t(self.beta_t(s)) - 1. /  self.alpha_t(self.beta_t(t)))
    return v_ou

  def a_over_v(self, t):
    """Return a_ou(t, 1) / v_ou(t, 1).

    'VP' uses the algebraically equal closed form drift_coeff / (1/a - a).
    """
    ones = torch.ones_like(t)
    a_t1 = self.a_ou(t, ones)
    if self.drift_schedule=='VP':
      result = self.drift_coeff / (1./a_t1 - a_t1)
    else:
      result = a_t1 / self.v_ou(t, ones)
    return result

  # -------- \sigma_t**2 * \nabla_{z} \log p_{1|t}(x|z) --------
  def drift_adjustment(self, z, t):
    """Return sigma_t**2 * a**2 / v * (dest / a - z), with a = a_ou(t, 1), v = v_ou(t, 1)."""
    ones = torch.ones_like(t)
    a_t1 = self.a_ou(t, ones)
    gamma = a_t1 * self.a_over_v(t)
    adjustment = (self.dest / a_t1[:, None, None] - z ) * (self.variance(t) * gamma)[:, None, None]
    return adjustment

  def sde(self, z, t):
    """Return (drift, diffusion): the OU drift plus the bridge adjustment, and sigma_t.

    NOTE(review): sde() and pred() evaluate alpha_t at t, whereas a_ou/v_ou
    evaluate it at beta_t(.). These agree only for the constant 'VP' drift;
    confirm which clock is intended for 'MVP'/'SVP'.
    """
    diffusion = self.diffusion(t)
    drift = (self.alpha_t(t) * self.variance(t))[:, None, None] * z + self.drift_adjustment(z, t)
    return drift, diffusion

  # -------- mean, std of the perturbation kernel --------
  def marginal_prob(self, z0, t):
    """Return (mean, std) of the bridge marginal at time(s) ``t`` given z0 and dest.

    Combines the OU transitions 0 -> t and t -> 1. ``std`` has the shape of ``t``.
    """
    zeros = torch.zeros_like(t)
    ones = torch.ones_like(t)
    a_0t = self.a_ou(zeros, t)
    a_t1 = self.a_ou(t, ones)
    v_0t = self.v_ou(zeros, t)
    v_t1 = self.v_ou(t, ones)
    denom = v_t1 + v_0t * a_t1**2
    # -------- std --------
    std = torch.sqrt(v_0t * v_t1 / denom)
    # -------- mean --------
    coeff0 = v_t1 * a_0t / denom
    coeff1 = v_0t * a_t1 / denom
    mean = coeff0[:, None, None] * z0 + coeff1[:, None, None] * self.dest
    return mean, std

  def prior_sampling(self, shape, device='cpu'):
    """Draw i.i.d. standard-normal noise of the given shape."""
    return torch.randn(*shape, device=device)

  def prior_sampling_sym(self, shape, device='cpu'):
    """Draw symmetric standard-normal noise with a zero diagonal over the last two dims."""
    z = torch.randn(*shape, device=device).triu(1)
    return z + z.transpose(-1,-2)

  def loss_coeff(self, t):
    """Return sigma_t * a_ou(t, 1) / v_ou(t, 1)."""
    return self.diffusion(t) * self.a_over_v(t)

  def prior_logp(self, z):
    """Return the per-sample log-density of ``z`` under a standard normal.

    Reduces over every non-batch dimension, consistent with ``N`` counting all
    of them. For 3-D input this is the original ``dim=(1, 2)`` reduction.
    Requires ``z.dim() >= 2``.
    """
    shape = z.shape
    N = np.prod(shape[1:])
    logps = -N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=tuple(range(1, z.dim()))) / 2.
    return logps

  def ei_coeff(self, t):
    """Return coefficients for an exponential-integrator step over [t, 1].

    NOTE(review): the derivation is not shown here. The step length is
    eps = 1 - t. Treat the formulas as given.

    Returns:
      (coeff1, coeff2, coeff3, coeff_int)
    """
    eps = 1. - t
    ones = torch.ones_like(t)
    zeros = torch.zeros_like(t)
    # variance(1) - variance(0) == sigma_1**2 - sigma_0**2 for every supported schedule.
    delta = self.variance(ones) - self.variance(zeros)
    var = self.variance(t)
    beta = self.beta_t(t)
    gamma = self.a_ou(t, ones)**2 / self.v_ou(t, ones)
    coeff1 = 0.5 * eps * delta / var
    coeff2 = 0.5 * eps * var * gamma
    coeff3 = 0.5 * eps * var * (-2 * self.alpha_t(beta) + gamma * ( 1 - delta / (self.drift_coeff * self.variance(beta)**2) ))

    # NOTE(review): heuristic constant; its derivation is not shown here.
    log_coeff = 2. / self.variance(ones)
    coeff_int = 2 * (self.variance(self.beta_t(t)) - self.variance(self.beta_t(ones))) / (self.variance(ones) * (-delta)) * log_coeff

    return coeff1, coeff2, coeff3, coeff_int

  def pred(self, drift, z, t):
    """Invert the drift of :meth:`sde` into a destination estimate.

    Given drift = sde(z, t)[0], this returns self.dest (up to rounding).
    """
    var = self.variance(t)
    ones = torch.ones_like(t)
    a_ou = self.a_ou(t, ones)
    gamma_inv = self.v_ou(t, ones) / a_ou**2
    pred = drift - (self.alpha_t(t) * var)[:, None, None] * z
    pred = (pred * (gamma_inv / var)[:, None, None] + z) * a_ou[:, None, None]
    return pred
