from argparse import ArgumentParser
import copy
import contextlib
import torch
from typing import Dict, Optional, Sequence, Tuple, List, Union, NamedTuple
from argparse import ArgumentParser
from ldm.modules.diffusionmodules.model import Encoder
from pl_modules.utils import load_config_from_yaml
from collections import OrderedDict
from collections import defaultdict
import numpy as np
import pathlib
import pytorch_lightning as pl
import wandb
from ldm.util import instantiate_from_config
from data.operators import create_operator, create_noise_schedule
from data.metrics import ssim

# Number of evenly spaced degradation levels t in [0, 1] used by the ordering evaluation.
NUM_EVAL_LEVELS = 10


def rescale_to_zero_one(x):
    """Map values from [-1, 1] to [0, 1] via (x + 1) / 2."""
    return (x + 1.) / 2.


def rescale_to_minusone_one(x):
    """Map values from [0, 1] to [-1, 1] via 2x - 1 (inverse of rescale_to_zero_one)."""
    return x * 2. - 1.

class SevEncoder(torch.nn.Module):
    """Encoder mapping an image to a latent mean and a per-sample scalar variance.

    The architecture mirrors the encoder half of an LDM first-stage autoencoder
    (``Encoder`` + 1x1 ``quant_conv``), built from ``config``:

    * ``config['target']`` selects how the ``quant_conv`` output is split:
      for ``AutoencoderKL`` it is chunked into (mean, std) along the channel
      axis; for VQ models it is used as the mean and an extra 1x1 ``std_conv``
      head predicts the std.
    * ``config['params']['ddconfig']`` is forwarded to ``Encoder``.
    * ``config['params']['embed_dim']`` is the latent channel count.

    Calling the module returns ``(mean, var)`` where ``var`` has shape ``(b,)``:
    the squared std averaged over all non-batch dimensions.
    """

    _KL_TYPE = 'ldm.models.autoencoder.AutoencoderKL'
    _VQ_TYPES = ('ldm.models.autoencoder.VQModel', 'ldm.models.autoencoder.VQModelInterface')

    def __init__(self, config):
        super(SevEncoder, self).__init__()
        ddconfig = config['params']['ddconfig']
        self.embed_dim = config['params']['embed_dim']
        self.encoder = Encoder(**ddconfig)
        # With double_z the encoder emits two stacked halves (e.g. mean and std).
        z_factor = 2 if ddconfig["double_z"] else 1
        self.quant_conv = torch.nn.Conv2d(z_factor * ddconfig["z_channels"], z_factor * self.embed_dim, 1)
        self.encoder_type = config['target']
        self.sigma_max = 1.0  # not read by this class; kept for external users
        if self.encoder_type in self._VQ_TYPES:
            # Need to map the latent to variances. std_conv is applied to the
            # quant_conv output, so its input width must match that output
            # (z_factor * embed_dim), not z_channels.
            self.std_conv = torch.nn.Conv2d(z_factor * self.embed_dim, self.embed_dim, 1)
        self.num_resolutions = self.encoder.num_resolutions

    def cov_to_var(self, cov):
        """Average ``cov`` over all non-batch dimensions, returning shape ``(b,)``."""
        b = cov.shape[0]
        return cov.view(b, -1).mean(dim=1)

    def get_embedding(self, x):
        """Encode ``x`` into ``(mean, var)``.

        Returns:
            mean: latent mean tensor from the ``quant_conv`` output.
            var: per-sample scalar variance of shape ``(b,)``; the predicted
                std is squared, then averaged by ``cov_to_var``.

        Raises:
            ValueError: if ``encoder_type`` is neither the KL nor a VQ target.
        """
        h = self.encoder(x)
        z = self.quant_conv(h)
        if self.encoder_type == self._KL_TYPE:
            mean, std = torch.chunk(z, 2, dim=1)
        elif self.encoder_type in self._VQ_TYPES:
            mean, std = z, self.std_conv(z)
        else:
            raise ValueError('Unknown encoder type: {}'.format(self.encoder_type))
        # The second output is treated as a std, so squaring keeps var >= 0.
        var = std.pow(2)
        return mean, self.cov_to_var(var)

    def forward(self, x):
        """Alias for ``get_embedding``; defined as ``forward`` so module hooks run."""
        return self.get_embedding(x)

