import torch
import torch.nn.functional as F

from einops import rearrange, einsum
from typing import List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
from torch import distributed as tdist, nn as nn
from torch.nn import functional as F

from . import dist
from .abstract_modules.base_quantizer import BaseVectorQuantizer


class VectorQuantizer(BaseVectorQuantizer):
    """
    Nearest-neighbour vector quantizer with a straight-through gradient estimator.
    Both the encoder outputs (commitment loss) and the codebook (codebook loss) are optimized through the loss.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, commitment_cost: float = 0.25):

        """
        Original VectorQuantizer with straight through gradient estimator (loss is optimized on inputs and codebook)
        :param num_embeddings: size of the latent dictionary (num of embedding vectors).
        :param embedding_dim: size of a single tensor in dict.
        :param commitment_cost: scaling factor for e_loss
        """

        super().__init__(num_embeddings, embedding_dim)

        self.commitment_cost = commitment_cost

    def _nearest_code_indices(self, flat_x: torch.Tensor) -> torch.Tensor:
        """
        :param flat_x: vectors to quantize (N, D).
        :return index of the closest codebook entry for every vector (N,)
        """
        codebook = self.codebook.weight

        # argmin is not differentiable: the distance matrix never needs an autograd graph
        with torch.no_grad():
            # squared euclidean distance expanded as |x|^2 + |e|^2 - 2 x.e
            # distances is a matrix (N, codebook_size)
            distances = (torch.sum(flat_x ** 2, dim=1, keepdim=True)
                         + torch.sum(codebook ** 2, dim=1)
                         - 2 * torch.matmul(flat_x, codebook.t()))
            return torch.argmin(distances, dim=1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param x: tensors (output of the Encoder - B,D,H,W).
        :return quantized_x (B, D, H, W), detached codes (B, H*W), latent_loss (scalar tensor)
        """

        b, c, h, w = x.shape

        # Flat input to vectors of embedding dim = C.
        flat_x = rearrange(x, 'b c h w -> (b h w) c')

        # Indices of the closest vector in dict, one per input vector
        encoding_indices = self._nearest_code_indices(flat_x)

        # Quantize by gathering the selected rows. This yields the same rows as multiplying a one-hot
        # (N, codebook_size) mask with the codebook, without materializing that mask.
        quantized = self.codebook.weight[encoding_indices]

        # Loss functions
        # e_loss pulls the encoder outputs towards the (frozen) codes, q_loss pulls the codes towards the outputs
        e_loss = self.commitment_cost * F.mse_loss(quantized.detach(), flat_x)
        q_loss = F.mse_loss(quantized, flat_x.detach())

        # during backpropagation quantized = inputs (copy gradient trick)
        quantized = flat_x + (quantized - flat_x).detach()

        quantized = rearrange(quantized, '(b h w) c -> b c h w', b=b, h=h, w=w)
        encoding_indices = rearrange(encoding_indices, '(b h w) -> b (h w)', b=b, h=h, w=w).detach()

        return quantized, encoding_indices, q_loss + e_loss

    @torch.no_grad()
    def vec_to_codes(self, x: torch.Tensor) -> torch.LongTensor:
        """
        :param x: tensors (output of the Encoder - B,D,H,W).
        :return flat codebook indices (B, H * W)
        """
        b, c, h, w = x.shape

        # Flat input to vectors of embedding dim = C.
        flat_x = rearrange(x, 'b c h w -> (b h w) c')

        # Get indices of the closest vector in dict
        encoding_indices = self._nearest_code_indices(flat_x)

        return rearrange(encoding_indices, '(b h w) -> b (h w)', b=b, h=h, w=w)


