import pytorch_lightning as pl
from matplotlib import pyplot as plt
import numpy as np
from torch import nn
import torch 
import abc
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from scipy import integrate


# Smallest diffusion time that is used: training times are drawn from
# [EPS_SDE, T] instead of [0, T], because at t = 0 the VP-SDE marginal std
# sqrt(1 - exp(2 * log_mean_coeff)) is exactly 0.
EPS_SDE: float = 1e-5

class MLPDenoiser(nn.Module):
    """Small MLP that maps a sample ``x`` and its diffusion time ``t`` to an output vector.

    The time is appended to ``x`` as one extra input feature, so ``in_dim`` must be
    ``x.shape[-1] + 1`` (the defaults correspond to 2-D data: ``in_dim=3``, ``out_dim=2``).

    Args:
        in_dim: number of input features, including the appended time feature.
        hid_dim: width of every hidden layer.
        out_dim: number of output features.
        num_hid_layers: number of hidden-to-hidden layers after the input layer.
        dropout: dropout probability applied after every hidden activation
            (only active in training mode).
        activation: element-wise non-linearity applied after every hidden layer.
    """

    def __init__(self, *, in_dim=3, hid_dim=64, out_dim=2, num_hid_layers=1, dropout=0.1, activation=F.relu):
        super().__init__()
        self.num_hid_layers = num_hid_layers
        self.dropout = dropout
        self.activation = activation
        # Input layer followed by `num_hid_layers` hidden layers; all are followed
        # by activation + dropout in `forward`.
        self.layers = nn.ModuleList()
        self.layers.append(nn.Linear(in_dim, hid_dim))
        for _ in range(num_hid_layers):
            self.layers.append(nn.Linear(hid_dim, hid_dim))
        self.out = nn.Linear(hid_dim, out_dim)

    def forward(self, x, t):
        """Evaluate the network.

        Args:
            x: samples of shape [n, d].
            t: diffusion times; a Python number, a 0-dim tensor (shared by the
                whole batch), or a tensor of shape [n] or [n, 1].

        Returns:
            Tensor of shape [n, out_dim].
        """
        # Normalise t to a [n, 1] column on x's dtype/device. The previous
        # `t[None, :].T` only handled shape [n]; for [n, 1] it produced a 3-D
        # tensor and torch.cat failed.
        t = torch.as_tensor(t, dtype=x.dtype, device=x.device)
        t_col = t.reshape(-1, 1).expand(x.shape[0], 1)
        h = torch.cat([x, t_col], dim=-1)  # time as an extra input feature
        for layer in self.layers:
            h = self.activation(layer(h))
            h = F.dropout(h, p=self.dropout, training=self.training)
        return self.out(h)



