import torch
from sklearn.neighbors import NearestNeighbors



def compute_deformation_gradient(X, x_t, k=10, eps=1e-6):
    """
    Compute the deformation gradient F = I + ∇u from neighbour differences
    using regularised least squares.

    For every particle i, its k nearest neighbours j are found in the
    reference configuration X. They give relative positions
    dX_j = X_j - X_i and relative displacements du_j = u_j - u_i, where
    u = x_t - X. The local displacement gradient is then fitted by
    du_j ≈ ∇u · dX_j, with eps * I added to the normal equations.

    Args:
        X (torch.Tensor): reference positions, shape (N, dim).
        x_t (torch.Tensor): current positions, shape (N, dim).
        k (int): number of neighbours per particle (excluding itself).
        eps (float): Tikhonov regularisation added to each Gram matrix.

    Returns:
        torch.Tensor: F of shape (N, dim, dim), in torch's default floating dtype.
    """
    N, dim = X.shape
    device = X.device

    # k + 1 nearest neighbours in the reference configuration; column 0 is
    # taken to be the query point itself and dropped.
    # detach(): .numpy() refuses tensors that require grad.
    X_np = X.detach().cpu().numpy()
    knn = NearestNeighbors(n_neighbors=k + 1).fit(X_np)
    _, indices = knn.kneighbors(X_np)
    neighbors = torch.from_numpy(indices[:, 1:]).to(device)  # (N, k)

    # Relative positions and relative displacements for all particles at once.
    X_diff = X[neighbors] - X.unsqueeze(1)  # (N, k, dim)
    u = x_t - X
    u_diff = u[neighbors] - u.unsqueeze(1)  # (N, k, dim)

    # Batched normal equations (dX^T dX + eps I) G = dX^T du.
    # G[a, b] = d u_b / d X_a, i.e. G is the transpose of ∇u.
    X_diff_T = X_diff.transpose(1, 2)  # (N, dim, k)
    eye = torch.eye(dim, device=device)
    gram = X_diff_T @ X_diff + eps * eye
    grad_u_T = torch.linalg.solve(gram, X_diff_T @ u_diff)  # (N, dim, dim)

    # F = I + ∇u. Start from an identity in the default dtype, as the
    # per-particle version did, and add the transposed solution in place.
    F = eye.repeat(N, 1, 1)
    F += grad_u_T.transpose(1, 2)
    return F

import faiss
def knn_gpu(X, k):
    """
    Return the k + 1 nearest-neighbour indices of every point in X, using an
    exact (flat L2) Faiss index placed on GPU 0.

    The query set is the indexed set itself, so each row normally starts with
    the point itself (distance 0). Callers that want only true neighbours
    drop column 0.

    Args:
        X (torch.Tensor): points, shape (N, d).
        k (int): number of neighbours wanted besides the point itself.

    Returns:
        torch.Tensor: neighbour indices of shape (N, k + 1), on X.device.
    """
    # Faiss needs a C-contiguous float32 numpy array. numpy is not imported in
    # this module, so do the dtype conversion on the torch side.
    X_np = X.detach().to(torch.float32).cpu().contiguous().numpy()
    index = faiss.IndexFlatL2(X_np.shape[1])  # exact L2 distance
    faiss_res = faiss.StandardGpuResources()
    gpu_index = faiss.index_cpu_to_gpu(faiss_res, 0, index)
    gpu_index.add(X_np)
    _, I = gpu_index.search(X_np, k + 1)
    return torch.from_numpy(I).to(X.device)

import faiss


# @torch.no_grad()
# def compute_deformation_gradient_gpu(X, x_t, k=10, eps=1e-6):
#     """完全向量化的GPU实现"""
#     N, dim = X.shape
#     device = X.device
#
#     # 1. 构建KD树（使用Faiss GPU库）
#     res = faiss.StandardGpuResources()
#     index = faiss.IndexFlatL2(dim)
#     gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
#     gpu_index.add(X.cpu().numpy())  # Faiss暂不支持直接传入GPU张量
#     _, indices = gpu_index.search(X.cpu().numpy(), k + 1)
#     indices = torch.from_numpy(indices).to(device)  # (N, k+1)
#
#     # 2. 向量化计算差分
#     neighbors = indices[:, 1:]  # (N, k)
#     X_diff = X[neighbors] - X.unsqueeze(1)  # (N, k, 3)
#     u_diff = (x_t[neighbors] - X[neighbors]) - (x_t - X).unsqueeze(1)  # (N, k, 3)
#
#     u = x_t - X  # 每个粒子的位移
#     u_diff_debug = u[neighbors] - u.unsqueeze(1)  # u_j - u_i
#
#     # 3. 批量最小二乘求解
#     X_diff_T = X_diff.transpose(1, 2)  # (N, 3, k)
#     gram = torch.matmul(X_diff_T, X_diff) + eps * torch.eye(3, device=device).unsqueeze(0)
#     grad_u = torch.linalg.solve(gram, torch.matmul(X_diff_T, u_diff))  # (N, 3, 3)
#
#     F = torch.eye(3, device=device).unsqueeze(0) + grad_u
#     return F

