# --------------------------------------------------------
# Pose Compositional Tokens
# Written by Zigang Geng (zigang@mail.ustc.edu.cn)
# --------------------------------------------------------

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist

from mmpose.models.builder import build_loss
from mmpose.models.builder import HEADS
from timm.models.layers import trunc_normal_

from .modules import MixerLayer
from .estimate_invisible_points import InvisiblePointEstimator


@HEADS.register_module()
class GroupedL2Loss(nn.Module):
    """Body-part-weighted keypoint regression loss.

    Joints are split into a head/torso group, an arm group and a leg group.
    For each group the per-coordinate error ``sqrt((pred - gt)**2 + 1e-6)`` is
    multiplied by the joint's visibility flag, averaged over every element of
    the group (visible or not) and scaled by the group weight. The three group
    terms are summed.

    Args:
        head_indices (Sequence[int]): Joint indices of the head/torso group.
        arm_indices (Sequence[int]): Joint indices of the arm group.
        leg_indices (Sequence[int]): Joint indices of the leg group.
        arm_weight (float): Multiplier applied to the arm group term.
        leg_weight (float): Multiplier applied to the leg group term.
        head_weight (float): Multiplier applied to the head/torso group term.
            The default of 1.0 leaves that term unscaled.
    """

    def __init__(self, head_indices=(0, 1, 2, 3, 4),
                 arm_indices=(5, 6, 7, 8, 9, 10),
                 leg_indices=(11, 12, 13, 14, 15, 16),
                 arm_weight=2.0, leg_weight=2.0, head_weight=1.0):
        super().__init__()
        # Copy into fresh lists so instances never share a (mutable) default
        # object, while keeping list-based advanced indexing of tensors.
        self.head_indices = list(head_indices)
        self.arm_indices = list(arm_indices)
        self.leg_indices = list(leg_indices)
        self.arm_weight = arm_weight
        self.leg_weight = leg_weight
        self.head_weight = head_weight

    @staticmethod
    def _group_loss(pred, gt, visible, indices):
        """Return the visibility-masked mean error over the joints in ``indices``."""
        # Element-wise sqrt of the squared error; the 1e-6 keeps the gradient
        # of sqrt finite when the error is exactly zero.
        err = torch.sqrt(
            F.mse_loss(pred[:, indices, :], gt[:, indices, :-1],
                       reduction='none') + 1e-6)
        return (err * visible[:, indices, :]).mean()

    def forward(self, pred, gt):
        """Compute the grouped loss.

        Args:
            pred (Tensor): Predicted coordinates, indexed as
                ``pred[:, joints, :]`` (presumably (bs, num_joints, 2)).
            gt (Tensor): Ground truth whose last channel is the visibility
                flag and whose remaining channels are the coordinates.

        Returns:
            Tensor: Scalar loss.
        """
        # Visibility flag, kept as a trailing singleton dim for broadcasting.
        visible = gt[:, :, -1].unsqueeze(-1)

        head_loss = self._group_loss(pred, gt, visible, self.head_indices) * self.head_weight
        arm_loss = self._group_loss(pred, gt, visible, self.arm_indices) * self.arm_weight
        leg_loss = self._group_loss(pred, gt, visible, self.leg_indices) * self.leg_weight

        return head_loss + arm_loss + leg_loss


