import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from quantizer import GaussianVectorQuantizer
from model.util import Block, TopDownBlockBase, parse_bottom_up_layer_string, parse_top_down_layer_string
import itertools


class ResModel(nn.Module):
    """Base class for residual-quantization VAEs (RSQ-VAE / RVQ-VAE).

    An input is encoded by ``BottomUp``; ``TopDown`` then quantizes the encoder
    output in ``num_residual`` successive stages and decodes the accumulated
    quantized latent.  Subclasses supply the reconstruction term by
    implementing ``_calc_distortion``.
    """

    def __init__(self, cfgs, flgs, model_name, device="cuda"):
        """Build codebooks, quantizers, encoder and decoder.

        Args:
            cfgs: Configuration object.  This constructor reads
                ``cfgs.dataset.dim_x``, ``cfgs.quantization`` (``size_dict``,
                ``dim_dict``, ``temperature.init``, ``flg_codebook_share``),
                ``cfgs.model`` (``log_param_q_init``, ``param_var_q_first``)
                and ``cfgs.network``.
            flgs: Flag object; only ``flgs.arelbo`` is read here.
            model_name: ``"ressqvae"`` selects ``GaussianVectorQuantizer``,
                ``"resvqvae"`` selects ``DeterministicVectorQuantizer``.
            device: Device on which the codebooks and the scalar
                ``log_param_q_scalar_q`` parameter are created.
        """
        super().__init__()

        # Common setting
        self.device = device

        # RSQ-VAE or RVQ-VAE
        self.model_name = model_name

        # Data space
        self.dim_x = cfgs.dataset.dim_x
        self.flg_arelbo = flgs.arelbo
        if not self.flg_arelbo:
            # Learnable log-variance of the data distribution; only needed when
            # the distortion is not computed in its "arelbo" closed form.
            self.logvar_x = nn.Parameter(torch.tensor(np.log(0.1)))

        # Codebook
        self.num_residual = len(cfgs.quantization.size_dict)
        self.size_dict = []
        self.dim_dict = []
        self.codebook = nn.ParameterList()
        # Plain Python list: its entries are NOT registered as sub-modules.
        self.quantizer = []
        self.log_param_q_scalar_q = nn.Parameter(
            torch.tensor(cfgs.model.log_param_q_init, device=self.device)
        )
        for idx in range(self.num_residual):
            self.size_dict.append(cfgs.quantization.size_dict[idx])
            self.dim_dict.append(cfgs.quantization.dim_dict[idx])
            self.codebook.append(nn.Parameter(
                torch.randn(self.size_dict[idx], self.dim_dict[idx], device=self.device)))
            if self.model_name == "ressqvae":
                self.quantizer.append(
                    GaussianVectorQuantizer(
                        size_dict=self.size_dict[idx],
                        dim_dict=self.dim_dict[idx],
                        # Only the last residual stage gets this flag set.
                        flg_loss_continuous=(idx == self.num_residual - 1),
                        temperature=cfgs.quantization.temperature.init,
                        device=self.device
                    )
                )
            elif self.model_name == "resvqvae":
                # NOTE(review): DeterministicVectorQuantizer is not imported in
                # the visible part of this file; confirm where it comes from.
                self.quantizer.append(
                    DeterministicVectorQuantizer(
                        self.size_dict[idx], self.dim_dict[idx]
                    )
                )

        # Encoder/decoder
        self.encoder = BottomUp(
            cfgs=cfgs.network,
            param_var_q=cfgs.model.param_var_q_first)
        self.decoder = TopDown(
            cfgs=cfgs.network,
            cfgs_model=cfgs.model,
            num_residual=self.num_residual,
            flg_codebook_share=cfgs.quantization.flg_codebook_share)

    def forward(self, x, flg_train=False, flg_quant_det=True):
        """Encode, residually quantize and decode ``x``.

        Args:
            x: Input batch passed to the encoder.
            flg_train: Forwarded unchanged to the decoder/quantizers.
            flg_quant_det: Forwarded unchanged to the decoder/quantizers.

        Returns:
            Tuple ``(xs_reconst, stats)`` where ``stats`` is a dict with keys
            ``all`` (distortion + summed KL), ``mse``, ``kl_term``,
            ``perplexity`` (list of floats, one per stage) and ``indices``
            (list with one entry per stage).
        """
        activation = self.encoder(x)
        xs_reconst, stats = self.decoder(
            activation=activation,
            quantizer=self.quantizer,
            codebook=self.codebook,
            log_param_q_scalar_q=self.log_param_q_scalar_q,
            flg_train=flg_train,
            flg_quant_det=flg_quant_det)

        if self.model_name == "resvqvae":
            # Codebook update with EMA: each stage reports its new codebook.
            for i, _ in enumerate(self.codebook):
                self.codebook[i] = stats[i]["codebook"]

        # Objective function + information summarization
        distortion, mse = self._calc_distortion(xs_reconst, x)
        kl_term = 0.0
        perplexity_list = []
        indices_list = []
        for statdict in stats:
            kl_term += statdict['kl']
            perplexity_list.append(statdict['perplexity'].detach().cpu().item())
            indices_list.append(statdict['indices'])
        elbo = distortion + kl_term
        stats = dict(all=elbo, mse=mse, kl_term=kl_term, perplexity=perplexity_list, indices=indices_list)

        return xs_reconst, stats

    def forward_uncond_samples(self, x, flg_quant_det=True):
        """Run the decoder's ``forward_uncond`` path on the encoding of ``x``.

        NOTE(review): this method reads ``self.res`` and
        ``self.log_param_q_scalar_p``; neither attribute is set by
        ``ResModel.__init__``, so it only works if a subclass defines them.
        """
        activation = self.encoder(x)
        x_reconst = self.decoder.forward_uncond(activation[str(self.res[0])], self.quantizer, self.codebook, self.log_param_q_scalar_p, False, flg_quant_det)
        return x_reconst

    def _calc_distortion(self, x_rec, x):
        """Return ``(distortion, mse)`` for a reconstruction; subclass hook."""
        raise NotImplementedError()

    def decode_from_indices(self, indices):
        """Decode a sample from per-stage codebook indices.

        Args:
            indices: Sequence with one integer index tensor per residual
                stage.  The final ``permute(0, 3, 1, 2)`` implies each tensor
                is 3-D (presumably batch x height x width -- confirm with the
                caller).

        Returns:
            Output of ``self.decoder.decode_from_latents`` for the looked-up
            codebook vectors.
        """
        z_quantized_list = []
        for i, indices_i in enumerate(indices):
            if self.decoder.flg_codebook_share:
                codebook = self.codebook[0]
            else:
                codebook = self.codebook[i]
            z_shape = indices_i.shape + (self.dim_dict[i],)
            # Flatten to a column of indices.  ``view(-1, 1)`` also works when
            # there is a single index in total, where a
            # ``squeeze().unsqueeze(1)`` round trip would fail on the
            # resulting 0-dim tensor.
            indices_cb = indices_i.contiguous().view(-1, 1)
            encodings_i = F.one_hot(indices_cb, num_classes=self.size_dict[i]).type_as(codebook)
            z_quantized_i = torch.matmul(encodings_i, codebook).view(z_shape)
            # Move the embedding axis in front of the two trailing index axes.
            z_quantized_list.append(z_quantized_i.permute(0, 3, 1, 2).contiguous())
        x_sampled = self.decoder.decode_from_latents(z_quantized_list)
        return x_sampled