@torch.no_grad()
def update_F_incrementally(F_prev, x_prev, x_current, k=10, eps=1e-6):
    """
    Incrementally update the deformation gradient: F_new = (I + ∇Δu) · F_prev.

    ∇Δu is fitted by least squares over the k nearest neighbours of each
    particle in the previous configuration x_prev.

    Args:
        F_prev (torch.Tensor): deformation gradient of the previous frame (N, 3, 3).
        x_prev (torch.Tensor): particle positions of the previous frame (N, 3).
        x_current (torch.Tensor): particle positions of the current frame (N, 3).
        k (int): number of neighbours.
        eps (float): regularisation added to each Gram matrix.

    Returns:
        F_new (torch.Tensor): updated deformation gradient (N, 3, 3).
    """
    device = x_prev.device
    _, dim = x_prev.shape

    # 1. Displacement increment Δu = x_current - x_prev.
    delta_u = x_current - x_prev

    # 2. Faiss GPU neighbour search. Faiss needs float32 numpy input, and
    #    .numpy() refuses tensors that require grad even under no_grad.
    x_np = x_prev.detach().float().cpu().numpy()
    res = faiss.StandardGpuResources()
    index = faiss.IndexFlatL2(dim)
    gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
    gpu_index.add(x_np)
    _, indices = gpu_index.search(x_np, k + 1)
    neighbors = torch.from_numpy(indices[:, 1:]).to(device)  # (N, k); column 0 (self) dropped

    # 3. Least-squares fit of the increment gradient.
    X_diff = x_prev[neighbors] - x_prev.unsqueeze(1)  # (N, k, 3)
    delta_u_diff = delta_u[neighbors] - delta_u.unsqueeze(1)  # (N, k, 3)

    X_diff_T = X_diff.transpose(1, 2)  # (N, 3, k)
    eye = torch.eye(3, device=device).unsqueeze(0)
    gram = X_diff_T @ X_diff + eps * eye
    # The solve yields G with G[a, b] = dΔu_b / dx_a, i.e. G = (∇Δu)^T.
    grad_delta_u_T = torch.linalg.solve(gram, X_diff_T @ delta_u_diff)  # (N, 3, 3)

    # 4. ΔF = I + ∇Δu. G has to be transposed back (the same convention as
    #    compute_deformation_gradient); otherwise shears and rotations would
    #    be applied transposed.
    delta_F = eye + grad_delta_u_T.transpose(1, 2)

    # 5. F_new = ΔF · F_prev
    return torch.bmm(delta_F, F_prev)  # (N, 3, 3)

@torch.no_grad()
def update_F_explicitly(F_prev, x_prev, x_current, dt, k=10, eps=1e-6):
    """
    Explicitly integrate the deformation gradient: F_new = (I + ∇v · dt) @ F_prev.

    The velocity is v = (x_current - x_prev) / dt. Its gradient ∇v is fitted
    by least squares over the k nearest neighbours of each particle in x_prev.

    Args:
        F_prev (torch.Tensor): deformation gradient of the previous frame (N, 3, 3).
        x_prev (torch.Tensor): particle positions of the previous frame (N, 3).
        x_current (torch.Tensor): particle positions of the current frame (N, 3).
        dt (float): time step.
        k (int): number of neighbours.
        eps (float): regularisation added to each Gram matrix.

    Returns:
        F_new (torch.Tensor): updated deformation gradient (N, 3, 3).
    """
    device = x_prev.device
    _, dim = x_prev.shape

    # 1. Per-particle velocity v = Δu / dt.
    v = (x_current - x_prev) / dt  # (N, 3)

    # 2. Faiss GPU neighbour search (float32 numpy input; detach for .numpy()).
    x_np = x_prev.detach().float().cpu().numpy()
    res = faiss.StandardGpuResources()
    index = faiss.IndexFlatL2(dim)
    gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
    gpu_index.add(x_np)
    _, indices = gpu_index.search(x_np, k + 1)
    neighbors = torch.from_numpy(indices[:, 1:]).to(device)  # (N, k); column 0 (self) dropped

    # 3. Least-squares fit of the velocity gradient from velocity differences.
    X_diff = x_prev[neighbors] - x_prev.unsqueeze(1)  # (N, k, 3)
    v_diff = v[neighbors] - v.unsqueeze(1)  # (N, k, 3)

    X_diff_T = X_diff.transpose(1, 2)  # (N, 3, k)
    eye = torch.eye(3, device=device).unsqueeze(0)
    gram = X_diff_T @ X_diff + eps * eye
    # G[a, b] = dv_b / dx_a, i.e. G = (∇v)^T.
    grad_v_T = torch.linalg.solve(gram, X_diff_T @ v_diff)  # (N, 3, 3)

    # DEBUG safeguard: cap the singular values at 0.1 / dt, so that for
    # dt > 0 the per-step increment ∇v·dt has spectral norm <= 0.1.
    # Singular values are non-negative, so only the upper bound matters.
    # G and its transpose share singular values, so clamping before the
    # transpose is equivalent to clamping ∇v itself.
    U, S, Vh = torch.linalg.svd(grad_v_T)
    S_clamped = torch.clamp(S, min=-0.1 / dt, max=0.1 / dt)
    grad_v_T = U @ torch.diag_embed(S_clamped) @ Vh
    grad_v = grad_v_T.transpose(1, 2)  # ∇v

    # 4. F_new = (I + ∇v · dt) @ F_prev
    delta_F = eye + grad_v * dt
    return torch.bmm(delta_F, F_prev)  # (N, 3, 3)


