# calculate the center of the patch, then calculate the positional embedding
from typing import Optional

from einops import rearrange
from timm.layers import DropPath, Mlp
from torch import nn
import torch.nn.functional as F
from einops.layers.torch import Rearrange
from torch.nn import MultiheadAttention

from chamfer_dist import ChamferDistanceL1
from model.quantize import Quantizer as VectorQuantizer
import torch
from timm.models.vision_transformer import PatchEmbed, Block, LayerScale, Attention
import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
# 确保能从您的训练脚本中正确导入Locator模型类
from .train_locator import Locator


class AdjacencyExpert(nn.Module):
    """Pluggable expert that maps patch center coordinates to an attention bias.

    A frozen, pre-trained ``Locator`` classifies every patch center into an
    index of a hop lookup table (LUT). The LUT entry for a pair of predicted
    indices is read as the hop count between the two patches, and a small
    learnable generator converts hop counts into additive bias values.

    Args:
        locator_path: Checkpoint file whose ``'model_state_dict'`` entry holds
            the locator weights.
        template_path: File containing a mapping with a ``'lut'`` tensor; its
            first dimension is used as the number of locator classes.
    """

    def __init__(self, locator_path, template_path):
        super().__init__()

        # --- 1. Load and freeze the locator model ---
        print("正在加载专家模块: 定位器...")
        # map_location="cpu" makes loading independent of the device the files
        # were saved from; the owning model moves parameters and buffers to the
        # target device later through .to()/.cuda().
        template_data = torch.load(template_path, map_location="cpu", weights_only=True)
        num_classes = template_data['lut'].shape[0]

        self.locator = Locator(num_standard_indices=num_classes)
        checkpoint = torch.load(locator_path, map_location="cpu", weights_only=True)

        # The checkpoint stores the weights under the 'model_state_dict' key.
        self.locator.load_state_dict(checkpoint['model_state_dict'])
        self.locator.eval()

        # The locator is frozen: it does not take part in MAE training.
        for param in self.locator.parameters():
            param.requires_grad = False

        # --- 2. Register the hop lookup table as a (non-trainable) buffer ---
        print("正在加载专家模块: 查询表...")
        self.register_buffer("hop_lookup_table", template_data['lut'])

        # --- 3. Learnable generator that turns hop counts into bias values ---
        # Its parameters (A, sigma, B) are trained together with the MAE.
        self.bias_generator = self.AdaptiveBiasGenerator()
        print("专家模块加载完成。")

    def train(self, mode: bool = True):
        """Set train/eval mode while keeping the frozen locator in eval mode.

        ``nn.Module.train`` propagates the mode to every submodule, so calling
        ``model.train()`` on a parent would otherwise undo the ``eval()`` done
        in ``__init__`` and re-enable any train-only behaviour (e.g. dropout
        or batch-norm statistics updates) inside the locator.
        """
        super().train(mode)
        self.locator.eval()
        return self

    @torch.no_grad()  # locating and table lookup never need gradients
    def _get_hop_matrix(self, patch_positions):
        """Run "locate + lookup" and return the pairwise hop matrix.

        Args:
            patch_positions: Tensor of shape [batch_size, num_patches, 3].

        Returns:
            Tensor of shape [batch_size, num_patches, num_patches] read from
            the lookup table.
        """
        batch_size, num_patches, _ = patch_positions.shape

        # Flatten all patches so the locator runs once over the whole batch.
        # reshape (not view) also accepts non-contiguous inputs.
        flat_positions = patch_positions.reshape(-1, 3)

        # Step 1: the locator predicts one table index per patch.
        pred_logits = self.locator(flat_positions)
        pred_indices = torch.argmax(pred_logits, dim=-1)  # [batch_size * num_patches]

        # Step 2: broadcast the indices into all (a, b) pairs and gather the
        # hop count of every pair from the table with advanced indexing.
        idx_a = pred_indices.view(batch_size, num_patches, 1)
        idx_b = pred_indices.view(batch_size, 1, num_patches)
        hop_matrix = self.hop_lookup_table[idx_a, idx_b]

        return hop_matrix

    def forward(self, patch_positions):
        """Return the attention bias for a batch of patch centers.

        Args:
            patch_positions: Tensor of shape [batch_size, num_patches, 3].

        Returns:
            Float tensor of shape [batch_size, num_patches, num_patches].
        """
        hop_matrix = self._get_hop_matrix(patch_positions)
        attention_bias = self.bias_generator(hop_matrix)
        return attention_bias

    class AdaptiveBiasGenerator(nn.Module):
        """Learnable Gaussian mapping from hop counts to bias values.

        Computes ``A * exp(-hop^2 / (2 * sigma^2 + 1e-6)) + B`` where ``A`` and
        ``sigma`` are stored in log space and ``B`` is stored directly.
        """

        def __init__(self, initial_A=2.0, initial_sigma=2.0, initial_B=-8.0):
            super().__init__()
            # A and sigma are parameterised by their logarithm so that the
            # effective values stay positive during training.
            self.log_A = nn.Parameter(torch.tensor(initial_A).log())
            self.log_sigma = nn.Parameter(torch.tensor(initial_sigma).log())
            self.B = nn.Parameter(torch.tensor(initial_B))

        def forward(self, hop_matrix):
            """Map a hop-count tensor to a float bias tensor of the same shape."""
            A = self.log_A.exp()
            sigma = self.log_sigma.exp()

            # Negative entries (e.g. -1 for "unreachable") are replaced by a
            # very large hop count, so their Gaussian term vanishes and the
            # resulting bias is just B.
            hop_matrix = hop_matrix.float().masked_fill(hop_matrix < 0, 1e9)

            bias = A * torch.exp(-hop_matrix.pow(2) / (2 * sigma.pow(2) + 1e-6)) + self.B
            return bias


