import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from torch.optim.lr_scheduler import LambdaLR
import numpy as np
from packaging import version
from ldm.modules.ema import LitEma
from contextlib import contextmanager

from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer

from ldm.modules.diffusionmodules.model import *

from ldm.util import instantiate_from_config


class VQDiGAN(pl.LightningModule):
    """Conditional 3D VQ autoencoder trained with a configurable loss.

    The encoder (``MultiClass_DiEncoder``) and decoder
    (``MultiClass_DiDecoderWithResidual``) are built from ``ddconfig`` with
    separate ``num_classes`` values. Both receive a ``tensor_mapping_index``
    alongside the volume; presumably this is a per-sample class id (TODO
    confirm). The decoder is additionally conditioned on encoder feature lists
    computed from four auxiliary volumes (``b0``, ``b1000x``, ``b1000y``,
    ``b1000z``). Latents are projected with 1x1x1 ``Conv3d`` layers around a
    ``VectorQuantizer`` codebook of ``n_embed`` entries of size ``embed_dim``.
    """

    def __init__(self,
                 ddconfig,
                 lossconfig,
                 n_embed,
                 embed_dim,
                 ckpt_path=None,
                 ignore_keys=None,
                 image_key="image",
                 colorize_nlabels=None,
                 monitor=None,
                 batch_resize_range=None,
                 scheduler_config=None,
                 lr_g_factor=1.0,
                 remap=None,
                 sane_index_shape=False,  # tell vector quantizer to return indices as bhw
                 use_ema=False
                 ):
        """Build encoder, decoder, quantizer and loss; optionally restore a checkpoint.

        Args:
            ddconfig: Mapping of encoder/decoder kwargs. Must contain
                ``num_classes_encoder``, ``num_classes_decoder`` and
                ``z_channels``. It is NOT modified.
            lossconfig: Config passed to ``instantiate_from_config`` for the loss.
            n_embed: Number of codebook entries.
            embed_dim: Dimension of each codebook entry.
            ckpt_path: Optional checkpoint whose ``state_dict`` is loaded.
            ignore_keys: State-dict key prefixes to drop before loading.
            image_key: Batch key of the volume to reconstruct.
            colorize_nlabels: If given (int), registers a random ``colorize``
                buffer used by ``to_rgb``.
            monitor: Optional metric name stored as ``self.monitor``.
            batch_resize_range: Stored; not used inside this class.
            scheduler_config: Optional config whose ``schedule`` drives a LambdaLR.
            lr_g_factor: Stored; NOTE(review): not applied in ``configure_optimizers``.
            remap, sane_index_shape: Forwarded to ``VectorQuantizer``.
            use_ema: Keep an EMA copy of the weights (``LitEma``).
        """
        super().__init__()
        self.embed_dim = embed_dim
        self.n_embed = n_embed
        self.image_key = image_key

        # Work on a shallow copy: popping from the caller's config would make a
        # second instantiation from the same (or a saved) config fail with KeyError.
        ddconfig = dict(ddconfig)
        num_classes_encoder = ddconfig.pop("num_classes_encoder")
        num_classes_decoder = ddconfig.pop("num_classes_decoder")
        self.encoder = MultiClass_DiEncoder(**{**ddconfig, "num_classes": num_classes_encoder})
        self.decoder = MultiClass_DiDecoderWithResidual(**{**ddconfig, "num_classes": num_classes_decoder})
        self.loss = instantiate_from_config(lossconfig)
        self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25,
                                        remap=remap,
                                        sane_index_shape=sane_index_shape)
        self.quant_conv = torch.nn.Conv3d(ddconfig["z_channels"], embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv3d(embed_dim, ddconfig["z_channels"], 1)
        if colorize_nlabels is not None:
            assert type(colorize_nlabels) == int
            self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1))
        if monitor is not None:
            self.monitor = monitor
        self.batch_resize_range = batch_resize_range
        if self.batch_resize_range is not None:
            print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.")

        self.use_ema = use_ema
        if self.use_ema:
            self.model_ema = LitEma(self)
            print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.")

        # Loaded after the EMA module exists so EMA buffers in the checkpoint
        # can be restored too.
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)
        self.scheduler_config = scheduler_config
        self.lr_g_factor = lr_g_factor

    @contextmanager
    def ema_scope(self, context=None):
        """Temporarily swap in EMA weights (no-op when ``use_ema`` is False).

        Training weights are restored on exit, even if the body raises.
        """
        if self.use_ema:
            self.model_ema.store(self.parameters())
            self.model_ema.copy_to(self)
            if context is not None:
                print(f"{context}: Switched to EMA weights")
        try:
            yield None
        finally:
            if self.use_ema:
                self.model_ema.restore(self.parameters())
                if context is not None:
                    print(f"{context}: Restored training weights")

    def init_from_ckpt(self, path, ignore_keys=None):
        """Load ``state_dict`` from ``path`` non-strictly, dropping keys by prefix.

        Args:
            path: Checkpoint file containing a ``"state_dict"`` entry.
            ignore_keys: Iterable of key prefixes to remove before loading.
        """
        ignore_prefixes = tuple(ignore_keys or ())
        sd = torch.load(path, map_location="cpu")["state_dict"]
        for k in list(sd):
            # One startswith() over all prefixes: a key matching several
            # prefixes is deleted exactly once (previously a KeyError).
            if k.startswith(ignore_prefixes):
                print("Deleting key {} from state_dict.".format(k))
                del sd[k]
        missing, unexpected = self.load_state_dict(sd, strict=False)
        print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
        if missing:
            print(f"Missing Keys: {missing}")
        if unexpected:
            print(f"Unexpected Keys: {unexpected}")

    def on_train_batch_end(self, *args, **kwargs):
        """Update the EMA weights after every training batch when enabled."""
        if self.use_ema:
            self.model_ema(self)

    def encode(self, x, tensor_mapping_index, ret_feature=False):
        """Encode and quantize ``x``.

        Returns:
            ``(quant, emb_loss, info)``; with ``ret_feature=True`` the encoder's
            intermediate feature list is appended as a fourth element.
        """
        phi_list = None
        if ret_feature:
            h, phi_list = self.encoder(x, tensor_mapping_index, ret_feature)
        else:
            h = self.encoder(x, tensor_mapping_index)
        h = self.quant_conv(h)
        quant, emb_loss, info = self.quantize(h)
        if ret_feature:
            return quant, emb_loss, info, phi_list
        return quant, emb_loss, info

    def encode_to_prequant(self, x, tensor_mapping_index=None):
        """Return the latent after ``quant_conv`` but before quantization.

        ``tensor_mapping_index`` is forwarded to the encoder when given (as in
        ``encode``). When omitted, the encoder is called with ``x`` alone, as before.
        """
        if tensor_mapping_index is None:
            h = self.encoder(x)
        else:
            h = self.encoder(x, tensor_mapping_index)
        h = self.quant_conv(h)
        return h

    def encode_cond(self, b0, b1000x, b1000y, b1000z, tensor_mapping_index):
        """Run the encoder once over all four conditioning volumes and split its features.

        The volumes are concatenated along dim 0 in the order
        b0, b1000x, b1000y, b1000z. b0 uses mapping index 2 and the three b1000
        volumes use index 3. Each returned list holds, per encoder feature, the
        slice belonging to that volume.
        """
        batch_inputs = torch.cat([b0, b1000x, b1000y, b1000z], dim=0)
        b0_mapping_index = torch.ones_like(tensor_mapping_index) * 2
        b1000_mapping_index = torch.ones_like(tensor_mapping_index) * 3
        batch_mapping_indices = torch.cat(
            [b0_mapping_index, b1000_mapping_index, b1000_mapping_index, b1000_mapping_index], dim=0)
        _, _, _, batch_phi_list = self.encode(batch_inputs, batch_mapping_indices, ret_feature=True)

        batch_size = b0.shape[0]
        phi_b0_list = [phi[:batch_size] for phi in batch_phi_list]
        phi_b1000x_list = [phi[batch_size:2*batch_size] for phi in batch_phi_list]
        phi_b1000y_list = [phi[2*batch_size:3*batch_size] for phi in batch_phi_list]
        phi_b1000z_list = [phi[3*batch_size:] for phi in batch_phi_list]

        return phi_b0_list, phi_b1000x_list, phi_b1000y_list, phi_b1000z_list

    def decode(self, quant, tensor_mapping_index, b0, b1000x, b1000y, b1000z):
        """Decode quantized latents, conditioned on features of the four auxiliary volumes."""
        phi_b0_list, phi_b1000x_list, phi_b1000y_list, phi_b1000z_list = self.encode_cond(
            b0, b1000x, b1000y, b1000z, tensor_mapping_index)

        cond_dict = dict(
            phi_b0_list=phi_b0_list,
            phi_b1000x_list=phi_b1000x_list,
            phi_b1000y_list=phi_b1000y_list,
            phi_b1000z_list=phi_b1000z_list,
        )

        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant, y=tensor_mapping_index, cond_dict=cond_dict)
        return dec

    def decode_code(self, code_b, *decode_args, **decode_kwargs):
        """Look up codebook entries for ``code_b`` and decode them.

        ``decode`` needs more than the quantized latent (the mapping index and
        conditioning inputs). Those are forwarded through ``decode_args`` /
        ``decode_kwargs``; previously this method always raised TypeError.
        NOTE(review): confirm the quantizer class provides ``embed_code``.
        """
        quant_b = self.quantize.embed_code(code_b)
        dec = self.decode(quant_b, *decode_args, **decode_kwargs)
        return dec

    def forward(self, input, tensor_mapping_index, b0, b1000x, b1000y, b1000z, return_pred_indices=False):
        """Encode ``input``, quantize it, and decode it conditioned on the auxiliary volumes.

        Returns:
            ``(dec, diff, phi_list)``, or ``(dec, diff, ind, phi_list)`` when
            ``return_pred_indices`` is True. ``diff`` is the quantizer loss and
            ``ind`` the third element of the quantizer info.
        """
        quant, diff, (_, _, ind), phi_list = self.encode(input, tensor_mapping_index, ret_feature=True)
        dec = self.decode(quant, tensor_mapping_index, b0, b1000x, b1000y, b1000z)
        if return_pred_indices:
            return dec, diff, ind, phi_list
        return dec, diff, phi_list

    def get_input(self, batch, k):
        """Return ``batch[k]``, appending a trailing singleton axis to 3-D tensors."""
        x = batch[k]
        if len(x.shape) == 3:
            x = x[..., None]
        return x

    def _get_model_inputs(self, batch):
        """Fetch (image, tensor_mapping_index, b0, b1000x, b1000y, b1000z) from ``batch``."""
        keys = (self.image_key, 'tensor_mapping_index', 'b0', 'b1000x', 'b1000y', 'b1000z')
        return tuple(self.get_input(batch, key) for key in keys)

    def training_step(self, batch, batch_idx):
        """Reconstruct the batch and return the autoencoder loss."""
        # https://github.com/pytorch/pytorch/issues/37142
        # try not to fool the heuristics
        x, x_tensor_mapping_index, x_b0, x_b1000x, x_b1000y, x_b1000z = self._get_model_inputs(batch)
        xrec, qloss, _ = self(x, x_tensor_mapping_index, x_b0, x_b1000x, x_b1000y, x_b1000z)

        # autoencode
        aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="train")

        self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True)
        return aeloss

    def validation_step(self, batch, batch_idx):
        """Validate with the current weights, then again inside ``ema_scope``.

        The second pass always runs and logs under ``val_ema/...``. When
        ``use_ema`` is False, ``ema_scope`` is a no-op, so those metrics simply
        duplicate the plain ones.
        """
        log_dict = self._validation_step(batch, batch_idx)
        with self.ema_scope():
            self._validation_step(batch, batch_idx, suffix="_ema")
        return log_dict

    def _validation_step(self, batch, batch_idx, suffix=""):
        """Compute and log validation losses under the ``"val" + suffix`` split.

        Returns:
            The loss module's log dictionary. Previously this returned the bound
            ``self.log_dict`` method by mistake.
        """
        x, x_tensor_mapping_index, x_b0, x_b1000x, x_b1000y, x_b1000z = self._get_model_inputs(batch)
        xrec, qloss, _ = self(x, x_tensor_mapping_index, x_b0, x_b1000x, x_b1000y, x_b1000z)

        aeloss, log_dict_ae = self.loss(qloss, x, xrec, self.global_step, split="val"+suffix)

        rec_key = f"val{suffix}/rec_loss"
        rec_loss = log_dict_ae[rec_key]
        self.log(rec_key, rec_loss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        self.log(f"val{suffix}/aeloss", aeloss,
                 prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True)
        # Log the remaining entries from a copy so the returned dict stays complete.
        remaining = dict(log_dict_ae)
        if version.parse(pl.__version__) >= version.parse('1.4.0'):
            # rec_loss was already logged above; avoid logging the same key twice
            # (version-gated as before).
            del remaining[rec_key]
        self.log_dict(remaining)
        return log_dict_ae

    def configure_optimizers(self):
        """Adam over all autoencoder parameters, plus an optional per-step LambdaLR.

        ``self.learning_rate`` is not set in this class; presumably the training
        script assigns it. NOTE(review): ``lr_g_factor`` is not applied here.
        """
        lr = self.learning_rate
        print("lr", lr)
        modules = (self.encoder, self.decoder, self.quantize, self.quant_conv, self.post_quant_conv)
        opt_ae = torch.optim.Adam([p for module in modules for p in module.parameters()],
                                  lr=lr, betas=(0.5, 0.9))

        if self.scheduler_config is not None:
            scheduler = instantiate_from_config(self.scheduler_config)

            print("Setting up LambdaLR scheduler...")
            scheduler = [
                {
                    'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule),
                    'interval': 'step',
                    'frequency': 1
                }
            ]
            return [opt_ae], scheduler
        return [opt_ae], []

    def get_last_layer(self):
        """Return None; the decoder's last-layer weight is intentionally not exposed."""
        # return self.decoder.conv_out.weight
        return None

    def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs):
        """Return inputs, reconstructions and conditioning volumes for image logging.

        All model inputs are moved to ``self.device`` before the forward pass.
        Previously only the image was moved. ``plot_ema`` is accepted but unused.
        """
        log = dict()
        x, x_tensor_mapping_index, x_b0, x_b1000x, x_b1000y, x_b1000z = self._get_model_inputs(batch)
        x = x.to(self.device)
        if only_inputs:
            log["inputs"] = x
            return log

        x_tensor_mapping_index, x_b0, x_b1000x, x_b1000y, x_b1000z = (
            t.to(self.device) for t in (x_tensor_mapping_index, x_b0, x_b1000x, x_b1000y, x_b1000z))
        xrec, _, _ = self(x, x_tensor_mapping_index, x_b0, x_b1000x, x_b1000y, x_b1000z)
        log["inputs"] = x
        log["reconstructions"] = xrec
        log["b0"] = x_b0
        log["b1000x"] = x_b1000x
        log["b1000y"] = x_b1000y
        log["b1000z"] = x_b1000z
        return log

    def to_rgb(self, x):
        """Project a multi-channel segmentation volume to 3 channels scaled to [-1, 1]."""
        assert self.image_key == "segmentation"
        if not hasattr(self, "colorize"):
            self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x))
        weight = self.colorize
        if weight.dim() == 4:
            # conv3d needs a 5-D weight (out, in, kD, kH, kW). The buffer is
            # stored 4-D, so append a singleton kernel axis instead of failing.
            weight = weight.unsqueeze(-1)
        x = F.conv3d(x, weight=weight)
        x = 2.*(x-x.min())/(x.max()-x.min()) - 1.
        return x

    
