import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from typing import Tuple


class SequenceGaussianVectorQuantizer(nn.Module):
    """
    Vector quantizer for sequence data whose codebook is a fixed set of
    periodic Gaussian wave packets.

    Handles [B x S x D x H x W] sequence data. The H*W spatial positions are
    split into 4 factors of N = H*W // 4 positions each. For every factor and
    position, the whole length-S sequence (S*D values) is matched against K
    candidate code sequences. Candidate k at time step s is codebook row
    ``(k - label[b, s, f] * step) % K``, i.e. the codebook rolled by the label.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 sigma: float = 0.01,
                 step: int = 8):
        """
        Args:
            num_embeddings: number of codebook rows K (one Gaussian centre each).
            embedding_dim: length D of each codebook row.
            sigma: standard deviation of the Gaussian packets.
            step: codebook roll (in rows) per unit of label; defaults to 8,
                the value that used to be hard-coded.
        Raises:
            ValueError: if ``num_embeddings`` or ``embedding_dim`` is < 1.
        """
        super(SequenceGaussianVectorQuantizer, self).__init__()
        if num_embeddings < 1 or embedding_dim < 1:
            raise ValueError("num_embeddings and embedding_dim must be >= 1")
        self.K = num_embeddings  # e.g. 512
        self.D = embedding_dim  # e.g. 64
        self.sigma = sigma
        self.step = step

        # K centres evenly spaced on the periodic interval [-1, 1); the end
        # point 1 is dropped because it coincides with -1 on the circle.
        centers = torch.linspace(-1, 1, steps=num_embeddings + 1)[:-1]

        # Spacing between neighbouring centres. It is computed analytically
        # because indexing centers[1] would fail when K == 1.
        center_distance = 2.0 / num_embeddings
        print(f"相邻中心点的距离: {center_distance:.6f}")
        print(f"sigma/distance比率: {sigma / center_distance:.4f}")

        # D sample points evenly spaced on [-1, 1).
        x = torch.linspace(-1, 1, embedding_dim + 1)[:-1].view(1, -1)  # [1, D]
        centers = centers.view(-1, 1)  # [K, 1]

        # Signed distance on a circle of circumference 2: keep whichever of the
        # direct offset and its wrapped counterpart has the smaller magnitude.
        direct_diff = x - centers  # [K, D]
        periodic_diff = torch.where(direct_diff > 0,
                                    direct_diff - 2,
                                    direct_diff + 2)
        diff = torch.where(torch.abs(direct_diff) < torch.abs(periodic_diff),
                           direct_diff, periodic_diff)

        # Fixed (non-trainable) codebook. Stored only as a buffer so it follows
        # .to(device) and is saved in state_dict; `codebook` aliases it.
        self.register_buffer('embedding_weight',
                             torch.exp(-(diff ** 2) / (2 * sigma ** 2)))  # [K, D]

        print(f"已创建固定高斯codebook: {self.codebook.shape}")

    @property
    def codebook(self) -> Tensor:
        """The fixed Gaussian codebook ``[K, D]`` (alias of ``embedding_weight``)."""
        return self.embedding_weight

    def _to_factors(self, latents: Tensor) -> Tensor:
        """Reshape ``[B, S, D, H, W]`` latents into factor sequences ``[B, 4, N, S*D]``."""
        B, S, D, H, W = latents.shape
        N = H * W // 4
        factors = latents.reshape(B, S, D, 4, N).permute(0, 3, 4, 1, 2)  # [B, 4, N, S, D]
        return factors.reshape(B, 4, N, S * D)

    def _shifted_codebook(self, label: Tensor, device: torch.device) -> Tensor:
        """Build label-rolled candidate sequences ``[B, 4, K, S*D]`` from ``label`` ``[B, S, 4]``."""
        B, S, _ = label.shape
        shift = label.permute(0, 2, 1).long() * self.step  # [B, 4, S]
        k_idx = torch.arange(self.K, device=device).view(1, 1, 1, self.K)  # [1, 1, 1, K]
        # Rolling by `shift` means candidate k reads row (k - shift) mod K.
        rolled = (k_idx - shift.unsqueeze(-1)) % self.K  # [B, 4, S, K]
        table = self.embedding_weight[rolled]  # [B, 4, S, K, D]
        # Put K before S so that each candidate's S rows are contiguous.
        return table.permute(0, 1, 3, 2, 4).reshape(B, 4, self.K, S * self.D)

    @staticmethod
    def _nearest_indices(factors: Tensor, table: Tensor) -> Tensor:
        """Return the index ``[B, 4, N]`` of the closest (squared L2) candidate row."""
        a_sq = (factors ** 2).sum(dim=-1, keepdim=True)  # [B, 4, N, 1]
        b_sq = (table ** 2).sum(dim=-1).unsqueeze(-2)  # [B, 4, 1, K]
        cross = 2 * torch.matmul(factors, table.transpose(-1, -2))  # [B, 4, N, K]
        return (a_sq + b_sq - cross).argmin(dim=-1)

    @staticmethod
    def _gather_codes(table: Tensor, indices: Tensor) -> Tensor:
        """Pick rows of ``table`` ``[B, 4, K, L]`` by ``indices`` ``[B, 4, N]`` -> ``[B, 4, N, L]``."""
        idx = indices.unsqueeze(-1).expand(-1, -1, -1, table.shape[-1])
        return torch.gather(table, 2, idx)

    @staticmethod
    def _from_factors(codes: Tensor, B: int, S: int, D: int, H: int, W: int) -> Tensor:
        """Inverse of ``_to_factors``: ``[B, 4, N, S*D]`` -> ``[B, S, D, H, W]``."""
        N = H * W // 4
        return codes.reshape(B, 4, N, S, D).permute(0, 3, 4, 1, 2).reshape(B, S, D, H, W)

    def forward(self, latents: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Quantize a batch of sequences against the label-rolled Gaussian codebook.

        Args:
            latents: tensor of shape [B x S x D x H x W]; H*W must be divisible by 4.
            label: tensor of shape [B x S x 4]; per-step, per-factor roll amounts
                (converted with ``.long()`` and multiplied by ``step``).
        Returns:
            quantized: [B x S x D x H x W] quantized tensor (straight-through
                gradient to ``latents``).
            vq_loss: commitment loss (MSE between latents and detached codes).
        """
        B, S, D, H, W = latents.shape
        factor_latents = self._to_factors(latents)  # [B, 4, N, S*D]
        table = self._shifted_codebook(label, latents.device)  # [B, 4, K, S*D]
        codes = self._gather_codes(table, self._nearest_indices(factor_latents, table))

        # The codebook is fixed, so only the commitment term is needed.
        commitment_loss = F.mse_loss(codes.detach(), factor_latents)
        # Straight-through estimator: forward value is the code, gradient goes to latents.
        codes = factor_latents + (codes - factor_latents).detach()

        return self._from_factors(codes, B, S, D, H, W), commitment_loss

    def test_forward(self, latents_init: Tensor, label_init: Tensor, label_new: Tensor) -> Tensor:
        """
        Re-render latents at new labels using the code indices found at the initial labels.

        Args:
            latents_init: [B x S x D x H x W] initial latents.
            label_init: [B x S x 4] labels belonging to ``latents_init``.
            label_new: [B x S_new x 4] labels to render.
        Returns:
            quantized: [B x S_new x D x H x W] codes selected at ``label_init``
                and re-read from the codebook rolled by ``label_new``.
        """
        B, S, D, H, W = latents_init.shape
        _, S_new, _ = label_new.shape
        device = latents_init.device

        factor_latents = self._to_factors(latents_init)  # [B, 4, N, S*D]
        init_table = self._shifted_codebook(label_init, device)  # [B, 4, K, S*D]
        encoding_inds = self._nearest_indices(factor_latents, init_table)  # [B, 4, N]

        new_table = self._shifted_codebook(label_new, device)  # [B, 4, K, S_new*D]
        codes = self._gather_codes(new_table, encoding_inds)  # [B, 4, N, S_new*D]
        return self._from_factors(codes, B, S_new, D, H, W)


