import math
import torch
import torch.nn as nn

class PointCloudEncoder(nn.Module):
    """Embed point coordinates and add a hybrid positional encoding.

    Each point's 3 coordinates are linearly projected to ``base_dim``
    features. A positional term, indexed by the point's position along
    dimension 1 of the input, is added to that embedding. The positional term
    is a sigmoid-weighted mix of a learnable table and a sin/cos encoding
    whose frequencies, global frequency scale and phases are learnable.
    A kernel-size-1 convolution finally expands the channel count to
    ``base_dim * num_out_scale``.

    Args:
        max_length: Number of rows in the learnable positional table, i.e.
            the largest number of points ``forward`` accepts.
        num_out_scale: Multiplier applied to ``base_dim`` to obtain the
            number of output channels.
    """

    def __init__(self, max_length: int, num_out_scale: int = 1) -> None:
        super().__init__()
        self.base_dim = 8
        self.num_out_scale = num_out_scale
        self.max_length = max_length

        # Point-coordinate embedding layer.
        self.pc_embed = nn.Linear(3, self.base_dim)

        # Learnable positional table, one row per point index.
        self.learnable_pos = nn.Parameter(torch.zeros(1, max_length, self.base_dim))
        nn.init.normal_(self.learnable_pos, mean=0.0, std=0.02)

        # Learnable frequencies, initialised to the standard sinusoidal
        # schedule 10000^(-2i / base_dim).
        div_term = torch.exp(
            torch.arange(0, self.base_dim, 2).float() * (-math.log(10000.0) / self.base_dim)
        )
        self.freq = nn.Parameter(div_term.clone(), requires_grad=True)
        self.freq_scale = nn.Parameter(torch.ones(1))  # global frequency scaling factor

        # Learnable phases, initialised with a small random perturbation.
        self.phase = nn.Parameter(torch.zeros(self.base_dim // 2))
        nn.init.uniform_(self.phase, -0.1, 0.1)

        # Raw mixing parameter; forward() passes it through a sigmoid, so the
        # initial weight on the learnable table is sigmoid(0.5) ~= 0.62.
        self.alpha = nn.Parameter(torch.tensor(0.5))

        # 1x1 convolution that expands the channel dimension.
        self.expand_conv = nn.Conv1d(
            in_channels=self.base_dim,
            out_channels=self.base_dim * num_out_scale,
            kernel_size=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a point cloud.

        Args:
            x: Tensor of shape ``[batch, num_points, 3]``.

        Returns:
            Tensor of shape ``[batch, num_points, base_dim * num_out_scale]``.

        Raises:
            ValueError: If ``num_points`` exceeds ``max_length``.
        """
        num_points = x.size(1)
        if num_points > self.max_length:
            # Without this check the truncated learnable table fails to
            # broadcast against the sin/cos table (or, for max_length == 1,
            # silently broadcasts a single row over every point).
            raise ValueError(
                f"Got {num_points} points but the encoder was built with "
                f"max_length={self.max_length}"
            )

        # Per-point coordinate embedding: [batch, num_points, base_dim].
        point_feat = self.pc_embed(x)

        # Angles are computed in at least float32 because large point indices
        # are not representable in half precision.
        angle_dtype = torch.promote_types(self.freq.dtype, torch.float32)
        positions = torch.arange(num_points, dtype=angle_dtype, device=x.device).unsqueeze(1)
        scaled_freq = (self.freq_scale * self.freq).to(angle_dtype)  # [base_dim // 2]
        phase = self.phase.to(angle_dtype)
        angle = positions * scaled_freq.unsqueeze(0) + phase.unsqueeze(0)

        # Interleave as [sin_0, cos_0, sin_1, cos_1, ...] along the feature
        # axis, then cast to the parameter dtype so the sum below keeps the
        # dtype that expand_conv expects.
        sincos_pos = torch.stack((torch.sin(angle), torch.cos(angle)), dim=-1).flatten(-2)
        sincos_pos = sincos_pos.unsqueeze(0).to(self.learnable_pos.dtype)

        # Rows of the learnable table for the points actually present.
        learn_pos = self.learnable_pos[:, :num_points, :]

        # Convex combination of the two encodings; sigmoid keeps the weight
        # in (0, 1).
        mix_alpha = torch.sigmoid(self.alpha)
        combined_pos = mix_alpha * learn_pos + (1 - mix_alpha) * sincos_pos

        encoded = point_feat + combined_pos

        # Conv1d works on [batch, channels, length]; permute in and back out.
        expanded = self.expand_conv(encoded.permute(0, 2, 1))
        return expanded.permute(0, 2, 1)

    def extra_repr(self) -> str:
        """Return the frequency scale and effective mixing weight for repr()."""
        return f"freq_scale={self.freq_scale.item():.3f}, mix_alpha={torch.sigmoid(self.alpha).item():.3f}"
    

if __name__ == "__main__":
    # Smoke test: build the encoder, run one forward pass, print the shape.
    num_points = 277410
    # Use the GPU when present, but do not crash on CPU-only machines.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    encoder = PointCloudEncoder(max_length=num_points, num_out_scale=4).to(device)
    input_data = torch.rand(1, num_points, 3, device=device)

    # Only the output shape is inspected, so skip building an autograd graph.
    with torch.no_grad():
        output = encoder(input_data)
    print(output.shape)
    