def update_F_explicitly_batch(F_prev, x_prev, x_current, dt, k=10, eps=1e-6, batch_size=100):
    """
    Chunked explicit update F_new = (I + ∇v · dt) @ F_prev.

    Particles are processed in chunks of ``batch_size`` to bound the size of
    the per-chunk neighbour tensors. One global Faiss index is built once.
    Unlike update_F_explicitly, the singular values of ∇v are not clamped.

    Args:
        F_prev (torch.Tensor): previous deformation gradients (N, 3, 3).
        x_prev (torch.Tensor): previous positions (N, 3).
        x_current (torch.Tensor): current positions (N, 3).
        dt (float): time step.
        k (int): number of neighbours per particle.
        eps (float): regularisation added to each Gram matrix.
        batch_size (int): particles per chunk.

    Returns:
        torch.Tensor: F_new (N, 3, 3).
    """
    device = x_prev.device
    N, dim = x_prev.shape
    F_new = torch.zeros_like(F_prev)

    # Velocities of ALL particles. The neighbour ids returned by the global
    # index are global, so neighbour velocities must be gathered from this
    # full array. Indexing a per-chunk slice with global ids would read the
    # wrong rows or go out of bounds.
    v = (x_current - x_prev) / dt  # (N, 3)

    # 1. Build the global Faiss index once. Convert to float32 numpy once and
    #    slice it per chunk.
    x_np = x_prev.detach().float().cpu().numpy()
    res = faiss.StandardGpuResources()
    index = faiss.IndexFlatL2(dim)
    gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
    gpu_index.add(x_np)

    eye = torch.eye(3, device=device).unsqueeze(0)

    # 2. Process the particles chunk by chunk.
    for start in range(0, N, batch_size):
        end = min(start + batch_size, N)
        x_batch = x_prev[start:end]
        v_batch = v[start:end]

        _, indices = gpu_index.search(x_np[start:end], k + 1)
        neighbors = torch.from_numpy(indices[:, 1:]).to(device)  # (B, k) global ids, self dropped

        X_diff = x_prev[neighbors] - x_batch.unsqueeze(1)  # (B, k, 3)
        v_diff = v[neighbors] - v_batch.unsqueeze(1)  # (B, k, 3)
        X_diff_T = X_diff.transpose(1, 2)
        gram = X_diff_T @ X_diff + eps * eye
        grad_v_T = torch.linalg.solve(gram, X_diff_T @ v_diff)  # G = (∇v)^T

        delta_F = eye + grad_v_T.transpose(1, 2) * dt  # I + ∇v · dt
        F_new[start:end] = torch.bmm(delta_F, F_prev[start:end])

    return F_new