class RelativeWindowMultiheadAttention(nn.Module):
    """Multi-head self-attention that adds a precomputed bias to its scores.

    The module owns a fused QKV projection, an output projection and two
    dropout layers (on the attention weights and on the projected output).
    An optional bias tensor is added to the scaled dot-product scores before
    the softmax and is shared by all heads.

    Args:
        embed_dim: Channel size of the input and output tokens.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        dropout: Dropout probability used for both dropout layers.
    """

    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert embed_dim % num_heads == 0, "embed_dim 必须能被 num_heads 整除"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // self.num_heads
        self.scale = self.head_dim ** -0.5

        # Fused projection producing queries, keys and values in one matmul.
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(dropout)
        self.proj_drop = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attention_bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Attend over ``x`` with an optional additive score bias.

        Args:
            x: Tokens of shape [B, N, C].
            attention_bias: ``None`` or a tensor that becomes broadcastable to
                the [B, heads, N, N] scores after a head axis is inserted at
                dim 1 (e.g. [B, N, N]).

        Returns:
            Tensor of shape [B, N, C].
        """
        batch, seq_len, channels = x.shape
        per_head_shape = (batch, seq_len, self.num_heads, self.head_dim)

        # Split the fused projection into q/k/v and move heads in front of the
        # sequence axis: [B, N, C] -> [B, heads, N, head_dim].
        query, key, value = (
            part.view(per_head_shape).transpose(1, 2)
            for part in torch.chunk(self.qkv(x), 3, dim=-1)
        )

        scores = (query @ key.transpose(-2, -1)) * self.scale

        if attention_bias is not None:
            # Insert a head axis so the same bias is applied to every head.
            scores = scores + attention_bias[:, None]

        weights = self.attn_drop(torch.softmax(scores, dim=-1))

        # Merge heads back: [B, heads, N, head_dim] -> [B, N, C].
        context = (weights @ value).transpose(1, 2).reshape(batch, seq_len, channels)
        return self.proj_drop(self.out_proj(context))

class PerspectiveTransform(nn.Module):
    """Project normalized features into a semantic view and a geometric view.

    Two independent linear layers of size ``dim -> dim`` are applied to the
    same input, matching ``X'_s = X' W_sem^T`` and ``X'_g = X' W_geo^T``.

    Args:
        dim: Channel size of the input; both outputs keep this size.
    """
    def __init__(self, dim: int):
        super().__init__()
        self.sem_proj = nn.Linear(dim, dim)  # semantic-view projection
        self.geo_proj = nn.Linear(dim, dim)  # geometric-view projection

    def forward(self, x: torch.Tensor) -> (torch.Tensor, torch.Tensor):
        """Return ``(semantic_view, geometric_view)`` for the input ``x``."""
        return self.sem_proj(x), self.geo_proj(x)

class ChannelModulation(nn.Module):
    """Per-token channel gating through a bottleneck MLP.

    A ``dim -> hidden -> dim`` MLP followed by a sigmoid produces modulation
    weights ``W_mod`` for every token, and the input is rescaled element-wise:
    ``Y_cal = Y * W_mod``.

    Args:
        dim: Channel size of the input and output.
        reduction_ratio: Bottleneck factor; the hidden width is
            ``dim // reduction_ratio`` but never smaller than 1.
        act_layer: Activation class instantiated between the two linear layers.
    """
    def __init__(self, dim: int, reduction_ratio: int = 4, act_layer: nn.Module = nn.GELU):
        super().__init__()
        # Clamp to at least one hidden unit: with dim < reduction_ratio the
        # floor division gives 0, which would create a zero-width bottleneck
        # and make the gate independent of the input.
        hidden_dim = max(1, dim // reduction_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, hidden_dim, bias=True),
            act_layer(),
            nn.Linear(hidden_dim, dim, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` scaled element-wise by its sigmoid modulation weights."""
        mod_weights = self.mlp(x)  # W_mod, values in (0, 1), same shape as x
        return x * mod_weights