class EMAVectorQuantizer(BaseVectorQuantizer):
    """
    Vector quantizer whose codebook is updated with exponential moving averages (EMA) of the encoder outputs
    instead of a codebook loss term. Only the commitment loss is returned by forward.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, commitment_cost: float = 0.25, decay: float = 0.95,
                 epsilon: float = 1e-5):

        """
        EMA ALGORITHM
        Each codebook entry is updated according to the encoder outputs that selected it.
        The important thing is that the codebook updating is not a loss term anymore.
        Specifically, for every codebook item w_i, the summed code m_i and usage count N_i are tracked:
            N_i <- N_i * γ + n_i * (1 − γ)
            m_i <- m_i * γ + (1 − γ) * sum_j E(x_j)     (sum over the n_i encoder outputs assigned to item i)
            w_i <- m_i / N_i
        where γ is a discount factor (decay) and n_i is the number of vectors of the batch assigned to item i.

        :param num_embeddings: size of the latent dictionary (num of embedding vectors).
        :param embedding_dim: size of a single tensor in dictionary
        :param commitment_cost: scaling factor for e_loss
        :param decay: decay for EMA updating
        :param epsilon: smoothing parameters for EMA weights
        """

        super().__init__(num_embeddings, embedding_dim)

        self.commitment_cost = commitment_cost

        # EMA does not require grad
        self.codebook.requires_grad_(False)

        # ema parameters
        # ema usage count: total count of each embedding through epochs
        self.register_buffer('ema_count', torch.zeros(num_embeddings))

        # running (un-normalized) means: same size as the dict, drawn from U(-1/K, 1/K) independently of the codebook
        self.register_buffer('ema_weight', torch.empty((self.num_embeddings, self.embedding_dim)))
        self.ema_weight.data.uniform_(-1 / self.num_embeddings, 1 / self.num_embeddings)

        self.decay = decay
        self.epsilon = epsilon

    def _nearest_code_indices(self, flat_x: torch.Tensor) -> torch.Tensor:
        """
        :param flat_x: vectors to quantize (N, D).
        :return index of the closest codebook entry for every vector (N,)
        """
        codebook = self.codebook.weight

        # argmin is not differentiable: the distance matrix never needs an autograd graph
        with torch.no_grad():
            # squared euclidean distance expanded as |x|^2 + |e|^2 - 2 x.e
            # distances is a matrix (N, codebook_size)
            distances = (torch.sum(flat_x ** 2, dim=1, keepdim=True)
                         + torch.sum(codebook ** 2, dim=1)
                         - 2 * torch.matmul(flat_x, codebook.t()))
            return torch.argmin(distances, dim=1)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        In training mode this call also updates the EMA buffers and overwrites the codebook weights.

        :param x: tensors (output of the Encoder - B,D,H,W).
        :return quantized_x (B, D, H, W), detached codes (B, H*W), latent_loss (scalar tensor)
        """

        b, c, h, w = x.shape
        device = x.device

        # Flat input to vectors of embedding dim = C.
        flat_x = rearrange(x, 'b c h w -> (b h w) c')

        # Get indices of the closest vector in dict, and create a one-hot mask on the selected indexes
        # encoding_indices = (num_vectors_in_batch,)
        # encodings = (num_vectors_in_batch, codebook_size)
        encoding_indices = self._nearest_code_indices(flat_x)
        encodings = torch.zeros(encoding_indices.shape[0], self.num_embeddings, device=device)
        encodings.scatter_(1, encoding_indices.unsqueeze(1), 1)

        # Quantize with the codebook as it was before this step's EMA update
        quantized = torch.matmul(encodings, self.codebook.weight)

        # Use EMA to update the embedding vectors
        # Update a codebook vector as the mean of the encoder outputs that are closer to it
        # Calculate the usage count of codes and the mean code, then update the codebook vector dividing the two
        if self.training:
            with torch.no_grad():
                ema_count = self.get_buffer('ema_count') * self.decay + (1 - self.decay) * torch.sum(encodings, 0)

                # Laplace smoothing of the ema count.
                # n must be the total count mass (not the batch size): only then the smoothed counts still sum
                # to n, otherwise every step shrinks the counts and inflates the codebook vectors.
                n = torch.sum(ema_count)
                self.ema_count = (ema_count + self.epsilon) / (n + self.num_embeddings * self.epsilon) * n

                # per-code sum of the encoder outputs assigned to it
                dw = torch.matmul(encodings.t(), flat_x)
                self.ema_weight = self.get_buffer('ema_weight') * self.decay + (1 - self.decay) * dw

                self.codebook.weight.data = self.get_buffer('ema_weight') / self.get_buffer('ema_count').unsqueeze(1)

        # Loss function (only the inputs are updated)
        e_loss = self.commitment_cost * F.mse_loss(quantized.detach(), flat_x)

        # during backpropagation quantized = inputs (copy gradient trick)
        quantized = flat_x + (quantized - flat_x).detach()

        quantized = rearrange(quantized, '(b h w) c -> b c h w', b=b, h=h, w=w)
        encoding_indices = rearrange(encoding_indices, '(b h w) -> b (h w)', b=b, h=h, w=w).detach()

        return quantized, encoding_indices, e_loss

    @torch.no_grad()
    def vec_to_codes(self, x: torch.Tensor) -> torch.LongTensor:
        """
        :param x: tensors (output of the Encoder - B,D,H,W).
        :return flat codebook indices (B, H * W)
        """
        b, c, h, w = x.shape

        # Flat input to vectors of embedding dim = C.
        flat_x = rearrange(x, 'b c h w -> (b h w) c')

        # Get indices of the closest vector in dict
        encoding_indices = self._nearest_code_indices(flat_x)

        return rearrange(encoding_indices, '(b h w) -> b (h w)', b=b, h=h, w=w)