class SeverityEncoderModule(pl.LightningModule):
    """Lightning module that trains a severity encoder on degraded images.

    The trainable ``SevEncoder`` (``self.encoder``) maps a degraded image
    ``batch['degraded_noisy']`` to a latent mean and a per-sample scalar
    variance. The mean is regressed onto the latent of the clean image
    ``batch['clean']`` produced by a frozen pretrained autoencoder
    (``get_z0``). The frozen autoencoder and the encoder initialization come
    either from a full LDM checkpoint (``ldm_model_ckpt_path``) or from a
    standalone autoencoder checkpoint plus a separate severity-encoder
    config/checkpoint.
    """

    def __init__(
        self,
        operator_config,
        noise_config,
        ldm_model_ckpt_path,
        ldm_model_config_path,
        sev_encoder_ckpt_path,
        sev_encoder_config_path,
        pretrained_encoder_ckpt_path,
        pretrained_encoder_config_path,
        sigma_reg,
        img_space_reg,
        lr,
        lr_step_size,
        lr_gamma,
        lr_milestones,
        weight_decay=0.0,
        logger_type='wandb',
        **kwargs,
    ):
        """
        Args:
            operator_config: passed to ``create_operator``; the operator is
                used by ``eval_ordering`` to degrade clean images.
            noise_config: passed to ``create_noise_schedule`` (must contain
                ``'sigma_max'``), or None to disable additive noise.
            ldm_model_ckpt_path: optional full LDM checkpoint. When given, both
                the frozen autoencoder and the severity encoder are
                initialized from its ``first_stage_model.*`` weights.
            ldm_model_config_path: LDM config; required with ``ldm_model_ckpt_path``.
            sev_encoder_ckpt_path: optional checkpoint whose ``encoder.*``
                weights initialize the severity encoder (non-LDM path only).
            sev_encoder_config_path: severity encoder config (non-LDM path only).
            pretrained_encoder_ckpt_path: frozen autoencoder checkpoint;
                required when no LDM checkpoint is given.
            pretrained_encoder_config_path: frozen autoencoder config;
                required when no LDM checkpoint is given.
            sigma_reg: weight of the variance terms in the losses.
            img_space_reg: weight of the image-space reconstruction loss;
                values <= 0 skip decoding in ``get_loss``.
            lr, lr_step_size, lr_gamma, lr_milestones, weight_decay: optimizer
                and scheduler settings (see ``configure_optimizers``).
            logger_type: ``'wandb'``, or anything else for the
                ``add_image``-style path in ``log_image``.
            **kwargs: not used directly by this class.
        """
        super().__init__()
        self.save_hyperparameters()
        self.logger_type = logger_type
        if self.logger_type == 'wandb':
            global wandb
            import wandb
        self.fwd_operator = create_operator(operator_config)
        if noise_config is None:
            self.noise_schedule = None
            self.fwd_sigma_max = 0.0
        else:
            self.noise_schedule = create_noise_schedule(noise_config)
            self.fwd_sigma_max = noise_config['sigma_max']

        self.lr = lr
        self.lr_step_size = lr_step_size
        self.lr_milestones = lr_milestones
        self.lr_gamma = lr_gamma
        self.weight_decay = weight_decay

        # Check if given checkpoint-config combination is valid. Validate the
        # constructor arguments: the corresponding attributes do not exist yet.
        if ldm_model_ckpt_path is not None:
            assert ldm_model_config_path is not None, \
                'ldm_model_config_path is required with ldm_model_ckpt_path.'
            print("Setting up encoders/decoders from pretrained LDM.")
        else:
            assert pretrained_encoder_config_path is not None and pretrained_encoder_ckpt_path is not None, \
                'pretrained_encoder_config_path and pretrained_encoder_ckpt_path are required without an LDM checkpoint.'
            print('Loading ground truth encoder from pretrained autoencoder.')

        # Set up pretrained (frozen) autoencoder
        self.ldm_model_ckpt_path = ldm_model_ckpt_path
        self.ldm_model_config_path = ldm_model_config_path
        if ldm_model_ckpt_path is None:
            print('Initializing pretrained autoencoder from config {} and checkpoint {}.'.format(pretrained_encoder_config_path, pretrained_encoder_ckpt_path))
            self.pretrained_autoencoder_config = load_config_from_yaml(pretrained_encoder_config_path)['model']
            self.pretrained_encoder_ckpt_path = pretrained_encoder_ckpt_path
            self.pretrained_autoencoder = instantiate_from_config(self.pretrained_autoencoder_config)
            self.pretrained_autoencoder = self.pretrained_autoencoder.load_from_checkpoint(checkpoint_path=self.pretrained_encoder_ckpt_path,
                                **self.pretrained_autoencoder_config['params'])
            self.pretrained_encoder_type = self.pretrained_autoencoder_config['target']
        else:
            print('Initializing autoencoder from pretrained LDM with config {} and checkpoint {}.'.format(self.ldm_model_config_path, self.ldm_model_ckpt_path))
            ldm_config = load_config_from_yaml(self.ldm_model_config_path)
            ldm_checkpoint = torch.load(self.ldm_model_ckpt_path, map_location='cpu')['state_dict']
            # Keep only the first-stage (autoencoder) weights, stripping their prefix.
            autoencoder_checkpoint = {k.replace('first_stage_model.', ''): v for k, v in ldm_checkpoint.items() if 'first_stage_model.' in k}
            self.pretrained_autoencoder = instantiate_from_config(ldm_config['model']['params']['first_stage_config'])
            self.pretrained_autoencoder.load_state_dict(autoencoder_checkpoint)
            self.pretrained_encoder_type = ldm_config['model']['params']['first_stage_config']['target']
        self.pretrained_autoencoder.eval()
        for param in self.pretrained_autoencoder.parameters():  # Freeze ae
            param.requires_grad = False

        # Set up trainable severity encoder
        if ldm_model_ckpt_path is None:
            self.sev_encoder_config_path = sev_encoder_config_path
            self.sev_encoder_config = load_config_from_yaml(sev_encoder_config_path)['model']
            self.sev_encoder_ckpt_path = sev_encoder_ckpt_path
            self.encoder = SevEncoder(self.sev_encoder_config)
            if self.sev_encoder_ckpt_path is not None:
                print('Initializing Severity Encoder with pretrained model from ', self.sev_encoder_ckpt_path)
                checkpoint = torch.load(self.sev_encoder_ckpt_path, map_location='cpu')['state_dict']
                checkpoint = {k.replace('encoder.', ''): v for k, v in checkpoint.items() if 'encoder.' in k}
                self.encoder.encoder.load_state_dict(checkpoint)
        else:
            print('Initializing Severity Encoder with pretrained LDM encoder from ', self.ldm_model_config_path)
            self.sev_encoder_config = load_config_from_yaml(self.ldm_model_config_path)['model']['params']['first_stage_config']
            self.encoder = SevEncoder(self.sev_encoder_config)
            # ldm_checkpoint was loaded above in the same (LDM) branch.
            encoder_checkpoint = {k.replace('first_stage_model.encoder.', ''): v for k, v in ldm_checkpoint.items() if 'first_stage_model.encoder.' in k}
            quant_conv_checkpoint = {k.replace('first_stage_model.quant_conv.', ''): v for k, v in ldm_checkpoint.items() if 'first_stage_model.quant_conv.' in k}
            self.encoder.encoder.load_state_dict(encoder_checkpoint)
            self.encoder.quant_conv.load_state_dict(quant_conv_checkpoint)
            del ldm_checkpoint, autoencoder_checkpoint, encoder_checkpoint, quant_conv_checkpoint

        # Regularization
        self.sigma_reg = sigma_reg
        self.img_space_reg = img_space_reg

    def forward(self, x):
        """Run the severity encoder; returns ``(z_mean, z_var)``."""
        return self.encoder(x)

    def encode(self, x, get_var=True):
        """Encode ``x``; return ``(z_mean, z_var)``, or only ``z_mean`` if ``get_var`` is False."""
        z_mean, z_var = self.encoder(x)
        if get_var:
            return z_mean, z_var
        else:
            return z_mean

    def decode(self, x, force_not_quantize=False):
        """Decode latent ``x`` to image space with the frozen pretrained autoencoder.

        NOTE(review): ``force_not_quantize`` is overridden to True
        unconditionally, so VQ latents never pass through the quantization
        layer whatever the caller passes. Kept for backward compatibility —
        TODO decide whether the argument should be honored.

        Raises:
            ValueError: for an unsupported ``pretrained_encoder_type``.
        """
        force_not_quantize = True
        if self.pretrained_encoder_type in ['ldm.models.autoencoder.VQModel', 'ldm.models.autoencoder.VQModelInterface']:
            # also go through quantization layer
            if not force_not_quantize:
                quant, _, _ = self.pretrained_autoencoder.quantize(x)
            else:
                quant = x
            quant2 = self.pretrained_autoencoder.post_quant_conv(quant)
            dec = self.pretrained_autoencoder.decoder(quant2)
            return dec
        elif self.pretrained_encoder_type == 'ldm.models.autoencoder.AutoencoderKL':
            return self.pretrained_autoencoder.decode(x)
        else:
            raise ValueError('Unknown model type.')

    def get_z0(self, x):
        """Return the detached target latent of clean image ``x``.

        For ``AutoencoderKL`` this is the mode of the encoding distribution;
        for VQ models it is the pre-quantization ``quant_conv`` output.

        Raises:
            ValueError: for an unsupported ``pretrained_encoder_type``.
        """
        if self.pretrained_encoder_type == 'ldm.models.autoencoder.AutoencoderKL':
            return self.pretrained_autoencoder.encode(x).mode().detach()
        elif self.pretrained_encoder_type in ['ldm.models.autoencoder.VQModel', 'ldm.models.autoencoder.VQModelInterface']:
            h = self.pretrained_autoencoder.encoder(x)
            h = self.pretrained_autoencoder.quant_conv(h)
            return h.detach()
        else:
            raise ValueError('Unknown encoder type')

    def get_loss(self, batch):
        """Compute the training objective and diagnostic metrics for a batch.

        Args:
            batch: dict with ``'clean'`` and ``'degraded_noisy'`` image tensors
                (batch first; presumably in [-1, 1], given the
                ``rescale_to_zero_one`` before SSIM — TODO confirm).

        Returns:
            Dict of scalar metrics. ``'naive_loss'`` is the objective optimized
            by ``training_step``. It combines:

            * the per-sample latent MSE ``mean_term / d``;
            * ``sigma_reg`` times the squared gap between the predicted variance
              and the empirical variance of the latent error;
            * ``img_space_reg`` times the per-pixel image-space squared error.
        """
        b = batch['clean'].shape[0]
        d_img = batch['clean'].view(b, -1).shape[1]
        z0 = self.get_z0(batch['clean'])
        z_mean, z_var = self.encoder(batch['degraded_noisy'])
        mean_term = (z_mean - z0).pow(2).view(b, -1).sum(1)  # (b,)
        mean_term_scaled = mean_term / (z_var + 1e-9)
        var_term = z_var
        loss = mean_term_scaled + self.sigma_reg * var_term

        # Image space loss: per-sample, like mean_term, so its weight does not
        # scale with the batch size.
        if self.img_space_reg > 0.0:
            x_pred = self.decode(z_mean)
            img_space_loss = (x_pred - batch['clean']).view(b, -1).pow(2).sum(1)
            recon_ssim = 0
            with torch.no_grad():  # SSIM is only a logged metric
                for i in range(b):
                    recon_ssim += ssim(rescale_to_zero_one(x_pred[i]),
                                       rescale_to_zero_one(batch['clean'][i]),
                                      )
            recon_ssim /= b
        else:
            # Tensor (not int 0) so that .mean() below works.
            img_space_loss = torch.zeros(b, device=z_mean.device, dtype=z_mean.dtype)
            recon_ssim = 0.0

        # Discrepancy metrics between the predicted variance and the empirical
        # (unbiased) variance of the latent error over its d entries.
        e_i = (z_mean - z0).view(b, -1)
        d = e_i.shape[1]  # dimension of latent vec
        mu_i = e_i.sum(1) / d
        var_i = (e_i - mu_i.unsqueeze(1)).pow(2).sum(1) / (d - 1)
        var_discrep_sq = (var_i - z_var).pow(2)
        kl_loss = (mean_term / (var_i + 1e-9) + d * z_var / (var_i + 1e-9) - d
                   + d * (torch.log(var_i + 1e-9) - torch.log(z_var + 1e-9)))
        naive_loss = 1 / d * mean_term + self.sigma_reg * var_discrep_sq + self.img_space_reg * 1 / d_img * img_space_loss
        mean_discrep = torch.abs(mu_i - 0.0).mean()
        var_discrep = torch.abs(var_i - z_var).mean()
        var_discrep_relative = torch.abs((var_i - z_var) / (var_i + 1e-9)).mean()
        return {'loss': loss.mean(),
                'kl_loss': kl_loss.mean(),
                'naive_loss': naive_loss.mean(),
                'mean_term': mean_term.mean(),
                'mean_term_scaled': mean_term_scaled.mean(),
                'var_term': var_term.mean(),
                'var_bars': var_i.mean(),
                'var_discrep_relative': var_discrep_relative,
                'var_discrep_sq': var_discrep_sq.mean(),
                'mean_discrep': mean_discrep,
                'var_discrep': var_discrep,
                'img_space_loss': img_space_loss.mean(),
                'recon_ssim': recon_ssim
               }

    def training_step(self, batch, batch_idx):
        """Log every metric under ``train/`` and optimize ``naive_loss``."""
        losses = self.get_loss(batch)
        for name, value in losses.items():
            self.log('train/{}'.format(name), value)
        return losses['naive_loss']

    def validation_step(self, batch, batch_idx):
        """Log metrics and ordering scores; log sample images for the first batch only.

        Reads ``batch['t']`` for image captions on the first batch.
        """
        val_losses = self.get_loss(batch)
        for name, value in val_losses.items():
            self.log('val/{}'.format(name), value)
        log_images = batch_idx == 0
        ord_loss, ord_acc = self.eval_ordering(batch, log_images=log_images)
        self.log('val/ord_loss', ord_loss)
        self.log('val/ord_acc', ord_acc)
        if log_images:
            b = batch['clean'].shape[0]
            _, z_var = self.encoder(batch['degraded_noisy'])
            for i in range(b):
                self.log_image('val/images/img_{}'.format(i), batch["degraded_noisy"][i].unsqueeze(0), 'var_pred: {}, t_gt: {}'.format(str(z_var[i].detach().cpu().numpy()),str(batch["t"][i].detach().cpu().numpy())))

    def test_step(self, batch, batch_idx):
        pass

    def eval_ordering(self, batch, log_images=True):
        """Check whether predicted variances increase with degradation severity.

        ``batch['clean']`` is degraded at ``NUM_EVAL_LEVELS`` evenly spaced
        ``t`` in [0, 1] (plus noise when a schedule is configured). The
        pairwise ordering of the predicted variances is then compared with the
        ordering of ``t``.

        Args:
            batch: dict with a ``'clean'`` image tensor.
            log_images: log the first sample at every level (default True,
                matching the previous behavior).

        Returns:
            ``(ord_loss, ord_acc)``: half the mean squared difference between
            the ordering matrices, and the mean pairwise ordering accuracy.
        """
        x0 = rescale_to_zero_one(batch['clean'])
        b = x0.shape[0]
        var_preds = torch.zeros(b, NUM_EVAL_LEVELS)
        ts = torch.linspace(0.0, 1.0, NUM_EVAL_LEVELS)
        for i in range(NUM_EVAL_LEVELS):
            # clone(): pass a fresh 0-dim tensor (torch.tensor(tensor) warns).
            y = self.fwd_operator(x0, ts[i].clone().to(x0.device))
            if self.noise_schedule is not None:
                z, _ = self.noise_schedule(ts[i], y.shape)
                # Out-of-place: the operator might return x0 itself (e.g. at t=0),
                # and an in-place add would then corrupt x0 for later levels.
                y = y + z.to(y.device)
            y = rescale_to_minusone_one(y)
            var_preds[:, i] = self.encoder(y)[1]
            if log_images:
                self.log_image('val/images/seq/img_{}'.format(i), y[0].unsqueeze(0), 'var_pred: {}, t_gt: {}'.format(str(var_preds[0, i].detach().cpu().numpy()),str(ts[i].detach().cpu().numpy())))

        ord_true = self.ordering_mx(ts.repeat(b, 1))
        ord_pred = self.ordering_mx(var_preds)
        ord_loss = 0.5 * (ord_true - ord_pred).pow(2).mean()
        ord_acc = self.ordering_acc(ord_pred, ord_true).mean()
        return ord_loss, ord_acc

    def ordering_acc(self, ordering_mx, gt_mx):
        """Per-sample fraction of strictly-upper-triangular entries where ``ordering_mx`` equals ``gt_mx``."""
        num_vals = (NUM_EVAL_LEVELS - 1) * NUM_EVAL_LEVELS / 2
        return torch.triu(torch.where(ordering_mx == gt_mx, 1.0, 0.0), diagonal=1).sum(dim=[1, 2]) / num_vals

    def ordering_mx(self, vals):
        """Pairwise ordering matrix of ``vals`` with shape ``(b, NUM_EVAL_LEVELS)``.

        Entry ``[k, i, j]`` is 1.0 when ``vals[k, i] <= vals[k, j]`` and 0.0
        otherwise; the diagonal is zeroed.
        """
        assert vals.shape[1] == NUM_EVAL_LEVELS
        x = vals.unsqueeze(2).expand(vals.size(0), NUM_EVAL_LEVELS, NUM_EVAL_LEVELS)
        xT = vals.unsqueeze(1).expand(vals.size(0), NUM_EVAL_LEVELS, NUM_EVAL_LEVELS)
        ord_mx = torch.where(x - xT > 0, 0.0, 1.0)
        mask = 1.0 - torch.eye(NUM_EVAL_LEVELS, NUM_EVAL_LEVELS, dtype=vals.dtype, device=vals.device)
        ord_mx = ord_mx * mask  # Zero out diagonals to avoid numerical errors when comparing the same float to itself
        return ord_mx

    @staticmethod
    def triu_values(x):
        """Return the strictly upper-triangular entries of a square matrix (2D) or of each matrix in a batch (3D).

        Raises:
            ValueError: for any other rank.
        """
        if len(x.shape) == 2:
            mask = torch.ones(x.shape[0], x.shape[0], device=x.device)
            return x[mask.triu(diagonal=1) == 1]
        elif len(x.shape) == 3:
            mask = torch.ones(x.shape[1], x.shape[1], device=x.device)
            return x[:, mask.triu(diagonal=1) == 1]
        else:
            raise ValueError("Can't deal with this shape {}".format(x.shape))

    def diff_reg(self, vals):
        """Reciprocal of the smallest pairwise absolute difference among ``b`` values.

        ``vals`` must be reshapeable to ``(b, 1)``. The result grows as two
        values approach each other.
        """
        b = vals.shape[0]
        x = vals.view(b, 1).expand(vals.size(0), vals.size(0))
        xT = vals.view(1, b).expand(vals.size(0), vals.size(0))
        d = torch.abs(x - xT)
        d = self.triu_values(d).min()
        return 1 / (d + 1e-9)

    def configure_optimizers(self):
        """Adam with an optional StepLR (preferred) or MultiStepLR schedule."""
        optim = torch.optim.Adam(
            self.parameters(), lr=self.lr, weight_decay=self.weight_decay
        )
        if self.lr_step_size is not None:
            scheduler = torch.optim.lr_scheduler.StepLR(
                optim, self.lr_step_size, self.lr_gamma
            )
            return [optim], [scheduler]
        elif self.lr_milestones is not None:
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optim, self.lr_milestones, self.lr_gamma
            )
            return [optim], [scheduler]
        return optim

    def log_image(self, name, image, caption=None):
        """Log ``image`` under ``name``.

        With wandb the caption is attached. Otherwise ``add_image`` is called
        at ``global_step`` and the caption is ignored.
        """
        if self.logger_type == 'wandb':
            # wandb logging
            self.logger.experiment.log({name:  wandb.Image(image, caption=caption)})
        else:
            # tensorboard logging (default)
            self.logger.experiment.add_image(name, image, global_step=self.global_step)

    @staticmethod
    def add_model_specific_args(parent_parser):  # pragma: no-cover
        """
        Define parameters that only apply to this model
        """
        parser = ArgumentParser(parents=[parent_parser], add_help=False)

        # configs/model setup
        parser.add_argument(
            "--ldm_model_config_path",
            type=str,
            default=None,
            help="Config file for LDM."
        )
        parser.add_argument(
            "--ldm_model_ckpt_path",
            type=str,
            default=None,
            help="Path to pretrained LDM to initialize sev encoder."
        )
        parser.add_argument(
            "--sev_encoder_config_path",
            type=str,
            default=None,
            help="Config file for encoder arch."
        )
        parser.add_argument(
            "--sev_encoder_ckpt_path",
            type=str,
            default=None,
            help="Path to pretrained encoder to initialize sev encoder with."
        )
        parser.add_argument(
            "--pretrained_encoder_ckpt_path",
            type=str,
            default=None,
            help="Path to pretrained AE model checkpoint."
        )
        parser.add_argument(
            "--pretrained_encoder_config_path",
            type=str,
            default=None,
            help="Config file for pretrained autoencoder"
        )

        # training params (opt)
        parser.add_argument(
            "--lr",
            default=0.0001,
            type=float,
            help="Adam learning rate"
        )
        parser.add_argument(
            "--lr_step_size",
            default=None,
            type=int,
            help="Number of epochs to decrease step size",
        )
        parser.add_argument(
            "--lr_milestones",
            default=None,
            type=int,
            nargs='+',
            help="List of epochs to decrease step size",
        )
        parser.add_argument(
            "--lr_gamma",
            default=0.1,
            type=float,
            help="Amount to decrease step size"
        )
        parser.add_argument(
            "--weight_decay",
            default=0.0,
            type=float,
            help="Strength of weight decay regularization",
        )
        parser.add_argument(
            "--sigma_reg",
            default=0.0,
            type=float,
            help="Weight of sigma regularization."
        )
        parser.add_argument(
            "--img_space_reg",
            default=0.0,
            type=float,
            help="Weight of image domain regularization."
        )

        return parser