import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from quantizer import GaussianVectorQuantizer
from model.util import Block, TopDownBlockBase, parse_bu_layer_string, parse_td_layer_string, parse_sq_layer_string
import itertools


class HierSQVAE(nn.Module):
    """Hierarchical SQ-VAE: bottom-up encoder, top-down quantized decoder, codebooks.

    Args:
        cfgs: Configuration object. This class reads ``cfgs.network``
            (``blocks_bu``, ``blocks_td``, ``blocks_sq``), ``cfgs.model``
            (``param_var_q_first``, ``log_param_q_init``,
            ``flg_progressive_coding``) and ``cfgs.quantization``
            (``size_dict``, ``dim_dict``, ``temperature.init``).
        flgs: Accepted for interface compatibility; not read by this class.
    """

    def __init__(self, cfgs, flgs):
        super(HierSQVAE, self).__init__()
        # Encoder/decoder
        str_block_bu = parse_bu_layer_string(cfgs.network.blocks_bu)
        str_block_td = parse_td_layer_string(cfgs.network.blocks_td)
        str_block_sq = parse_sq_layer_string(cfgs.network.blocks_sq)
        self.path_bottom_up = BottomUp(cfgs.network, cfgs.model.param_var_q_first, str_block_bu)
        self.path_top_down = TopDown(cfgs.network, cfgs.model, str_block_td, str_block_sq)

        # Resolution: collect the index of the layer that precedes every
        # up-sampling layer, plus the very last layer.  Those layers get
        # ``flg_loss_continuous=True`` below.
        idx_end = set()
        last_idx = None
        for idx, (_, up_rate) in enumerate(str_block_sq):
            if up_rate:
                idx_end.add(idx - 1)
            last_idx = idx
        if last_idx is not None:  # an empty layer spec has no final layer
            idx_end.add(last_idx)

        # Quantizer
        # Parameters used to be created with a hard-coded device="cuda", which
        # made construction fail on CPU-only hosts.  CUDA is still preferred
        # when present, so behaviour on GPU machines is unchanged.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.size_dict = []
        self.dim_dict = []
        self.codebook = nn.ParameterList()
        # NOTE(review): kept as a plain Python list, so the quantizers are not
        # registered as sub-modules of this model — confirm this is intended.
        self.quantizer = []
        self.log_param_q_scalar_q = nn.Parameter(torch.tensor(cfgs.model.log_param_q_init, device=device))
        for i, _ in enumerate(str_block_sq):
            self.size_dict.append(cfgs.quantization.size_dict[i])
            self.dim_dict.append(cfgs.quantization.dim_dict[i])
            self.codebook.append(nn.Parameter(torch.randn(self.size_dict[i], self.dim_dict[i], device=device)))
            flg_loss_continuous = i in idx_end
            self.quantizer.append(
                GaussianVectorQuantizer(
                    self.size_dict[i], self.dim_dict[i], flg_loss_continuous, cfgs.quantization.temperature.init
                )
            )

        # Others
        self.progressive_coding = cfgs.model.flg_progressive_coding

    def forward(self, x_input, flg_train=False, flg_quant_det=True):
        """Encode ``x_input``, quantize along the top-down path and reconstruct.

        Args:
            x_input: Input batch fed to the bottom-up path.
            flg_train: Forwarded unchanged to the top-down path / quantizers.
            flg_quant_det: Forwarded unchanged to the top-down path / quantizers.

        Returns:
            Tuple ``(x_reconst, info)`` where ``info`` is a dict with keys
            ``all`` (the loss returned by the top-down path), ``mse``,
            ``perplexity`` (list of Python floats, one per layer) and
            ``indices`` (list of per-layer index tensors).
        """
        activations = self.path_bottom_up(x_input)
        x_reconst, elbo, mse, stats = self.path_top_down(
            x_input, activations, self.quantizer, self.codebook, self.log_param_q_scalar_q, flg_train, flg_quant_det)
        perplexity_list = self.ext_info(stats, 'perplexity')
        indices_list = [statdict['indices'] for statdict in stats]

        return x_reconst, dict(all=elbo, mse=mse, perplexity=perplexity_list, indices=indices_list)

    def ext_info(self, stats, key):
        """Return ``statdict[key]`` of every entry in ``stats`` as a Python scalar.

        Each value must be a single-element tensor (``.item()`` is called on it).
        """
        return [statdict[key].detach().cpu().item() for statdict in stats]

    def decode_from_indices(self, indices):
        """Decode an image batch from per-layer codebook indices.

        Args:
            indices: Sequence of integer tensors, one per quantization layer.
                Entry ``i`` indexes ``self.codebook[i]``; it must be 3-D,
                because the looked-up codes are permuted with ``(0, 3, 1, 2)``.

        Returns:
            Whatever ``self.path_top_down.decode_from_latents`` returns for the
            list of looked-up code tensors.
        """
        z_quantized_list = []
        for i, index_map in enumerate(indices):
            z_shape = index_map.shape + (self.dim_dict[i],)
            # Flatten with reshape(-1): the former view/squeeze/unsqueeze chain
            # collapsed to a 0-d tensor (and raised) for a single-element map.
            flat_indices = index_map.reshape(-1)
            encodings_i = F.one_hot(flat_indices, num_classes=self.size_dict[i]).type_as(self.codebook[i])
            z_quantized_i = torch.matmul(encodings_i, self.codebook[i]).view(z_shape)
            # Move the code dimension from last to second position.
            z_quantized_list.append(z_quantized_i.permute(0, 3, 1, 2).contiguous())
        x_sampled = self.path_top_down.decode_from_latents(z_quantized_list)
        return x_sampled




