# soundstream_vq.py
# Simple vector-quantizer (VQ) module to emulate SoundStream discrete tokenization.
# Not a full SoundStream implementation: this is a VQ-VAE-style quantizer that maps continuous frames -> codebook ids.
# It computes the VQ commitment/codebook loss and performs the codebook (embedding) lookup.

from typing import Tuple, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F


class VectorQuantizer(nn.Module):
    """
    Vector-quantization layer with a (K, D) codebook.

    - codebook: (K, D) parameter, K = ``num_codes``, D = ``code_dim``
    - forward: input (B, T, D) -> (quantized (B, T, D), ids (B, T), loss scalar)

    Codebook update modes:

    - ``ema=False``: the codebook is trained by gradient descent through the
      codebook loss term ``mse(quantized, x.detach())``.
    - ``ema=True``: in training mode every forward pass overwrites the codebook
      with an exponential moving average (factor ``decay``) of the inputs
      assigned to each code, divided by additively smoothed (``eps``) cluster
      sizes. The codebook loss term is omitted from the returned loss in this
      mode so that optimizer steps do not fight the EMA update.
    """

    def __init__(
        self,
        num_codes: int = 512,
        code_dim: int = 256,
        beta: float = 0.25,
        ema: bool = False,
        decay: float = 0.99,
        eps: float = 1e-5,
    ):
        super().__init__()
        self.num_codes = num_codes
        self.code_dim = code_dim
        self.beta = beta  # weight of the commitment term
        self.ema = ema
        self.decay = decay  # EMA decay factor; only used when ema=True
        self.eps = eps  # additive smoothing for cluster sizes; only used when ema=True
        self.codebook = nn.Parameter(torch.randn(num_codes, code_dim) * 0.01)
        # EMA stats (running per-code assignment counts and per-code input sums)
        if self.ema:
            self.register_buffer("ema_cluster_size", torch.zeros(num_codes))
            self.register_buffer("ema_codebook", torch.randn(num_codes, code_dim) * 0.01)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Quantize every D-dim frame of ``x`` to its nearest codebook entry.

        Args:
            x: (B, T, D) tensor with D == code_dim.

        Returns:
            quantized: (B, T, D) straight-through output. Its forward values are
                the selected codebook rows; its gradient w.r.t. ``x`` is identity.
            ids: (B, T) long tensor of selected code indices.
            loss: scalar. ``codebook_loss + beta * commitment_loss`` when
                ``ema=False``; ``beta * commitment_loss`` only when ``ema=True``.

        Raises:
            ValueError: if the last dimension of ``x`` differs from ``code_dim``.
        """
        B, T, D = x.shape
        if D != self.code_dim:
            raise ValueError(f"Input dim must match code_dim (got {D}, expected {self.code_dim})")

        flat_x = x.reshape(-1, D)  # (B*T, D)
        codebook = self.codebook  # (K, D)

        # Code assignment goes through argmin (non-differentiable), so no graph is needed.
        with torch.no_grad():
            # ||x - e||^2 = ||x||^2 - 2 x.e + ||e||^2
            distances = (
                flat_x.pow(2).sum(1, keepdim=True)
                - 2 * torch.matmul(flat_x, codebook.t())
                + codebook.pow(2).sum(1).unsqueeze(0)
            )  # (B*T, K)
            encoding_indices = torch.argmin(distances, dim=1)  # (B*T,)

        # Row lookup: same values and codebook gradients as one_hot @ codebook,
        # but keeps the codebook's dtype and skips the (B*T, K) one-hot matrix.
        quantized = codebook.index_select(0, encoding_indices).view(B, T, D)

        # Commitment loss: pulls x toward its (frozen) selected code.
        e_latent_loss = F.mse_loss(quantized.detach(), x)
        if self.ema:
            # Codebook is maintained by the EMA update; no gradient term on it.
            loss = self.beta * e_latent_loss
        else:
            # Codebook loss: pulls the selected codes toward the (frozen) inputs.
            q_latent_loss = F.mse_loss(quantized, x.detach())
            loss = q_latent_loss + self.beta * e_latent_loss

        # Skip empty batches: with no history (n == 0) the normalisation below
        # would divide by zero and poison the codebook with inf/nan.
        if self.ema and self.training and flat_x.shape[0] > 0:
            self._ema_update(flat_x.detach(), encoding_indices)

        # Straight-through estimator: forward value = quantized, d(out)/dx = identity.
        quantized_st = x + (quantized - x).detach()
        ids = encoding_indices.view(B, T)
        return quantized_st, ids, loss

    @torch.no_grad()
    def _ema_update(self, flat_x: torch.Tensor, indices: torch.Tensor) -> None:
        """Fold one batch into the EMA stats and rewrite the codebook from them."""
        encodings = F.one_hot(indices, num_classes=self.num_codes).to(flat_x.dtype)  # (N, K)
        cluster_size = encodings.sum(0)  # (K,) assignments per code in this batch
        dw = torch.matmul(encodings.t(), flat_x)  # (K, D) sum of inputs per code

        # In-place so the registered buffers keep their identity.
        self.ema_cluster_size.mul_(self.decay).add_(cluster_size, alpha=1 - self.decay)
        self.ema_codebook.mul_(self.decay).add_(dw, alpha=1 - self.decay)

        # Additive smoothing keeps unused codes from dividing by zero while
        # preserving the total count n.
        n = self.ema_cluster_size.sum()
        smoothed = (self.ema_cluster_size + self.eps) / (n + self.num_codes * self.eps) * n
        # Write through .data, as the original did, so the parameter's autograd
        # version counter is not bumped.
        self.codebook.data.copy_(self.ema_codebook / smoothed.unsqueeze(1))
