# lightning model for 2D toy example
from typing import List, Sequence, Union

import torch
import numpy as np

import os
import matplotlib.pyplot as plt
from pytorch_lightning import LightningModule
from src.constants import EPS_SDE
from src.diffusion.sampling.samplers import get_sampler

class ToyJanek(LightningModule):
    """Lightning module that trains a score network on a 2D toy dataset.

    Training uses denoising score matching: each clean sample ``x_0`` is
    perturbed with Gaussian noise at a random time ``t``, the denoiser predicts
    the score of the noisy sample, and the loss compares ``-score * std`` with
    the noise that was actually added.

    Args:
        config: experiment config. Reads ``optim.lr``, ``optim.beta1``,
            ``optim.eps``, ``optim.warmup``, ``trainer.batch_size`` and
            ``trainer.check_val_every_n_epoch``; it is also passed to
            ``get_sampler``.
        denoiser: network called as ``denoiser(x_t, t)``; its output is
            treated as the predicted score.
        sde: SDE object; this module uses its ``T`` attribute and its
            ``marginal_prob_terms(x_0, t)`` method (returning mean and std).
    """

    def __init__(self, config, denoiser, sde):
        super().__init__()
        self.lr = config.optim.lr
        # beta1 / eps are stored for reference only: configure_optimizers
        # deliberately uses Adam's default hyper-parameters.
        self.beta1 = config.optim.beta1
        self.eps = config.optim.eps
        self.warmup = config.optim.warmup
        self.model = denoiser
        self.sde = sde
        self.batch_size = config.trainer.batch_size

        self.sampler = get_sampler(config, sde, denoiser=self.model)
        self.check_val_every_n_epoch = config.trainer.check_val_every_n_epoch

    def configure_optimizers(self):
        """Create Adam (default betas/eps) with a per-step linear LR warm-up.

        Returns:
            ``([optimizer], [scheduler_config])`` where the scheduler is a
            ``LambdaLR`` stepped after every training step.
        """

        class scheduler_lambda_function:
            """LR multiplier: ``s / warm_up`` while ``s < warm_up``, else 1.

            Implemented as a callable object rather than a lambda so that its
            attributes are saved by ``LambdaLR.state_dict()``; keep the
            attribute names stable so existing checkpoints keep loading.
            """

            def __init__(self, warm_up):
                self.use_warm_up = bool(warm_up > 0)
                self.warm_up = warm_up

            def __call__(self, s):
                if self.use_warm_up and s < self.warm_up:
                    return s / self.warm_up
                return 1

        # In Janek's settings Adam uses its defaults, so self.beta1 / self.eps
        # are intentionally not passed here.
        optimizer = torch.optim.Adam(lr=self.lr, params=self.model.parameters())
        scheduler = {
            "scheduler": torch.optim.lr_scheduler.LambdaLR(
                optimizer, scheduler_lambda_function(self.warmup)
            ),
            "interval": "step",  # step the scheduler after each training step
        }
        return [optimizer], [scheduler]

    # Alternative (disabled) plateau-based scheduler, kept for reference:
    # def configure_optimizers(self):
    #     optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr)
    #     # return optimizer
    #     return {
    #     "optimizer": optimizer,
    #     "lr_scheduler": {
    #         "scheduler": torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=100, threshold=0.001, threshold_mode='rel', cooldown=0, min_lr=1e-5),
    #         "monitor": "val_loss",
    #         "frequency": self.check_val_every_n_epoch,
    #         # If "monitor" references validation metrics, then "frequency" should be set to a
    #         # multiple of "trainer.check_val_every_n_epoch".
    #     },
    # }

    def _denoising_loss(self, x_0) -> torch.Tensor:
        """Return the denoising score-matching loss for a batch ``x_0``.

        Each sample gets its own time ``t`` drawn uniformly from
        ``[EPS_SDE, T)``; the ``EPS_SDE`` offset keeps ``t`` strictly positive.
        """
        t = torch.rand(x_0.shape[0], device=self.device) * (self.sde.T - EPS_SDE) + EPS_SDE
        pred_score, z_true, std = self.noise_and_predict(x_0, t)
        return self.loss_fn(z_true, pred_score, std)

    def _log_loss(self, name, loss):
        """Log ``loss`` under ``name`` as an epoch-level (not per-step) metric."""
        self.log(
            name,
            loss,
            on_step=False,
            on_epoch=True,
            prog_bar=False,
            logger=True,
            batch_size=self.batch_size,
        )

    def training_step(self, x_0) -> torch.Tensor:
        """Compute and log the training loss for one batch.

        Args:
            x_0: batch from the dataloader. For the 2D Gauss data it is a
                one-element list whose single element is the sample tensor.
        Returns:
            The scalar loss used for the optimizer step.
        """
        loss = self._denoising_loss(x_0[0])
        self._log_loss("train_loss", loss)
        return loss

    def validation_step(self, x_0, batch_idx) -> torch.Tensor:
        """Compute and log the validation loss for one batch.

        Args:
            x_0: batch from the dataloader (one-element list, as in training).
            batch_idx: index of the batch (unused).
        Returns:
            The scalar validation loss.
        """
        loss = self._denoising_loss(x_0[0])
        self._log_loss("val_loss", loss)
        return loss

    def on_validation_epoch_end(self):
        """Sample from the model and log a scatter plot of the samples."""
        if self.logger is None:
            # Nothing to log the figure to; skip the sampling run entirely.
            return
        x_new, _ = self.sampler.run_sampler(eta=1.0)
        # detach() so converting to numpy cannot fail on a tensor that still
        # requires grad.
        x_new = x_new.detach().cpu().numpy()
        fig, ax = plt.subplots()
        ax.scatter(x_new[:, 0], x_new[:, 1])
        # Assumes a TensorBoard-style logger whose experiment has add_figure.
        self.logger.experiment.add_figure("samples", fig, self.current_epoch)
        # Close explicitly so figures never accumulate in pyplot's registry,
        # regardless of whether the logger closes them itself.
        plt.close(fig)

    @staticmethod
    def _per_sample(v, like):
        """Reshape per-sample values ``v`` [n] to broadcast against ``like`` [n, ...]."""
        return v.reshape(-1, *([1] * (like.dim() - 1)))

    def noise_x0(self, x_0, t):
        """Add noise to batch of original samples x_0.

        Args:
            x_0: batch of original samples [n, d] (any [n, ...] shape works).
            t: time steps [n].
        Returns:
            x_t: noisy samples ``mean + std * z``, same shape as ``x_0``.
            z: the standard-normal noise that was added, same shape as ``x_0``.
            std: per-sample std of the perturbation kernel [n].
        """
        # marginal_prob_terms returns the kernel mean [n, d] and std [n];
        # both are needed to build x_t.
        mean, std = self.sde.marginal_prob_terms(x_0, t)
        z = torch.randn_like(x_0)
        x_t = mean + z * self._per_sample(std, x_0)
        return x_t, z, std

    def predict_score(self, x_t, t: torch.Tensor):
        """Return the denoiser's score prediction for noisy samples ``x_t`` at ``t``."""
        return self.model(x_t, t)

    def noise_and_predict(self, x_0, t):
        """Noise x_0 at times t and predict the score of the noisy samples.

        Each sample gets its own timestep from ``t``.

        Returns:
            pred_score: predicted score for the noisy samples.
            z_true: true noise added to the original samples.
            std: standard deviation of the perturbation kernel at time t.
        """
        x_t, z_true, std = self.noise_x0(x_0, t)
        pred_score = self.predict_score(x_t, t)
        return pred_score, z_true, std

    def loss_fn(
        self,
        z_true: torch.Tensor,
        pred_score: torch.Tensor,
        sdt):
        """MSE between the true noise and the noise implied by the predicted score.

        For a Gaussian perturbation kernel ``x_t = mean + std * z``, the score
        is ``-z / std``, so the predicted noise is ``-score * std``.

        Args:
            z_true: noise actually added to the samples.
            pred_score: predicted score for the noisy samples.
            sdt: per-sample kernel std [n].
        """
        z_pred = -pred_score * self._per_sample(sdt, pred_score)  # score -> noise
        return torch.nn.functional.mse_loss(z_pred, target=z_true)

    @torch.no_grad()
    def sample(self, eta=1.0):
        """Draw samples with the configured sampler, without tracking gradients.

        The denoiser is switched to eval mode while sampling; its previous
        train/eval mode is restored afterwards, so calling this mid-training
        does not leave the model in eval mode.
        """
        was_training = self.model.training
        self.model.eval()
        try:
            samples, _ = self.sampler.run_sampler(eta=eta)
        finally:
            self.model.train(was_training)
        return samples

    def forward_diffusion(self, x):
        """Run the sampler's forward (noising) process starting from ``x``."""
        noisy_evo = self.sampler.run_sampler_forward(x=x)
        return noisy_evo