class ResSQVAE(ResModel):
    """``ResModel`` configured with ``model_name="ressqvae"``."""

    def __init__(self, cfgs, flgs, device="cuda"):
        """Create the model.

        Args:
            cfgs: Configuration object forwarded to ``ResModel``.
            flgs: Flag object forwarded to ``ResModel``.
            device: Device forwarded to ``ResModel``.  Defaults to ``"cuda"``,
                which is what the base class used when no device was given,
                so existing two-argument callers are unaffected.
        """
        super().__init__(cfgs, flgs, model_name="ressqvae", device=device)

    def _calc_distortion(self, x_rec, x):
        """Compute the reconstruction term of the objective.

        Args:
            x_rec: Reconstruction produced by the decoder.
            x: Target batch; ``x.shape[0]`` is used as the batch size.

        Returns:
            Tuple ``(distortion, mse)`` where ``mse`` is the summed squared
            error divided by the batch size.
        """
        bs = x.shape[0]
        mse = F.mse_loss(x_rec, x, reduction="sum") / bs

        if self.flg_arelbo:
            # "Preventing oversmoothing in VAE via generalized variance parameterization"
            # https://www.sciencedirect.com/science/article/pii/S0925231222010591
            distortion = self.dim_x * torch.log(mse) / 2
        else:
            # Uses the learnable log-variance ``self.logvar_x``.
            distortion = mse / (2 * self.logvar_x.exp()) + self.dim_x * self.logvar_x / 2

        return distortion, mse


