import torch
from argparse import ArgumentParser
from models.ncsn_models import SongUNet as NCSN
import pytorch_lightning as pl
from data.metrics import psnr, LPIPS, ssim
from data.operators import create_operator, create_noise_schedule

# Scalar entries of the NCSN_Module.get_losses() output that are logged under
# "<stage>/<key>" each step; NCSN_Module's loss_type must be one of these.
LOSS_KEYS = [
    # squared error between the re-degraded prediction and the degraded input
    "mse_loss",
    "weighted_mse_loss",
    # squared error between the prediction and the clean image
    "mse_x_loss",
    "weighted_mse_x_loss",
    # squared error against the next (less degraded) observation
    "inc_recon_loss",
    "weighted_inc_recon_loss",
    # image-quality metrics
    "psnr_loss",
    "ssim_loss",
    "lpips_loss",
]

class NCSN_Module(pl.LightningModule):
    """Lightning module training a SongUNet denoiser (configured with the
    NCSN++-style settings below) to recover the clean image from a degraded,
    noisy observation.

    Training/validation batches are dicts that must provide the keys
    ``'degraded_noisy'``, ``'noise_std'``, ``'clean'``, ``'t_this'``,
    ``'t_next'``, ``'degraded_this'`` and ``'degraded_next'`` (see
    :meth:`get_losses`).
    """

    def __init__(
        self,
        dt,
        operator_config,
        noise_config,
        loss_type,
        max_epochs,
        lr,
        lr_step_size,
        lr_gamma,
        residual_prediction=True,
        weight_decay=0.0,
        logger_type='wandb',
        full_val_only_last_epoch=True,
        val_dt=None,
        num_log_images=10,
        **kwargs,
    ):
        """
        Args:
            dt: Time step of the reverse process.
            operator_config: Passed to ``create_operator``; the result is
                called as ``fwd_operator(x, t)``.
            noise_config: Passed to ``create_noise_schedule``.
            loss_type: Key of the :meth:`get_losses` output returned as the
                training loss; must be one of ``LOSS_KEYS``.
            max_epochs: Stored on the module (not read in this class).
            lr: Adam learning rate.
            lr_step_size: ``StepLR`` period.
            lr_gamma: ``StepLR`` multiplicative decay factor.
            residual_prediction: If True, the network output is added to its
                input, i.e. the network predicts a residual.
            weight_decay: Adam weight decay.
            logger_type: ``'wandb'`` logs images via ``wandb.Image``; any other
                value uses the TensorBoard ``add_image(s)`` API.
            full_val_only_last_epoch: Stored on the module. NOTE(review): not
                read anywhere in this class; the full reconstruction in
                :meth:`validation_step` runs on every validation epoch.
            val_dt: Time step used by :meth:`inference` (validation);
                defaults to ``dt``.
            num_log_images: Number of leading batches whose first sample is
                logged as images (train and val separately).
            **kwargs: Unused by this class.

        Raises:
            ValueError: if ``loss_type`` is not one of ``LOSS_KEYS``.
        """
        super().__init__()
        self.save_hyperparameters()
        # Fail fast instead of crashing with a KeyError after the first
        # training forward pass.
        if loss_type not in LOSS_KEYS:
            raise ValueError(f"loss_type must be one of {LOSS_KEYS}, got {loss_type!r}")
        self.full_val_only_last_epoch = full_val_only_last_epoch
        self.logger_type = logger_type
        self.num_log_images = num_log_images
        if self.logger_type == 'wandb':
            # Imported lazily (only when needed) and bound at module level so
            # that log_image can reference it.
            global wandb
            import wandb

        self.dt = dt
        self.val_dt = self.dt if val_dt is None else val_dt

        self.fwd_operator = create_operator(operator_config)
        self.noise_schedule = create_noise_schedule(noise_config)
        self.residual_prediction = residual_prediction
        self.lr = lr
        self.lr_step_size = lr_step_size
        self.lr_gamma = lr_gamma
        self.weight_decay = weight_decay

        self.loss_type = loss_type

        self.max_epochs = max_epochs

        self.denoiser_fn = NCSN(
            img_resolution=256,                 # Image resolution at input/output.
            in_channels=3,                      # Number of color channels at input.
            out_channels=3,                     # Number of color channels at output.
            label_dim           = 0,            # Number of class labels, 0 = unconditional.
            augment_dim         = 0,            # Augmentation label dimensionality, 0 = no augmentation.

            model_channels      = 128,          # Base multiplier for the number of channels.
            channel_mult        = [1, 1, 2, 2, 2, 2, 2],    # Per-resolution multipliers for the number of channels.
            channel_mult_emb    = 4,            # Multiplier for the dimensionality of the embedding vector.
            num_blocks          = 2,            # Number of residual blocks per resolution.
            attn_resolutions    = [16],         # List of resolutions with self-attention.
            dropout             = 0.10,         # Dropout probability of intermediate activations.
            label_dropout       = 0,            # Dropout probability of class labels for classifier-free guidance.

            embedding_type      = 'fourier',    # Timestep embedding type: 'positional' for DDPM++, 'fourier' for NCSN++.
            channel_mult_noise  = 2,            # Timestep embedding size: 1 for DDPM++, 2 for NCSN++.
            encoder_type        = 'residual',   # Encoder architecture: 'standard' for DDPM++, 'residual' for NCSN++.
            decoder_type        = 'standard',   # Decoder architecture: 'standard' for both DDPM++ and NCSN++.
            resample_filter     = [1,3,3,1],    # Resampling filter: [1,1] for DDPM++, [1,3,3,1] for NCSN++.
        )

        # Metric network stored as an attribute so (assuming LPIPS is an
        # nn.Module) it moves to the module's device with the rest of the model.
        # It is not trained: configure_optimizers only optimizes denoiser_fn.
        self.lpips = LPIPS('vgg')

    def get_prediction(self, y, noise_labels):
        """Predict the clean image from the noisy degraded input ``y``.

        ``noise_labels`` is forwarded to the network as its conditioning.
        When ``residual_prediction`` is set, the network output is added to ``y``.
        """
        x0 = self.denoiser_fn(y, noise_labels, class_labels=None)
        if self.residual_prediction:
            # out-of-place so the network's output tensor is not mutated
            x0 = x0 + y
        return x0

    def get_losses(self, batch):
        """Run the denoiser on ``batch`` and compute batch-averaged losses/metrics.

        For every sample, the prediction ``x0`` is re-degraded with the forward
        operator at ``t_this`` and ``t_next``. Those results are compared with
        ``degraded_this`` / ``degraded_next``, and ``x0`` itself is compared with
        ``clean``. ``weighted_*`` entries scale each sample's summed squared
        error by ``1 / noise_std**2``.

        Returns:
            dict with ``'prediction'`` (the network output) and the scalar
            entries listed in ``LOSS_KEYS``, each averaged over the batch.
        """
        x0 = self.get_prediction(y=batch['degraded_noisy'], noise_labels=batch['noise_std'])
        b = x0.shape[0]
        totals = dict.fromkeys(
            (
                'mse_loss',
                'weighted_mse_loss',
                'mse_x_loss',
                'weighted_mse_x_loss',
                'inc_recon_loss',
                'weighted_inc_recon_loss',
                'psnr_loss',
                'ssim_loss',
                'lpips_loss',
            ),
            0,
        )
        # The image-quality metrics need gradients only when one of them is the
        # training objective. Otherwise, building their graphs (notably the
        # LPIPS VGG forward pass) just wastes memory and compute.
        metric_grad = torch.is_grad_enabled() and self.loss_type in (
            'psnr_loss', 'ssim_loss', 'lpips_loss',
        )
        for i in range(b):  # TODO: replace with a batched implementation
            x0_i = x0[i]
            clean = batch['clean'][i]
            weight = 1 / batch['noise_std'][i] ** 2

            # re-degrade the prediction at the current and next time
            y_hat = self.fwd_operator(x0_i, batch['t_this'][i])
            y_hat_next = self.fwd_operator(x0_i, batch['t_next'][i])

            # data-consistency error at the current degradation level
            error = (batch['degraded_this'][i] - y_hat).pow(2).sum()
            totals['mse_loss'] += error
            totals['weighted_mse_loss'] += error * weight

            # error in image space
            error_x = (x0_i - clean).pow(2).sum()
            totals['mse_x_loss'] += error_x
            totals['weighted_mse_x_loss'] += error_x * weight

            # incremental reconstruction error at the next degradation level
            error_next = (batch['degraded_next'][i] - y_hat_next).pow(2).sum()
            totals['inc_recon_loss'] += error_next
            totals['weighted_inc_recon_loss'] += error_next * weight

            # image-quality metrics
            with torch.set_grad_enabled(metric_grad):
                totals['psnr_loss'] += psnr(x0_i, clean)
                totals['ssim_loss'] += ssim(x0_i, clean)
                totals['lpips_loss'] += self.lpips(x0_i, clean)

        return {'prediction': x0, **{key: total / b for key, total in totals.items()}}

    @torch.no_grad()
    def inference(self, degraded_noisy, degr_update_method='naive'):
        """Simple reverse-process sampler used during validation.

        A more configurable sampler lives in ``reverse_diffusion.py``.

        Starting at ``t = 1``, the sampler runs ``1 / val_dt`` steps. Each step:

        - predicts ``x0``;
        - moves ``y`` towards the re-degraded prediction (the denoising term);
        - adds the change in degradation from ``t`` to ``t - val_dt``;
        - injects Gaussian noise with variance ``std(t)**2 - std(t - val_dt)**2``.

        Args:
            degraded_noisy: Observation with a leading batch dimension; only
                ``x0_pred[0]`` is re-degraded, so a batch of 1 is assumed.
            degr_update_method: ``'naive'`` evaluates the operator at the next
                time. ``'taylor'`` uses the difference between the current and
                the previous time (naive on the first step, where no previous
                time exists).

        Returns:
            The final iterate, with the same shape as ``degraded_noisy``.

        Raises:
            ValueError: if ``degr_update_method`` is not ``'naive'`` or ``'taylor'``.
        """
        if degr_update_method not in ('naive', 'taylor'):
            raise ValueError(
                f"degr_update_method must be 'naive' or 'taylor', got {degr_update_method!r}"
            )
        step = self.val_dt
        # Small tolerance so that floating-point error in 1 / step (e.g. a
        # value just below an integer) does not silently drop the last step.
        num_steps = int(1.0 / step + 1e-6)
        y = degraded_noisy.clone()
        for i in range(num_steps):
            t_this = torch.tensor(1.0 - step * i, device=y.device)
            t_next = t_this - step
            std_this = torch.as_tensor(self.noise_schedule.get_std(t_this), device=y.device)
            std_next = torch.as_tensor(self.noise_schedule.get_std(t_next), device=y.device)
            x0_pred = self.get_prediction(y=y, noise_labels=std_this.view(1, ))
            var_step = std_this**2 - std_next**2

            # diffusion term
            diffusion = torch.randn_like(degraded_noisy) * torch.sqrt(var_step)

            # denoising term
            y_hat = self.fwd_operator(x0_pred[0], t_this).unsqueeze(0)
            denoising = (y_hat - y) * (var_step / std_this**2)

            # degradation update term; the first Taylor step falls back to naive
            # because t_this + step would lie beyond t = 1
            if degr_update_method == 'naive' or t_this + step >= 1.0:
                degr_update = self.fwd_operator(x0_pred[0], t_next).unsqueeze(0) - y_hat
            else:
                degr_update = y_hat - self.fwd_operator(x0_pred[0], t_this + step).unsqueeze(0)

            y = y + degr_update + denoising + diffusion
        return y

    def training_step(self, batch, batch_idx, optimizer_idx=0):
        """Compute losses, log sample images and losses, return ``losses[loss_type]``."""
        losses = self.get_losses(batch)
        self._log_sample_images('train', batch, losses['prediction'], batch_idx)
        self.log_losses(losses, 'train')
        return losses[self.loss_type]

    def validation_step(self, batch, batch_idx):
        """Compute and log validation losses.

        On the first batch, also run a full reconstruction with :meth:`inference`.
        """
        losses = self.get_losses(batch)

        if batch_idx == 0:
            self._run_full_validation(batch, batch_idx)

        self._log_sample_images('val', batch, losses['prediction'], batch_idx)
        self.log_losses(losses, 'val')

        return {"val_loss": losses[self.loss_type]}

    def _run_full_validation(self, batch, batch_idx):
        """Reconstruct the first sample from a freshly noised t = 1 observation.

        Both degradation-update rules are used, and the resulting images and
        metrics are logged.
        """
        clean = batch['clean'][0]
        t_one = torch.ones_like(batch['t_this'][0])
        degraded_y = self.fwd_operator(clean, t_one)
        z, _ = self.noise_schedule(1.0, clean.shape)
        degraded_noisy = (degraded_y + z.to(clean.device)).unsqueeze(0)

        recons = {
            method: self.inference(degraded_noisy, degr_update_method=method)
            for method in ('naive', 'taylor')
        }

        self.log_image(f"val_recon/{batch_idx}/target", clean.unsqueeze(0))
        self.log_image(f"val_recon/{batch_idx}/degraded_noisy", degraded_noisy)
        for method, recon in recons.items():
            self.log_image(f"val_recon/{batch_idx}/recon_{method}", recon)

        for method, recon in recons.items():
            # data consistency: re-degrade the reconstruction and compare it with
            # the noiseless degraded observation
            degraded_recon = self.fwd_operator(recon[0], t_one)
            mse_dc = (degraded_y - degraded_recon).pow(2).sum()
            self.log(f"val_recon/{method}/mse_dc", mse_dc, on_epoch=True)
            self.log(f"val_recon/{method}/psnr_x", psnr(recon[0], clean), on_epoch=True)
            self.log(f"val_recon/{method}/ssim_x", ssim(recon[0], clean), on_epoch=True)
            self.log(f"val_recon/{method}/lpips", self.lpips(recon[0], clean), on_epoch=True)

    def _log_sample_images(self, stage, batch, prediction, batch_idx):
        """Log target, input and prediction of the first sample (rank 0, early batches only)."""
        if batch_idx >= self.num_log_images or self.global_rank != 0:
            return
        noised_im = batch['degraded_noisy'][0].unsqueeze(0).detach()
        denoised_im = prediction[0].unsqueeze(0).detach()
        self.log_image(f"{stage}/{batch_idx}/target", batch["clean"][0].unsqueeze(0))
        self.log_image(f"{stage}/{batch_idx}/degraded_noisy", noised_im)
        self.log_image(f"{stage}/{batch_idx}/prediction", denoised_im)

    def test_step(self, batch, batch_idx):
        """No-op; testing is not implemented in this module."""
        pass

    def log_losses(self, losses, folder_name):
        """Log every ``LOSS_KEYS`` entry of ``losses`` as ``<folder_name>/<key>``."""
        for key in LOSS_KEYS:
            self.log(f'{folder_name}/{key}', losses[key])

    def configure_optimizers(self):
        """Set up Adam over the denoiser parameters, with a ``StepLR`` schedule.

        The LPIPS metric network is not passed to the optimizer.
        """
        optim_denoiser = torch.optim.Adam(
            self.denoiser_fn.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        denoiser_scheduler = torch.optim.lr_scheduler.StepLR(
            optim_denoiser, self.lr_step_size, self.lr_gamma
        )
        return [optim_denoiser], [denoiser_scheduler]

    def log_image(self, name, image):
        """Log an image tensor (``CHW`` or batched ``NCHW``) under ``name``."""
        if self.logger_type == 'wandb':
            # wandb logging
            self.logger.experiment.log({name: wandb.Image(image)})
        elif image.dim() == 4:
            # TensorBoard: add_image defaults to CHW and rejects a batch
            # dimension, so batched tensors go through add_images (NCHW).
            self.logger.experiment.add_images(name, image, global_step=self.global_step)
        else:
            # TensorBoard: single CHW image
            self.logger.experiment.add_image(name, image, global_step=self.global_step)

    @staticmethod
    def add_model_specific_args(parent_parser):  # pragma: no-cover
        """
        Define parameters that only apply to this model
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        # network params
        parser.add_argument(
            "--num_steps",
            default=50,
            type=int,
            help="Number of degradation steps for the reverse process.",
        )
        parser.add_argument(
            "--num_val_steps",
            default=None,
            type=int,
            help="Number of degradation steps for validation.",
        )
        parser.add_argument(
            '--no_residual_prediction',
            default=False,
            action='store_true',
            help='If set, the model predicts the clean image, otherwise the residual image (default).'
        )

        # training params (opt)
        parser.add_argument(
            "--loss_type",
            default="weighted_mse_loss",
            type=str,
            help="Loss used to train the model."
        )

        parser.add_argument(
            "--lr",
            default=0.0001,
            type=float,
            help="Adam learning rate"
        )
        parser.add_argument(
            "--lr_step_size",
            default=400000000,
            type=int,
            help="Period of learning rate decay (StepLR step_size)",
        )
        parser.add_argument(
            "--lr_gamma",
            default=0.1,
            type=float,
            help="Multiplicative factor of learning rate decay"
        )
        parser.add_argument(
            "--weight_decay",
            default=0.0,
            type=float,
            help="Strength of weight decay regularization",
        )

        # logging related
        parser.add_argument(
            "--num_log_images",
            default=8,
            type=int,
            help="Number of images to log (train and val datasets separately).",
        )

        return parser