import math
import torch

def random_rotate_translate_batch(batch_xyz, max_rot_deg=10.0, max_translate=0.02):
    """
    Apply a random z-axis rotation and a random translation to a batch of point clouds.

    One rotation angle and one translation vector are drawn per sample and
    shared by every frame and every point of that sample.

    Args:
        batch_xyz: floating-point Tensor of shape (B, T, N, C) with C >= 3.
            The first three channels are treated as xyz coordinates; any
            extra channels are copied through unchanged.
        max_rot_deg: the rotation angle is drawn uniformly from
            [-max_rot_deg, max_rot_deg] degrees.
        max_translate: every translation component is drawn uniformly from
            [-max_translate, max_translate].

    Returns:
        A new Tensor with the same shape, dtype and device as ``batch_xyz``.
        The input tensor is not modified in place.

    Raises:
        TypeError: if ``batch_xyz`` is not a floating-point tensor.
    """
    B, T, N, C = batch_xyz.shape
    assert C >= 3, "Expect last dim >=3 (xyz + optional features)"
    if not batch_xyz.is_floating_point():
        raise TypeError("batch_xyz must be a floating-point tensor")
    device = batch_xyz.device
    dtype = batch_xyz.dtype

    # Sample in the default dtype and in the original order (angles first,
    # then translations) so results stay reproducible for a fixed seed.
    max_rad = math.radians(max_rot_deg)
    angles = (torch.rand(B, device=device) * 2 - 1) * max_rad  # [-max_rad, max_rad]
    trans = (torch.rand(B, 3, device=device) * 2 - 1) * max_translate  # (B, 3)

    # Evaluate the trigonometry in the wider of (sampling dtype, input dtype),
    # then cast to the input dtype so the matmul below never mixes dtypes.
    angles = angles.to(torch.promote_types(angles.dtype, dtype))
    cos = torch.cos(angles)
    sin = torch.sin(angles)
    zero = torch.zeros_like(cos)
    one = torch.ones_like(cos)
    # Rotation about the z axis only; change this matrix for full 3D rotations.
    R = torch.stack(
        [cos, -sin, zero,
         sin, cos, zero,
         zero, zero, one],
        dim=-1,
    ).reshape(B, 3, 3).to(dtype)
    trans = trans.to(dtype)

    # Row-vector convention: p' = p @ R^T + t, applied to all frames and
    # points of each sample with a single batched matmul.
    coords = batch_xyz[..., :3].reshape(B, T * N, 3)
    moved = torch.bmm(coords, R.transpose(1, 2)) + trans.unsqueeze(1)

    out = batch_xyz.clone()
    out[..., :3] = moved.reshape(B, T, N, 3)
    return out

def fgsm_attack(model, loss_fn, orig_inputs, targets, input_lens, target_lens,
                eps=0.02, alpha=None, device='cuda'):
    """
    Build single-step FGSM adversarial examples for the given inputs.

    The model is switched to eval mode for the forward/backward pass and its
    previous train/eval mode is restored afterwards, even if an error is
    raised. Only the gradient with respect to the inputs is computed, so the
    ``.grad`` fields of the model parameters are left untouched.

    Args:
        model: callable module; ``model(inputs)`` must return logits. A 3-D
            output is assumed to be (B, L, C) and is transposed to (L, B, C)
            before ``log_softmax``; any other rank is used as is.
        loss_fn: called as ``loss_fn(log_probs, targets, input_lens,
            target_lens)`` and must return a scalar loss.
        orig_inputs: clean input tensor (float). It is not modified.
        targets, input_lens, target_lens: forwarded to ``loss_fn`` unchanged.
        eps: radius of the L_inf ball around ``orig_inputs``.
        alpha: step size; defaults to ``eps`` when ``None``.
        device: device on which the attack is computed.

    Returns:
        Detached adversarial tensor on ``device`` with the same shape as
        ``orig_inputs``, within ``eps`` (L_inf) of the clean inputs.
    """
    if alpha is None:
        alpha = eps

    # Keep a device-resident copy of the clean inputs so the final projection
    # never mixes devices when orig_inputs lives somewhere else.
    clean = orig_inputs.detach().to(device)
    adv_inputs = clean.clone().requires_grad_()

    was_training = getattr(model, "training", True)
    model.eval()
    try:
        # enable_grad lets the attack work even when called under no_grad().
        with torch.enable_grad():
            logits = model(adv_inputs)
            if logits.dim() == 3:  # assume (B, L, C) -> (L, B, C)
                logp = logits.transpose(0, 1).log_softmax(-1)
            else:
                logp = logits.log_softmax(-1)
            loss = loss_fn(logp, targets, input_lens, target_lens)
            # Differentiate w.r.t. the inputs only; unlike loss.backward()
            # this does not accumulate into the parameters' .grad.
            (grad,) = torch.autograd.grad(loss, adv_inputs)
    finally:
        if was_training:
            model.train()

    adv = adv_inputs.detach() + alpha * grad.sign()
    # Project back into the eps ball around the clean inputs.
    adv = torch.min(torch.max(adv, clean - eps), clean + eps)
    return adv

def pgd_attack(model, loss_fn, orig_inputs, targets, input_lens, target_lens,
               eps=0.02, alpha=0.005, steps=3, device='cuda'):
    """
    Build PGD adversarial examples: repeated signed-gradient steps, each
    followed by a projection back into the L_inf ball around the clean inputs.

    The attack runs in whatever train/eval mode the model is already in and
    leaves that mode unchanged. Only the gradient with respect to the inputs
    is computed, so the ``.grad`` fields of the model parameters are neither
    cleared nor accumulated into.

    Args:
        model: callable; ``model(inputs)`` must return logits. A 3-D output
            is assumed to be (B, L, C) and is transposed to (L, B, C) before
            ``log_softmax``; any other rank is used as is.
        loss_fn: called as ``loss_fn(log_probs, targets, input_lens,
            target_lens)`` and must return a scalar loss.
        orig_inputs: clean input tensor (float). It is not modified.
        targets, input_lens, target_lens: forwarded to ``loss_fn`` unchanged.
        eps: radius of the L_inf ball around ``orig_inputs``.
        alpha: step size of each iteration.
        steps: number of gradient steps.
        device: device on which the attack is computed.

    Returns:
        Detached adversarial tensor on ``device`` with the same shape as
        ``orig_inputs``, within ``eps`` (L_inf) of the clean inputs.
    """
    # Keep the clean inputs and the ball bounds on the attack device so the
    # projections never mix devices when orig_inputs lives somewhere else.
    clean = orig_inputs.detach().to(device)
    lower = clean - eps
    upper = clean + eps

    # Small random start inside the ball.
    adv = clean + torch.randn_like(clean) * (eps * 0.001)
    adv = torch.min(torch.max(adv, lower), upper)

    for _ in range(steps):
        adv.requires_grad_()
        # enable_grad lets the attack work even when called under no_grad().
        with torch.enable_grad():
            logits = model(adv)
            if logits.dim() == 3:  # assume (B, L, C) -> (L, B, C)
                logp = logits.transpose(0, 1).log_softmax(-1)
            else:
                logp = logits.log_softmax(-1)
            loss = loss_fn(logp, targets, input_lens, target_lens)
            # Differentiate w.r.t. the inputs only, leaving the parameter
            # gradients of the model exactly as the caller had them.
            (grad,) = torch.autograd.grad(loss, adv)
        adv = adv.detach() + alpha * grad.sign()
        adv = torch.min(torch.max(adv, lower), upper)

    return adv.detach()