class BottomUp(nn.Module):
    """Bottom-up (encoder) path that collects feature maps keyed by resolution.

    Args:
        cfgs: Network config; reads ``width``, ``image_channels`` and
            ``bottleneck_multiple``.
        param_var_q: Variance parameterisation name.  Unless it is
            ``"gaussian_1"`` or ``"vmf"``, one extra non-residual block with
            ``2 * width`` output channels is appended.
        str_block: Iterable of ``(res, down_rate)`` pairs, one per main block.
    """

    def __init__(self, cfgs, param_var_q, str_block):
        super().__init__()
        channels = cfgs.width

        def build_block(out_channels, residual, down_rate, use_3x3):
            # All blocks share input width and bottleneck width.
            return Block(
                channels,
                int(channels * cfgs.bottleneck_multiple),
                out_channels,
                down_rate=down_rate,
                residual=residual,
                use_3x3=use_3x3,
            )

        stem = nn.Conv2d(cfgs.image_channels, channels, 4, stride=2, padding=1)
        self.in_conv = nn.Sequential(stem, nn.ReLU(True))

        stages = []
        for res, down_rate in str_block:
            # 3x3 kernels are skipped on 1x1 and 2x2 feature maps.
            use_3x3 = res > 2
            stages.append(build_block(channels, True, down_rate, use_3x3))
        if param_var_q not in ("gaussian_1", "vmf"):
            # Reuses ``down_rate`` / ``use_3x3`` from the last loop iteration.
            stages.append(build_block(channels * 2, False, down_rate, use_3x3))

        # Weight down-scaling trick taken from "very deep VAE"
        # (possibly unnecessary here).
        depth = len(str_block)
        for stage in stages:
            stage.c4.weight.data *= np.sqrt(1 / depth)
        self.blocks_main = nn.ModuleList(stages)

    def forward(self, x):
        """Run the encoder and return a dict of feature maps.

        Keys are ``str(tensor.shape[2])``.  A block may return a 1-tuple
        (next features) or a longer tuple whose second element is also stored.
        """
        h = self.in_conv(x)
        feats = {str(h.shape[2]): h}
        for layer in self.blocks_main:
            result = layer(h)
            h = result[0]
            if len(result) != 1:
                skip = result[1]
                feats[str(skip.shape[2])] = skip
        # The final features overwrite any earlier entry of the same size.
        feats[str(h.shape[2])] = h

        return feats