class SequenceVectorQuantizer(nn.Module):
    """
    Sequence vector quantizer with a standard trainable codebook.

    Handles [B x S x D x H x W] sequence data. The H*W spatial positions are
    split into 4 factors of N = H*W // 4 positions. For every factor and
    position, the length-S sequence is matched against K candidate code
    sequences, where candidate k at step s is codebook row
    ``(k - label[b, s, f] * step) % K`` (the codebook rolled by the label).
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, step: int = 1, eta=0.25):
        """
        Args:
            num_embeddings: number of codebook rows K (e.g. 512).
            embedding_dim: size D of each codebook row (e.g. 64).
            step: codebook roll (in rows) per unit of label.
            eta: weight of the commitment term in the VQ loss (the beta of
                VQ-VAE, where 0.25 is the customary value).
        """
        super().__init__()
        self.K = num_embeddings
        self.D = embedding_dim
        self.step = step
        self.eta = eta
        # Trainable codebook [K, D] with a small random initialisation.
        self.codebook = nn.Parameter(torch.randn(self.K, self.D) * 0.1)

    def forward(self, latents: Tensor, label: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            latents: [B x S x D x H x W]; H*W must be divisible by 4.
            label:   [B x S x 4] roll labels (converted with ``.long()``).
        Returns:
            quantized: [B x S x D x H x W] (straight-through gradient to latents).
            vq_loss: ``codebook_loss + eta * commitment_loss``.
        """
        B, S, D, H, W = latents.shape
        N = H * W // 4
        device = latents.device

        # [B, S, D, H, W] -> [B, 4, N, S, D] -> [B, 4, N, S*D]; reshape (rather
        # than view) also accepts non-contiguous inputs.
        factor_latents = (latents.reshape(B, S, D, 4, N)
                          .permute(0, 3, 4, 1, 2)
                          .reshape(B, 4, N, S * D))

        # Per-sequence rolled codebooks: candidate k at step s reads row (k - shift_s) mod K.
        shift_amounts = label.permute(0, 2, 1).long() * self.step  # [B, 4, S]
        k_idx = torch.arange(self.K, device=device).view(1, 1, 1, self.K)  # [1, 1, 1, K]
        rolled_indices = (k_idx - shift_amounts.unsqueeze(-1)) % self.K  # [B, 4, S, K]

        expanded_embed = self.codebook[rolled_indices]  # [B, 4, S, K, D]
        expanded_embed = (expanded_embed.permute(0, 1, 3, 2, 4)  # [B, 4, K, S, D]
                          .reshape(B, 4, self.K, S * D))  # [B, 4, K, S*D]

        # Squared L2 distance via ||a||^2 + ||b||^2 - 2 a.b, then nearest candidate.
        a_sq = (factor_latents ** 2).sum(dim=-1, keepdim=True)  # [B, 4, N, 1]
        b_sq = (expanded_embed ** 2).sum(dim=-1).unsqueeze(-2)  # [B, 4, 1, K]
        cross = 2 * torch.matmul(factor_latents, expanded_embed.transpose(-1, -2))  # [B, 4, N, K]
        encoding_inds = (a_sq + b_sq - cross).argmin(dim=-1)  # [B, 4, N]

        encoding_inds_exp = encoding_inds.unsqueeze(-1).expand(-1, -1, -1, S * D)  # [B, 4, N, S*D]
        quantized = torch.gather(expanded_embed, 2, encoding_inds_exp)  # [B, 4, N, S*D]

        # VQ-VAE loss: the codebook term moves codes toward (detached) latents;
        # the commitment term, weighted by eta (beta), keeps the encoder output
        # close to its (detached) code. The original code had these weights swapped.
        codebook_loss = F.mse_loss(quantized, factor_latents.detach())
        commitment_loss = F.mse_loss(quantized.detach(), factor_latents)
        vq_loss = codebook_loss + self.eta * commitment_loss

        # Straight-through estimator.
        quantized = factor_latents + (quantized - factor_latents).detach()

        # Back to [B, S, D, H, W].
        quantized = quantized.reshape(B, 4, N, S, D).permute(0, 3, 4, 1, 2)
        quantized = quantized.reshape(B, S, D, H, W)

        return quantized, vq_loss