#
# @torch.no_grad()
# def compute_deformation_gradient_gpu_svd(X, x_t, k=10, eps=1e-6):
#     """
#     X: (N, 3) 初始位置（GPU Tensor）
#     x_t: (N, 3) 当前位置（GPU Tensor）
#     返回每个粒子的形变梯度 F: (N, 3, 3)
#     """
#     N, dim = X.shape
#     device = X.device
#
#     # 1. Faiss GPU 邻居查询（使用当前位置 x_t 查询更准确）
#     res = faiss.StandardGpuResources()
#     index_flat = faiss.IndexFlatL2(dim)
#     gpu_index = faiss.index_cpu_to_gpu(res, 0, index_flat)
#
#     x_t_cpu = x_t.detach().float().cpu().numpy()
#     gpu_index.add(x_t_cpu)
#     _, indices_np = gpu_index.search(x_t_cpu, k + 1)
#     indices = torch.from_numpy(indices_np).to(device)[:, 1:]  # (N, k)
#
#     # 2. 计算加权协方差矩阵（引入高斯权重）
#     X_neighbors = X[indices]  # (N, k, 3)
#     x_t_neighbors = x_t[indices]  # (N, k, 3)
#
#     # 计算邻居点相对于中心点的位移
#     X_diff = X_neighbors - X.unsqueeze(1)  # (N, k, 3)
#     x_diff = x_t_neighbors - x_t.unsqueeze(1)  # (N, k, 3)
#
#     # 高斯权重（距离越近权重越大）
#     distances = torch.norm(x_diff, dim=2)  # (N, k)
#     h = torch.mean(distances)  # 自适应平滑核半径
#     weights = torch.exp(-distances**2 / (2 * h**2))  # (N, k)
#     weights = weights.unsqueeze(2)  # (N, k, 1)
#
#     # 3. 计算协方差矩阵 C = Σ w_j (X_j - X_i) ⊗ (x_j - x_i)
#     covariance = torch.einsum('nki,nkj->nij', weights * X_diff, x_diff)  # (N, 3, 3)
#
#     # 4. SVD 分解计算形变梯度 F = V U^T
#     U, S, Vh = torch.linalg.svd(covariance)
#     F = Vh.transpose(1, 2) @ U.transpose(1, 2)  # (N, 3, 3)
#
#     # 5. 可选：强制体积守恒（det(F) = 1）
#     det_F = torch.det(F)
#     F = F / det_F.unsqueeze(-1).unsqueeze(-1) ** (1/3)
#
#     return F

@torch.no_grad()
def compute_deformation_gradient_gpu_weighted(X, x_t, k=10, eps=1e-6, weight_scale=4.0):
    """
    Deformation gradient F = I + ∇u from an inverse-distance weighted
    least-squares fit over the k nearest neighbours, found in X.

    The weights are w_j = 1 / (|X_j - X_i| + 1e-6), normalised per particle
    to sum to (approximately) 1.

    Args:
        X (torch.Tensor): reference positions (N, 3).
        x_t (torch.Tensor): current positions (N, 3).
        k (int): number of neighbours (excluding the particle itself).
        eps (float): regularisation added to each weighted Gram matrix.
        weight_scale (float): currently unused. It parameterised an
            exponential kernel that has been replaced by inverse distance;
            kept for backward compatibility.

    Returns:
        F (torch.Tensor): deformation gradients (N, 3, 3).
    """
    _, dim = X.shape
    device = X.device

    # 1. Faiss neighbour search: k + 1 results, column 0 (self) dropped.
    res = faiss.StandardGpuResources()
    index_flat = faiss.IndexFlatL2(dim)
    gpu_index = faiss.index_cpu_to_gpu(res, 0, index_flat)
    X_cpu = X.detach().float().cpu().numpy()
    gpu_index.add(X_cpu)
    _, indices_np = gpu_index.search(X_cpu, k + 1)
    indices = torch.from_numpy(indices_np).to(device)[:, 1:]  # (N, k)

    # 2. Relative positions and relative displacements.
    X_diff = X[indices] - X.unsqueeze(1)  # (N, k, 3)
    u = x_t - X
    u_diff = u[indices] - u.unsqueeze(1)  # (N, k, 3)

    # 3. Normalised inverse-distance weights.
    distances = torch.norm(X_diff, dim=2)  # (N, k)
    weights = 1.0 / (distances + 1e-6)
    weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-6)  # (N, k)
    w = weights.unsqueeze(-1)  # (N, k, 1) for broadcasting over xyz

    # 4. Weighted normal equations (dX^T W dX + eps I) G = dX^T W du.
    X_diff_T = X_diff.transpose(1, 2)  # (N, 3, k)
    eye = torch.eye(3, device=device).unsqueeze(0)
    gram = X_diff_T @ (w * X_diff) + eps * eye  # (N, 3, 3)
    rhs = X_diff_T @ (w * u_diff)  # (N, 3, 3)
    grad_u_T = torch.linalg.solve(gram, rhs)  # G[a, b] = du_b / dX_a = (∇u)^T

    # F = I + ∇u: transpose G back, otherwise affine shears come out transposed.
    return eye + grad_u_T.transpose(1, 2)

