import torch
from src.datasets.bio_tokenizer import BioTokenizer

# Every residue table below places Cα at the origin of its local frame, so the
# shared row is kept once here and spliced in by _backbone_positions.
_CA_ORIGIN = (0.000, 0.000, 0.000)


def _backbone_positions(n, c, o):
    """Return a (4, 3) tensor of backbone coordinates in row order N, Cα, C, O.

    ``n``, ``c`` and ``o`` are (x, y, z) triples; the Cα row is always the
    origin. The tensor uses torch's default floating dtype.
    """
    return torch.tensor([n, _CA_ORIGIN, c, o])


# Idealized backbone coordinates keyed by one-letter amino-acid code.
# Each value is a (4, 3) tensor with rows N, Cα, C, O. Units are presumably
# Ångström (the N–Cα and Cα–C distances come out at ~1.46 Å and ~1.52 Å).
#
# NOTE(review): these numbers appear to match AlphaFold's
# residue_constants.rigid_group_atom_positions, where O belongs to rigid group 3
# (the psi frame, origin at C) rather than the backbone frame. Consistent with
# that, the O row sits ~1.23 Å from the origin (a C=O bond length), i.e. from Cα
# if it is read in the same frame as N and C. Confirm that consumers apply the
# psi-frame transform to row 3 before treating it as a Cα-frame position.
rigid_group_atom_positions = {
    'A': _backbone_positions(  # ALA
        n=(-0.525, 1.363, 0.000),
        c=(1.526, -0.000, -0.000),
        o=(0.627, 1.062, 0.000),
    ),
    'C': _backbone_positions(  # CYS
        n=(-0.522, 1.362, 0.000),
        c=(1.524, 0.000, 0.000),
        o=(0.625, 1.062, -0.000),
    ),
    'D': _backbone_positions(  # ASP
        n=(-0.525, 1.362, 0.000),
        c=(1.527, 0.000, 0.000),
        o=(0.626, 1.062, -0.000),
    ),
    'E': _backbone_positions(  # GLU
        n=(-0.528, 1.361, 0.000),
        c=(1.526, -0.000, 0.000),
        o=(0.626, 1.062, 0.000),
    ),
    'F': _backbone_positions(  # PHE
        n=(-0.518, 1.363, 0.000),
        c=(1.524, -0.000, -0.000),
        o=(0.626, 1.062, -0.000),
    ),
    'G': _backbone_positions(  # GLY
        n=(-0.572, 1.337, 0.000),
        c=(1.517, -0.000, -0.000),
        o=(0.626, 1.062, -0.000),
    ),
    'H': _backbone_positions(  # HIS
        n=(-0.527, 1.360, 0.000),
        c=(1.525, -0.000, 0.000),
        o=(0.625, 1.063, 0.000),
    ),
    'I': _backbone_positions(  # ILE
        n=(-0.493, 1.373, 0.000),
        c=(1.527, -0.000, -0.000),
        o=(0.627, 1.062, -0.000),
    ),
    'K': _backbone_positions(  # LYS
        n=(-0.526, 1.362, 0.000),
        c=(1.526, -0.000, -0.000),
        o=(0.626, 1.062, -0.000),
    ),
    'L': _backbone_positions(  # LEU
        n=(-0.520, 1.363, 0.000),
        c=(1.525, -0.000, -0.000),
        o=(0.625, 1.063, -0.000),
    ),
    'M': _backbone_positions(  # MET
        n=(-0.521, 1.364, 0.000),
        c=(1.525, -0.000, 0.000),
        o=(0.625, 1.062, -0.000),
    ),
    'N': _backbone_positions(  # ASN
        n=(-0.536, 1.357, 0.000),
        c=(1.526, -0.000, -0.000),
        o=(0.625, 1.062, 0.000),
    ),
    'P': _backbone_positions(  # PRO
        n=(-0.566, 1.351, 0.000),
        c=(1.527, -0.000, 0.000),
        o=(0.621, 1.066, 0.000),
    ),
    'Q': _backbone_positions(  # GLN
        n=(-0.526, 1.361, 0.000),
        c=(1.526, -0.000, 0.000),
        o=(0.626, 1.062, -0.000),
    ),
    'R': _backbone_positions(  # ARG
        n=(-0.524, 1.362, 0.000),
        c=(1.525, -0.000, -0.000),
        o=(0.626, 1.062, 0.000),
    ),
    'S': _backbone_positions(  # SER
        n=(-0.529, 1.360, 0.000),
        c=(1.525, -0.000, -0.000),
        o=(0.626, 1.062, -0.000),
    ),
    'T': _backbone_positions(  # THR
        n=(-0.517, 1.364, 0.000),
        c=(1.526, -0.000, -0.000),
        o=(0.626, 1.062, 0.000),
    ),
    'V': _backbone_positions(  # VAL
        n=(-0.494, 1.373, 0.000),
        c=(1.527, -0.000, -0.000),
        o=(0.627, 1.062, -0.000),
    ),
    'W': _backbone_positions(  # TRP
        n=(-0.521, 1.363, 0.000),
        c=(1.525, -0.000, 0.000),
        o=(0.627, 1.062, 0.000),
    ),
    'Y': _backbone_positions(  # TYR
        n=(-0.522, 1.362, 0.000),
        c=(1.524, -0.000, -0.000),
        o=(0.627, 1.062, -0.000),
    ),
}

tokenizer = BioTokenizer()

# Stack the per-residue tables in the tokenizer's protein-alphabet order, so
# rigid_group_tensor[i] holds the positions for tokenizer.alphabet_protein[i].
# Any alphabet symbol without an entry in rigid_group_atom_positions raises
# KeyError while this module is being imported.
rigid_group_tensor = torch.stack(
    [rigid_group_atom_positions[symbol] for symbol in tokenizer.alphabet_protein]
)