import logging
from typing import Union

import numpy as np
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn

from mGPT.utils.tensors import collate_tensors

try:
    from torch.nn.utils.parametrizations import weight_norm
except ImportError:
    logging.getLogger(__name__).warning(
        "Using torch.nn.utils.weight_norm instead of torch.nn.utils.parametrizations due to PyTorch version"
    )
    from torch.nn.utils import weight_norm


def log(t, eps=1e-20):
    """Return the natural log of ``t`` after flooring every element at ``eps``.

    The floor keeps zero or negative entries from producing ``-inf``/``nan``.
    """
    floored = torch.clamp(t, min=eps)
    return floored.log()


def gumbel_noise(t):
    """Draw Gumbel(0, 1) noise with the same shape, dtype and device as ``t``.

    Uses the inverse-CDF form ``-log(-log(u))`` with ``u ~ Uniform(0, 1)``;
    the clamped ``log`` helper guards against ``u == 0``.
    """
    uniform = torch.empty_like(t).uniform_(0, 1)
    inner = -log(uniform)
    return -log(inner)


def gumbel_sample(logits, temperature=1.0, stochastic=False, dim=-1, training=True):
    """Pick an index along ``dim``, optionally via Gumbel-max sampling.

    Noise is only added when ``training`` and ``stochastic`` are both truthy
    and ``temperature`` is positive; otherwise this is a plain ``argmax``.
    """
    add_noise = training and stochastic and temperature > 0
    if not add_noise:
        return logits.argmax(dim=dim)

    perturbed = (logits / temperature) + gumbel_noise(logits)
    return perturbed.argmax(dim=dim)


def WNConv1d(*args, **kwargs):
    """Build an ``nn.Conv1d`` wrapped in weight normalisation, cast to float32.

    All arguments are forwarded unchanged to ``nn.Conv1d``.
    """
    conv = nn.Conv1d(*args, **kwargs)
    normed = weight_norm(conv)
    return normed.float()


def WNConvTranspose1d(*args, **kwargs):
    """Build an ``nn.ConvTranspose1d`` wrapped in weight normalisation, cast to float32.

    All arguments are forwarded unchanged to ``nn.ConvTranspose1d``.
    """
    deconv = nn.ConvTranspose1d(*args, **kwargs)
    normed = weight_norm(deconv)
    return normed.float()


