from typing import List, Optional

import torch
from torch import Tensor

from .base import BaseVQVAE


class VQVAE(BaseVQVAE):
    """VQ-VAE model with a rolling buffer of recent latent frames.

    Encoder, decoder and quantizer are created by ``_init_encoder``,
    ``_init_decoder`` and ``_init_quantizer`` (not defined in this class).
    The model also owns ``z_buffer``, a registered buffer of shape
    ``(pad_batch, output_emb_width, pad_length)``. It is handed to the
    quantizer, and after every ``forward`` / ``encode`` / ``decode`` call the
    newest latents are appended along the last axis, keeping only the trailing
    ``pad_length`` frames per batch row. ``reset_buffer`` clears it.
    """

    # Quantizer types that receive the encoder output as-is and have their
    # own decode path; every other type works on postprocessed latents
    # flattened to 2-D ``(-1, C)`` vectors.
    _UNFLATTENED_QUANTIZERS = ("rvq_guo", "rvq_lfq", "mvq")

    def __init__(self, **kwargs):
        # Caller-supplied kwargs override these defaults.
        # NOTE(review): the defaults set "quantizer", while the methods below
        # read ``self.hparams.quantizer_type``; presumably BaseVQVAE maps one
        # onto the other — confirm.
        defaults = {
            "encoder_type": "resnet1d",
            "decoder_type": "resnet1d",
            "quantizer": "ema_reset",
            "code_num": 512,
            "code_dim": 512,
            "output_emb_width": 512,
            "down_t": 3,
            "stride_t": 2,
            "layers": 0,
            "width": 512,
            "depth": 3,
            "dilation_growth_rate": 3,
            "norm": None,
            "activation": "relu",
            "num_quantizers": 1,
            "shared_codebook": False,
            "heads": 1,
            "quantize_dropout": False,
            "quantize_dropout_prob": 0.2,
            "quantize_dropout_cutoff_index": 0,
            "quantize_dropout_multiple_of": 1,
            "accept_image_fmap": False,
            "kmeans_init": True,
            "kmeans_iters": 10,
            "regroup": False,
            "pad_length": 128,
            "pad_batch": 1024,
            "max_length": 2048,
            "beta": 0.25,
            "share_quant_resi": 4,
            "v_lengths": (1, 2, 3, 4),
        }

        updated_kwargs = {**defaults, **kwargs}
        super().__init__(**updated_kwargs)

        self._init_encoder()
        self._init_decoder()
        self._init_quantizer()

        z_buffer = torch.zeros(
            self.hparams.pad_batch,
            self.hparams.output_emb_width,
            self.hparams.pad_length,
        )
        self.register_buffer("z_buffer", z_buffer)

    def _update_z_buffer(self, z: Tensor) -> None:
        """Append ``z`` to the rolling latent buffer.

        ``z`` is concatenated after the first ``z.shape[0]`` buffer rows along
        the last axis, and only the trailing ``pad_length`` frames are kept.

        ``z`` is detached first. The buffer persists across calls, so writing a
        graph-attached tensor into it in place would make the buffer part of
        this step's autograd graph. That chains graphs across steps (leaking
        memory) and, when the quantizer uses the buffer, lets a later
        ``backward`` reach a graph that has already been freed.
        """
        n = z.shape[0]
        z_all = torch.cat([self.z_buffer[:n], z.detach()], dim=-1)
        self.z_buffer[:n] = z_all[:, :, -self.hparams.pad_length :]

    def forward(self, features: Tensor, infer_dec: bool = False):
        """Reconstruct ``features`` through encoder, quantizer and decoder.

        Args:
            features: Input batch; its first dimension is the batch size.
            infer_dec: If True and the decoder exposes ``inference``, decode
                with ``decoder.inference`` instead of the regular call.

        Returns:
            ``(x_out, loss, perplexity)``. ``loss`` and ``perplexity`` are
            returned unchanged from the quantizer.
        """
        x_in = self.preprocess(features)
        x_encoder = self.encoder(x_in)

        # The quantizer sees the buffer contents from previous calls.
        x_quantized, loss, perplexity = self.quantizer(x_encoder, self.z_buffer)
        self._update_z_buffer(x_encoder)

        if infer_dec and hasattr(self.decoder, "inference"):
            x_decoder = self.decoder.inference(x_quantized)
        else:
            x_decoder = self.decoder(x_quantized)

        x_out = self.postprocess(x_decoder)
        return x_out, loss, perplexity

    def encode(
        self,
        features: Tensor,
    ) -> tuple:
        """Encode ``features`` into code indices.

        Args:
            features: Input batch; its first dimension is the batch size.

        Returns:
            ``(code_idx, None)``, where ``code_idx`` is whatever
            ``quantizer.quantize`` returns.
        """
        x_in = self.preprocess(features)

        if hasattr(self.encoder, "inference"):
            x_encoder = self.encoder.inference(x_in)
        else:
            x_encoder = self.encoder(x_in)

        # Keep the raw encoder output for the buffer. The quantizer input may
        # be flattened to 2-D below, which cannot be concatenated with the
        # 3-D buffer. This also matches what ``forward`` stores.
        z_latent = x_encoder

        if self.hparams.quantizer_type not in self._UNFLATTENED_QUANTIZERS:
            x_encoder = self.postprocess(x_encoder)
            x_encoder = x_encoder.contiguous().view(-1, x_encoder.shape[-1])  # (NT, C)

        code_idx = self.quantizer.quantize(x_encoder, self.z_buffer)
        self._update_z_buffer(z_latent)

        return code_idx, None

    def decode(self, z: List[Tensor], memory: bool = False):
        """Decode code indices ``z`` back into the output space.

        Args:
            z: Codes to decode. ``len(z)`` is used as the batch size when
                reshaping the output of flattened quantizers.
            memory: If True and the decoder exposes ``inference``, the first
                two latent frames are passed as a separate second argument
                (presumably decoder context/memory — confirm against the
                decoder implementation).

        Returns:
            The postprocessed decoder output.
        """
        n_batch = len(z)
        quantizer_type = self.hparams.quantizer_type

        if quantizer_type not in self._UNFLATTENED_QUANTIZERS:
            x_d = self.quantizer.dequantize(z)
            x_d = x_d.view(n_batch, -1, self.hparams.code_dim).permute(0, 2, 1).contiguous()
        elif quantizer_type == "rvq_guo":
            # Sum the per-level code vectors along dim 0.
            x_d = self.quantizer.get_codes_from_indices(z).sum(dim=0).permute(0, 2, 1)
        elif quantizer_type == "rvq_lfq":
            x_d = self.quantizer.get_output_from_indices(z).permute(0, 2, 1)
        else:  # "mvq": dequantization also consumes the latent buffer.
            x_d = self.quantizer.dequantize(z, z_buffer=self.z_buffer)

        self._update_z_buffer(x_d)

        if hasattr(self.decoder, "inference"):
            if memory:
                x_decoder = self.decoder.inference(x_d[..., 2:], x_d[..., :2])
            else:
                x_decoder = self.decoder.inference(x_d)
        else:
            x_decoder = self.decoder(x_d)

        return self.postprocess(x_decoder)

    def reset_buffer(self, target_module: Optional[str] = None, z_buffer: bool = True):
        """Clear the latent buffer (if ``z_buffer``) and delegate to the base class.

        Args:
            target_module: Forwarded unchanged to ``BaseVQVAE.reset_buffer``.
            z_buffer: Whether to zero this model's ``z_buffer``.
        """
        if z_buffer:
            # zeros_like keeps the buffer's shape, device *and* dtype, so a
            # half-precision model is not silently switched back to float32.
            self.z_buffer = torch.zeros_like(self.z_buffer)
        super().reset_buffer(target_module)
