import torch
import torch.nn as nn
import torch.nn.functional as F


class PerDimSoftQuantizer(nn.Module):
    """Soft quantizer in which every latent dimension owns a 1D codebook.

    Each input coordinate is replaced by a softmax-weighted average of the
    entries of its dimension's codebook, where the logits are the negative
    absolute distances to the entries divided by a shared temperature.
    """

    def __init__(self, num_latents, num_values_per_latent, temperature=1.0, min_temperature=0.1, anneal_rate=2e-2):
        """
        Soft quantizer where each dimension has its own codebook (1D)
        Args:
            num_latents: number of latent dimensions
            num_values_per_latent: int or list of int, number of discrete values per latent dimension
            temperature: initial softmax temperature
            min_temperature: the lowest temperature allowed during annealing
            anneal_rate: fraction by which the temperature shrinks on every
                call to ``anneal_temperature`` (multiplied by ``1 - anneal_rate``)
        """
        super().__init__()
        self.num_latents = num_latents
        self.min_temperature = min_temperature
        self.anneal_rate = anneal_rate

        if isinstance(num_values_per_latent, int):
            self.num_values_per_latent = [num_values_per_latent] * num_latents
        else:
            assert len(num_values_per_latent) == num_latents, (
                "num_values_per_latent must have one entry per latent dimension"
            )
            # Copy so later mutation of the caller's sequence cannot desync
            # this attribute from the codebooks built below.
            self.num_values_per_latent = list(num_values_per_latent)

        # Register the 1D codebooks as learnable parameters, each initialised
        # to evenly spaced values on [0, 1].
        self.codebooks = nn.ParameterList(
            [
                nn.Parameter(torch.linspace(0, 1, steps=n_values), requires_grad=True)  # shape: (K,)
                for n_values in self.num_values_per_latent
            ]
        )

        # Register temperature as a buffer so it is saved/moved with the module
        # but is not a trainable param. Force a floating-point value: an integer
        # argument (e.g. temperature=1) would otherwise create an integer
        # buffer that truncates to 0 on the first annealing step.
        self.register_buffer('temperature', torch.tensor(float(temperature)))

    def anneal_temperature(self):
        """Multiply the temperature by ``1 - anneal_rate``, floored at ``min_temperature``."""
        # Out-of-place update followed by a rebind of the buffer: the tensor a
        # previous forward pass may have saved for backward is left untouched.
        decayed = self.temperature.detach() * (1.0 - self.anneal_rate)
        self.temperature = decayed.clamp(min=self.min_temperature)

    def forward(self, x):
        """
        Args:
            x: (B, D) input tensor (continuous latent vector)
        Returns:
            z_q: (B, D) soft quantized latent
            weights: list of (B, K_i) softmax weights per dimension
            recon_loss: scalar mean squared error between z_q and x
        """
        B, D = x.shape
        assert D == self.num_latents, "input feature size must equal num_latents"

        z_q_list = []
        weights_list = []

        for x_col, values in zip(x.unbind(dim=1), self.codebooks):
            x_d = x_col.unsqueeze(1)  # shape: (B, 1)
            values_d = values.unsqueeze(0)  # shape: (1, K)

            # Negative absolute distance: closer codebook entries get larger logits
            distances = -(x_d - values_d).abs()  # (B, K)
            weights = F.softmax(distances / self.temperature, dim=-1)  # (B, K)

            # Soft assignment: expectation of the codebook under the weights
            z_q_list.append(torch.sum(weights * values_d, dim=1))  # (B,)
            weights_list.append(weights)

        z_q = torch.stack(z_q_list, dim=1)  # (B, D)
        recon_loss = F.mse_loss(z_q, x, reduction='mean')  # scalar loss

        return z_q, weights_list, recon_loss