class VQDiGANInterface(VQDiGAN):
    """Inference-side interface to ``VQDiGAN`` working on pre-quantization latents.

    ``encode`` stops after ``quant_conv`` (no quantization). ``decode`` takes a
    latent plus a precomputed ``cond_dict`` and, by default, quantizes first.
    The ``*_multidirection`` helpers handle stacks of volumes shaped
    ``(batch, directions, C, H, W, D)``. In the full stack, the first 4
    directions are b0, b1000x, b1000y and b1000z.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def encode(self, x, tensor_mapping_index, ret_feature=False):
        """Return the latent after ``quant_conv`` (unquantized), optionally with encoder features."""
        phi_list = None
        if ret_feature:
            h, phi_list = self.encoder(x, tensor_mapping_index, ret_feature)
        else:
            h = self.encoder(x, tensor_mapping_index)
        h = self.quant_conv(h)
        if ret_feature:
            return h, phi_list
        return h

    def encode_cond(self, b_x, tensor_mapping_index):
        """Encode the 4 conditioning directions in one pass.

        Args:
            b_x: Tensor ``(batch, 4, C, H, W, D)`` ordered b0, b1000x, b1000y, b1000z.
            tensor_mapping_index: Per-sample index tensor ``(batch,)``. Only its
                shape/dtype/device are used; b0 gets index 2 and b1000* get 3.

        Returns:
            ``(h, phi_b0_list, phi_b1000x_list, phi_b1000y_list, phi_b1000z_list)``
            where ``h`` is ``(batch, 4, C', H', W', D')``.
        """
        batch_size, num_directions = b_x.shape[:2]
        b0_x, b1000x_x, b1000y_x, b1000z_x = b_x.unbind(dim=1)
        # Direction-major concatenation: rows [d*batch_size:(d+1)*batch_size] are direction d.
        batch_inputs = torch.cat([b0_x, b1000x_x, b1000y_x, b1000z_x], dim=0)
        b0_mapping_index = torch.ones_like(tensor_mapping_index) * 2
        b1000_mapping_index = torch.ones_like(tensor_mapping_index) * 3
        batch_mapping_indices = torch.cat(
            [b0_mapping_index, b1000_mapping_index, b1000_mapping_index, b1000_mapping_index], dim=0)
        h, batch_phi_list = self.encode(batch_inputs, batch_mapping_indices, ret_feature=True)
        # Undo the direction-major layout. A plain reshape to (batch, dirs, ...)
        # would interleave samples and directions whenever batch_size > 1.
        h = h.reshape(num_directions, batch_size, *h.shape[1:]).transpose(0, 1).contiguous()

        phi_b0_list = [phi[:batch_size] for phi in batch_phi_list]
        phi_b1000x_list = [phi[batch_size:2*batch_size] for phi in batch_phi_list]
        phi_b1000y_list = [phi[2*batch_size:3*batch_size] for phi in batch_phi_list]
        phi_b1000z_list = [phi[3*batch_size:] for phi in batch_phi_list]

        return h, phi_b0_list, phi_b1000x_list, phi_b1000y_list, phi_b1000z_list

    def encode_multidirection(self, x, tensor_mapping_index, ret_feature=True, full=True):
        """Encode a stack of volumes ``(batch, directions, C, H, W, D)``.

        With ``full=True``, directions 0-3 are the conditioning volumes (encoded
        via ``encode_cond``) and the rest are encoded with their own mapping
        indices from ``tensor_mapping_index[:, 4:]``. With ``full=False``, ``x``
        holds only the 4 conditioning directions.

        Returns:
            The latents ``(batch, directions, C', H', W', D')`` after
            ``quant_conv``; with ``ret_feature`` also the conditioning ``cond_dict``.
        """
        if full:
            b_x = x[:, :4]
            b_tensor_mapping_index = tensor_mapping_index[:, 0]
            b_z, phi_b0_list, phi_b1000x_list, phi_b1000y_list, phi_b1000z_list = self.encode_cond(
                b_x, b_tensor_mapping_index)
            cond_dict = dict(
                phi_b0_list=phi_b0_list,
                phi_b1000x_list=phi_b1000x_list,
                phi_b1000y_list=phi_b1000y_list,
                phi_b1000z_list=phi_b1000z_list,
            )

            tensor_x = x[:, 4:]
            tensor_mapping_index = tensor_mapping_index[:, 4:]
            batch_size, num_directions, num_channels, height, width, depth = tensor_x.shape
            tensor_x = tensor_x.reshape(batch_size * num_directions, num_channels, height, width, depth)
            tensor_mapping_index = tensor_mapping_index.reshape(batch_size * num_directions)
            tensor_z = self.encoder(tensor_x, tensor_mapping_index)
            tensor_z = self.quant_conv(tensor_z)
            tensor_z = tensor_z.reshape(batch_size, num_directions, -1, *tensor_z.shape[-3:])

            z = torch.cat([b_z, tensor_z], dim=1)
            if ret_feature:
                return z, cond_dict
            return z

        tensor_mapping_index = tensor_mapping_index[:, 0]
        # encode_cond -> encode already applies quant_conv. Applying it again
        # (as before) crashed on the 6-D tensor and would double-project anyway.
        h, phi_b0_list, phi_b1000x_list, phi_b1000y_list, phi_b1000z_list = self.encode_cond(
            x, tensor_mapping_index)
        cond_dict = dict(
            phi_b0_list=phi_b0_list,
            phi_b1000x_list=phi_b1000x_list,
            phi_b1000y_list=phi_b1000y_list,
            phi_b1000z_list=phi_b1000z_list,
        )
        if ret_feature:
            return h, cond_dict
        return h

    def decode(self, h, tensor_mapping_index, cond_dict, force_not_quantize=False):
        """Quantize ``h`` (unless ``force_not_quantize``), then post-project and decode.

        ``cond_dict`` holds the conditioning feature lists as produced by
        ``encode_multidirection``.
        """
        # also go through quantization layer
        quant = h if force_not_quantize else self.quantize(h)[0]
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant, y=tensor_mapping_index, cond_dict=cond_dict)
        return dec

    @staticmethod
    def _broadcast_cond_feature(phi, batch_size, num_directions):
        """Match ``phi``'s leading dim to the batch-major flattened ``batch_size * num_directions``."""
        total = batch_size * num_directions
        if phi.shape[0] == total:
            return phi
        if phi.shape[0] == batch_size and batch_size != 1:
            # Flattened row b*N + n belongs to sample b, so each sample's feature
            # is repeated N times consecutively. expand() cannot do this for batch > 1.
            return phi.repeat_interleave(num_directions, dim=0)
        # Singleton batch: zero-copy broadcast. Any other size mismatch raises here.
        return phi.expand(total, *phi.shape[1:])

    def decode_multidirection(self, h, tensor_mapping_index, cond_dict, force_not_quantize=False):
        """Decode a stack of latents ``(batch, directions, C, H, W, D)``.

        ``tensor_mapping_index`` must hold ``batch * directions`` entries. The
        conditioning features in ``cond_dict`` are broadcast to every direction
        of their sample. The caller's ``cond_dict`` is left unchanged.

        Returns:
            Decoded volumes reshaped to ``(batch, directions, ...)``.
        """
        batch_size, num_directions, num_channels, height, width, depth = h.shape
        h = h.reshape(batch_size * num_directions, num_channels, height, width, depth)
        tensor_mapping_index = tensor_mapping_index.reshape(batch_size * num_directions)
        expanded_cond = {
            key: [self._broadcast_cond_feature(phi, batch_size, num_directions) for phi in features]
            for key, features in cond_dict.items()
        }

        dec = self.decode(h, tensor_mapping_index, expanded_cond, force_not_quantize=force_not_quantize)
        dec = dec.reshape(batch_size, num_directions, *dec.shape[1:])
        return dec
    