class InjSQBlock(TopDownBlockBase):
    """Top-down block that up-samples the running latent and injects a new code.

    Relies on ``self.width``, ``self.cond_width``, ``self.zdim`` and
    ``self.use_3x3``, which are presumably set by ``TopDownBlockBase.__init__``
    (not visible in this file — TODO confirm).

    Args:
        cfgs: Network config forwarded to ``TopDownBlockBase``.
        ch_out: Output channels of the up-sampling branch.
        up_rate: Spatial up-sampling factor; ``int(log2(up_rate))`` stride-2
            transposed convolutions are stacked.
        n_blocks: Pair ``(n_upsample_blocks, n_posterior_blocks)``.
    """

    def __init__(self, cfgs, ch_out, up_rate=1, n_blocks=(1, 1)):
        # ``n_blocks`` default is a tuple (was a shared mutable list).
        super(InjSQBlock, self).__init__(cfgs)
        # For upsampling
        blocks_upsample = []
        for _ in range(n_blocks[0]):
            blocks_upsample.append(Block(self.width, self.cond_width, self.width, residual=True, use_3x3=self.use_3x3))
        for _ in range(int(np.log2(up_rate))):
            blocks_upsample.append(nn.ConvTranspose2d(self.width, self.width, 4, stride=2, padding=1))
            blocks_upsample.append(nn.ReLU(True))
            blocks_upsample.append(Block(self.width, self.cond_width, self.width, residual=True, use_3x3=self.use_3x3))
        blocks_upsample.append(nn.ConvTranspose2d(self.width, ch_out, 3, stride=1, padding=1))
        self.upsample = nn.Sequential(*blocks_upsample)
        # For posterior
        blocks_posterior = []
        for ilayer in range(n_blocks[1]):
            if ilayer == 0:
                # First block consumes the concatenation [z_pass, act].
                blocks_posterior.append(Block(self.width * 2, self.cond_width, self.zdim, residual=False, use_3x3=self.use_3x3))
            else:
                blocks_posterior.append(Block(self.zdim, self.cond_width, self.zdim, residual=True, use_3x3=self.use_3x3))
        self.posterior = nn.Sequential(*blocks_posterior)

    def forward(self, z_cur, z_res, act, quantizer, codebook, log_param_q_scalar_q, flg_train, flg_quant_det):
        """Up-sample ``z_cur``, quantize the posterior features and add the code.

        Args:
            z_cur: Running latent from the previous layer.
            z_res: Ignored on input (kept for signature parity with
                ``ResSQBlock.forward``); a fresh residual is returned.
            act: Bottom-up activation concatenated with the up-sampled latent.
            quantizer: Callable returning a dict that contains ``"z_q"``.
            codebook: Codebook passed to ``quantizer``.
            log_param_q_scalar_q: Log of the quantization parameter(s).
            flg_train, flg_quant_det: Forwarded to ``quantizer``.

        Returns:
            ``(z_cur, z_res, stat)``: the updated latent, the residual between
            the posterior features and their quantized version, and the
            quantizer output.
        """
        # Pre-process
        z_pass = self.upsample(z_cur)
        z_cur = self.posterior(torch.cat([z_pass, act], dim=1))
        # Quantization
        param_q = log_param_q_scalar_q.exp()
        stat = quantizer(z_cur, param_q, codebook, flg_train, flg_quant_det)
        # Residual and current latent tensors
        z_q = stat["z_q"]
        z_res = z_cur - z_q
        z_cur = z_pass + z_q

        return z_cur, z_res, stat

    def pass_through(self, z_cur, act):
        """Propagate ``z_cur`` without quantization, with ``act`` replaced by zeros."""
        z_pass = self.upsample(z_cur)
        z_cur = self.posterior(torch.cat([z_pass, torch.zeros_like(act)], dim=1))
        return z_pass + z_cur

    def sample_uncond(self, x, quantizer, codebook, log_param_q_scalar, flg_train, flg_quant_det):
        """Unconditional sampling — not implemented yet.

        The earlier draft of this method could never succeed: it returned the
        undefined name ``z_quantized`` and called the quantizer with seven
        arguments, whereas ``forward`` calls it with five.  It now fails
        explicitly instead of with an incidental error.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("InjSQBlock.sample_uncond is not implemented yet")

    def forward_uncond(self, x, quantizer, codebook, log_param_q_scalar, flg_train, flg_quant_det):
        """Unconditional forward pass; depends on ``sample_uncond`` (unimplemented).

        Raises:
            NotImplementedError: Propagated from ``sample_uncond``.
        """
        z, x = self.sample_uncond(x, quantizer, codebook, log_param_q_scalar, flg_train, flg_quant_det)
        # NOTE(review): ``self.out_conv`` is not defined in this class —
        # presumably provided by TopDownBlockBase; confirm before implementing.
        x = self.out_conv(x + z)
        return x

    def decode_from_latent(self, z_cur, z_q):
        """Up-sample ``z_cur`` and add an already-quantized code ``z_q``."""
        z_pass = self.upsample(z_cur)
        z_cur = z_pass + z_q
        return z_cur



class ResSQBlock(TopDownBlockBase):
    """Top-down block that quantizes the residual left by the previous layers.

    Holds no layers of its own; construction is delegated to
    ``TopDownBlockBase``.
    """

    def __init__(self, cfgs):
        super(ResSQBlock, self).__init__(cfgs)

    def sample(self, x, quantizer, codebook, log_param_q_scalar_q, flg_train, flg_quant_det):
        """Quantize ``x`` directly and return ``(z_quantized, stat)``.

        NOTE(review): ``self.get_param_q`` and ``quantizer.param_var_q`` are
        not defined in this file — presumably provided by ``TopDownBlockBase``
        and the quantizer class; confirm before relying on this method.
        """
        # Decoding
        qm, qv = x, None
        # Quantization
        param_q = self.get_param_q(qv, quantizer, log_param_q_scalar_q, quantizer.param_var_q)
        stat = quantizer(qm, param_q, codebook, flg_train, flg_quant_det)
        # The quantized latent lives under "z_q", as in ``forward``; the old
        # code returned an undefined name here and always raised NameError.
        z_quantized = stat["z_q"]

        return z_quantized, stat

    def forward(self, z_cur, z_res, act, quantizer, codebook, log_param_q_scalar_q, flg_train, flg_quant_det):
        """Quantize the residual ``z_res`` and fold the code into ``z_cur``.

        Args:
            z_cur: Running latent.
            z_res: Residual still to be encoded.
            act: Unused (kept for signature parity with ``InjSQBlock.forward``).
            quantizer: Callable returning a dict that contains ``"z_q"``.
            codebook: Codebook passed to ``quantizer``.
            log_param_q_scalar_q: Log of the quantization parameter(s).
            flg_train, flg_quant_det: Forwarded to ``quantizer``.

        Returns:
            ``(z_cur + z_q, z_res - z_q, stat)``.
        """
        # Quantization
        param_q = log_param_q_scalar_q.exp()
        stat = quantizer(z_res, param_q, codebook, flg_train, flg_quant_det)
        # Residual/current latent
        z_q = stat["z_q"]
        z_res = z_res - z_q
        z_cur = z_cur + z_q
        return z_cur, z_res, stat

    def pass_through(self, z_cur, act):
        """Return ``z_cur`` unchanged; ``act`` is ignored."""
        return z_cur

    def decode_from_latent(self, z_cur, z_q):
        """Add an already-quantized code ``z_q`` to the running latent."""
        return z_cur + z_q



class TopDown(nn.Module):
    """Top-down path: a stack of quantization blocks followed by the image decoder.

    Args:
        cfgs: Network config; reads ``width``, ``bottleneck_multiple`` and
            ``image_channels``.
        cfgs_model: Model config; reads ``scale_width_noise``,
            ``scale_down_rate``, ``flg_progressive_coding`` and ``flg_arelbo``.
        str_block_td: Iterable of ``(res, up_rate)`` pairs describing the
            final decoder (``self.final_fn``).
        str_block_sq: Sequence of ``(res, up_rate)`` pairs, one per
            quantization block; a truthy ``up_rate`` selects ``InjSQBlock``,
            otherwise ``ResSQBlock``.
        flg_markov: Accepted for interface compatibility; not used.
    """

    def __init__(self, cfgs, cfgs_model, str_block_td, str_block_sq, flg_markov=True):
        super().__init__()
        self.str_block_sq = str_block_sq
        self.scale_width_noise = cfgs_model.scale_width_noise
        self.scale_down_rate = cfgs_model.scale_down_rate
        self.flg_progressive_coding = cfgs_model.flg_progressive_coding
        self.cfgs = cfgs
        self.flg_arelbo = cfgs_model.flg_arelbo
        blocks_sq = []
        for _, up_rate in self.str_block_sq:
            if up_rate:
                blocks_sq.append(InjSQBlock(cfgs, cfgs.width, up_rate))
            else:
                blocks_sq.append(ResSQBlock(cfgs))
        # Only used as an indexable/iterable container, never called as a whole.
        self.blocks_sq = nn.Sequential(*blocks_sq)

        blocks_td = []
        for res, up_rate in str_block_td:
            use_3x3 = res > 2  # Don't use 3x3s for 1x1, 2x2 patches
            if up_rate:
                blocks_td.append(nn.ConvTranspose2d(cfgs.width, cfgs.width, 4, stride=2, padding=1))
            else:
                blocks_td.append(Block(cfgs.width, int(cfgs.width * cfgs.bottleneck_multiple), cfgs.width, down_rate=None, residual=True, use_3x3=use_3x3))
        # Emit as many channels as the encoder consumes (was hard-coded to 3,
        # which mismatched the input whenever image_channels != 3).
        blocks_td.append(nn.ConvTranspose2d(cfgs.width, cfgs.image_channels, 3, stride=1, padding=1))
        blocks_td.append(nn.Sigmoid())
        self.final_fn = nn.Sequential(*blocks_td)

    def forward(self, x_input, activations, quantizer, codebook, log_param_q_scalar_q, flg_train, flg_quant_det):
        """Run all quantization blocks and accumulate the training loss.

        Args:
            x_input: Original input batch (reconstruction target).
            activations: Dict of bottom-up features keyed by ``str(res)``.
            quantizer: Indexable collection of quantizers, one per block.
            codebook: Indexable collection of codebooks, one per block.
            log_param_q_scalar_q: 1-D tensor; each block receives the slice
                from the start of its resolution group up to its own index.
            flg_train, flg_quant_det: Forwarded to every block.

        Returns:
            ``(x_rec, loss, mse, stats)`` where ``stats`` holds one quantizer
            output dict per block.  With progressive coding, a reconstruction
            loss is added after every block and ``x_rec`` / ``mse`` come from
            the last one; otherwise a single reconstruction is computed at
            the end.
        """
        stats = []
        loss = 0
        idx_start = 0
        for idx, (res, up_rate) in enumerate(self.str_block_sq):
            if idx == 0:
                # The first residual is the bottom-up feature itself.
                z_res = activations[str(res)]
                z_cur = torch.zeros_like(z_res)
            if up_rate:
                # A new resolution group starts at every up-sampling block.
                idx_start = idx
            z_cur, z_res, block_stats = self.blocks_sq[idx](
                z_cur, z_res, activations[str(res)],
                quantizer[idx], codebook[idx], log_param_q_scalar_q[idx_start:idx + 1],
                flg_train, flg_quant_det)
            if self.flg_progressive_coding:
                x_rec = self.forward_pass(z_cur, activations, idx)
                loss_rec, mse = self._calc_loss_rec(x_rec, x_input, idx)
                loss = loss + loss_rec
                block_stats["mse"] = mse
            loss = loss + block_stats["kl"]
            stats.append(block_stats)
        if self.flg_progressive_coding:
            mse = block_stats["mse"]
        else:
            x_rec = self.final_fn(z_cur)
            loss_rec, mse = self._calc_loss_rec(x_rec, x_input)
            loss = loss + loss_rec

        return x_rec, loss, mse, stats

    def forward_pass(self, z_cur, activations, idx):
        """Decode ``z_cur`` by passing it through all blocks after ``idx`` unquantized."""
        for idx_rest, (res, _) in enumerate(self.str_block_sq):
            if idx_rest > idx:
                z_cur = self.blocks_sq[idx_rest].pass_through(z_cur, activations[str(res)])

        return self.final_fn(z_cur)

    def _calc_loss_rec(self, x_reconst, x_orig, idx=None):
        """Return ``(distortion, mse_real)`` for a reconstruction.

        ``mse_real`` is always the per-sample summed squared error at full
        resolution.  When ``idx`` is given, the distortion is instead computed
        on average-pooled images (factor ``scale_down_rate[idx]``) with an
        additive ``dim_x * scale_width_noise[idx]**2 / 12`` term.
        """
        bs, *T = x_orig.shape
        # Plain int keeps the arithmetic below tensor-first.
        dim_x = int(np.prod(T))
        mse_real = F.mse_loss(x_reconst, x_orig, reduction="sum") / bs
        if idx is None:
            mse = mse_real
        else:
            width_noise = self.scale_width_noise[idx]
            down_rate = self.scale_down_rate[idx]
            x_reconst = F.avg_pool2d(x_reconst, kernel_size=down_rate, stride=down_rate)
            x_orig = F.avg_pool2d(x_orig, kernel_size=down_rate, stride=down_rate)
            bs, *T = x_orig.shape
            dim_x = int(np.prod(T))
            mse = F.mse_loss(x_reconst, x_orig, reduction="sum") / bs + dim_x * width_noise**2 / 12

        # "Preventing Oversmoothing in VAE via Generalized Variance Parameterization"
        # https://arxiv.org/abs/2102.08663
        distortion = dim_x * torch.log(mse) / 2

        return distortion, mse_real

    def forward_uncond(self, x, quantizer, codebook, log_param_q_scalar, flg_train, flg_quant_det):
        """Unconditional decoding through every quantization block.

        NOTE(review): this path looks unfinished — it requires every block in
        ``self.blocks_sq`` to provide a working ``forward_uncond``; confirm
        before use.
        """
        # Was ``self.blocks_td``, an attribute that is never assigned; the
        # per-index quantizer/codebook lookup matches ``self.blocks_sq``.
        for idx, block in enumerate(self.blocks_sq):
            x = block.forward_uncond(x, quantizer[idx], codebook[idx], log_param_q_scalar[idx], flg_train, flg_quant_det)
        return self.final_fn(x)

    def decode_from_latents(self, z_quantized_list):
        """Decode an image batch from one already-quantized latent per block."""
        for idx, block in enumerate(self.blocks_sq):
            if idx == 0:
                # Start from a zero latent shaped like the first code.
                z_cur = torch.zeros_like(z_quantized_list[idx])
            z_cur = block.decode_from_latent(z_cur, z_quantized_list[idx])

        return self.final_fn(z_cur)