class ScoreMatching(pl.LightningModule):
    """Lightning module that trains a score network with denoising score matching.

    For every batch ``x_0`` one time ``t ~ U(EPS_SDE, T)`` is drawn per sample,
    ``x_0`` is perturbed as ``x_t = mean + std * z`` using the SDE's marginal, and
    the denoiser's score prediction is converted to a noise prediction
    ``-score * std`` that is regressed onto ``z`` with an MSE loss.

    Args:
        denoiser: network called as ``denoiser(x_t, t)``; its output is treated
            as the score of the perturbed data.
        sde: SDE object providing ``T``, ``marginal_prob_terms``,
            ``get_sde_coefficients`` and ``get_reverse_sde_coefficients``
            (for example ``VPSDE``).
    """

    def __init__(self, denoiser, sde):
        super().__init__()
        self.lr = 0.001
        self.sde = sde
        # Nominal batch size; logging uses the actual size of each batch.
        self.batch_size = 64
        # NOTE(review): not read anywhere in this module; Lightning takes
        # `check_val_every_n_epoch` from the Trainer. Kept for compatibility.
        self.check_val_every_n_epoch = 10
        self.warmup = 0  # number of epochs covered by the LR factor schedule
        self.denoiser = denoiser
        self.p_steps = 1000  # Euler-Maruyama steps used by the samplers
        self.sample_shape = None  # shape of the first training batch, set lazily

    def configure_optimizers(self):
        """Return Adam over the denoiser's parameters and an epoch-wise LambdaLR schedule."""

        class scheduler_lambda_function:
            # A callable object rather than a plain function, so that
            # LambdaLR.state_dict() can store its attributes in checkpoints.
            def __init__(self, warm_up):
                self.use_warm_up = warm_up > 0
                self.warm_up = warm_up

            def __call__(self, s):
                # NOTE(review): the factor is 101 at s = 0 and decays linearly to 1
                # at s = warm_up, i.e. the LR starts high and decreases; a
                # conventional warm-up ramps up instead. Confirm this is intended.
                if self.use_warm_up and s < self.warm_up:
                    return 100 * (self.warm_up - s) / self.warm_up + 1
                return 1

        optimizer = torch.optim.Adam(lr=self.lr, params=self.denoiser.parameters())  # Adam defaults apart from lr
        scheduler = {
            'scheduler': torch.optim.lr_scheduler.LambdaLR(optimizer, scheduler_lambda_function(self.warmup)),
            'interval': 'epoch',  # stepped after each training epoch
        }
        return [optimizer], [scheduler]

    def _denoising_loss(self, x_0):
        """Draw t ~ U(EPS_SDE, T) per sample and return the score-matching loss on x_0."""
        t = torch.rand(x_0.shape[0], device=self.device) * (self.sde.T - EPS_SDE) + EPS_SDE
        pred_score, z_true, std = self.noise_and_predict(x_0, t)
        return self.loss_fn(z_true, pred_score, std)

    def training_step(self, x_0) -> torch.Tensor:
        """Compute and log the training loss for one batch ``x_0`` of shape [n, ...]."""
        if self.sample_shape is None:
            self.sample_shape = x_0.shape

        loss = self._denoising_loss(x_0)

        # Log with the real batch size so a smaller final batch is weighted
        # correctly in the epoch average (a fixed self.batch_size was not).
        self.log(
            "train_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            batch_size=x_0.shape[0],
        )
        return loss

    def validation_step(self, x_0, batch_idx) -> torch.Tensor:
        """Compute and log the validation loss for one batch ``x_0`` of shape [n, ...]."""
        loss = self._denoising_loss(x_0)
        self.log(
            "val_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            batch_size=x_0.shape[0],
        )
        return loss

    def on_validation_epoch_end(self):
        pass

    @staticmethod
    def _expand_to(v, x):
        """Reshape a per-sample vector [n] (a 0-dim tensor is returned as-is) to broadcast against x [n, ...]."""
        if v.dim() == 0:
            return v
        return v.reshape(v.shape[0], *([1] * (x.dim() - 1)))

    def noise_x0(self, x_0, t):
        """Perturb x_0 to time t.

        Returns:
            (x_t, z, std): x_t = mean + std * z with z ~ N(0, I) shaped like x_0,
            and std of shape [n] as returned by ``sde.marginal_prob_terms``.
        """
        mean, std = self.sde.marginal_prob_terms(x_0, t)  # mean: like x_0, std: [n]
        z = torch.randn_like(x_0)
        x_t = mean + z * self._expand_to(std, x_0)
        return x_t, z, std

    def predict_score(self, x_t, t: torch.Tensor):
        """Return the denoiser's score estimate at (x_t, t)."""
        return self.denoiser(x_t, t)

    def noise_and_predict(self, x_0, t):
        """Perturb x_0 to time t and predict the score; returns (pred_score, z_true, std)."""
        x_t, z_true, std = self.noise_x0(x_0, t)
        pred_score = self.predict_score(x_t, t)
        return pred_score, z_true, std

    def loss_fn(self, z_true: torch.Tensor, pred_score: torch.Tensor, sdt):
        """MSE between the added noise and the noise implied by the predicted score."""
        # For x_t = mean + std * z the true score is -z / std, so -score * std
        # recovers z.
        z_pred = -pred_score * self._expand_to(sdt, pred_score)
        return torch.nn.functional.mse_loss(z_pred, target=z_true)

    @torch.no_grad()
    def run_sampler_forward(self, x):
        """Simulate the forward SDE from t = 0 to t = T with Euler-Maruyama.

        Takes ``p_steps`` steps of size ``T / p_steps`` (previously p_steps + 1
        steps of size 1 / p_steps, overshooting T and ignoring ``sde.T``).

        Returns:
            Tensor of shape [p_steps + 1, *x.shape]; index 0 is the input x.
        """
        dt = self.sde.T / self.p_steps
        sqrt_dt = dt ** 0.5
        # Left end points of the p_steps intervals of [0, T].
        times = torch.linspace(0, self.sde.T, self.p_steps + 1, device=x.device)[:-1]
        trajectory = [x]
        for t in times:
            drift, diffusion = self.sde.get_sde_coefficients(x, t)
            x = x + drift * dt + self._expand_to(diffusion, x) * torch.randn_like(x) * sqrt_dt
            trajectory.append(x.detach().clone())
        return torch.stack(trajectory)

    @torch.no_grad()
    def run_sampler_backward(self, x):
        """Integrate the reverse-time SDE from t = T down to t = EPS_SDE with Euler-Maruyama.

        With drift = f(x, t) - g(t)^2 * score(x, t) as returned by
        ``sde.get_reverse_sde_coefficients``, each step applies
        ``x <- x - drift * dt + g(t) * sqrt(dt) * z`` while time decreases by dt.
        NOTE(review): the score network is called in whatever train/eval mode it
        is in; dropout is active unless the caller switched it to eval().

        Returns:
            Tensor of shape [p_steps + 1, *x.shape]; index 0 is the input x.
        """
        n = x.shape[0]
        dt = (self.sde.T - EPS_SDE) / self.p_steps
        sqrt_dt = dt ** 0.5
        # Time runs backwards: T, T - dt, ..., EPS_SDE + dt.
        times = torch.linspace(self.sde.T, EPS_SDE, self.p_steps + 1, device=x.device)[:-1]
        trajectory = [x]
        for t in times:
            # One time per sample, shape [n] (a [n, 1] time made the SDE
            # coefficients broadcast to [n, n, d]).
            t_batch = t.expand(n)
            drift, diffusion = self.sde.get_reverse_sde_coefficients(x, t_batch)
            x = x - drift * dt + self._expand_to(diffusion, x) * torch.randn_like(x) * sqrt_dt
            trajectory.append(x.detach().clone())
        return torch.stack(trajectory)
    

  