class GumbelVectorQuantizer(BaseVectorQuantizer):
    """
    Quantizer that treats the encoder output as per-position scores over the codebook entries and samples
    a (soft or hard) one-hot selection with the Gumbel-Softmax relaxation during training.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, straight_through: bool = False, temp: float = 1.0,
                 kl_cost: float = 5e-4):
        """
        :param num_embeddings: size of the latent dictionary (num of embedding vectors).
        :param embedding_dim: size of a single tensor in dict.
        :param straight_through: if True, will one-hot quantize, but still differentiate as if it is the soft sample
        :param temp: temperature parameter for gumbel softmax
        :param kl_cost: cost for kl divergence
        """
        super().__init__(num_embeddings, embedding_dim)

        # 1x1 convolution mapping the N encoder channels to N logits
        self.x_to_logits = torch.nn.Conv2d(num_embeddings, num_embeddings, 1)
        self.straight_through = straight_through
        self.temp = temp
        self.kl_cost = kl_cost

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param x: tensors (output of the Encoder - B,N,H,W). Note that N = number of embeddings in dict!
        :return quantized_x (B, D, H, W), detached codes (B, H, W), latent_loss (scalar tensor)
        """

        logits = self.x_to_logits(x)

        if self.training:
            # stochastic relaxation; hard one-hot samples only if straight_through is set
            one_hot = F.gumbel_softmax(logits, tau=self.temp, dim=1, hard=self.straight_through)
            encoding_indices = one_hot.argmax(dim=1).detach()
        else:
            # deterministic quantization during inference.
            # F.gumbel_softmax adds freshly sampled Gumbel noise even with hard=True, so the plain argmax of the
            # logits is taken instead.
            encoding_indices = logits.argmax(dim=1)
            soft = F.softmax(logits / self.temp, dim=1)
            hard = torch.zeros_like(soft).scatter_(1, encoding_indices.unsqueeze(1), 1.0)
            # value is exactly the one-hot (soft - soft.detach() is zero), gradient is the one of the softmax
            one_hot = hard + (soft - soft.detach())
            encoding_indices = encoding_indices.detach()

        quantized = einsum(one_hot, self.get_codebook(), 'b n h w, n d -> b d h w')

        # + kl divergence to the prior (uniform) loss, increase cb usage
        # Note:
        #       KL(P(x), Q(x)) = sum_x (P(x) * log(P(x) / Q(x)))
        #       in this case: P(x) is qy, Q(x) is uniform distribution (1 / num_embeddings)
        qy = F.softmax(logits, dim=1)
        kl_loss = self.kl_cost * torch.sum(qy * torch.log(qy * self.num_embeddings + 1e-10), dim=1).mean()

        return quantized, encoding_indices, kl_loss

    def get_consts(self) -> Tuple[float, float]:
        """
        return temp, kl_cost
        """
        return self.temp, self.kl_cost

    def set_consts(self, temp: Optional[float] = None, kl_cost: Optional[float] = None) -> None:
        """
        update values for temp, kl_cost
        :param temp: new value for temperature (if not None)
        :param kl_cost: new value for kl_cost (if not None)
        """
        if temp is not None:
            self.temp = temp

        if kl_cost is not None:
            self.kl_cost = kl_cost

    @torch.no_grad()
    def vec_to_codes(self, x: torch.Tensor) -> torch.LongTensor:
        """
        Deterministic counterpart of forward in eval mode: same logits projection, plain argmax, no sampling.

        :param x: tensors (output of the Encoder - B,N,H,W). Note that N = number of embeddings in dict!
        :return codebook indices (B, H, W)
        """

        return self.x_to_logits(x).argmax(dim=1)


