# lightning model for 2D toy example
from typing import List, Sequence, Union

import torch
import numpy as np

import os
import matplotlib.pyplot as plt
from pytorch_lightning import LightningModule
from src.constants import EPS_SDE
from src.diffusion.sampling.samplers import get_sampler
import wandb

class ToyDiffusion(LightningModule):
    """Lightning module training a score network on a 2D toy dataset.

    The wrapped ``denoiser`` predicts the score of the perturbed data
    distribution and is trained with a noise-matching MSE loss (see
    ``loss_fn``). Every ``sample_every_n_epochs`` validation epochs the sampler
    is run and a scatter plot of the generated points is logged to the
    Weights & Biases run behind ``self.logger``.

    Args:
        config: experiment config; this class reads ``lr``, ``warmup``,
            ``trainer.batch_size`` and ``trainer.check_val_every_n_epoch``
            (``get_sampler`` may read more).
        denoiser: network called as ``denoiser(x_t, t)``; its output is treated
            as the score estimate for ``x_t``.
        sde: forward SDE exposing ``T`` and ``marginal_prob_terms(x_0, t)``.
    """

    # Run the sampler and log a scatter plot on epochs divisible by this value.
    sample_every_n_epochs = 10

    def __init__(self, config, denoiser, sde):
        super().__init__()
        self.lr = config.lr
        self.model = denoiser
        self.sde = sde
        self.batch_size = config.trainer.batch_size

        self.sampler = get_sampler(config, sde, denoiser=self.model)
        self.check_val_every_n_epoch = config.trainer.check_val_every_n_epoch
        self.warmup = config.warmup

    def configure_optimizers(self):
        """Create an Adam optimizer with an epoch-wise ``LambdaLR`` schedule.

        NOTE(review): despite its name, the ``warmup`` schedule does not ramp
        the learning rate *up*. The multiplier is 101 at epoch 0 and decays
        linearly towards 1 at epoch ``warmup``, then stays at 1. With
        ``warmup <= 0`` the multiplier is always 1. Confirm this is intended.
        """

        class scheduler_lambda_function:
            # Kept as a callable object rather than a lambda/closure:
            # LambdaLR.state_dict() saves the attributes of callable objects,
            # so existing checkpoints keep loading into this schedule.
            def __init__(self, warm_up):
                self.use_warm_up = warm_up > 0
                self.warm_up = warm_up

            def __call__(self, s):
                if self.use_warm_up and s < self.warm_up:
                    return 100 * (self.warm_up - s) / self.warm_up + 1
                return 1

        # In janek's settings, Adam uses its default hyperparameters.
        optimizer = torch.optim.Adam(lr=self.lr, params=self.model.parameters())

        scheduler = {
            "scheduler": torch.optim.lr_scheduler.LambdaLR(
                optimizer, scheduler_lambda_function(self.warmup)
            ),
            "interval": "epoch",  # stepped after each training epoch
        }
        return [optimizer], [scheduler]

    def _sample_times(self, n):
        """Draw one diffusion time per sample, returned as a tensor of shape [n].

        For most SDEs times are uniform on [EPS_SDE, T). For ``PinnedBrownSDE``
        they are uniform on [0, T - EPS_SDE) instead. Training and validation
        both use this helper so they see the same time distribution.
        """
        u = torch.rand(n, device=self.device)
        if self.sde.__class__.__name__ == "PinnedBrownSDE":
            return u * (self.sde.T - EPS_SDE)
        return u * (self.sde.T - EPS_SDE) + EPS_SDE

    def _batch_loss(self, x_0):
        """Noise the clean batch ``x_0`` [n, d] at random times and return the loss."""
        t = self._sample_times(x_0.shape[0])
        pred_score, z_true, sdt = self.noise_and_predict(x_0, t)
        return self.loss_fn(z_true, pred_score, sdt)

    def training_step(self, x_0) -> torch.Tensor:
        """Compute and log the epoch-averaged training loss for one batch.

        Args:
            x_0: batch as delivered by the dataloader. For the 2D Gauss data it
                is a one-element list whose only element is the sample tensor.
        Returns:
            Scalar loss tensor used for backpropagation.
        """
        loss = self._batch_loss(x_0[0])

        self.log(
            "train_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            batch_size=self.batch_size,
        )
        return loss

    def validation_step(self, x_0, batch_idx) -> torch.Tensor:
        """Compute and log the epoch-averaged validation loss for one batch.

        Uses the same time distribution as ``training_step`` (including the
        ``PinnedBrownSDE`` special case), so the two losses are comparable.
        """
        loss = self._batch_loss(x_0[0])

        self.log(
            "val_loss",
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            batch_size=self.batch_size,
        )
        return loss

    def on_validation_epoch_end(self):
        """Periodically sample from the model and log a 2D scatter plot to wandb."""
        if self.current_epoch % self.sample_every_n_epochs != 0:
            return

        x_new, _ = self.sampler.run_sampler(eta=1.0)
        x_new = x_new.detach().cpu()

        in_train_sample = plt.figure()
        try:
            plt.scatter(x_new[:, 0], x_new[:, 1])
            # No explicit step=current_epoch. The Lightning logger advances
            # wandb's step on every metric log, so an epoch-based step soon
            # falls behind it. wandb then discards the entry as non-monotonic.
            self.logger.experiment.log(
                {"samples_epoch_" + str(self.current_epoch): wandb.Image(in_train_sample)}
            )
        finally:
            # pyplot keeps every figure alive until it is closed; close it so
            # repeated validation epochs do not accumulate figures.
            plt.close(in_train_sample)

    def noise_x0(self, x_0, t):
        """Perturb a batch of clean samples with the SDE's Gaussian kernel.

        Args:
            x_0: batch of original samples [n, d]
            t: time steps [n]
        Returns:
            x_t: noisy samples ``mean + std * z`` [n, d]
            z: standard normal noise that was added [n, d]
            std: per-sample std of the perturbation kernel at time t [n]
        """
        # marginal_prob_terms returns the kernel mean [n, d] and std [n].
        mean, std = self.sde.marginal_prob_terms(x_0, t)

        z = torch.randn_like(x_0)
        x_t = mean + z * std[:, None]  # expand std [n] -> [n, 1] to broadcast over d
        return x_t, z, std

    def predict_score(self, x_t, t: torch.Tensor):
        """Return the denoiser's score estimate for noisy samples ``x_t`` at times ``t``."""
        return self.model(x_t, t)

    def noise_and_predict(self, x_0, t):
        """Noise ``x_0`` at per-sample times ``t`` and predict the score.

        Returns:
            pred_score: predicted score for the noisy samples
            z_true: true noise added to the original samples
            std: standard deviation of the perturbation kernel at time t
        """
        x_t, z_true, std = self.noise_x0(x_0, t)
        pred_score = self.predict_score(x_t, t)
        return pred_score, z_true, std

    def loss_fn(
        self,
        z_true: torch.Tensor,
        pred_score: torch.Tensor,
        sdt: torch.Tensor,
    ) -> torch.Tensor:
        """Noise-matching MSE loss.

        With ``x_t = mean + std * z``, the score of the perturbation kernel is
        ``-z / std``. The noise implied by a predicted score is therefore
        ``-score * std``, and it is compared to the noise actually added.

        Args:
            z_true: noise that was added [n, d]
            pred_score: predicted score [n, d]
            sdt: per-sample kernel std [n]
        """
        z_pred = -pred_score * sdt[:, None]
        return torch.nn.functional.mse_loss(z_pred, target=z_true)

    @torch.no_grad()
    def sample(self, eta=1.0):
        """Run the reverse-time sampler in eval mode and return the samples.

        The model's previous train/eval mode is restored afterwards, so calling
        this mid-training does not leave dropout/normalization layers switched
        to eval behaviour.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            samples, _ = self.sampler.run_sampler(eta=eta)
        finally:
            self.model.train(was_training)
        return samples

    def forward_diffusion(self, x):
        """Return the sampler's forward-noising trajectory starting from ``x``."""
        return self.sampler.run_sampler_forward(x=x)

        


