import torch

def generate_Cbeta(N, Ca, C):
    """Place a virtual C-beta atom from backbone N, CA and C coordinates.

    The position is CA plus a fixed linear combination of three vectors:
    the N->CA bond vector, the CA->C bond vector, and their cross product.
    The coefficients are hard-coded constants (presumably fitted to ideal
    residue geometry -- not derivable from this code alone).

    Args:
        N, Ca, C: Broadcast-compatible coordinate tensors whose last
            dimension has size 3 (required by ``torch.cross(..., dim=-1)``).

    Returns:
        Tensor of virtual C-beta coordinates with the broadcast shape of
        the inputs.
    """
    n_to_ca = Ca - N
    ca_to_c = C - Ca
    # Vector perpendicular to the N-CA-C plane.
    plane_normal = torch.cross(n_to_ca, ca_to_c, dim=-1)
    offset = -0.58273431 * plane_normal + 0.56802827 * n_to_ca - 0.54067466 * ca_to_c
    return offset + Ca

def calc_distogram(pos, lower):
    """Return a one-hot distogram of all pairwise distances in ``pos``.

    Bin ``k`` covers the open interval ``(lower[k], lower[k + 1])``. The last
    bin's upper edge is the sentinel ``1e8``.

    Args:
        pos: Coordinate tensor of shape ``(..., L, D)``. Any number of leading
            batch dimensions is accepted, including none. Distances are the
            Euclidean norm over the last axis.
        lower: Lower bin edges, shape ``(num_bins,)``, expected in ascending
            order. May be a tensor or anything ``torch.as_tensor`` accepts
            (e.g. a list of floats). It is moved to ``pos.device``.

    Returns:
        Tensor of shape ``(..., L, L, num_bins)`` with dtype ``pos.dtype``.
        It holds 1 in the bin containing each pairwise distance and 0
        elsewhere.

    NOTE: both bin comparisons are strict. A distance below ``lower[0]``
    (e.g. the zero self-distance when ``lower[0] > 0``), or one exactly equal
    to a bin edge, yields an all-zero row. This is kept unchanged so that
    outputs stay compatible with existing callers.
    """
    lower = torch.as_tensor(lower, device=pos.device)
    # (..., L, 1, D) - (..., 1, L, D) -> (..., L, L, D); norm -> (..., L, L, 1)
    dists_2d = torch.linalg.norm(
        pos[..., :, None, :] - pos[..., None, :, :], dim=-1)[..., None]
    # Each bin's upper edge is the next bin's lower edge; the last bin is
    # capped by a large sentinel.
    upper = torch.cat([lower[1:], lower.new_tensor([1e8])], dim=-1)
    dgram = ((dists_2d > lower) & (dists_2d < upper)).to(pos.dtype)
    return dgram

def build_distogram_lower(splits, steps):
    """Build a 1-D tensor of distogram bin edges from piecewise-linear segments.

    Segment ``k`` spans ``splits[k]`` to ``splits[k + 1]`` with ``steps[k]``
    evenly spaced points (via ``torch.linspace``, endpoints included). Every
    segment after the first drops its first point, because that point equals
    the previous segment's last point.

    Args:
        splits: Sequence of segment boundaries.
        steps: Point count per segment. Needs at least ``len(splits) - 1``
            entries; any extra entries are ignored.

    Returns:
        1-D tensor of the concatenated segment points. It is empty when
        ``splits`` has fewer than two entries.
    """
    # Seed with an empty tensor so the result is a (0,) tensor when there
    # are no segments.
    pieces = [torch.zeros(0)]
    for seg, (start, stop) in enumerate(zip(splits[:-1], splits[1:])):
        points = torch.linspace(start, stop, steps[seg])
        pieces.append(points if seg == 0 else points[1:])
    return torch.cat(pieces)