class SELayer(nn.Module):
    """Squeeze-and-Excitation channel gating.

    Global-average-pools each channel of a 4-D input ``(b, c, h, w)``, passes
    the pooled vector through a bottleneck MLP
    (Linear -> ReLU -> Linear -> Sigmoid), and rescales every input channel by
    its resulting gate in (0, 1).

    Args:
        channel (int): Number of input channels ``c``.
        reduction (int): Bottleneck reduction factor. The hidden width is
            ``channel // reduction``, clamped to at least 1 so that small
            channel counts do not yield a zero-width bottleneck (which would
            make every gate a constant 0.5).
    """

    def __init__(self, channel, reduction=16):
        super(SELayer, self).__init__()
        hidden = max(1, channel // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(inplace=True),  # author's note: consider switching to GELU
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        # Squeeze: (b, c, h, w) -> (b, c)
        y = self.avg_pool(x).view(b, c)
        # Excite: per-channel gate, reshaped to broadcast over h and w.
        y = self.fc(y).view(b, c, 1, 1)
        return x * y.expand_as(x)


@HEADS.register_module()
class PCT_Tokenizer(nn.Module):
    """ Tokenizer of Pose Compositional Tokens.
        paper ref: Zigang Geng et al. "Human Pose as
            Compositional Tokens"

    Args:
        stage_pct (str): Training stage ("tokenizer" or "classifier").
        tokenizer (dict): Config about the tokenizer. Reads the sub-dicts
            ``encoder``, ``decoder`` and ``codebook`` plus ``loss_keypoint``.
        num_joints (int): Number of annotated joints in the dataset.
        guide_ratio (float): The ratio of image guidance.
        guide_channels (int): Feature Dim of the image guidance.
    """

    def __init__(self,
                 stage_pct,
                 tokenizer=None,
                 num_joints=17,
                 guide_ratio=0,
                 guide_channels=0):
        super().__init__()

        self.stage_pct = stage_pct
        self.guide_ratio = guide_ratio
        self.num_joints = num_joints

        # Estimator that fills in coordinates of invisible joints
        # (ported from codebook.py).
        self.invisible_estimator = InvisiblePointEstimator(min_visible_for_estimation=3)

        self.drop_rate = tokenizer['encoder']['drop_rate']
        self.enc_num_blocks = tokenizer['encoder']['num_blocks']
        self.enc_hidden_dim = tokenizer['encoder']['hidden_dim']
        self.enc_token_inter_dim = tokenizer['encoder']['token_inter_dim']
        self.enc_hidden_inter_dim = tokenizer['encoder']['hidden_inter_dim']
        self.enc_dropout = tokenizer['encoder']['dropout']

        # Std of the Gaussian noise written into masked joint features
        # (ported from codebook.py).
        self.mask_noise_scale = tokenizer['encoder'].get('mask_noise_scale', 0.1)

        self.dec_num_blocks = tokenizer['decoder']['num_blocks']
        self.dec_hidden_dim = tokenizer['decoder']['hidden_dim']
        self.dec_token_inter_dim = tokenizer['decoder']['token_inter_dim']
        self.dec_hidden_inter_dim = tokenizer['decoder']['hidden_inter_dim']
        self.dec_dropout = tokenizer['decoder']['dropout']

        self.token_num = tokenizer['codebook']['token_num']
        self.token_class_num = tokenizer['codebook']['token_class_num']
        self.token_dim = tokenizer['codebook']['token_dim']
        self.decay = tokenizer['codebook']['ema_decay']

        # NOTE(review): not referenced anywhere in this class (masked joints
        # now receive noise instead); under DDP this presumably requires
        # find_unused_parameters=True — confirm.
        self.invisible_token = nn.Parameter(
            torch.zeros(1, 1, self.enc_hidden_dim))
        trunc_normal_(self.invisible_token, mean=0., std=0.02, a=-0.02, b=0.02)

        # Width of the coordinate embedding; the positional encoding is added
        # to start_embed's output, so it must use the same width.
        coord_embed_dim = int(self.enc_hidden_dim*(1-self.guide_ratio))

        if self.guide_ratio > 0:
            self.start_img_embed = nn.Linear(
                guide_channels, int(self.enc_hidden_dim*self.guide_ratio))
        self.start_embed = nn.Linear(2, coord_embed_dim)

        self.encoder = nn.ModuleList(
            [MixerLayer(self.enc_hidden_dim, self.enc_hidden_inter_dim,
                self.num_joints, self.enc_token_inter_dim,
                self.enc_dropout) for _ in range(self.enc_num_blocks)])
        self.encoder_layer_norm = nn.LayerNorm(self.enc_hidden_dim)

        # Original single-codebook design.
        self.token_mlp = nn.Linear(
            self.num_joints, self.token_num)
        self.feature_embed = nn.Linear(
            self.enc_hidden_dim, self.token_dim)

        self.register_buffer('codebook',
            torch.empty(self.token_class_num, self.token_dim))
        self.codebook.data.normal_()
        self.register_buffer('ema_cluster_size',
            torch.zeros(self.token_class_num))
        self.register_buffer('ema_w',
            torch.empty(self.token_class_num, self.token_dim))
        self.ema_w.data.normal_()

        self.decoder_token_mlp = nn.Linear(
            self.token_num, self.num_joints)
        self.decoder_start = nn.Linear(
            self.token_dim, self.dec_hidden_dim)

        self.decoder = nn.ModuleList(
            [MixerLayer(self.dec_hidden_dim, self.dec_hidden_inter_dim,
                self.num_joints, self.dec_token_inter_dim,
                self.dec_dropout) for _ in range(self.dec_num_blocks)])
        self.decoder_layer_norm = nn.LayerNorm(self.dec_hidden_dim)

        # Learnable per-joint positional encoding (ported from codebook.py).
        self.learnable_positional_encoding = nn.Parameter(
            torch.randn(self.num_joints, coord_embed_dim)
        )
        nn.init.trunc_normal_(self.learnable_positional_encoding, mean=0., std=0.02)

        # Projects [sinusoidal, learnable] back to the embedding width
        # (ported from codebook.py).
        self.position_encoding_proj = nn.Linear(2 * coord_embed_dim, coord_embed_dim)

        self.recover_embed = nn.Linear(self.dec_hidden_dim, 2)

        # SE module (ported from codebook.py).
        # NOTE(review): its use in forward is commented out, so these
        # parameters are currently unused; channel is fixed at 17.
        self.se = SELayer(channel=17)

        # Grouped loss (ported from codebook.py).
        # NOTE(review): indices are hard-coded up to 16, so get_loss assumes
        # num_joints >= 17 (COCO-style layout presumably) — confirm.
        self.grouped_l2_loss = GroupedL2Loss(
            head_indices=[0, 1, 2, 3, 4],
            arm_indices=[5, 6, 7, 8, 9, 10],
            leg_indices=[11, 12, 13, 14, 15, 16]
        )

        self.loss = build_loss(tokenizer['loss_keypoint'])

    def position_encoding(self, joints_coord):
        """Return a hybrid (sinusoidal + learnable) positional encoding.

        A sinusoidal encoding of the joints' first coordinate is concatenated
        with the learnable per-joint encoding and linearly projected back to
        the encoding width D (ported from codebook.py).

        Args:
            joints_coord (Tensor): Joint coordinates of shape
                (bs, num_joints, C); only channel 0 is used.

        Returns:
            Tensor: (bs, num_joints, D), where D is the width of
            ``learnable_positional_encoding``.
        """
        bs, num_joints, _ = joints_coord.shape
        enc_dim = self.learnable_positional_encoding.shape[-1]
        device = joints_coord.device
        position_encoding = torch.zeros((bs, num_joints, enc_dim), device=device)

        div_term = torch.exp(torch.arange(0, enc_dim, 2, device=device) * -(torch.log(torch.tensor(10000.0)) / enc_dim))
        # NOTE(review): only the first coordinate (x) drives both the sine and
        # the cosine channels; the second coordinate does not contribute.
        # Confirm this is intended.
        x_coord = joints_coord[:, :, 0:1]
        position_encoding[:, :, 0::2] = torch.sin(x_coord * div_term)
        # For an odd width there is one fewer cosine slot than frequencies.
        position_encoding[:, :, 1::2] = torch.cos(x_coord * div_term[:enc_dim // 2])

        learnable_encoding = self.learnable_positional_encoding.unsqueeze(0).expand(bs, -1, -1)

        combined_encoding = torch.cat((position_encoding, learnable_encoding), dim=-1)
        return self.position_encoding_proj(combined_encoding)

    @staticmethod
    def _sync_codebook_stats(encodings, dw):
        """Sum codebook statistics over all processes when running distributed.

        Returns the inputs unchanged when torch.distributed is unavailable,
        not initialized, or has a single process.
        """
        if not (dist.is_available() and dist.is_initialized()) \
                or dist.get_world_size() == 1:
            return encodings, dw
        n_encodings, n_dw = encodings.numel(), dw.numel()
        combined = torch.cat((encodings.flatten(), dw.flatten()))
        dist.all_reduce(combined)  # sum across ranks
        sync_encodings, sync_dw = torch.split(combined, [n_encodings, n_dw])
        return sync_encodings.view(encodings.shape), sync_dw.view(dw.shape)

    def forward(self, joints, joints_feature, cls_logits, train=True):
        """Encode joints to codebook tokens (when needed) and decode joints.

        Args:
            joints (Tensor): Joint coordinates with the visibility flag in the
                last channel. Used when ``train`` is true or in the tokenizer
                stage.
            joints_feature (Tensor): Image guidance features for
                ``start_img_embed``; used only when ``guide_ratio > 0``.
            cls_logits (Tensor): Token class scores that are multiplied with
                the codebook in the classifier stage; presumably of shape
                (bs * token_num, token_class_num).
            train (bool): Training mode flag.

        Returns:
            tuple: (recovered joints of shape (bs, num_joints, 2),
            codebook indices or None, commitment loss or None).
        """

        if train or self.stage_pct == "tokenizer":
            # Encoder of Tokenizer, Get the PCT groundtruth class labels.
            joints_coord, joints_visible, bs \
                = joints[:,:,:-1], joints[:,:,-1].bool(), joints.shape[0]

            # Fill in invisible joints' coordinates before embedding
            # (ported from codebook.py).
            coords_for_embedding = self.invisible_estimator.estimate_invisible(
                joints_coord.clone(),
                joints_visible.clone()
            )

            # Coordinate embedding plus hybrid positional encoding
            # (ported from codebook.py).
            # NOTE(review): the positional encoding uses the raw joints_coord,
            # not coords_for_embedding — confirm intended.
            encode_feat = self.start_embed(coords_for_embedding) + self.position_encoding(joints_coord)

            if self.guide_ratio > 0:
                # SE re-weighting of joints_feature is currently disabled:
                # joints_feature = joints_feature.view(bs, 17, 32, 32)
                # joints_feature = self.se(joints_feature)
                # joints_feature = joints_feature.view(bs, 17, -1)
                encode_img_feat = self.start_img_embed(joints_feature)
                encode_feat = torch.cat((encode_feat, encode_img_feat), dim=2)

            # While training the tokenizer, randomly drop extra joints.
            if train and self.stage_pct == "tokenizer":
                rand_mask_ind = torch.rand(
                    joints_visible.shape, device=joints.device) > self.drop_rate
                joints_visible = torch.logical_and(rand_mask_ind, joints_visible)

            # Masking strategy (ported from codebook.py): features of invisible
            # or dropped joints are replaced by Gaussian noise instead of zeros.
            # NOTE(review): noise is drawn on every call, including eval and
            # classifier-stage label generation, so tokens for samples with
            # invisible joints are not deterministic — confirm intended.
            w = joints_visible.unsqueeze(-1).type_as(encode_feat)
            noise = torch.randn_like(encode_feat) * self.mask_noise_scale
            encode_feat = encode_feat * w + noise * (1 - w)

            for num_layer in self.encoder:
                encode_feat = num_layer(encode_feat)
            encode_feat = self.encoder_layer_norm(encode_feat)

            # Original tokenization: mix the joint axis into token_num tokens,
            # project to token_dim, then flatten to (bs * token_num, token_dim).
            encode_feat = encode_feat.transpose(2, 1)
            encode_feat = self.token_mlp(encode_feat).transpose(2, 1)
            encode_feat = self.feature_embed(encode_feat).flatten(0,1)

            # Squared Euclidean distance from every token to every codebook entry.
            distances = torch.sum(encode_feat**2, dim=1, keepdim=True) \
                + torch.sum(self.codebook**2, dim=1) \
                - 2 * torch.matmul(encode_feat, self.codebook.t())

            encoding_indices = torch.argmin(distances, dim=1)
            # One-hot assignment of each token to its nearest codebook entry.
            encodings = torch.zeros(
                encoding_indices.shape[0], self.token_class_num, device=joints.device)
            encodings.scatter_(1, encoding_indices.unsqueeze(1), 1)
        else:
            bs = cls_logits.shape[0] // self.token_num
            encoding_indices = None

        if self.stage_pct == "classifier":
            part_token_feat = torch.matmul(cls_logits, self.codebook)
        else:
            part_token_feat = torch.matmul(encodings, self.codebook)

        if train and self.stage_pct == "tokenizer":
            # Update the codebook with exponential moving averages of the
            # assignment counts and of the summed assigned encoder features.
            dw = torch.matmul(encodings.t(), encode_feat.detach())

            # Aggregate over all ranks in distributed training; identity in
            # single-process runs.
            sync_encodings, sync_dw = self._sync_codebook_stats(encodings, dw)

            self.ema_cluster_size = self.ema_cluster_size * self.decay + \
                                    (1 - self.decay) * torch.sum(sync_encodings, 0)

            # Laplace smoothing keeps every cluster size strictly positive.
            n = torch.sum(self.ema_cluster_size.data)
            self.ema_cluster_size = (
                (self.ema_cluster_size + 1e-5)
                / (n + self.token_class_num * 1e-5) * n)

            self.ema_w = self.ema_w * self.decay + (1 - self.decay) * sync_dw
            self.codebook = self.ema_w / self.ema_cluster_size.unsqueeze(1)
            e_latent_loss = F.mse_loss(part_token_feat.detach(), encode_feat)
            # Straight-through estimator: forward uses the quantized tokens,
            # backward passes gradients to the encoder features.
            part_token_feat = encode_feat + (part_token_feat - encode_feat).detach()
        else:
            e_latent_loss = None

        # Decoder of Tokenizer, Recover the joints.
        part_token_feat = part_token_feat.view(bs, -1, self.token_dim)

        part_token_feat = part_token_feat.transpose(2,1)
        part_token_feat = self.decoder_token_mlp(part_token_feat).transpose(2,1)
        decode_feat = self.decoder_start(part_token_feat)

        for num_layer in self.decoder:
            decode_feat = num_layer(decode_feat)
        decode_feat = self.decoder_layer_norm(decode_feat)

        recoverd_joints = self.recover_embed(decode_feat)

        return recoverd_joints, encoding_indices, e_latent_loss

    def get_loss(self, output_joints, joints, e_latent_loss):
        """Compute the tokenizer losses.

        The keypoint loss returned by ``self.loss`` is blended with the
        grouped loss as ``0.8 * kpt_loss + 0.2 * grouped_loss``.

        Returns:
            dict: ``joint_loss`` (the blended loss) and ``e_latent_loss`` (as
            returned by ``self.loss``).
        """
        losses = dict()

        kpt_loss, e_latent_loss = self.loss(output_joints, joints, e_latent_loss)

        # Grouped loss (ported from codebook.py).
        grouped_loss = self.grouped_l2_loss(output_joints, joints)

        # Weighted blend of the two regression losses.
        total_loss = 0.8 * kpt_loss + 0.2 * grouped_loss

        losses['joint_loss'] = total_loss
        losses['e_latent_loss'] = e_latent_loss

        return losses

    def init_weights(self, pretrained=""):
        """Initialize model weights.

        In the classifier stage, loads this tokenizer's parameters and buffers
        from ``pretrained['state_dict']``. The prefix
        ``keypoint_head.tokenizer.`` is stripped, and checkpoint keys that are
        not part of this module are ignored. When ``pretrained`` is not an
        existing file, nothing is loaded; in the classifier stage a reminder
        is printed.

        Raises:
            AssertionError: If a checkpoint file is given in the tokenizer stage.
        """
        parameters_names = {name for name, _ in self.named_parameters()}
        buffers_names = {name for name, _ in self.named_buffers()}

        if os.path.isfile(pretrained):
            assert (self.stage_pct == "classifier"), \
                "Training tokenizer does not need to load model"
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            pretrained_state_dict = torch.load(pretrained,
                            map_location=lambda storage, loc: storage)

            need_init_state_dict = {}

            for name, m in pretrained_state_dict['state_dict'].items():
                if 'keypoint_head.tokenizer.' in name:
                    name = name.replace('keypoint_head.tokenizer.', '')
                if name in parameters_names or name in buffers_names:
                    need_init_state_dict[name] = m
            self.load_state_dict(need_init_state_dict, strict=True)
        else:
            if self.stage_pct == "classifier":
                print('If you are training a classifier, '\
                    'must check that the well-trained tokenizer '\
                    'is located in the correct path.')