class EntropyVectorQuantizer(BaseVectorQuantizer):
    """
    Nearest-neighbour quantizer whose loss combines commitment loss, codebook loss and an entropy term
    computed on the (negated) distances to the codebook entries.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, ent_loss_ratio: float = 0.1,
                 ent_temperature: float = 0.01, ent_loss_type: str = 'softmax', commitment_cost: float = 0.25):
        """
        :param num_embeddings: size of the latent dictionary (num of embedding vectors).
        :param embedding_dim: size of a single tensor in dict.
        :param ent_loss_ratio: scaling factor for the entropy loss
        :param ent_temperature: value the affinities are divided by before the softmax of the entropy loss
        :param ent_loss_type: 'softmax' or 'argmax', selects how the target probabilities are built
        :param commitment_cost: scaling factor for the commitment loss
        """

        super().__init__(num_embeddings, embedding_dim)

        # hparams
        self.commitment_cost = commitment_cost
        self.ent_loss_type = ent_loss_type
        self.ent_temperature = ent_temperature
        self.ent_loss_ratio = ent_loss_ratio

    @staticmethod
    def _codebook_entropy_loss(affinity: torch.Tensor, temperature: float, loss_type: str = 'softmax'):
        """
        Increase codebook usage by maximizing entropy

        :param affinity: 2D tensor of size Dim, n_classes
        :param temperature: divisor applied to the affinities
        :param loss_type: 'softmax' (soft targets) or 'argmax' (one-hot targets with softmax gradients)
        :return per-sample entropy minus entropy of the averaged distribution
        """

        num_classes = affinity.shape[-1]

        scaled = torch.div(affinity, temperature)
        probabilities = F.softmax(scaled, dim=-1)
        log_probabilities = F.log_softmax(scaled + 1e-5, dim=-1)

        if loss_type == "softmax":
            targets = probabilities
        elif loss_type == "argmax":
            winners = torch.argmax(scaled, dim=-1)
            hard = F.one_hot(winners, num_classes).to(winners)
            # forward value is the one-hot, gradient is taken from the probabilities
            targets = probabilities - (probabilities - hard).detach()
        else:
            raise ValueError("Entropy loss {} not supported".format(loss_type))

        mean_probabilities = torch.mean(targets, dim=0)
        entropy_of_mean = -torch.sum(mean_probabilities * torch.log(mean_probabilities + 1e-5))
        mean_of_entropy = -torch.mean(torch.sum(targets * log_probabilities, dim=-1))
        return mean_of_entropy - entropy_of_mean

    @staticmethod
    def _pairwise_sq_distances(vectors: torch.Tensor, codebook_t: torch.Tensor) -> torch.Tensor:
        """
        :param vectors: flat vectors (N, D)
        :param codebook_t: transposed codebook (D, codebook_size)
        :return squared distances (N, codebook_size), expanded as |a|^2 - 2ab + |b|^2
        """
        vec_norms = torch.sum(vectors ** 2, dim=1, keepdim=True)
        code_norms = torch.sum(codebook_t ** 2, dim=0, keepdim=True)
        cross = torch.matmul(vectors, codebook_t)
        return vec_norms - 2 * cross + code_norms

    def forward(self, x: torch.Tensor) -> (torch.Tensor, torch.IntTensor, float):
        """
        :param x: tensors (output of the Encoder - B,D,H,W).
        :return quantized_x (B, D, H, W), detached codes (B, H*W), latent_loss
        """

        n_batch, _, height, width = x.shape

        # distances between every flat vector and every codebook entry
        vectors = rearrange(x, 'b c h w -> (b h w) c')
        dists = self._pairwise_sq_distances(vectors, self.get_codebook().T)

        # closest entry per vector, then look the entries up and restore the spatial layout
        codes = torch.argmin(dists, dim=1)
        looked_up = rearrange(self.codebook(codes), '(b h w) c -> b c h w', b=n_batch, h=height, w=width)
        codes = rearrange(codes, '(b h w)-> b (h w)', b=n_batch, h=height, w=width).detach()

        # loss terms: commitment (encoder side), codebook side, entropy on the negated distances
        commitment_term = torch.mean((looked_up.detach() - x) ** 2) * self.commitment_cost
        codebook_term = torch.mean((looked_up - x.detach()) ** 2)
        entropy_term = self._codebook_entropy_loss(-dists, self.ent_temperature,
                                                   self.ent_loss_type) * self.ent_loss_ratio
        total = commitment_term + codebook_term + entropy_term

        # straight-through: forward value of the codes, gradient of the inputs
        straight_through = x + (looked_up - x).detach()

        return straight_through, codes, total

    @torch.no_grad()
    def vec_to_codes(self, x: torch.Tensor) -> torch.IntTensor:
        """
        :param x: tensors (output of the Encoder - B,D,H,W).
        :return flat codebook indices (B, H * W)
        """

        n_batch, _, height, width = x.shape

        vectors = rearrange(x, 'b c h w -> (b h w) c')
        dists = self._pairwise_sq_distances(vectors, self.get_codebook().T)

        codes = torch.argmin(dists, dim=1)
        return rearrange(codes, '(b h w)-> b (h w)', b=n_batch, h=height, w=width)



class VectorQuantizer2(nn.Module):
    """
    Multi-scale residual vector quantizer.

    A feature map is approximated by a sum of token maps, one per entry of `v_patch_nums` (from small to large).
    At every scale the remaining residual is resized to that scale, matched against one shared embedding table,
    resized back to the full resolution, passed through a `quant_resi` module and subtracted from the residual.
    """

    # VQGAN originally use beta=1.0, never tried 0.25; SD seems using 0.25
    def __init__(
        self, vocab_size, Cvae, using_znorm, beta: float = 0.25,
        default_qresi_counts=0, v_patch_nums=None, quant_resi=0.5, share_quant_resi=4,  # share_quant_resi: args.qsr
    ):
        """
        :param vocab_size: number of rows of the embedding table.
        :param Cvae: channel dimension of the features and of every embedding vector.
        :param using_znorm: if True the nearest embedding is chosen by the largest dot product of L2-normalized
                            vectors, otherwise by the smallest squared euclidean distance.
        :param beta: weight of the loss term that pulls the input features towards the (constant) quantized ones.
        :param default_qresi_counts: number of modules built in the non-shared case; 0 means one per scale.
        :param v_patch_nums: side length of the token map of every scale, from small to large. Required in
                             practice: its length is read during construction.
        :param quant_resi: mixing ratio handed to every Phi; if |quant_resi| <= 1e-6 nn.Identity is used instead.
        :param share_quant_resi: 0 = one module per scale, 1 = a single module for all scales, any other value =
                                 that many modules shared between the scales.
        """
        super().__init__()
        self.vocab_size: int = vocab_size
        self.Cvae: int = Cvae
        self.using_znorm: bool = using_znorm
        self.v_patch_nums: Tuple[int] = v_patch_nums

        self.quant_resi_ratio = quant_resi
        if share_quant_resi == 0:   # non-shared: \phi_{1 to K} for K scales
            self.quant_resi = PhiNonShared([(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in range(default_qresi_counts or len(self.v_patch_nums))])
        elif share_quant_resi == 1: # fully shared: only a single \phi for K scales
            self.quant_resi = PhiShared(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity())
        else:                       # partially shared: \phi_{1 to share_quant_resi} for K scales
            self.quant_resi = PhiPartiallyShared(nn.ModuleList([(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in range(share_quant_resi)]))

        # running per-scale histogram of how often every vocabulary entry is selected (S, V)
        self.register_buffer('ema_vocab_hit_SV', torch.full((len(self.v_patch_nums), self.vocab_size), fill_value=0.0))
        # number of histogram updates done so far; selects the momentum used in forward
        self.record_hit = 0

        self.beta: float = beta
        self.embedding = nn.Embedding(self.vocab_size, self.Cvae)

        # only used for progressive training of VAR (not supported yet, will be tested and supported in the future)
        self.prog_si = -1   # progressive training: not supported yet, prog_si always -1

    @staticmethod
    def _scale_pos(si: int, SN: int) -> float:
        """Relative position of scale `si` among `SN` scales in [0, 1]; 0.0 when there is a single scale."""
        # guard the single-scale case, where si / (SN - 1) would divide by zero
        return si / (SN - 1) if SN > 1 else 0.0

    def eini(self, eini):
        """
        Re-initialize the embedding table in place.
        :param eini: > 0: truncated normal with std=eini; < 0: uniform in +-|eini| / vocab_size; 0: left untouched.
        """
        if eini > 0: nn.init.trunc_normal_(self.embedding.weight.data, std=eini)
        elif eini < 0: self.embedding.weight.data.uniform_(-abs(eini) / self.vocab_size, abs(eini) / self.vocab_size)

    def extra_repr(self) -> str:
        return f'{self.v_patch_nums}, znorm={self.using_znorm}, beta={self.beta}  |  S={len(self.v_patch_nums)}, quant_resi={self.quant_resi_ratio}'

    # ===================== `forward` is only used in VAE training =====================
    def forward(self, f_BChw: torch.Tensor, ret_usages=False) -> Tuple[torch.Tensor, Optional[List[float]], torch.Tensor]:
        """
        Quantize a feature map scale by scale and compute the VQ loss.

        :param f_BChw: features (B, C, H, W); cast to float32 if needed.
        :param ret_usages: if True, also return per-scale codebook usage percentages.
        :return f_hat (B, C, H, W) with the value of the quantized features and the gradient of the input
                (straight-through), usages (list of percentages, one per scale, or None), mean_vq_loss.
                Note: the usage histogram is only updated when training with an initialized distributed setup.
        """
        dtype = f_BChw.dtype
        if dtype != torch.float32: f_BChw = f_BChw.float()
        B, C, H, W = f_BChw.shape
        f_no_grad = f_BChw.detach()

        f_rest = f_no_grad.clone()
        f_hat = torch.zeros_like(f_rest)

        SN = len(self.v_patch_nums)
        # hit counts are all-reduced (and recorded) only in distributed training
        sync_hits = self.training and dist.initialized()

        with torch.cuda.amp.autocast(enabled=False):
            mean_vq_loss: torch.Tensor = 0.0
            for si, pn in enumerate(self.v_patch_nums): # from small to large
                # resize the residual to this scale (the last scale is used at full resolution) and flatten to (N, C)
                rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='area').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN-1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
                # find the nearest embedding
                if self.using_znorm:
                    rest_NC = F.normalize(rest_NC, dim=-1)
                    idx_N = torch.argmax(rest_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
                else:
                    d_no_grad = torch.sum(rest_NC.square(), dim=1, keepdim=True) + torch.sum(self.embedding.weight.data.square(), dim=1, keepdim=False)
                    d_no_grad.addmm_(rest_NC, self.embedding.weight.data.T, alpha=-2, beta=1)  # (B*h*w, vocab_size)
                    idx_N = torch.argmin(d_no_grad, dim=1)

                hit_V = idx_N.bincount(minlength=self.vocab_size).float()
                # start the reduction now and wait for it only after the loss of this scale has been set up
                handler = tdist.all_reduce(hit_V, async_op=True) if sync_hits else None

                # calc loss
                idx_Bhw = idx_N.view(B, pn, pn)
                h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W), mode='bicubic').contiguous() if (si != SN-1) else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()
                h_BChw = self.quant_resi[self._scale_pos(si, SN)](h_BChw)
                f_hat = f_hat + h_BChw
                f_rest -= h_BChw

                if sync_hits:
                    handler.wait()
                    # first update copies the histogram, then a fast (0.9) and later a slow (0.99) moving average
                    if self.record_hit == 0: self.ema_vocab_hit_SV[si].copy_(hit_V)
                    elif self.record_hit < 100: self.ema_vocab_hit_SV[si].mul_(0.9).add_(hit_V.mul(0.1))
                    else: self.ema_vocab_hit_SV[si].mul_(0.99).add_(hit_V.mul(0.01))
                    self.record_hit += 1
                # beta-weighted term moves the features towards the codes, the other term moves the codes
                mean_vq_loss += F.mse_loss(f_hat.data, f_BChw).mul_(self.beta) + F.mse_loss(f_hat, f_no_grad)

            mean_vq_loss *= 1. / SN
            # straight-through estimator: value of f_hat, gradient of f_BChw
            f_hat = (f_hat.data - f_no_grad).add_(f_BChw)

        usages = None
        if ret_usages:
            # tdist.get_world_size() raises without an initialized process group: count a single process then
            world_size = tdist.get_world_size() if (tdist.is_available() and tdist.is_initialized()) else 1
            margin = world_size * (f_BChw.numel() / f_BChw.shape[1]) / self.vocab_size * 0.08
            # margin = pn*pn / 100
            usages = [(self.ema_vocab_hit_SV[si] >= margin).float().mean().item() * 100 for si in range(SN)]
        return f_hat, usages, mean_vq_loss
    # ===================== `forward` is only used in VAE training =====================

    def embed_to_fhat(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale=True, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
        """
        Accumulate per-scale embedded token maps into reconstructed feature maps.

        :param ms_h_BChw: one embedded token map (B, C, h, w) per scale, from small to large.
        :param all_to_max_scale: if True every map is upsampled to the largest scale before accumulation,
                                 otherwise the running sum is upsampled step by step.
        :param last_one: if True only the final accumulated tensor is returned, otherwise a list with the
                         accumulated tensor after every scale.
        """
        ls_f_hat_BChw = []
        B = ms_h_BChw[0].shape[0]
        H = W = self.v_patch_nums[-1]
        SN = len(self.v_patch_nums)
        if all_to_max_scale:
            f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, H, W, dtype=torch.float32)
            for si, pn in enumerate(self.v_patch_nums): # from small to large
                h_BChw = ms_h_BChw[si]
                if si < len(self.v_patch_nums) - 1:
                    h_BChw = F.interpolate(h_BChw, size=(H, W), mode='bicubic')
                h_BChw = self.quant_resi[self._scale_pos(si, SN)](h_BChw)
                f_hat.add_(h_BChw)
                if last_one: ls_f_hat_BChw = f_hat
                else: ls_f_hat_BChw.append(f_hat.clone())
        else:
            # WARNING: this is not the case in VQ-VAE training (where we'll interpolate every token map to the max scale), so it may cause some training-inference inconsistency
            # WARNING: this should only be used for experimental visualization
            f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, self.v_patch_nums[0], self.v_patch_nums[0], dtype=torch.float32)
            for si, pn in enumerate(self.v_patch_nums): # from small to large
                f_hat = F.interpolate(f_hat, size=(pn, pn), mode='bicubic')
                h_BChw = self.quant_resi[self._scale_pos(si, SN)](ms_h_BChw[si])
                f_hat.add_(h_BChw)
                if last_one: ls_f_hat_BChw = f_hat
                else: ls_f_hat_BChw.append(f_hat)

        return ls_f_hat_BChw

    def f_to_idxBl_or_fhat(self, f_BChw: torch.Tensor, to_fhat: bool, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[Union[torch.Tensor, torch.LongTensor]]:  # z_BChw is the feature from inp_img_no_grad
        """
        Quantize features without computing any loss.

        :param f_BChw: features (B, C, H, W); only their detached values are used.
        :param to_fhat: if True collect the accumulated reconstruction after every scale,
                        otherwise the token indices (B, ph*pw) of every scale.
        :param v_patch_nums: optional override of the scales, each an int or an (h, w) pair;
                             the last one must match (H, W).
        :return list with one tensor per processed scale.
        """
        B, C, H, W = f_BChw.shape
        f_no_grad = f_BChw.detach()
        f_rest = f_no_grad.clone()
        f_hat = torch.zeros_like(f_rest)

        f_hat_or_idx_Bl: List[torch.Tensor] = []

        patch_hws = [(pn, pn) if isinstance(pn, int) else (pn[0], pn[1]) for pn in (v_patch_nums or self.v_patch_nums)]    # from small to large
        assert patch_hws[-1][0] == H and patch_hws[-1][1] == W, f'{patch_hws[-1]=} != ({H=}, {W=})'

        SN = len(patch_hws)
        for si, (ph, pw) in enumerate(patch_hws): # from small to large
            if 0 <= self.prog_si < si: break    # progressive training: not supported yet, prog_si always -1
            # find the nearest embedding
            z_NC = F.interpolate(f_rest, size=(ph, pw), mode='area').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN-1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
            if self.using_znorm:
                z_NC = F.normalize(z_NC, dim=-1)
                idx_N = torch.argmax(z_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
            else:
                d_no_grad = torch.sum(z_NC.square(), dim=1, keepdim=True) + torch.sum(self.embedding.weight.data.square(), dim=1, keepdim=False)
                d_no_grad.addmm_(z_NC, self.embedding.weight.data.T, alpha=-2, beta=1)  # (B*h*w, vocab_size)
                idx_N = torch.argmin(d_no_grad, dim=1)

            idx_Bhw = idx_N.view(B, ph, pw)
            h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W), mode='bicubic').contiguous() if (si != SN-1) else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()
            h_BChw = self.quant_resi[self._scale_pos(si, SN)](h_BChw)
            f_hat.add_(h_BChw)
            f_rest.sub_(h_BChw)
            f_hat_or_idx_Bl.append(f_hat.clone() if to_fhat else idx_N.reshape(B, ph*pw))

        return f_hat_or_idx_Bl

    # ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input =====================
    def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> Optional[torch.Tensor]:
        """
        Build the teacher-forcing input from ground-truth token indices.

        :param gt_ms_idx_Bl: token indices (B, pn*pn) of every scale, from small to large.
        :return for every scale but the last, the reconstruction accumulated so far downsampled to the next
                scale and flattened to (B, pn_next*pn_next, C), concatenated along dim 1; None if nothing was built.
        """
        next_scales = []
        B = gt_ms_idx_Bl[0].shape[0]
        C = self.Cvae
        H = W = self.v_patch_nums[-1]
        SN = len(self.v_patch_nums)

        f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32)
        pn_next: int = self.v_patch_nums[0]
        for si in range(SN-1):
            if self.prog_si == 0 or (0 <= self.prog_si-1 < si): break   # progressive training: not supported yet, prog_si always -1
            h_BChw = F.interpolate(self.embedding(gt_ms_idx_Bl[si]).transpose_(1, 2).view(B, C, pn_next, pn_next), size=(H, W), mode='bicubic')
            f_hat.add_(self.quant_resi[self._scale_pos(si, SN)](h_BChw))
            pn_next = self.v_patch_nums[si+1]
            next_scales.append(F.interpolate(f_hat, size=(pn_next, pn_next), mode='area').view(B, C, -1).transpose(1, 2))
        return torch.cat(next_scales, dim=1) if len(next_scales) else None    # cat BlCs to BLC, this should be float32

    # ===================== get_next_autoregressive_input: only used in VAR inference, for getting next step's input =====================
    def get_next_autoregressive_input(self, si: int, SN: int, f_hat: torch.Tensor, h_BChw: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]: # only used in VAR inference
        """
        Add the embedded token map of scale `si` to `f_hat` (in place) and derive the input of the next scale.

        :param si: index of the current scale.
        :param SN: total number of scales.
        :param f_hat: reconstruction accumulated so far; modified in place.
        :param h_BChw: embedded token map of the current scale.
        :return (f_hat, f_hat resized to the next scale); for the last scale both entries are f_hat.
        """
        HW = self.v_patch_nums[-1]
        if si != SN-1:
            h = self.quant_resi[self._scale_pos(si, SN)](F.interpolate(h_BChw, size=(HW, HW), mode='bicubic'))     # conv after upsample
            f_hat.add_(h)
            return f_hat, F.interpolate(f_hat, size=(self.v_patch_nums[si+1], self.v_patch_nums[si+1]), mode='area')
        else:
            h = self.quant_resi[self._scale_pos(si, SN)](h_BChw)
            f_hat.add_(h)
            return f_hat, f_hat


class Phi(nn.Conv2d):
    """
    3x3, stride 1, same-padding convolution blended with its own input:
    output = (1 - r) * h + r * conv(h), with r = |quant_resi|.
    """

    def __init__(self, embed_dim, quant_resi):
        kernel = 3
        super().__init__(
            in_channels=embed_dim,
            out_channels=embed_dim,
            kernel_size=kernel,
            stride=1,
            padding=kernel // 2,
        )
        # the sign of quant_resi is ignored, only its magnitude is used as mixing ratio
        self.resi_ratio = abs(quant_resi)

    def forward(self, h_BChw):
        ratio = self.resi_ratio
        kept = h_BChw.mul(1 - ratio)
        # the convolution output is a fresh tensor, so it can be scaled in place
        convolved = super().forward(h_BChw).mul_(ratio)
        return kept + convolved


class PhiShared(nn.Module):
    """Container exposing one single residual module, whatever index it is asked for."""

    def __init__(self, qresi: Phi):
        super().__init__()
        # registered as a submodule through the regular nn.Module attribute assignment
        self.qresi: Phi = qresi

    def __getitem__(self, _) -> Phi:
        # the index is deliberately ignored: every scale shares the same module
        shared = self.qresi
        return shared


class PhiPartiallyShared(nn.Module):
    """
    Holds K residual modules and maps a relative position in [0, 1] to the module whose tick is closest.
    Ticks are K evenly spaced values; their outer margin is 1/(3K) when K == 4 and 1/(2K) otherwise.
    """

    def __init__(self, qresi_ls: nn.ModuleList):
        super().__init__()
        self.qresi_ls = qresi_ls
        K = len(qresi_ls)
        if K == 4:
            self.ticks = np.linspace(1/3/K, 1-1/3/K, K)
        else:
            self.ticks = np.linspace(1/2/K, 1-1/2/K, K)

    def __getitem__(self, at_from_0_to_1: float) -> Phi:
        # index of the tick with the smallest absolute gap to the requested position (first one on ties)
        gaps = np.abs(self.ticks - at_from_0_to_1)
        nearest = np.argmin(gaps).item()
        return self.qresi_ls[nearest]

    def extra_repr(self) -> str:
        return f'ticks={self.ticks}'


class PhiNonShared(nn.ModuleList):
    """
    ModuleList of K residual modules that is indexed by a relative position in [0, 1] instead of an integer:
    the module whose tick is closest to the position is returned.
    Ticks are K evenly spaced values; their outer margin is 1/(3K) when K == 4 and 1/(2K) otherwise.
    """

    def __init__(self, qresi: List):
        super().__init__(qresi)
        K = len(qresi)
        if K == 4:
            self.ticks = np.linspace(1/3/K, 1-1/3/K, K)
        else:
            self.ticks = np.linspace(1/2/K, 1-1/2/K, K)

    def __getitem__(self, at_from_0_to_1: float) -> Phi:
        # translate the position to the integer slot of the nearest tick, then use the regular list lookup
        gaps = np.abs(self.ticks - at_from_0_to_1)
        slot = np.argmin(gaps).item()
        return super().__getitem__(slot)

    def extra_repr(self) -> str:
        return f'ticks={self.ticks}'
