import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat, reduce, pack, unpack

# from vector_quantize_pytorch import ResidualVQ

# Borrowed from vector_quantize_pytorch

def log(t, eps = 1e-20):
    """Natural log that never returns -inf: inputs are floored at ``eps`` first.

    Zero (or negative) entries, such as the uniform noise drawn in
    ``gumbel_noise``, map to ``log(eps)`` instead of ``-inf``/NaN.
    """
    floored = t.clamp(min=eps)
    return floored.log()

def gumbel_noise(t):
    """Sample standard Gumbel noise shaped like ``t`` (same dtype and device).

    Applies the inverse-CDF transform ``-log(-log(U))`` to ``U ~ Uniform(0, 1)``.
    Both logs go through the clamped ``log`` helper, so the result stays finite
    even at the endpoints of the uniform draw.
    """
    uniform = torch.zeros_like(t).uniform_(0, 1)
    neg_log_u = -log(uniform)
    return -log(neg_log_u)

def gumbel_sample(
    logits,
    temperature = 1.,
    stochastic = False,
    dim = -1,
    training = True
):
    """Pick an index along ``dim`` from ``logits``, optionally via Gumbel-max sampling.

    Gumbel noise is added to ``logits / temperature`` only when ``training``,
    ``stochastic`` and ``temperature > 0`` all hold. Otherwise this is a plain
    ``argmax`` over the raw logits.

    Returns:
        Integer tensor of indices with ``dim`` reduced away.
    """
    use_noise = training and stochastic and temperature > 0
    if not use_noise:
        return logits.argmax(dim = dim)

    perturbed = logits / temperature + gumbel_noise(logits)
    return perturbed.argmax(dim = dim)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Each vector along the last axis is divided by
    ``sqrt(mean(x**2) + eps)`` and then multiplied element-wise by a
    learnable ``scale`` of length ``dim`` (initialised to ones). Unlike
    LayerNorm there is no mean subtraction and no bias.
    """

    def __init__(self, dim, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # x: (..., dim)
        mean_square = (x * x).mean(dim=-1, keepdim=True)
        root = (mean_square + self.eps).sqrt()
        return (x / root) * self.scale
class QuantizeEMAReset(nn.Module):
    """Vector quantizer with an EMA-updated codebook and dead-code reset.

    ``forward`` takes ``(N, code_dim, T)`` tensors. In training mode the
    codebook is initialised from the first batch. After that it is updated
    as an exponential moving average (decay ``args.mu``) of the vectors
    assigned to each code. Codes whose EMA count falls below 1 are replaced
    by the matching row of the current batch; if the batch is smaller than
    the codebook, the batch is tiled and small noise is added.

    Args:
        nb_code: Number of codebook entries.
        code_dim: Dimensionality of each codebook entry.
        args: Config object. ``args.mu`` (EMA decay) is required. When
            ``args.q_self_attn == "SelfAttn"``, codes are chosen by the
            similarity of RMS-normalised vectors instead of by L2 distance.
            A missing ``q_self_attn`` means L2 distance.
    """

    def __init__(self, nb_code, code_dim, args):
        super(QuantizeEMAReset, self).__init__()
        self.nb_code = nb_code
        self.code_dim = code_dim
        self.mu = args.mu  # EMA decay applied to code_sum / code_count
        self.args = args
        # getattr: configs without the q_self_attn option (e.g. one holding
        # only ``mu``) fall back to L2-distance quantisation.
        if getattr(self.args, "q_self_attn", None) == "SelfAttn":
            self.rms_norm = RMSNorm(code_dim)
        self.reset_codebook()
        # Usage statistics read by get_codebook_stats(). These are
        # non-persistent buffers: they move with .to()/.cuda() but are not
        # added to state_dict, so existing checkpoints still load strictly.
        self.register_buffer(
            'code_usage',
            torch.zeros(nb_code, dtype=torch.long, device=self.codebook.device),
            persistent=False,
        )
        self.register_buffer(
            'total_usage_count',
            torch.tensor(0, dtype=torch.long, device=self.codebook.device),
            persistent=False,
        )

    def reset_codebook(self):
        """Zero the codebook and clear EMA state.

        The next training-mode forward pass re-initialises the codebook.
        """
        self.init = False
        self.code_sum = None
        self.code_count = None
        # Prefer CUDA when it is present, and fall back to CPU so the module
        # can be built on CPU-only hosts. module.to()/.cuda() still moves
        # the buffer afterwards.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.register_buffer(
            'codebook',
            torch.zeros(self.nb_code, self.code_dim, requires_grad=False, device=device),
        )

    def reset_usage_stats(self):
        """Reset usage statistics"""
        self.code_usage.zero_()
        self.total_usage_count.zero_()

    def _tile(self, x):
        """Return ``x`` repeated until it has at least ``nb_code`` rows.

        Gaussian jitter with std ``0.01 / sqrt(code_dim)`` is added to the
        repeated copies. If ``x`` already has ``nb_code`` or more rows, it is
        returned unchanged.
        """
        nb_code_x, code_dim = x.shape
        if nb_code_x < self.nb_code:
            n_repeats = (self.nb_code + nb_code_x - 1) // nb_code_x  # ceil(nb_code / nb_code_x)
            std = 0.01 / np.sqrt(code_dim)
            out = x.repeat(n_repeats, 1)
            out = out + torch.randn_like(out) * std
        else:
            out = x
        return out

    def get_codebook_stats(self):
        """Return codebook utilisation statistics collected by update_usage_stats().

        Returns:
            Dict with the same keys in every case:
              ``active_codes`` – number of codes used at least once;
              ``utilization_ratio`` – active_codes / nb_code;
              ``usage_histogram`` – per-code counts (CPU tensor);
              ``normalized_usage`` – per-code share of all assignments (CPU tensor);
              ``total_usage`` – total number of recorded assignments.
        """
        if self.total_usage_count == 0:
            return {
                'active_codes': 0,
                'utilization_ratio': 0.0,
                'usage_histogram': self.code_usage.cpu(),
                'normalized_usage': torch.zeros_like(self.code_usage, dtype=torch.float).cpu(),
                'total_usage': 0,
            }

        # Calculate statistics
        active_codes = (self.code_usage > 0).sum().item()
        normalized_usage = self.code_usage.float() / self.total_usage_count
        utilization_ratio = active_codes / self.nb_code

        return {
            'active_codes': active_codes,
            'utilization_ratio': utilization_ratio,
            'usage_histogram': self.code_usage.cpu(),
            'normalized_usage': normalized_usage.cpu(),
            'total_usage': self.total_usage_count.item()
        }

    @torch.no_grad()
    def update_usage_stats(self, code_idx):
        """Add the codes in ``code_idx`` to the running usage histogram."""
        unique_codes, counts = torch.unique(code_idx, return_counts=True)
        self.code_usage[unique_codes] += counts
        self.total_usage_count += code_idx.numel()

    @torch.no_grad()
    def init_codebook(self, x):
        """Seed the codebook and EMA state from the flattened batch ``x``.

        Runs without autograd so that no computation graph is stored in the
        ``codebook`` buffer.
        """
        out = self._tile(x)
        self.codebook = out[:self.nb_code]
        self.code_sum = self.codebook.clone()
        self.code_count = torch.ones(self.nb_code, device=self.codebook.device)
        self.init = True

    def quantize(self, x, sample_codebook_temp=0.):
        """Return one code index per row of ``x`` (NT, C), chosen by squared L2 distance.

        With ``sample_codebook_temp == 0``, or in eval mode, this is the
        nearest code. Otherwise the choice is Gumbel-sampled from ``-distance``.
        """
        # N X C -> C X N
        k_w = self.codebook.t()
        # x: NT X C
        # NT X N
        distance = torch.sum(x ** 2, dim=-1, keepdim=True) - \
                   2 * torch.matmul(x, k_w) + \
                   torch.sum(k_w ** 2, dim=0, keepdim=True)  # (N * L, b)

        code_idx = gumbel_sample(-distance, dim = -1, temperature = sample_codebook_temp, stochastic=True, training = self.training)

        return code_idx

    def quantize_self_attn(self, x, sample_codebook_temp=0.):
        """Return one code index per row of ``x`` (NT, C), chosen by normalised similarity.

        Both ``x`` and the codebook go through ``self.rms_norm``. Codes are
        then scored by dot product, which is proportional to cosine
        similarity while ``rms_norm.scale`` is uniform.
        """
        # Apply RMSNorm to both the inputs and the codebook.
        x_normalized = self.rms_norm(x)  # (NT, C)
        codebook_normalized = self.rms_norm(self.codebook)  # (nb_code, C)

        # Dot product of RMS-normalised vectors (scaled cosine-like similarity).
        # NOTE(review): similarity only feeds an argmax, so rms_norm.scale never
        # receives a gradient; confirm this is intended.
        similarity = torch.matmul(x_normalized, codebook_normalized.t())  # (NT, nb_code)

        code_idx = gumbel_sample(similarity, dim = -1, temperature = sample_codebook_temp, stochastic=True, training = self.training)

        return code_idx

    def dequantize(self, code_idx):
        """Look up codebook vectors for ``code_idx``; output shape is ``code_idx.shape + (code_dim,)``."""
        x = F.embedding(code_idx, self.codebook)
        return x

    def get_codebook_entry(self, indices):
        """Map 2-D ``indices`` (e.g. (N, T)) to codebook vectors shaped (N, code_dim, T)."""
        return self.dequantize(indices).permute(0, 2, 1)

    @torch.no_grad()
    def compute_perplexity(self, code_idx):
        """Return exp(entropy) of the code-assignment histogram of a flat ``code_idx``."""
        code_onehot = torch.zeros(self.nb_code, code_idx.shape[0], device=code_idx.device)  # nb_code, N * L
        code_onehot.scatter_(0, code_idx.view(1, code_idx.shape[0]), 1)

        code_count = code_onehot.sum(dim=-1)  # nb_code
        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))
        return perplexity

    @torch.no_grad()
    def update_codebook(self, x, code_idx):
        """Update the codebook with one EMA step and return the batch perplexity.

        Codes whose EMA count is below 1 are replaced by rows of the
        (possibly tiled) current batch ``x``.
        """
        code_onehot = torch.zeros(self.nb_code, x.shape[0], device=x.device) # nb_code, N * L
        code_onehot.scatter_(0, code_idx.view(1, x.shape[0]), 1)

        code_sum = torch.matmul(code_onehot, x) # nb_code, c
        code_count = code_onehot.sum(dim=-1) # nb_code

        out = self._tile(x)
        code_rand = out[:self.nb_code]

        # Update centres
        self.code_sum = self.mu * self.code_sum + (1. - self.mu) * code_sum
        self.code_count = self.mu * self.code_count + (1. - self.mu) * code_count

        # 1 for codes that are still alive, 0 for codes to be reset.
        usage = (self.code_count.view(self.nb_code, 1) >= 1.0).float()
        code_update = self.code_sum.view(self.nb_code, self.code_dim) / self.code_count.view(self.nb_code, 1)
        self.codebook = usage * code_update + (1-usage) * code_rand

        prob = code_count / torch.sum(code_count)
        perplexity = torch.exp(-torch.sum(prob * torch.log(prob + 1e-7)))

        return perplexity

    def preprocess(self, x):
        """Flatten (N, C, T) to (N*T, C), with rows ordered by batch then time."""
        x = rearrange(x, 'n c t -> (n t) c')
        return x

    def forward(self, x, return_idx=False, temperature=0.):
        """Quantise ``x`` of shape (N, code_dim, T).

        Args:
            x: Input tensor (N, code_dim, T).
            return_idx: If true, also return the (N, T) code indices.
            temperature: Gumbel sampling temperature. The default of 0, or
                eval mode, gives deterministic code selection.

        Returns:
            ``(x_d, commit_loss, perplexity)``, or
            ``(x_d, code_idx, commit_loss, perplexity)`` when ``return_idx``
            is set. ``x_d`` has the input's shape and passes gradients
            straight through to ``x``.
        """
        N, width, T = x.shape

        x = self.preprocess(x)
        if self.training and not self.init:
            self.init_codebook(x)
        if getattr(self.args, "q_self_attn", None) == "SelfAttn":
            code_idx = self.quantize_self_attn(x, temperature)
        else:
            code_idx = self.quantize(x, temperature)
        x_d = self.dequantize(code_idx)
        if self.training:
            perplexity = self.update_codebook(x, code_idx)
        else:
            perplexity = self.compute_perplexity(code_idx)

        # Commitment loss only (codebook is EMA-updated, not learned), as in T2M-GPT's code.
        commit_loss = F.mse_loss(x, x_d.detach())

        # Straight-through estimator: forward uses x_d, backward flows to x.
        x_d = x + (x_d - x).detach()

        # Postprocess: (N*T, C) -> (N, C, T)
        x_d = x_d.view(N, T, -1).permute(0, 2, 1).contiguous()
        code_idx = code_idx.view(N, T).contiguous()
        if return_idx:
            return x_d, code_idx, commit_loss, perplexity
        return x_d, commit_loss, perplexity
    
class QuantizeEMA(QuantizeEMAReset):
    """EMA vector quantizer without dead-code reset.

    Same as ``QuantizeEMAReset``, except that codes whose EMA count falls
    below 1 keep their previous codebook vector. They are not replaced by
    rows of the current batch.
    """

    @torch.no_grad()
    def update_codebook(self, x, code_idx):
        """Update the codebook with one EMA step and return the batch perplexity."""
        n_rows = x.shape[0]
        assign = torch.zeros(self.nb_code, n_rows, device=x.device)  # nb_code, N * L
        assign.scatter_(0, code_idx.view(1, n_rows), 1)

        batch_sum = assign @ x  # nb_code, c
        batch_count = assign.sum(dim=-1)  # nb_code

        decay = self.mu
        self.code_sum = decay * self.code_sum + (1. - decay) * batch_sum
        self.code_count = decay * self.code_count + (1. - decay) * batch_count

        count_col = self.code_count.view(self.nb_code, 1)
        alive = (count_col >= 1.0).float()
        centroids = self.code_sum.view(self.nb_code, self.code_dim) / count_col
        # Alive codes move to their EMA centroid; the rest stay where they were.
        self.codebook = alive * centroids + (1 - alive) * self.codebook

        p = batch_count / batch_count.sum()
        entropy = -(p * torch.log(p + 1e-7)).sum()
        return entropy.exp()
    
    
# class Args:
#     mu: float = 0.99

# def test_quantizer():
#     # Set random seed for reproducibility
#     torch.manual_seed(42)
    
#     # Create configuration
#     args = Args()
    
#     # Initialize quantizer
#     nb_code = 1024    # Number of codebook entries
#     code_dim = 256    # Dimension of each codebook vector
#     quantizer = QuantizeEMAReset(nb_code=nb_code, code_dim=code_dim, args=args)
    
#     # Create dummy input tensor
#     batch_size = 8
#     seq_length = 80
#     channels = 256
#     x = torch.randn(batch_size, channels, seq_length)
    
#     # Set to training mode
#     quantizer.train()
    
#     # Forward pass with different temperatures
#     print("Testing with different temperatures:")
#     for temp in [0.0, 0.5, 1.0]:
#         output, commit_loss, perplexity = quantizer(x, temperature=temp)
        
#         print(f"\nTemperature: {temp}")
#         print(f"Input shape: {x.shape}")
#         print(f"Output shape: {output.shape}")
#         print(f"Commitment loss: {commit_loss:.4f}")
#         print(f"Codebook perplexity: {perplexity:.2f}")
        
#     # Test with index return
#     output, indices, commit_loss, perplexity = quantizer(x, return_idx=True)
#     print("\nWith return_idx=True:")
#     print(f"Output shape: {output.shape}")
#     print(f"Indices shape: {indices.shape}")
    
#     # Evaluation mode
#     quantizer.eval()
#     with torch.no_grad():
#         output, commit_loss, perplexity = quantizer(x)
#         print("\nEvaluation mode:")
#         print(f"Commitment loss: {commit_loss:.4f}")
#         print(f"Codebook perplexity: {perplexity:.2f}")

# if __name__ == "__main__":
#     test_quantizer()