import torch


def normalize_last_dim(v):
    """Return ``v`` scaled to unit Euclidean length along its last dimension.

    The per-vector length is floored at 1e-8 before dividing. Vectors shorter
    than that are therefore divided by 1e-8 instead of their true length, and
    an all-zero vector stays zero instead of turning into NaN.
    """
    safe_length = torch.linalg.vector_norm(v, dim=-1, keepdim=True).clamp_min(1e-8)
    return v / safe_length


def bond_angles(a, b, c):
    """Return the angle a-b-c at the middle point ``b``, in radians.

    The angle is measured between the vectors b->a and b->c. This is the bond
    angle at the central atom of the chain a-b-c, using the same chain
    ordering as ``signed_dihedral_angle``. Coordinates live in the last
    dimension, which must have size 3 (required by the cross product).
    Leading dimensions broadcast.

    Returns:
        Tensor of angles in [0, pi] with the last dimension reduced away. If
        ``a`` or ``c`` coincides with ``b``, the vector is zero and the result
        is atan2(0, 0) = 0.
    """
    # Previously this used (b - a) and (c - a), which measured the angle at
    # vertex ``a`` rather than at the shared vertex ``b``.
    to_a = a - b
    to_c = c - b
    to_a, to_c = map(normalize_last_dim, (to_a, to_c))
    cos_angle = torch.linalg.vecdot(to_a, to_c, dim=-1)
    sin_angle = torch.linalg.norm(torch.linalg.cross(to_a, to_c, dim=-1), dim=-1)
    # atan2(|sin|, cos) stays accurate near 0 and pi, where acos(cos) loses
    # precision; a non-negative sine keeps the result in [0, pi].
    return torch.atan2(sin_angle, cos_angle)


def signed_dihedral_angle(a, b, c, d):
    """Return the signed dihedral (torsion) angle of the chain a-b-c-d, in radians.

    The angle is measured about the central bond b->c. It is the angle between
    the normal of plane (a, b, c), n1 = (b - a) x (c - b), and the normal of
    plane (b, c, d), n2 = (c - b) x (d - c). Its sign is the sign of
    (n1 x n2) . (c - b). Coordinates live in the last dimension, which must
    have size 3. Leading dimensions broadcast.

    Returns:
        Tensor of angles in [-pi, pi] with the last dimension reduced away.
        Degenerate inputs (collinear a, b, c or b, c, d, or b == c) give zero
        normals and hence atan2(0, 0) = 0.
    """
    b0 = b - a
    b1 = c - b
    b2 = d - c
    # Normals of the planes (a, b, c) and (b, c, d).
    n1 = torch.linalg.cross(b0, b1, dim=-1)
    n2 = torch.linalg.cross(b1, b2, dim=-1)
    n1, n2 = map(normalize_last_dim, (n1, n2))
    axis = normalize_last_dim(b1)

    cos_angle = torch.linalg.vecdot(n1, n2, dim=-1)
    # n1 and n2 are both perpendicular to b1, so n1 x n2 is parallel to b1.
    # Projecting it onto the unit axis gives the *signed* sine directly. The
    # previous sign(...) * norm(...) form had a torch.sign factor that is 0,
    # with zero derivative, exactly at planar (0 / pi) configurations, which
    # killed the gradient of the angle there.
    sin_angle = torch.linalg.vecdot(
        torch.linalg.cross(n1, n2, dim=-1), axis, dim=-1
    )
    return torch.atan2(sin_angle, cos_angle)
