import torch
import torch.nn as nn
from modules.mlp import MLP
from torch_scatter import scatter
from pytorch3d.ops import knn_points, knn_gather

class ImplicitPhysicsLayer(nn.Module):
    """Implicit physics-constraint layer (replacement for the original SpringSystem).

    Each point is pulled toward its K nearest neighbours with a learnable
    per-point stiffness, and damped by its velocity relative to those
    neighbours with a learnable per-point damping coefficient. The neighbour
    contributions are combined with feature-similarity attention weights, and
    the resulting force is integrated for one explicit step of size ``dt``.

    Args:
        feat_dim: channel count D of the per-point features.
        K: number of nearest neighbours passed to ``knn_points``.
        max_points: capacity of the neighbour-index cache; inputs with more
            points always recompute their neighbours and are never cached.
        cache_update_freq: when caching is enabled, neighbours are recomputed
            at least once every ``cache_update_freq`` forward calls.
    """

    def __init__(self, feat_dim=128, K=8, max_points=150000, cache_update_freq=10):
        super().__init__()
        self.K = K
        self.max_points = max_points
        self.cache_update_freq = cache_update_freq

        # Learnable physical-parameter heads; Sigmoid output lies in (0, 1).
        # NOTE(review): forward feeds them [B, D, N] (channel-first) input, so
        # MLP is presumably 1x1-conv based — confirm in modules.mlp.
        self.mlp_stiffness = MLP([feat_dim, 32, 1], last_op=nn.Sigmoid())
        self.mlp_damping = MLP([feat_dim, 32, 1], last_op=nn.Sigmoid())

        # Cached KNN indices (speeds up training). Buffers so they follow the
        # module's device and are saved in checkpoints.
        self.register_buffer('cache_idx', torch.zeros(max_points, K, dtype=torch.long))  # [max_points, K]
        self.register_buffer('cache_step', torch.tensor(0, dtype=torch.long))           # forward-call counter
        self.register_buffer('cache_xyz', torch.zeros(max_points, 3))                   # coords the cache was built from
        # Number of points the cache currently describes (0 = empty). Plain
        # attribute, deliberately not in state_dict, so older checkpoints still
        # load; after loading, the first cached call simply recomputes.
        self._cache_n = 0

    @torch.no_grad()
    def _neighbor_indices(self, xyz, use_cache):
        """Return [B, N, K] KNN indices, either freshly computed or from the cache."""
        B, N, _ = xyz.shape
        cacheable = N <= self.max_points

        needs_update = (
            not use_cache
            or not cacheable
            # Cache is empty or was built for a different point count; reusing
            # it could yield indices >= N (out-of-range gather).
            or self._cache_n != N
            or self.cache_step % self.cache_update_freq == 0
            # Geometry (of batch element 0) drifted too far from the cache.
            or (xyz[0] - self.cache_xyz[:N]).abs().mean() > 0.02
        )

        if not needs_update:
            # NOTE(review): the cache holds batch element 0 only and is shared
            # by every batch item — assumes all items have the same layout.
            # clone() detaches the indices from the buffer's storage so a later
            # in-place cache refresh cannot invalidate tensors autograd saved.
            return self.cache_idx[:N].clone().expand(B, N, self.K)

        _, knn_idx, _ = knn_points(xyz, xyz, K=self.K)  # [B, N, K]
        if cacheable:
            self.cache_idx[:N] = knn_idx[0]
            self.cache_xyz[:N] = xyz[0]
            self._cache_n = N
        return knn_idx

    def forward(self, xyz, features, velocity, dt=0.01, use_cache=False):
        """Apply one step of the implicit spring/damper constraint.

        Args:
            xyz:       [B, N, 3] point coordinates.
            features:  [B, N, D] per-point features (from the main network).
            velocity:  [B, N, 3] per-point velocities.
            dt:        time step used to integrate the aggregated force.
            use_cache: reuse cached KNN indices while they remain valid
                (speeds up computation).

        Returns:
            new_xyz:  [B, N, 3] constrained coordinates, ``xyz + dt * force``.
            loss_phy: scalar mean of the squared per-point force
                (force-smoothness penalty).
        """
        # Step 1: neighbourhoods (recomputed or cached).
        knn_idx = self._neighbor_indices(xyz, use_cache)  # [B, N, K]
        self.cache_step += 1

        # Step 2: scaled dot-product attention over each point's neighbours.
        # knn_points(xyz, xyz) normally lists each point as its own neighbour;
        # that slot contributes zero force below but still receives weight.
        feat_neighbors = knn_gather(features, knn_idx)                    # [B, N, K, D]
        attn_logits = (features.unsqueeze(2) * feat_neighbors).sum(dim=-1)  # [B, N, K]
        attn_weights = torch.softmax(attn_logits / features.size(-1) ** 0.5, dim=-1)

        # Step 3: per-point physical parameters predicted from features.
        feat_cf = features.permute(0, 2, 1)  # [B, D, N] channel-first for the MLPs
        stiffness = 2.0 * self.mlp_stiffness(feat_cf).permute(0, 2, 1)  # [B, N, 1], range [0, 2]
        damping = 1.0 * self.mlp_damping(feat_cf).permute(0, 2, 1)      # [B, N, 1], range [0, 1]

        # Step 4: per-(i, k) forces acting on point i.
        # Spring term: magnitude `stiffness`, directed toward neighbour k
        # (unit direction; independent of distance, no rest length).
        delta = knn_gather(xyz, knn_idx) - xyz.unsqueeze(2)  # [B, N, K, 3]
        dist = delta.norm(dim=-1, keepdim=True)              # [B, N, K, 1]
        force_spring = stiffness.unsqueeze(2) * delta / (dist + 1e-6)

        # Damping term: proportional to velocity relative to neighbour k.
        vel_diff = knn_gather(velocity, knn_idx) - velocity.unsqueeze(2)  # [B, N, K, 3]
        force_damping = damping.unsqueeze(2) * vel_diff

        # Attention-weighted sum over the K neighbours -> net force on point i.
        # These forces act on i, so they are reduced over k rather than
        # scattered onto the neighbour indices (which would push neighbours
        # away from i and anti-damp them).
        weighted = attn_weights.unsqueeze(-1) * (force_spring + force_damping)
        force_sum = weighted.sum(dim=2)  # [B, N, 3]

        # Step 5: explicit integration step and force-smoothness loss.
        new_xyz = xyz + dt * force_sum
        loss_phy = torch.mean(force_sum ** 2)

        return new_xyz, loss_phy