class VectorQuantize(nn.Module):
    """
    Vector quantizer with an EMA-updated codebook and factorized codes.

    Based on Karpathy's deep-vector-quantization
    (https://github.com/karpathy/deep-vector-quantization) with ideas from
    Improved VQGAN (https://arxiv.org/pdf/2110.04627.pdf):
        1. Factorized codes: the input is projected to ``codebook_dim`` by a
            1x1 convolution and the nearest-neighbour lookup happens in that
            low-dimensional space.
        2. EMA reset: codebook entries are exponential moving averages of the
            vectors assigned to them; entries whose EMA count is below 1 are
            re-drawn from the current batch.

    Notes
    -----
    * The lookup uses plain squared Euclidean distance; codes are NOT
      l2-normalized by this implementation.
    * ``reset_threshold`` is stored but not read anywhere in this class; the
      reset threshold on the EMA count is fixed at 1.0.
    * ``init``, ``code_sum`` and ``code_count`` are plain attributes, not
      registered buffers, so they are not part of ``state_dict()``. Only
      ``codebook`` is saved; resuming training from a checkpoint therefore
      re-initialises the codebook from the first batch.
    """

    def __init__(
        self,
        input_dim: int,
        codebook_size: int,
        codebook_dim: int,
        mu: float = 0.99,
        reset_threshold: float = 0.1,
        dropout: float = 0.0,
    ):
        """
        Parameters
        ----------
        input_dim : int
            Channel dimension of the tensors passed to ``forward``.
        codebook_size : int
            Number of codebook entries.
        codebook_dim : int
            Dimension of each codebook entry (the factorized lookup space).
        mu : float
            EMA decay for the per-code running sum and count.
        reset_threshold : float
            Stored for reference; unused by the visible code.
        dropout : float
            Probability of replacing a selected code index with a uniformly
            random one during training.
        """
        super().__init__()

        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.mu = mu
        self.reset_threshold = reset_threshold
        self.dropout = dropout
        self.init = False
        self.code_sum = None
        self.code_count = None

        self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1)
        self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1)

        # Created on the default device (no hard-coded ``.cuda()``) so the
        # module can be built on CPU-only hosts; as a registered buffer it
        # follows the module through ``.to(device)`` / ``.cuda()``.
        self.register_buffer(
            "codebook",
            torch.zeros(self.codebook_size, self.codebook_dim),
        )

    def _tile(self, x):
        """Return at least ``codebook_size`` rows built from ``x`` (N x D).

        When ``x`` has fewer rows than the codebook, it is repeated and
        perturbed with small Gaussian noise so the repeated rows differ.
        """
        nb_code_x, code_dim = x.shape
        if nb_code_x < self.codebook_size:
            n_repeats = (self.codebook_size + nb_code_x - 1) // nb_code_x
            std = 0.01 / np.sqrt(code_dim)
            out = x.repeat(n_repeats, 1)
            out = out + torch.randn_like(out) * std
        else:
            out = x
        return out

    def init_codebook(self, x):
        """Seed the codebook and EMA statistics from the batch ``x`` (N x D)."""
        out = self._tile(x)
        # ``x`` comes from ``in_proj`` and may carry autograd history; detach
        # so the buffer never keeps the encoder graph alive.
        self.codebook = out[: self.codebook_size].detach()
        self.code_sum = self.codebook.clone()
        self.code_count = torch.ones(self.codebook_size, device=self.codebook.device)
        self.init = True

    def forward(self, z):
        """Quantize the input tensor and return the corresponding codebook vectors.

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of input (after ``out_proj``)
        Tensor[B]
            Commitment loss per batch element, training the encoder to predict
            vectors closer to codebook entries
        Tensor[B x T]
            Codebook indices (quantized discrete representation of input)
        Tensor[B x codebook_dim x T]
            Projected latents (continuous representation of input before quantization)
        Tensor[]
            Codebook perplexity of the batch (0 in eval mode)
        """

        # Factorized codes (ViT-VQGAN) Project input into low-dimensional space
        z_e = self.in_proj(z)  # z_e : (B x codebook_dim x T)

        z_q, indices, perplexity = self.decode_latents(z_e)

        commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])

        z_q = (
            z_e + (z_q - z_e).detach()
        )  # noop in forward pass, straight-through gradient estimator in backward pass

        z_q = self.out_proj(z_q)

        return z_q, commitment_loss, indices, z_e, perplexity

    def embed_code(self, embed_id):
        """Look up codebook vectors for integer indices; appends a ``codebook_dim`` axis."""
        return F.embedding(embed_id, self.codebook)

    def decode_code(self, embed_id):
        """Look up indices (B x T) and return channel-first vectors (B x codebook_dim x T)."""
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents, sample_codebook_temp=0.0):
        """Assign each latent vector to a codebook entry.

        Parameters
        ----------
        latents : Tensor[B x codebook_dim x T]
        sample_codebook_temp : float
            Temperature for Gumbel sampling of the code; only has an effect in
            training mode and when positive, otherwise the nearest code is used.

        Returns
        -------
        Tensor[B x codebook_dim x T]
            Selected codebook vectors
        Tensor[B x T]
            Selected codebook indices
        Tensor[]
            Perplexity from ``update_codebook`` in training mode, else 0

        Side effects (training mode only): initialises the codebook on the
        first call and applies one EMA update on every call.
        """
        x = rearrange(latents, "b d t -> (b t) d")

        if self.training and not self.init:
            self.init_codebook(x)

        # Squared Euclidean distance ||x||^2 - 2 x.k + ||k||^2 to every code.
        k_w = self.codebook.t()
        distance = (
            torch.sum(x**2, dim=-1, keepdim=True)
            - 2 * torch.matmul(x, k_w)
            + torch.sum(k_w**2, dim=0, keepdim=True)
        )  # (B * T, codebook_size)

        code_idx = gumbel_sample(
            -distance,
            dim=-1,
            temperature=sample_codebook_temp,
            stochastic=True,
            training=self.training,
        )

        if self.training and self.dropout > 0:
            # Code dropout: keep an index where ``mask`` is True, otherwise
            # swap in a uniformly random codebook index.
            mask = torch.rand_like(code_idx.float()) > self.dropout
            code_idx = code_idx * mask + torch.randint(
                0,
                self.codebook_size,
                code_idx.shape,
                device=code_idx.device,
            ) * (~mask)

        indices = rearrange(code_idx, "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)

        if self.training:
            perplexity = self.update_codebook(x, code_idx)
        else:
            perplexity = torch.tensor(0.0, device=x.device)

        return z_q, indices, perplexity

    @torch.no_grad()
    def update_codebook(self, x, code_idx):
        """Apply one EMA update of the codebook and return the batch perplexity.

        Parameters
        ----------
        x : Tensor[N x codebook_dim]
            Flattened latents of the current batch.
        code_idx : Tensor[N]
            Codebook index assigned to each row of ``x``.

        Returns
        -------
        Tensor[]
            ``exp(entropy)`` of the empirical code distribution in this batch.
        """
        code_onehot = torch.zeros(
            self.codebook_size, x.shape[0], device=x.device
        )  # nb_code, N * L
        code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)

        code_sum = torch.matmul(code_onehot, x)  # nb_code, c
        code_count = code_onehot.sum(dim=-1)  # nb_code

        # Candidate replacement vectors for under-used codes.
        out = self._tile(x)
        code_rand = out[: self.codebook_size]

        # Update centres
        self.code_sum = self.mu * self.code_sum + (1.0 - self.mu) * code_sum
        self.code_count = self.mu * self.code_count + (1.0 - self.mu) * code_count

        used = self.code_count.view(self.codebook_size, 1) >= 1.0
        usage = used.float()
        code_update = self.code_sum.view(
            self.codebook_size, self.codebook_dim
        ) / self.code_count.view(self.codebook_size, 1)
        # A code whose EMA count has decayed to zero yields 0/0 or x/0 above.
        # Zero those rows before blending: ``0 * nan`` is still ``nan`` and
        # would otherwise poison the codebook entry that is being reset.
        code_update = code_update.masked_fill(~used, 0.0)
        self.codebook = usage * code_update + (1 - usage) * code_rand

        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))

        return perplexity

    def quantize(self, z, return_latent=False, sample_codebook_temp=0.5, **kwargs):
        """Encode ``z`` (B x D x T) to codebook indices (B x T).

        When ``return_latent`` is True, also returns the quantized tensor
        projected back to ``input_dim`` (B x D x T). Extra keyword arguments
        are accepted and ignored. In training mode this goes through
        ``decode_latents`` and therefore also updates the codebook.
        """
        z_e = self.in_proj(z)
        z_q, indices, _ = self.decode_latents(
            z_e, sample_codebook_temp=sample_codebook_temp
        )
        z_q = self.out_proj(z_q)

        if return_latent:
            return indices, z_q
        return indices

    def dequantize(self, indices, **kwargs):
        """Decode indices (B x T) to the quantized tensor (B x D x T).

        Extra keyword arguments are accepted and ignored.
        """
        z_q = self.decode_code(indices)
        z_q = self.out_proj(z_q)
        return z_q