class SequenceFSQ(nn.Module):
    """
    Rotation-based sequence quantizer.

    Each input of shape (num, d) is viewed as ``n_rot = num * d / 2`` 2-D pairs.
    Every pair is scaled to unit length and rotated by ``label * theta``, where
    ``theta`` holds one fixed random angle per pair. Label channel ``i`` drives
    ``a_dim[i]`` consecutive pairs. The rotated pairs are averaged over the
    sequence axis and renormalised. The mean is then rotated back by each step's
    own angle and returned through a straight-through estimator.
    """

    def __init__(self, num, d, a_dim, device):
        super().__init__()
        total = num * d
        assert total % 2 == 0, "num * d 必须是偶数"
        assert sum(a_dim) == total // 2, "sum(a_dim) 必须等于 num * d / 2"

        self.device = device
        self.num = num
        self.d = d
        self.a_dim = a_dim
        self.n_rot = total // 2

        # One fixed (non-trainable) random angle in [0, 2*pi) per 2-D pair.
        self.register_buffer(
            'theta', torch.rand(self.n_rot, device=self.device) * 2 * math.pi
        )

        # Pair -> label-channel lookup: channel i is repeated a_dim[i] times.
        channel_ids = torch.arange(len(a_dim), device=self.device)
        repeats = torch.tensor(a_dim, device=self.device)
        self.register_buffer('expanded_index',
                             torch.repeat_interleave(channel_ids, repeats))

    @staticmethod
    def _unit(v):
        """Scale the last (size-2) axis to unit length, with a 1e-8 guard."""
        return v / (torch.norm(v, dim=-1, keepdim=True) + 1e-8)

    @staticmethod
    def _rotate(pairs, cos, sin):
        """Rotate each 2-D pair counter-clockwise by the angle given via cos/sin."""
        first, second = pairs[..., 0], pairs[..., 1]
        return torch.stack([
            first * cos - second * sin,
            first * sin + second * cos
        ], dim=-1)

    def forward(self, x, y):
        """
        Args:
            x: (B, S, num, d) input features.
            y: (B, S, len(a_dim)) labels; each channel scales theta for its pairs.
        Returns:
            Tuple ``(x_st, x_unit, commit_loss)``. ``x_st`` has shape
            (B, S, n_rot, 2) and carries the restored values with a
            straight-through gradient to ``x_unit``. ``x_unit`` is the
            unit-normalised input pairs, of the same shape. ``commit_loss`` is
            the scalar MSE between the detached restored pairs and ``x_unit``.
        """
        B, S, num, d = x.shape
        assert num == self.num and d == self.d

        pairs = self._unit(x.view(B, S, self.n_rot, 2))  # (B, S, n_rot, 2)

        # Per-pair angle for every step: (B, S, n_rot).
        angles = y[:, :, self.expanded_index] * self.theta.view(1, 1, -1)
        rotated = self._rotate(pairs, torch.cos(angles), torch.sin(angles))

        # Sequence-averaged direction, renormalised and broadcast to every step.
        prototype = self._unit(rotated.mean(dim=1))
        prototype = prototype.unsqueeze(1).expand(-1, S, -1, -1)

        # Undo each step's rotation (rotate clockwise by the same angle).
        reverse = -angles
        restored = self._unit(
            self._rotate(prototype, torch.cos(reverse), torch.sin(reverse))
        )

        commit_loss = F.mse_loss(restored.detach(), pairs, reduction='mean')
        x_st = pairs + (restored - pairs).detach()

        return x_st, pairs, commit_loss