@torch.no_grad()
def compute_deformation_gradient_gpu(X, x_t, k=10, eps=1e-6):
    """
    Compute the deformation gradient F = I + ∇u on the GPU by neighbour
    least squares.

    Args:
        X (torch.Tensor): reference positions (N, 3).
        x_t (torch.Tensor): current positions (N, 3).
        k (int): neighbours per particle (the particle itself is excluded).
        eps (float): Tikhonov regularisation added to each Gram matrix.

    Returns:
        F (torch.Tensor): deformation gradients (N, 3, 3).
    """
    _, dim = X.shape
    device = X.device

    # 1. Faiss GPU resources and an exact L2 index. Faiss needs float32 CPU
    #    numpy input.
    res = faiss.StandardGpuResources()
    index_flat = faiss.IndexFlatL2(dim)
    gpu_index = faiss.index_cpu_to_gpu(res, 0, index_flat)
    X_cpu = X.detach().float().cpu().numpy()
    gpu_index.add(X_cpu)

    # 2. Query k + 1 neighbours. Column 0 is normally the point itself
    #    (distance 0) and is dropped.
    _, indices_np = gpu_index.search(X_cpu, k + 1)
    neighbors = torch.from_numpy(indices_np).to(device)[:, 1:]  # (N, k)

    # 3. Relative positions and relative displacements.
    X_diff = X[neighbors] - X.unsqueeze(1)  # (N, k, 3)
    u = x_t - X
    u_diff = u[neighbors] - u.unsqueeze(1)  # (N, k, 3)

    # 4. Least squares: (dX^T dX + eps I) G = dX^T du, where G = (∇u)^T.
    X_diff_T = X_diff.transpose(1, 2)  # (N, 3, k)
    eye = torch.eye(3, device=device).unsqueeze(0)
    gram = X_diff_T @ X_diff + eps * eye  # (N, 3, 3)
    rhs = X_diff_T @ u_diff  # (N, 3, 3)
    grad_u_T = torch.linalg.solve(gram, rhs)  # (N, 3, 3)

    # F = I + ∇u. For x_t = A X + b the fit gives G = (A - I)^T, so G must be
    # transposed back; I + G would return A^T instead of A.
    return eye + grad_u_T.transpose(1, 2)

def compute_rank(gram_matrices, tol=1e-6):
    """
    Numerical rank of a matrix or of a batch of matrices.

    The rank is the number of singular values strictly greater than ``tol``
    (an absolute threshold).

    Args:
        gram_matrices (torch.Tensor): matrix (m, n) or batch (..., m, n).
        tol (float): absolute singular-value threshold. Defaults to 1e-6.

    Returns:
        torch.Tensor: integer ranks of shape (...); a 0-d tensor for a
        single matrix.
    """
    s = torch.linalg.svdvals(gram_matrices)  # (..., min(m, n))
    # Reduce over the singular-value axis (the last one), so that single
    # matrices and arbitrarily batched inputs are both handled.
    return (s > tol).sum(dim=-1)