class BottomUp(nn.Module):
    """Encoder: a stride-2 convolutional stem followed by a configurable stack.

    The stack is described by ``cfgs.blocks_bu``; every entry is either a
    stride-2 convolution that doubles the channel count or a residual
    ``Block`` that keeps it.
    """

    def __init__(self, cfgs, param_var_q):
        """Build the stem and the main stack.

        Args:
            cfgs: Network configuration; reads ``width``, ``image_size``,
                ``res_top``, ``image_channels``, ``blocks_bu`` and
                ``bottleneck_multiple``.
            param_var_q: Accepted for interface compatibility; not used here.
        """
        super().__init__()

        # Stem: starting channel count is cfgs.width divided by a power of two
        # derived from the ratio image_size / res_top.
        exponent = int(np.log2((cfgs.image_size / cfgs.res_top)) - 1)
        channels = cfgs.width // 2**exponent
        stem_conv = nn.Conv2d(cfgs.image_channels, channels, 4, stride=2, padding=1)
        self.in_conv = nn.Sequential(stem_conv, nn.ReLU(True))

        # Main stack, in the order given by the parsed layer string.
        layers = []
        layer_spec = parse_bottom_up_layer_string(cfgs.blocks_bu)
        for res, down_rate in layer_spec:
            # 3x3 kernels are only used above 2x2 resolution.
            wide_kernel = res > 2
            if not down_rate:
                bottleneck = int(channels * cfgs.bottleneck_multiple)
                layers.append(
                    Block(
                        in_width=channels,
                        middle_width=bottleneck,
                        out_width=channels,
                        down_rate=None,
                        residual=True,
                        use_3x3=wide_kernel,
                    )
                )
                continue
            # Down-sampling entry: halve the resolution, double the channels.
            layers.append(nn.Conv2d(channels, channels * 2, 4, stride=2, padding=1))
            channels *= 2
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``self.main`` applied to the stem output for ``x``."""
        return self.main(self.in_conv(x))
 

class TopDownBlock(TopDownBlockBase):
    """A single residual quantization stage."""

    def __init__(self, cfgs):
        super().__init__(cfgs)

    def forward(self, z_cur, z_res, quantizer, codebook, log_param_q_scalar_q, flg_train, flg_quant_det):
        """Quantize the current residual and update the running latents.

        Args:
            z_cur: Sum of the quantized latents produced by earlier stages.
            z_res: Residual still to be quantized.
            quantizer: Callable invoked with ``z_pos``, ``var_q_pos``,
                ``codebook``, ``flg_train`` and ``flg_quant_det``; its result
                must provide a ``"z_q"`` entry.
            codebook: Codebook handed to the quantizer.
            log_param_q_scalar_q: Log-domain parameter; its exponential is
                passed to the quantizer as ``var_q_pos``.
            flg_train: Forwarded to the quantizer.
            flg_quant_det: Forwarded to the quantizer.

        Returns:
            Tuple ``(z_cur + z_q, z_res - z_q, stat)`` with ``stat`` being the
            quantizer's result.
        """
        # Quantization
        quant_out = quantizer(
            z_pos=z_res,
            var_q_pos=log_param_q_scalar_q.exp(),
            codebook=codebook,
            flg_train=flg_train,
            flg_quant_det=flg_quant_det,
        )

        # Residual/current latent
        code = quant_out["z_q"]
        remaining = z_res - code
        accumulated = z_cur + code
        return accumulated, remaining, quant_out

    def decode_from_latent(self, x, z_quantized):
        """Return ``z_quantized`` together with the difference ``x - z_quantized``."""
        return z_quantized, x - z_quantized


class TopDown(nn.Module):
    """Decoder: residual quantization stages followed by an up-sampling network.

    ``blocks_sq`` holds one ``TopDownBlock`` per residual stage; ``final_fn``
    maps the accumulated quantized latent back to image space.
    """

    def __init__(self, cfgs, cfgs_model, num_residual, flg_codebook_share=False):
        """Build the quantization stages and the output network.

        Args:
            cfgs: Network configuration; reads ``width``, ``blocks_td``,
                ``bottleneck_multiple``, ``image_channels`` and
                ``last_activate``.
            cfgs_model: Accepted for interface compatibility; not used here.
            num_residual: Number of residual quantization stages.
            flg_codebook_share: When true, every stage uses codebook entry 0.
        """
        super().__init__()
        self.cfgs = cfgs
        self.flg_codebook_share = flg_codebook_share

        # Quantizer
        self.blocks_sq = nn.ModuleList(
            [TopDownBlock(cfgs) for _ in range(num_residual)])

        # Top-down network
        blocks_final = []
        width = cfgs.width
        # NOTE(review): the top-down spec is parsed with the bottom-up parser;
        # entries flagged ``down_rate`` become stride-2 transposed convolutions
        # (i.e. up-sampling) below.  Confirm this parser is the intended one.
        blockstr = parse_bottom_up_layer_string(cfgs.blocks_td)
        for res, down_rate in blockstr:
            use_3x3 = res > 2  # Don't use 3x3s for 1x1, 2x2 patches
            if down_rate:
                blocks_final.append(
                    nn.ConvTranspose2d(width, width // 2, 4, stride=2, padding=1))
                width //= 2
            else:
                blocks_final.append(
                    Block(
                        in_width=width,
                        middle_width=int(width * cfgs.bottleneck_multiple),
                        out_width=width,
                        down_rate=None,
                        residual=True,
                        use_3x3=use_3x3))
        # Emit as many channels as the encoder consumes (it reads the same
        # ``image_channels`` setting) instead of a hard-coded 3.
        blocks_final.append(
            nn.ConvTranspose2d(width, cfgs.image_channels, 3, stride=1, padding=1))
        if cfgs.last_activate == "sigmoid":
            blocks_final.append(nn.Sigmoid())
        elif cfgs.last_activate == "tanh":
            blocks_final.append(nn.Tanh())
        self.final_fn = nn.Sequential(*blocks_final)

    def _codebook_index(self, idx):
        """Return the codebook entry used by stage ``idx`` (0 when shared)."""
        return 0 if self.flg_codebook_share else idx

    def forward(self, activation, quantizer, codebook, log_param_q_scalar_q, flg_train, flg_quant_det):
        """Residually quantize ``activation`` and decode the result.

        Args:
            activation: Encoder output; used as the initial residual.
            quantizer: Indexable collection with one quantizer per stage.
            codebook: Indexable collection of codebooks.
            log_param_q_scalar_q: Stage ``idx`` receives the slice
                ``log_param_q_scalar_q[:idx + 1]``.
            flg_train: Forwarded to every stage.
            flg_quant_det: Forwarded to every stage.

        Returns:
            Tuple ``(x_rec, stats)`` where ``stats`` lists the per-stage
            results in stage order.
        """
        # Initialization (done up front so it is defined even with no stages)
        z_res = activation
        z_cur = torch.zeros_like(z_res)

        stats = []
        for idx, block in enumerate(self.blocks_sq):
            # Quantization block
            z_cur, z_res, block_stats = block(
                z_cur=z_cur,
                z_res=z_res,
                quantizer=quantizer[idx],
                codebook=codebook[self._codebook_index(idx)],
                log_param_q_scalar_q=log_param_q_scalar_q[:idx+1],
                flg_train=flg_train,
                flg_quant_det=flg_quant_det
            )
            stats.append(block_stats)

        x_rec = self.final_fn(z_cur)

        return x_rec, stats

    def forward_uncond(self, x, quantizer, codebook, log_param_q_scalar, flg_train, flg_quant_det):
        """Chain every stage's ``forward_uncond`` and decode the result.

        The codebook for each stage is selected the same way as in
        ``forward``: entry 0 for all stages when ``flg_codebook_share`` is set.
        NOTE(review): ``forward_uncond`` is not defined on ``TopDownBlock`` in
        this file; presumably it is inherited from its base class.
        """
        for idx, block in enumerate(self.blocks_sq):
            x = block.forward_uncond(
                x,
                quantizer[idx],
                codebook[self._codebook_index(idx)],
                log_param_q_scalar[idx],
                flg_train,
                flg_quant_det)
        return self.final_fn(x)

    # For generation
    def decode_from_latents(self, z_quantized_list):
        """Decode from already-quantized latents, one per stage.

        Args:
            z_quantized_list: Indexable collection with at least
                ``len(self.blocks_sq)`` quantized latents.

        Returns:
            ``final_fn`` applied to the sum of the latents each stage returns.
        """
        for idx, block in enumerate(self.blocks_sq):
            if idx == 0:
                z_all, diff_z = block.decode_from_latent(0, z_quantized_list[idx])
                z_gen = z_all
            else:
                z, diff_z = block.decode_from_latent(diff_z, z_quantized_list[idx])
                z_gen = z_gen + z

        return self.final_fn(z_gen)