class VPSDE():
  def __init__(self, beta_min, beta_max, denoiser, N=1000, Ts=1):
    """Construct a Variance Preserving SDE.
    Args:
    beta_min: value of beta(0)
    beta_max: value of beta(1)
    N: number of discretization steps
    """
    self.beta_0 = beta_min
    self.beta_1 = beta_max
    self.N = N
    self.discrete_betas = torch.linspace(beta_min / N, beta_max / N, N)
    self.alphas = 1. - self.discrete_betas
    self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
    self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
    self.sqrt_1m_alphas_cumprod = torch.sqrt(1. - self.alphas_cumprod)
    self.Ts = Ts
    self.eps = EPS_SDE
    self.denoiser = denoiser

  @property
  def T(self):
    return self.Ts
  
  def perturbation_coefficients(self, t):
    log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
    a_t = torch.exp(log_mean_coeff)
    sigma_t = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
    return a_t, sigma_t 

  def snr(self, t):
    log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
    alpha_t = torch.exp(log_mean_coeff)
    std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff))
    return alpha_t**2/std**2

  def get_sde_coefficients(self, x, t):
    """Returns the drift and diffusion terms of the forward SDE, that is -0.5*beta*x_t and sqrt(beta)."""
    beta_t = self.beta_0 + t * (self.beta_1 - self.beta_0)
    drift = -0.5 * beta_t[(...,)+(None,)*len(x.shape[1:])] * x
    diffusion = torch.sqrt(beta_t)
    return drift, diffusion

  def init_score_fn(self, score_fn):
    self.score_fn = score_fn.to("cpu")

  def get_reverse_sde_coefficients(self, x, t):
    """Returns the drift and diffusion terms of the reverse SDE (ODE), that is [-0.5*beta*x_t - beta*score*const] and sqrt(beta) / 0 (sde/ode)."""
    drift, diffusion = self.get_sde_coefficients(x, t) # forward time coeffs
    score = self.denoiser(x, t)

    drift = drift - diffusion[(..., ) + (None, ) * len(x.shape[1:])] ** 2 * score 
    # Set the diffusion function to zero for ODEs.

    return drift, diffusion
    
  def marginal_prob_terms(self, x_0, t): # what is the shape of t?
    """Calculates the mean and std of the probability p(x_T|x_0) as in Ho et al. (2020) discretisation.
    Args:
      x_0: batch of original samples [n, d]
      t: time steps [n]
      Returns:
      mean: mean of the probability p(x_T|x_0) [n, d]
      std: std of the probability p(x_T|x_0) [n]"""
    assert t.shape[0] == x_0.shape[0], "t and x_0 must have the same batch size for mean and std calculation"
    log_mean_coeff = -0.25 * t ** 2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
    mean = torch.exp(log_mean_coeff[(...,)+(None,)*len(x_0.shape[1:])]) * x_0 # probs will break
    std = torch.sqrt(1. - torch.exp(2. * log_mean_coeff)) # [n]
    return mean, std

  def prior_sampling(self, shape):
    return torch.randn(*shape)

  def prior_logp(self, z): # currently will work with 2d toy only
    shape = z.shape
    # N = np.prod(shape[1:])
    N = shape[-1]
    logps = - N / 2. * np.log(2 * np.pi) - torch.sum(z ** 2, dim=-1) / 2. # probs shapes are broken
    return logps
  