@torch.no_grad()
def compute_deformation_gradient_gpu_fixed(X, x_t, k=15, eps=1e-6):
    """
    Compute the deformation gradient F = I + ∇u by neighbour least squares,
    with an adaptive Tikhonov term and a truncated-SVD solve.

    For each particle i, the k nearest neighbours j in the reference
    configuration X are used to fit du_j ≈ ∇u · dX_j, where
    dX_j = X_j - X_i and du_j = (x_t_j - X_j) - (x_t_i - X_i).

    Args:
        X (torch.Tensor): reference positions (N, 3).
        x_t (torch.Tensor): current positions (N, 3).
        k (int): neighbours per particle (15-30 suggested); clamped to N - 1.
        eps (float): relative regularisation. eps * mean(diag(Gram_i)) is
            added to the diagonal of each particle's Gram matrix.

    Returns:
        F (torch.Tensor): deformation gradients (N, 3, 3).
    """

    def stable_solve(A, b, min_singular_value=1e-6):
        """Solve A x = b through a truncated-SVD pseudo-inverse of A.

        Singular values <= min_singular_value get an inverse of 0 instead of
        being inverted, so near-singular systems do not blow up.
        """
        U, S, Vh = torch.linalg.svd(A)
        S_inv = torch.zeros_like(S)
        valid_singular = S > min_singular_value
        S_inv[valid_singular] = 1.0 / S[valid_singular]
        A_inv = Vh.transpose(-2, -1) @ torch.diag_embed(S_inv) @ U.transpose(-2, -1)
        return A_inv @ b

    N, dim = X.shape
    device = X.device
    k = min(k, N - 1)  # there are only N - 1 other particles
    # One spare slot beyond k + 1: a row may not contain the particle itself
    # (e.g. ties at distance 0 between duplicate points).
    n_search = min(k + 2, N)

    # 1. Neighbour search. Faiss needs float32 CPU numpy input.
    res = faiss.StandardGpuResources()
    index_flat = faiss.IndexFlatL2(dim)
    gpu_index = faiss.index_cpu_to_gpu(res, 0, index_flat)
    X_cpu = X.detach().float().cpu().numpy()
    gpu_index.add(X_cpu)
    _, indices_np = gpu_index.search(X_cpu, n_search)
    indices = torch.from_numpy(indices_np).to(device)  # (N, n_search), nearest first

    # Drop the particle itself and any out-of-range id (including negative
    # padding). Rows can end up with different numbers of valid entries, so a
    # boolean mask followed by .view(N, -1) would fail or misalign rows.
    # Instead, a stable sort moves the invalid entries to the end of each row
    # while keeping the distance order of the valid ones; then the first k
    # are taken.
    row_ids = torch.arange(N, device=device).unsqueeze(1)
    invalid = (indices == row_ids) | (indices < 0) | (indices >= N)
    order = torch.sort(invalid.long(), dim=1, stable=True).indices
    indices = torch.gather(indices, 1, order)[:, :k]  # (N, k)

    # 2. Relative positions and relative displacements.
    X_diff = X[indices] - X.unsqueeze(1)  # (N, k, 3)
    u_diff = (x_t[indices] - X[indices]) - (x_t.unsqueeze(1) - X.unsqueeze(1))  # (N, k, 3)

    # 3. Regularised normal equations.
    X_diff_T = X_diff.transpose(1, 2)  # (N, 3, k)
    gram = X_diff_T @ X_diff  # (N, 3, 3)

    # The regulariser is scaled by the mean diagonal entry (trace / 3) of each
    # Gram matrix, so it stays relative to the local neighbour spacing.
    eps_adaptive = eps * torch.mean(torch.diagonal(gram, dim1=1, dim2=2), dim=1)
    gram = gram + eps_adaptive.view(-1, 1, 1) * torch.eye(3, device=device)

    rhs = X_diff_T @ u_diff  # (N, 3, 3)
    grad_u_T = stable_solve(gram, rhs)  # G[a, b] = du_b / dX_a, i.e. (∇u)^T

    # 4. F = I + ∇u (transpose G back).
    F = torch.eye(3, device=device).unsqueeze(0) + grad_u_T.transpose(1, 2)
    return F

# @torch.no_grad()
# def compute_deformation_gradient_gpu(X, x_t, k=10, eps=1e-6):
#     """
#     X: (N, 3) 初始位置（GPU Tensor）
#     x_t: (N, 3) 当前位置（GPU Tensor）
#     返回每个粒子的形变梯度 F: (N, 3, 3)
#     """
#     N, dim = X.shape
#     device = X.device
#
#     # 1. Faiss GPU 邻居搜索
#     res = faiss.StandardGpuResources()
#     index_flat = faiss.IndexFlatL2(dim)
#     gpu_index = faiss.index_cpu_to_gpu(res, 0, index_flat)
#     X_cpu = X.detach().float().cpu().numpy()
#     gpu_index.add(X_cpu)
#     _, indices_np = gpu_index.search(X_cpu, k + 1)
#     indices = torch.from_numpy(indices_np).to(device)[:, 1:]  # (N, k)
#
#     # 2. 计算位移梯度 ∇u
#     X_diff = X[indices] - X.unsqueeze(1)  # (N, k, 3)
#     u = x_t - X
#     u_diff = u[indices] - u.unsqueeze(1)  # (N, k, 3)
#     X_diff_T = X_diff.transpose(1, 2)  # (N, 3, k)
#     gram = torch.matmul(X_diff_T, X_diff) + eps * torch.eye(3, device=device).unsqueeze(0)
#     rhs = torch.matmul(X_diff_T, u_diff)
#     grad_u = torch.linalg.solve(gram, rhs)  # (N, 3, 3)
#     F = torch.eye(3, device=device).unsqueeze(0) + grad_u  # F = I + ∇u
#
#     # # 3. 约束 F 的物理合理性
#     # U, S, Vh = torch.linalg.svd(F)  # SVD 分解
#     # S_clamped = torch.clamp(S, min=0.3, max=3.0)  # 限制奇异值范围
#     # F = U @ torch.diag_embed(S_clamped) @ Vh  # 重构 F
#
#     # 4. 检查 det(F) > 0
#     det_F = torch.linalg.det(F)
#     assert torch.all(det_F > 0), f"存在 det(F) ≤ 0 的粒子！最小 det(F) = {det_F.min().item()}"
#     return F