class MultiscaleVectorQuantize11(nn.Module):
    """Stack of ``VectorQuantize`` stages applied to successive residuals.

    Follows Algorithm in https://arxiv.org/pdf/2404.02905.pdf

    As implemented here, each stage quantizes what the previous stages left
    over (``residual = residual - z_q_i``) and the stage outputs are summed.
    ``v_lengths`` is only used to decide how many stages to build, and the
    ``z_buffer`` / ``padding_length`` arguments are accepted but unused.
    """

    def __init__(
        self,
        input_dim: int = 512,
        v_lengths: list[int] = 9,
        beta: float = 0.25,
        codebook_size: int = 1024,
        padding_length: int = 1024,
        codebook_dim: Union[int, list] = 8,
        quantizer_dropout: float = 0.0,
        codebook_dropout: float = 0.0,
        shared_codebook=False,
    ):
        """
        Parameters
        ----------
        input_dim : int
            Channel dimension of the tensors to quantize.
        v_lengths : list[int] or int
            One entry per quantizer stage; only its length is used. A bare
            int (such as the default) is taken as the number of stages.
        beta : float
            Stored on the instance; not read by the visible code.
        codebook_size : int
            Number of entries in each stage's codebook.
        padding_length : int
            Accepted for interface compatibility; unused.
        codebook_dim : int or list
            Int, or list of length 1 or 2. Normalised to a 2-element list;
            the last element is the lookup dimension of every stage.
        quantizer_dropout : float
            Fraction of each training batch that uses a random number of
            stages instead of all of them.
        codebook_dropout : float
            Forwarded to each ``VectorQuantize`` as ``dropout``.
        shared_codebook : bool
            If True, every stage is the same ``VectorQuantize`` instance.

        Raises
        ------
        ValueError
            If ``codebook_dim`` is a sequence whose length is not 1 or 2.
        """
        super().__init__()

        self.beta = beta
        self.v_lengths = v_lengths

        # The declared default is an int, on which ``len()`` would raise;
        # interpret an int as the stage count.
        # NOTE(review): confirm this matches the intended meaning of the default.
        if isinstance(v_lengths, int):
            n_codebooks = v_lengths
        else:
            n_codebooks = len(v_lengths)

        if isinstance(codebook_dim, int):
            codebook_dim = [codebook_dim, codebook_dim]
        elif len(codebook_dim) == 1:
            codebook_dim = [codebook_dim[0], codebook_dim[0]]
        elif len(codebook_dim) != 2:
            raise ValueError(
                f"codebook_dim {codebook_dim} should be either int or list of length 2"
            )

        self.n_codebooks = n_codebooks
        self.codebook_dim = codebook_dim
        self.codebook_size = codebook_size

        if shared_codebook:
            # One quantizer object reused at every stage (shared parameters
            # and shared codebook).
            quant = VectorQuantize(
                input_dim,
                codebook_size,
                codebook_dim=codebook_dim[-1],
                dropout=codebook_dropout,
            )
            self.quantizers = nn.ModuleList([quant for _ in range(n_codebooks)])
        else:
            self.quantizers = nn.ModuleList(
                [
                    VectorQuantize(
                        input_dim,
                        codebook_size,
                        codebook_dim=codebook_dim[-1],
                        dropout=codebook_dropout,
                    )
                    for _ in range(n_codebooks)
                ]
            )

        self.quantizer_dropout = quantizer_dropout

    def forward(
        self,
        z,
        z_buffer,
        n_quantizers: int = None,
        return_all_codes=False,
        sample_codebook_temp=0.5,
    ):
        """
        Quantize the input tensor with up to ``n`` stages and return the summed
        codebook vectors.

        Parameters
        ----------
        z : Tensor[B x D x T]
        z_buffer
            Unused.
        n_quantizers : int, optional
            No. of quantizers to use in eval mode (defaults to all of them).
            Note: in training mode this argument is ignored; the first
            ``int(B * quantizer_dropout)`` batch elements use a random number
            of quantizers and the rest use all of them.
        return_all_codes : bool
            If True, the codes are appended to the returned tuple.
        sample_codebook_temp : float
            Unused by this method.

        Returns
        -------
        tuple
            ``(z_q, loss, perplexity)`` or ``(z_q, loss, perplexity, codes)``:

            z_q : Tensor[B x D x T]
                Sum of the (masked) quantized outputs of the stages
            loss : Tensor[]
                Commitment loss averaged over batch and stages
            perplexity : Tensor[]
                Codebook perplexity averaged over stages
            codes : Tensor[B x (N*T)]
                Per-stage index tensors (each B x T) concatenated along dim 1
        """

        residual = z
        z_q = torch.zeros_like(z)

        codebook_indices = []
        all_perplexity = []
        commitment_loss = []

        if n_quantizers is None:
            n_quantizers = self.n_codebooks

        if self.training:
            # Per-sample stage budget: ``n_codebooks + 1`` means "use all";
            # the first ``n_dropout`` samples get a random budget instead.
            n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1
            dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],))
            n_dropout = int(z.shape[0] * self.quantizer_dropout)
            n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(z.device)

        for i, quantizer in enumerate(self.quantizers):
            if self.training is False and i >= n_quantizers:
                break

            z_q_i, commitment_loss_i, indices_i, _, perplexity_i = quantizer(residual)

            # Create mask to apply quantizer dropout
            mask = (
                torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers
            )

            z_q = z_q + z_q_i * mask[:, None, None]
            # The residual is reduced for every sample, including those
            # masked out of the sum above.
            residual = residual - z_q_i

            commitment_loss.append((commitment_loss_i * mask).mean())

            codebook_indices.append(indices_i)
            all_perplexity.append(perplexity_i)

        all_losses = sum(commitment_loss) / len(commitment_loss)
        all_perplexity = sum(all_perplexity) / len(all_perplexity)

        # Concatenate (not stack) the per-stage indices along the time axis.
        codes = torch.cat(codebook_indices, dim=1)

        ret = (z_q, all_losses, all_perplexity)
        if return_all_codes:
            ret = (*ret, codes)

        return ret

    def quantize(self, x, z_buffer, sample_codebook_temp=0.5):
        """
        Encode ``x`` with every stage and collect the indices.

        parameters:
            x: Tensor[B x D x T]
            z_buffer: unused
            sample_codebook_temp: forwarded to each stage's ``quantize``
        return:
            codes: Tensor[B x T x N]
                (assumes ``collate_tensors`` stacks the N per-stage (B x T)
                tensors along a new leading axis, padding with -1 — TODO confirm)
        """

        residual = x
        codebook_indices = []

        for quantizer in self.quantizers:
            indices_i, z_q_i = quantizer.quantize(
                residual,
                return_latent=True,
                sample_codebook_temp=sample_codebook_temp,
            )
            residual = residual - z_q_i
            codebook_indices.append(indices_i)

        codes = (
            collate_tensors(codebook_indices, fill_value=-1)
            .permute(1, 2, 0)
            .contiguous()
        )

        return codes

    def dequantize(self, codes, z_buffer):
        """
        Decode per-stage indices back to the summed quantized tensor.

        parameters:
            codes: Tensor[B x T x N], with -1 marking padded positions
            z_buffer: unused
        return:
            z_q: Tensor[B x D x T] (channel-first, as produced by each
                stage's ``dequantize``)
        """

        z_q = 0.0

        mask = codes == -1.0
        # Padded positions need a valid index for the lookup; their
        # contribution is zeroed out again below.
        codes = codes.masked_fill(mask, 0)

        for i, quantizer in enumerate(self.quantizers):
            codes_i = codes[..., i]

            z_q_i = quantizer.dequantize(codes_i).contiguous()
            # ``mask[..., i]`` is (B x T); add a channel axis so it broadcasts
            # over D of the (B x D x T) stage output instead of misaligning
            # the batch axis with the channel axis.
            z_q_i = z_q_i.masked_fill(mask[..., i].unsqueeze(1), 0.0)

            z_q = z_q + z_q_i

        return z_q
