import torch
import torch.nn.functional as F

def _make_odd(k):
    """确保 Kernel Size 是奇数，保证卷积中心对齐"""
    return k if (k % 2 == 1) else max(1, k - 1)

def gaussian_smooth(x, kernel_size=5, sigma=1.0):
    """
    x: [B, S, D]
    使用 Replicate Padding 避免边界产生虚假加速度
    """
    B, S, D = x.shape
    if S <= 1:
        return x
        
    # 动态调整 kernel_size，防止 kernel 比序列还长
    kernel_size = min(int(kernel_size), S)
    kernel_size = _make_odd(kernel_size)
    
    # 准备卷积核
    k = torch.arange(kernel_size, device=x.device).float() - (kernel_size - 1) // 2
    kernel = torch.exp(-k**2 / (2 * sigma**2))
    kernel = kernel / kernel.sum()
    kernel = kernel.view(1, 1, -1)
    
    # 变换维度适配 conv1d: [B, D, S]
    x_perm = x.permute(0, 2, 1)
    
    # 使用 replicate 填充 (关键改进)
    pad_len = kernel_size // 2
    # F.pad 对最后维度填充：(pad_left, pad_right)
    x_padded = F.pad(x_perm, (pad_len, pad_len), mode='replicate')
    
    # Depthwise Convolution
    weight = kernel.repeat(D, 1, 1)
    out = F.conv1d(x_padded, weight, groups=D)
    
    return out.permute(0, 2, 1) # [B, S, D]

def compute_target_acceleration(residual, dt=1.0, sigma=1.0):
    # """
    # 推荐版本：结合了 Scheme 2 的计算逻辑和 Scheme 1 的鲁棒性检查
    # """
    B, S, D = residual.shape
    
    # 1. 极短序列保护 (来自 Scheme 1)
    if S < 3:
        # 序列太短无法计算二阶导数，返回全0
        return torch.zeros_like(residual)
    
    # 2. 高斯平滑 (使用 Replicate 模式)
    residual_smooth = gaussian_smooth(residual, kernel_size=31, sigma=sigma) 
    
    # 3. 计算速度 (一阶导) - 使用 torch.gradient (Scheme 2)
    # edge_order=2 保证边界精度
    velocity_res = torch.gradient(residual_smooth, spacing=dt, dim=1, edge_order=2)[0]
    
    # 4. 计算加速度 (二阶导)
    acceleration_res = torch.gradient(velocity_res, spacing=dt, dim=1, edge_order=2)[0]
    

    return acceleration_res,velocity_res

    # B, S, D = residual.shape
    # if S < 2: 
    #     return torch.zeros_like(residual)
    
    # # 1. (可选) 高斯平滑，保持与原函数一致的鲁棒性
    # # 假设 gaussian_smooth 存在
    # residual_smooth = gaussian_smooth(residual, kernel_size=5, sigma=sigma) 
    
    # # 2. 计算速度 (一阶导)
    # # torch.gradient(..., dim=1)[0] 返回沿时间轴的梯度
    # velocity_res = torch.gradient(residual_smooth, spacing=dt, dim=1, edge_order=2)[0]
    
    # return velocity_res # 返回速度 V