@torch.no_grad()
def compute_affine_velocity(x_t, x_tp1, dt, k=10, eps=1e-6):
    """
    Per-particle affine velocity matrix C from the linear regression v ≈ C x
    over each particle's k nearest neighbours.

    NOTE(review): the regression uses absolute neighbour positions x_j, with
    no centring and no intercept. compute_affine_velocity_gpu instead fits
    v_j - v_p ≈ C_p (x_j - x_p). Here a uniform translation gives a non-zero
    C in general; confirm this is intended.

    Args:
        x_t (torch.Tensor): current positions (N, dim).
        x_tp1 (torch.Tensor): next-frame positions (N, dim).
        dt (float): time step.
        k (int): number of neighbours (excluding the particle itself).
        eps (float): Tikhonov regularisation added to each Gram matrix.

    Returns:
        C (torch.Tensor): affine velocity matrices (N, dim, dim), in torch's
        default floating dtype.
    """
    N, dim = x_t.shape
    device = x_t.device
    v = (x_tp1 - x_t) / dt  # (N, dim)

    # k + 1 nearest neighbours; column 0 is taken to be the particle itself
    # and dropped. detach(): .numpy() refuses tensors that require grad.
    x_np = x_t.detach().cpu().numpy()
    knn = NearestNeighbors(n_neighbors=k + 1).fit(x_np)
    _, indices = knn.kneighbors(x_np)
    neighbors = torch.from_numpy(indices[:, 1:]).to(device)  # (N, k)

    # Batched least squares (A^T A + eps I) S = A^T b, where S = C^T per particle.
    A = x_t[neighbors]  # (N, k, dim) absolute neighbour positions
    b = v[neighbors]  # (N, k, dim) neighbour velocities
    A_T = A.transpose(1, 2)  # (N, dim, k)
    gram = A_T @ A + eps * torch.eye(dim, device=device)
    solution = torch.linalg.solve(gram, A_T @ b)  # (N, dim, dim)

    C = torch.zeros((N, dim, dim), device=device)
    C[:] = solution.transpose(1, 2)
    return C


# def compute_affine_velocity_gpu(x_t, x_tp1, dt, k=10, eps=1e-6):
#     """
#     GPU加速的仿射速度场计算 (v ≈ C x)
#
#     参数:
#         x_t (torch.Tensor): 当前帧位置 (N, 3), 必须在GPU上
#         x_tp1 (torch.Tensor): 下一帧位置 (N, 3), 必须与x_t同设备
#         dt (float): 时间步长
#         k (int): 邻近粒子数
#
#     返回:
#         C (torch.Tensor): 仿射速度场 (N, 3, 3)
#     """
#     assert x_t.device == x_tp1.device, "输入张量必须在同一设备上"
#     N, dim = x_t.shape
#     device = x_t.device
#     v = (x_tp1 - x_t) / dt  # (N, 3)
#
#     # 1. Faiss GPU邻居搜索
#     x_t_np = x_t.cpu().numpy().astype('float32')  # Faiss需要float32
#     res = faiss.StandardGpuResources()
#     index = faiss.IndexFlatL2(dim)
#     gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
#     gpu_index.add(x_t_np)
#     _, indices = gpu_index.search(x_t_np, k + 1)  # (N, k+1)
#     indices = torch.from_numpy(indices).to(device)
#
#     # 2. 向量化计算差分
#     k_actual = min(k, N - 1)
#     neighbors = indices[:, 1:k_actual + 1]  # (N, k_actual)
#     A = x_t[neighbors]  # (N, k_actual, 3)
#     b = v[neighbors]  # (N, k_actual, 3)
#
#     # 3. 批量最小二乘求解
#     A_T = A.transpose(1, 2)  # (N, 3, k_actual)
#     gram = torch.matmul(A_T, A) + eps * torch.eye(3, device=device).unsqueeze(0)
#     C = torch.linalg.solve(gram, torch.matmul(A_T, b))  # (N, 3, 3)
#
#     return C.transpose(1, 2)  # 调整为(N, 3, 3)输出

