import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 计算旋转矩阵，绕指定轴旋转
def rotation_matrix(axis, theta):
    """生成绕任意轴旋转的旋转矩阵"""
    axis = axis / torch.norm(axis)  # 单位化轴
    cos_theta = torch.cos(theta)
    sin_theta = torch.sin(theta)
    cross = torch.tensor([  # 交叉乘积矩阵
        [0, -axis[2], axis[1]],
        [axis[2], 0, -axis[0]],
        [-axis[1], axis[0], 0]
    ], device=device)  # 确保cross矩阵也在同一设备
    I = torch.eye(3, device=device)  # 单位矩阵

    # 旋转矩阵公式
    R = cos_theta * I + (1 - cos_theta) * torch.outer(axis, axis) + sin_theta * cross
    return R

# 执行扰动操作
def perturb_structure(structure, rotation_stddev=0.01, bond_stddev=0.02):
    # 这里，structure的维度是[length, 4, 3]，对应四个骨架原子（N, CA, C, O）
    
    length = structure.shape[0]  # 获取结构的数量（length）
    
    perturbed_structure = structure.clone()  # 保留原始结构
    
    # 进行批量化处理
    N = structure[:, 0, :].to(device)
    CA = structure[:, 1, :].to(device)
    C = structure[:, 2, :].to(device)
    O = structure[:, 3, :].to(device)

    # 计算 N-CA 和 C-O 向量
    N_CA = N - CA
    CA_C = C - CA
    C_O = O - C  # 新增：计算 C 和 O 之间的相对向量

    # 计算旋转轴
    rotation_axis = torch.cross(N_CA, CA_C, dim=1)
    rotation_axis = rotation_axis / torch.norm(rotation_axis, dim=1, keepdim=True)  # 单位化

    # 旋转角度，使用高斯分布采样
    rotation_angles = torch.normal(torch.tensor(0.0, device=device), rotation_stddev, size=(length,))

    # 生成旋转矩阵
    rotation_matrices = torch.stack([rotation_matrix(axis, theta) for axis, theta in zip(rotation_axis, rotation_angles)], dim=0)

    # 旋转 N 原子
    N_rot = torch.bmm(rotation_matrices, (N - CA).unsqueeze(-1))  # 旋转矩阵 * 向量
    N_rot = N_rot.squeeze(-1) + CA  # 恢复 N 原子的位移

    # 执行旋转：旋转 C 原子
    C_rot = torch.bmm(rotation_matrices, (C - CA).unsqueeze(-1))  # 旋转矩阵 * 向量
    C_rot = C_rot.squeeze(-1) + CA  # 恢复 C 原子的位移

    # 保持 C 和 O 之间的相对向量不变：C_O 朝向和长度不变
    C_O_dist = torch.norm(C_O, dim=1)  # 计算C和O之间的距离
    C_O_unit_vector = C_O / C_O_dist.unsqueeze(1)  # 单位化C和O之间的向量

    # 计算旋转后的 C 和 O 之间的相对向量
    O_new = C_rot + C_O_unit_vector * C_O_dist.unsqueeze(1)

    # 扰动键长
    N_CA_dist = torch.norm(N - CA, dim=1)
    CA_C_dist = torch.norm(CA - C, dim=1)

    N_CA_dist_perturbed = N_CA_dist * (1 + torch.normal(torch.tensor(0.0, device=device), bond_stddev, size=(length,))).to(device)
    CA_C_dist_perturbed = CA_C_dist * (1 + torch.normal(torch.tensor(0.0, device=device), bond_stddev, size=(length,))).to(device)

    # 更新 N 和 C 原子的坐标
    N_new = CA + (N - CA) * (N_CA_dist_perturbed.unsqueeze(1) / N_CA_dist.unsqueeze(1))
    C_new = CA + (C_rot - CA) * (CA_C_dist_perturbed.unsqueeze(1) / CA_C_dist.unsqueeze(1))

    perturbed_structure[:, 0] = N_new
    perturbed_structure[:, 2] = C_new
    perturbed_structure[:, 3] = O_new  # 更新 O 的位置

    return perturbed_structure