class Block(nn.Module):
    """Transformer block that fuses standard attention with bias-aware attention.

    The normalized input is projected into a semantic and a geometric view.
    The semantic view goes through a standard ``Attention`` over all tokens;
    the geometric view (patch tokens only) goes through
    ``RelativeWindowMultiheadAttention`` together with an optional additive
    bias. Both outputs are blended with a learnable scalar gate whose weight
    is bounded by ``max_geom_weight``. The MLP sub-layer gates the normalized
    features with ``ChannelModulation`` before the MLP, and its residual is
    taken from the *normalized* features.

    Args:
        dim: Token channel size.
        num_heads: Number of heads for both attention paths.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the standard attention uses a QKV bias.
        proj_drop: Dropout of the standard attention projection and the MLP.
        attn_drop: Attention dropout of the standard path; also passed as the
            single dropout rate of the bias-aware path.
        drop_path: Stochastic-depth rate; ``0`` disables it.
        norm_layer: Normalization layer class.
        act_layer: Activation class for the MLP and the channel modulation.
        mlp_layer: MLP class.
        init_values: LayerScale init value; a falsy value disables LayerScale.
        max_geom_weight: Upper bound of the geometric path's blend weight.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int,
            mlp_ratio: float = 4.,
            qkv_bias: bool = True,
            proj_drop: float = 0.,
            attn_drop: float = 0.,
            drop_path: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
            act_layer: nn.Module = nn.GELU,
            mlp_layer: nn.Module = Mlp,
            init_values: Optional[float] = None,
            max_geom_weight: float = 0.66,
    ):
        super().__init__()

        def make_layer_scale():
            # LayerScale only when an init value was given, otherwise a no-op.
            if init_values:
                return LayerScale(dim, init_values=init_values)
            return nn.Identity()

        def make_drop_path():
            # Stochastic depth only for a positive rate, otherwise a no-op.
            if drop_path > 0.:
                return DropPath(drop_path)
            return nn.Identity()

        # ---- attention sub-layer ----
        self.norm1 = norm_layer(dim)
        self.perspective_transform = PerspectiveTransform(dim)

        # Path A: standard attention over every token.
        self.standard_attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            attn_drop=attn_drop,
            proj_drop=proj_drop,
        )

        # Path B: attention that accepts a precomputed additive bias.
        self.region_attn = RelativeWindowMultiheadAttention(
            dim,
            num_heads=num_heads,
            dropout=attn_drop,
        )

        # Learnable scalar gate; sigmoid(-5) keeps the geometric path almost
        # switched off at initialization.
        self.max_geom_weight = max_geom_weight
        self.raw_gate = nn.Parameter(torch.tensor(-5.0))

        self.ls1 = make_layer_scale()
        self.drop_path1 = make_drop_path()

        # ---- MLP sub-layer ----
        self.norm2 = norm_layer(dim)
        self.channel_modulation = ChannelModulation(dim, act_layer=act_layer)
        self.mlp = mlp_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=proj_drop,
        )
        self.ls2 = make_layer_scale()
        self.drop_path2 = make_drop_path()

    def forward(self, x: torch.Tensor, region_logits: Optional[torch.Tensor], has_cls_token: bool) -> torch.Tensor:
        """Apply the fused attention sub-layer followed by the gated MLP.

        Args:
            x: Tokens of shape [B, N, C]; when ``has_cls_token`` is true the
                first token is treated as the class token.
            region_logits: Additive attention bias for the patch tokens, or
                ``None`` for no bias. It is forwarded unchanged to the
                bias-aware attention.
            has_cls_token: Whether ``x`` starts with a class token.

        Returns:
            Tensor with the same shape as ``x``.
        """
        normed = self.norm1(x)
        semantic_view, geometric_view = self.perspective_transform(normed)

        # Path A sees all tokens; path B only sees patch tokens.
        global_out = self.standard_attn(semantic_view)
        patch_view = geometric_view[:, 1:, :] if has_cls_token else geometric_view
        local_out = self.region_attn(patch_view, region_logits)

        # Convex blend: the geometric weight lies in (0, max_geom_weight).
        geom_weight = self.max_geom_weight * torch.sigmoid(self.raw_gate)
        std_weight = 1.0 - geom_weight

        if not has_cls_token:
            fused = std_weight * global_out + geom_weight * local_out
        else:
            # The class token comes from the standard path only.
            cls_out = global_out[:, :1, :]
            patch_out = global_out[:, 1:, :]
            blended_patches = std_weight * patch_out + geom_weight * local_out
            fused = torch.cat([cls_out, blended_patches], dim=1)

        # First residual connection.
        x = x + self.drop_path1(self.ls1(fused))

        # Y = LayerNorm(X + X_mix); Y_cal = Y * W_mod; X_out = Y + MLP(Y_cal).
        # The residual input is the norm2 output, not the pre-norm tensor.
        post_norm = self.norm2(x)
        calibrated = self.channel_modulation(post_norm)
        return post_norm + self.drop_path2(self.ls2(self.mlp(calibrated)))

class Mesh_mae(nn.Module):
    """Masked autoencoder over mesh patches with a geometry-aware attention bias.

    Visible patches are embedded, enriched with a positional embedding of the
    patch centers and encoded by ``Block`` layers that receive an attention
    bias from ``AdjacencyExpert``. The decoder fills masked positions with a
    learnable mask token and reconstructs, for the masked patches, vertex
    coordinates (Chamfer L1 loss) and per-face features (MSE loss).

    Args:
        masking_ratio: Currently unused; the masking ratio is the ``ratio``
            argument of ``forward``.
        channels: Number of feature channels per face.
        num_heads: Encoder attention heads.
        encoder_depth: Number of encoder blocks.
        embed_dim: Encoder token size.
        decoder_num_heads: Decoder attention heads.
        decoder_depth: Number of decoder blocks.
        decoder_embed_dim: Decoder token size.
        patch_size: Number of faces per patch (last axis of ``feats``).
        norm_layer: Normalization layer class.
        weight: Weight of the feature loss in the total loss.
        locator_path: Checkpoint path passed to ``AdjacencyExpert``.
        template_path: Template/LUT path passed to ``AdjacencyExpert``.
    """

    def __init__(self, masking_ratio=0.5, channels=14, num_heads=12, encoder_depth=12, embed_dim=768,
                 decoder_num_heads=16, decoder_depth=6, decoder_embed_dim=512,
                 patch_size=1024, norm_layer=nn.LayerNorm, weight=0.2,
                 locator_path="./Right_locator_expert_robust.pth",
                 template_path="./Right_canonical_template.pt"):
        super(Mesh_mae, self).__init__()
        patch_dim = channels
        # Kept for backward compatibility; forward() derives the actual patch
        # count from its input.
        self.num_patches = 1024
        self.weight = weight
        self.points_per_patch = 45
        self.embed_dim = embed_dim

        # Positional embeddings computed from 3-D patch centers.
        self.pos_embedding = nn.Sequential(
            nn.Linear(3, 128),
            nn.GELU(),
            nn.Linear(128, embed_dim)
        )
        # NOTE: the attribute name (including its spelling) is part of the
        # state_dict keys and must not be renamed.
        self.decoer_pos_embedding = nn.Sequential(
            nn.Linear(3, 128),
            nn.GELU(),
            nn.Linear(128, decoder_embed_dim)
        )
        # [B, C, patches, faces] -> [B, patches, faces * C] -> [B, patches, embed_dim]
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c h p -> b h (p c)', p=patch_size),
            nn.Linear(patch_dim * patch_size, embed_dim),
            nn.LayerNorm(embed_dim)
        )

        # Class token and its positional embedding.
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.encoder_cls_token_pos = nn.Parameter(torch.zeros(1, 1, embed_dim))

        # Transformer encoder blocks.
        self.blocks = nn.ModuleList([
            Block(embed_dim, num_heads, mlp_ratio=4., qkv_bias=True, norm_layer=norm_layer)
            for _ in range(encoder_depth)
        ])

        self.norm = norm_layer(embed_dim)

        # Frozen locator + hop LUT + learnable bias generator; produces the
        # attention bias for the visible patches.
        self.region_assigner = AdjacencyExpert(locator_path=locator_path, template_path=template_path)

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        self.decoder_blocks = nn.ModuleList([
            Block(decoder_embed_dim, decoder_num_heads, mlp_ratio=4., qkv_bias=True, norm_layer=norm_layer)
            for _ in range(decoder_depth)
        ])

        self.decoder_norm = norm_layer(decoder_embed_dim)

        # --------------------------------------------------------------------------

        self.to_points = nn.Linear(decoder_embed_dim, patch_size * 9)
        self.to_pointsnew = nn.Linear(decoder_embed_dim, self.points_per_patch * 3)
        self.to_points_seg = nn.Linear(decoder_embed_dim, 9)
        self.to_features = nn.Linear(decoder_embed_dim, patch_size * channels)
        self.to_features_seg = nn.Linear(decoder_embed_dim, channels)
        self.build_loss_func()
        self.initialize_weights()

    def build_loss_func(self):
        """Create the Chamfer L1 loss module (moved to CUDA on construction)."""
        self.loss_func_cdl1 = ChamferDistanceL1().cuda()

    def initialize_weights(self):
        """Initialize tokens and layers without touching the pretrained locator.

        ``self.apply`` visits every submodule, including the frozen locator
        inside ``region_assigner``. Any ``nn.Linear``/``nn.LayerNorm`` it
        contains would be re-initialized and the loaded checkpoint lost, so
        its weights are snapshotted first and restored afterwards. The random
        stream consumed by the remaining layers is unchanged.
        """
        torch.nn.init.normal_(self.cls_token, std=.02)
        torch.nn.init.normal_(self.mask_token, std=.02)

        locator = self.region_assigner.locator
        pretrained_state = {
            key: value.detach().clone() for key, value in locator.state_dict().items()
        }
        self.apply(self._init_weights)
        locator.load_state_dict(pretrained_state)

    def _init_weights(self, m):
        """Xavier-uniform for linear layers; unit weight / zero bias for LayerNorm."""
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @staticmethod
    def _random_mask_split(batch, num_patches, num_masked, device):
        """Draw an independent random patch permutation per sample and split it.

        Returns ``(masked_indices, unmasked_indices)`` of shapes
        [batch, num_masked] and [batch, num_patches - num_masked] on ``device``.
        """
        # Sampling stays on the CPU generator (as before); only the placement
        # of the result follows the input device instead of a hard-coded CUDA.
        rand_indices = torch.rand(batch, num_patches).argsort(dim=-1).to(device)
        return rand_indices[:, :num_masked], rand_indices[:, num_masked:]

    @staticmethod
    def _assemble_decoder_input(mask_token, visible_tokens, visible_indices, num_patches):
        """Build the decoder sequence: mask tokens with visible tokens scattered in.

        Args:
            mask_token: Tensor of shape [1, 1, C].
            visible_tokens: Encoded visible patches, [B, num_visible, C].
            visible_indices: Patch index of every visible token, [B, num_visible].
            num_patches: Total number of patches per sample.

        Returns:
            Tensor of shape [B, num_patches, C].
        """
        batch = visible_tokens.shape[0]
        decoder_tokens = mask_token.repeat(batch, num_patches, 1)
        batch_range = torch.arange(batch, device=visible_indices.device)[:, None]
        # Index the batch axis explicitly. Indexing it with a plain slice
        # (tokens[:, idx, :]) broadcasts every sample's tokens into every
        # sample of the batch.
        decoder_tokens[batch_range, visible_indices] = visible_tokens.to(decoder_tokens.dtype)
        return decoder_tokens

    def forward(self, faces, feats, centers, Fs, cordinates, ratio=0.25):
        """Mask patches, encode the visible ones and reconstruct the masked ones.

        Args:
            faces: Unused.
            feats: Face features of shape [B, C, num_patches, faces_per_patch].
            centers: Face centers; averaged over dim 2 to get one 3-D center
                per patch (so expected as [B, num_patches, faces_per_patch, 3]).
            Fs: Unused.
            cordinates: Per-patch vertex coordinates indexed as
                [B, num_patches, ...]; the trailing axes are reshaped to (-1, 3).
            ratio: Fraction of patches that is masked.

        Returns:
            Tuple ``(loss, feats_con_loss, shape_con_loss)`` where
            ``loss = shape_con_loss + self.weight * feats_con_loss``.
        """
        points_num_per_patch = self.points_per_patch
        device = feats.device

        feats_patches = feats
        centers_patches = centers
        faces_per_patch = centers_patches.shape[2]
        center_of_patches = torch.sum(centers_patches, dim=2) / faces_per_patch
        batch, channel, *_ = feats_patches.shape
        cordinates_patches = cordinates

        # Positional embedding of every patch center.
        pos_emb = self.pos_embedding(center_of_patches)
        encoder_cls_token_pos = self.encoder_cls_token_pos.repeat(batch, 1, 1)
        tokens = self.to_patch_embedding(feats_patches)  # [B, num_patches, embed_dim]

        # Use the patch count of the actual input rather than a fixed constant.
        patch_num = tokens.shape[1]

        num_masked = int(ratio * patch_num)
        masked_indices, unmasked_indices = self._random_mask_split(batch, patch_num, num_masked, device)
        batch_range = torch.arange(batch, device=device)[:, None]
        tokens_unmasked = tokens[batch_range, unmasked_indices]

        # Attention bias between the visible patches, from their centers.
        unmasked_logits = self.region_assigner(center_of_patches[batch_range, unmasked_indices])

        # Prepend the class token (expanded from [1, 1, C] to [B, 1, C]) and
        # add positional information.
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        tokens_unmasked = torch.cat((cls_tokens, tokens_unmasked), dim=1)
        pos_emb_a = torch.cat((encoder_cls_token_pos, pos_emb[batch_range, unmasked_indices]), dim=1)
        tokens_unmasked = tokens_unmasked + pos_emb_a

        for blk in self.blocks:
            tokens_unmasked = blk(tokens_unmasked, unmasked_logits, has_cls_token=True)
        tokens_unmasked = self.norm(tokens_unmasked)
        # Drop the class token before handing the patches to the decoder.
        encoder_output_patches = self.decoder_embed(tokens_unmasked[:, 1:, :])  # [B, num_unmasked, C_dec]

        # Mask tokens are introduced only at the decoder.
        decoder_tokens = self._assemble_decoder_input(
            self.mask_token, encoder_output_patches, unmasked_indices, patch_num
        )
        decoder_pos_emb = self.decoer_pos_embedding(center_of_patches)
        decoded_tokens = decoder_tokens + decoder_pos_emb

        for blk in self.decoder_blocks:
            decoded_tokens = blk(decoded_tokens, region_logits=None, has_cls_token=False)
        decoded_tokens = self.decoder_norm(decoded_tokens)

        # Keep only the masked positions and project them to the targets.
        pred_tokens = decoded_tokens[batch_range, masked_indices]
        pred_vertices_coordinates = self.to_pointsnew(pred_tokens)
        # Size of the last feats axis (faces per patch).
        faces_values_per_patch = feats_patches.shape[-1]
        pred_vertices_coordinates = torch.reshape(pred_vertices_coordinates,
                                                  (batch * num_masked, points_num_per_patch, 3)).contiguous()

        # Ground-truth vertex coordinates of the masked patches.
        cordinates_patches = cordinates_patches[batch_range, masked_indices]
        cordinates_patches = torch.reshape(cordinates_patches, (batch, num_masked, -1, 3)).contiguous()

        # NOTE(review): torch.unique(dim=2) treats each slice [:, :, i, :] as
        # one element, so two vertex slots are merged only when they coincide
        # in every sample and every patch; this is not a per-patch dedup.
        # Confirm that this is the intended behaviour.
        cordinates_unique = torch.unique(cordinates_patches, dim=2)
        cordinates_unique = torch.reshape(cordinates_unique, (batch * num_masked, -1, 3)).contiguous()

        # The ':' keeps the channel axis; result is [B, num_masked, C, faces].
        masked_feats_patches = feats_patches[batch_range, :, masked_indices]

        pred_faces_features = self.to_features(pred_tokens)
        pred_faces_features = torch.reshape(pred_faces_features, (batch, num_masked, channel, faces_values_per_patch))

        # Reconstruction losses.
        shape_con_loss = self.loss_func_cdl1(pred_vertices_coordinates, cordinates_unique)
        feats_con_loss = F.mse_loss(pred_faces_features, masked_feats_patches)

        loss = shape_con_loss + self.weight * feats_con_loss

        #######################################################################
        # To visualize the reconstructed shape, return these instead:
        # pred_vertices_coordinates = pred_vertices_coordinates.reshape(batch, num_masked, -1, 3)
        # return loss, masked_indices, unmasked_indices, pred_vertices_coordinates, cordinates
        #######################################################################

        return loss, feats_con_loss, shape_con_loss