def knn_weights(bones, pts, K=5):
    """
    Inverse-distance weights of each point with respect to its K nearest bones.

    Args:
        bones (torch.Tensor): bone positions (n_bones, 3).
        pts (torch.Tensor): query points (n_pts, 3).
        K (int): number of nearest bones per point; clamped to n_bones.

    Returns:
        torch.Tensor: dense weights of shape (n_pts, n_bones), in torch's
        default floating dtype. Each row has at most K non-zero entries,
        which sum to 1.
    """
    n_pts, n_bones = pts.shape[0], bones.shape[0]
    K = min(K, n_bones)  # topk raises when K exceeds the number of bones

    dist = torch.norm(pts[:, None] - bones, dim=-1)  # (n_pts, n_bones)
    # topk already returns the selected distances, so they do not need to be
    # gathered and recomputed.
    nearest_dist, indices = torch.topk(dist, K, dim=-1, largest=False)  # (n_pts, K)
    weights = 1 / (nearest_dist + 1e-6)
    weights = weights / weights.sum(dim=-1, keepdim=True)  # (n_pts, K)

    weights_all = torch.zeros((n_pts, n_bones), device=pts.device)  # TODO: prevent init new one
    weights_all[torch.arange(n_pts, device=pts.device)[:, None], indices] = weights
    return weights_all

def compute_affine_velocity_gpu(x_t, x_tp1, dt, k=10, eps=1e-6):
    """
    GPU affine velocity field using an explicitly centred fit:
    v_i - v_p ≈ C_p (x_i - x_p) over the k nearest neighbours i of each particle p.

    Args:
        x_t (torch.Tensor): current positions (N, 3); expected on the GPU.
        x_tp1 (torch.Tensor): next-frame positions (N, 3), same device as x_t.
        dt (float): time step.
        k (int): number of neighbours; clamped to N - 1.
        eps (float): Tikhonov regularisation added to each Gram matrix.

    Returns:
        C (torch.Tensor): affine velocity matrices (N, 3, 3).
    """
    assert x_t.device == x_tp1.device, "输入张量必须在同一设备上"
    N, dim = x_t.shape
    device = x_t.device

    # Particle velocities v_p = (x_tp1 - x_t) / dt.
    v = (x_tp1 - x_t) / dt  # (N, 3)

    # 1. Faiss GPU neighbour search. Faiss needs float32 numpy input, and
    #    .numpy() refuses tensors that require grad, hence detach().
    k_actual = min(k, N - 1)
    x_t_np = x_t.detach().cpu().numpy().astype('float32')
    res = faiss.StandardGpuResources()
    index = faiss.IndexFlatL2(dim)
    gpu_index = faiss.index_cpu_to_gpu(res, 0, index)
    gpu_index.add(x_t_np)
    # Only request the k_actual + 1 results that are actually used.
    _, indices = gpu_index.search(x_t_np, k_actual + 1)
    neighbors = torch.from_numpy(indices[:, 1:]).to(device)  # (N, k_actual), self dropped

    # 2. Centred relative positions and velocities.
    A = x_t[neighbors] - x_t.unsqueeze(1)  # (N, k_actual, 3): x_i - x_p
    b = v[neighbors] - v.unsqueeze(1)  # (N, k_actual, 3): v_i - v_p

    # 3. Batched least squares. The solve gives C_p^T, which is transposed back.
    A_T = A.transpose(1, 2)  # (N, 3, k_actual)
    gram = A_T @ A + eps * torch.eye(3, device=device).unsqueeze(0)
    C_T = torch.linalg.solve(gram, A_T @ b)  # (N, 3, 3)
    return C_T.transpose(1, 2)

def compute_volume(X, F, V0_method="knn"):
    """
    Particle volumes V = V0 * det(F), from an initial volume V0 and the
    deformation gradient F.

    Args:
        X (torch.Tensor): reference positions (N, 3).
        F (torch.Tensor): deformation gradients (N, 3, 3).
        V0_method (str): "knn" estimates V0 = h^3, where h is the distance to
            the nearest other particle in X (a cube-shaped neighbourhood). Any
            other value uses a fixed placeholder V0 = 0.001 for every particle.

    Returns:
        V (torch.Tensor): volumes (N,).
    """
    device = X.device
    det_F = torch.linalg.det(F)  # (N,)

    if V0_method == "knn":
        # Nearest-neighbour spacing; column 0 of kneighbors is the point itself.
        # detach(): .numpy() refuses tensors that require grad.
        # NOTE(review): duplicate points give h = 0 and therefore V0 = 0.
        X_np = X.detach().cpu().numpy()
        knn = NearestNeighbors(n_neighbors=2).fit(X_np)
        distances = knn.kneighbors(X_np)[0][:, 1]
        h = torch.from_numpy(distances).float().to(device)
        V0 = h ** 3
    else:
        # Fixed placeholder volume (simplified computation).
        V0 = torch.full((len(X),), 0.001, device=device)

    return V0 * det_F