####################################################################################################
# CHAINS
####################################################################################################

# Chain (molecule) kinds known to this module. A kind's position in the list
# is the integer id stored in ``chain_type_ids``.
chain_types = ["PROTEIN", "DNA", "RNA", "NONPOLYMER"]

# Reverse lookup: chain kind name -> its position in ``chain_types``.
chain_type_ids = dict(zip(chain_types, range(len(chain_types))))

# Output categories, in a fixed order: pairwise chain-kind combinations,
# intra-chain categories, and "modified".
out_types = (
    "dna_protein rna_protein ligand_protein dna_ligand rna_ligand "
    "intra_ligand intra_dna intra_rna intra_protein protein_protein modified"
).split()

# One weight per entry of ``out_types``.
# NOTE(review): the ``_af3`` suffix presumably marks the AlphaFold 3 weighting;
# confirm where each of the two tables is consumed.
out_types_weights_af3 = dict(
    dna_protein=10.0,
    rna_protein=10.0,
    ligand_protein=10.0,
    dna_ligand=5.0,
    rna_ligand=5.0,
    intra_ligand=20.0,
    intra_dna=4.0,
    intra_rna=16.0,
    intra_protein=20.0,
    protein_protein=20.0,
    modified=0.0,
)

# Alternative weighting over the same categories as ``out_types_weights_af3``.
out_types_weights = dict(
    dna_protein=5.0,
    rna_protein=5.0,
    ligand_protein=20.0,
    dna_ligand=2.0,
    rna_ligand=2.0,
    intra_ligand=20.0,
    intra_dna=2.0,
    intra_rna=8.0,
    intra_protein=20.0,
    protein_protein=20.0,
    modified=0.0,
)


# Per-molecule-kind categories (lower-case names).
out_single_types = "protein ligand dna rna".split()

# Pairwise categories; these are exactly the values of ``chain_types_to_clash_type``.
clash_types = [
    "dna_protein", "rna_protein", "ligand_protein", "protein_protein",
    "dna_ligand", "rna_ligand", "ligand_ligand",
    "rna_dna", "dna_dna", "rna_rna",
]

# Unordered pair of chain kinds -> clash category. Keys are frozensets, so the
# lookup is order-independent and a same-kind pair collapses to a one-element set.
chain_types_to_clash_type = {
    frozenset(kinds): category
    for kinds, category in [
        (("PROTEIN", "DNA"), "dna_protein"),
        (("PROTEIN", "RNA"), "rna_protein"),
        (("PROTEIN", "NONPOLYMER"), "ligand_protein"),
        (("PROTEIN",), "protein_protein"),
        (("NONPOLYMER", "DNA"), "dna_ligand"),
        (("NONPOLYMER", "RNA"), "rna_ligand"),
        (("NONPOLYMER",), "ligand_ligand"),
        (("DNA", "RNA"), "rna_dna"),
        (("DNA",), "dna_dna"),
        (("RNA",), "rna_rna"),
    ]
}

# Chain kind -> the matching entry of ``out_single_types``.
chain_type_to_out_single_type = dict(
    PROTEIN="protein",
    DNA="dna",
    RNA="rna",
    NONPOLYMER="ligand",
)
####################################################################################################
# RESIDUES & TOKENS
####################################################################################################


# The 20 standard amino acids (three-letter codes) followed by "UNK", the
# unknown-protein token.
canonical_tokens = (
    "ALA ARG ASN ASP CYS GLN GLU GLY HIS ILE "
    "LEU LYS MET PHE PRO SER THR TRP TYR VAL"
).split() + ["UNK"]

# Full residue vocabulary: padding, gap, protein, RNA, then DNA tokens.
# A token's position in this list is its integer id.
tokens = (
    ["<pad>", "-"]
    + canonical_tokens
    + ["A", "G", "C", "U", "N"]  # RNA bases; "N" is the unknown RNA token
    + ["DA", "DG", "DC", "DT", "DN"]  # DNA bases; "DN" is the unknown DNA token
)

num_tokens = len(tokens)
token_ids = dict(zip(tokens, range(num_tokens)))

# Fallback token (and its id) for each polymer chain kind.
unk_token = dict(PROTEIN="UNK", DNA="DN", RNA="N")
unk_token_ids = {kind: token_ids[name] for kind, name in unk_token.items()}

# One-letter protein code -> residue token. The ambiguous / non-standard letters
# X, J, B, Z, O and U all collapse onto "UNK"; the gap character maps to itself.
prot_letter_to_token = {
    **dict(
        zip(
            "ARNDCEQGHILKMFPSTWYV",
            (
                "ALA ARG ASN ASP CYS GLU GLN GLY HIS ILE "
                "LEU LYS MET PHE PRO SER THR TRP TYR VAL"
            ).split(),
        )
    ),
    **dict.fromkeys("XJBZOU", "UNK"),
    "-": "-",
}

# Inverse map. Several letters share "UNK", so the inversion alone would keep
# whichever came last; it is pinned to "X" explicitly.
prot_token_to_letter = dict(zip(prot_letter_to_token.values(), prot_letter_to_token))
prot_token_to_letter["UNK"] = "X"

# RNA: the one-letter code is the token itself ("N" is the unknown base).
rna_letter_to_token = {base: base for base in "AGCUN"}
rna_token_to_letter = dict(zip(rna_letter_to_token.values(), rna_letter_to_token))

# DNA: the token is the one-letter code prefixed with "D" ("DN" is the unknown base).
dna_letter_to_token = {base: "D" + base for base in "AGCTN"}
dna_token_to_letter = dict(zip(dna_letter_to_token.values(), dna_letter_to_token))

####################################################################################################
# ATOMS
####################################################################################################

# NOTE(review): presumably the size of the element vocabulary (one slot per
# atomic number) — confirm against the featurizer.
num_elements = 128

# Chirality labels; position in the list is the integer id.
# NOTE(review): the names look like RDKit ``ChiralType`` members — confirm.
chirality_types = (
    "CHI_UNSPECIFIED CHI_TETRAHEDRAL_CW CHI_TETRAHEDRAL_CCW CHI_SQUAREPLANAR "
    "CHI_OCTAHEDRAL CHI_TRIGONALBIPYRAMIDAL CHI_OTHER"
).split()
chirality_type_ids = dict(zip(chirality_types, range(len(chirality_types))))
unk_chirality_type = "CHI_OTHER"  # label used when the chirality is not recognised

# Hybridization labels; position in the list is the integer id.
hybridization_map = (
    "S SP SP2 SP2D SP3 SP3D SP3D2 OTHER UNSPECIFIED"
).split()
hybridization_type_ids = dict(zip(hybridization_map, range(len(hybridization_map))))
unk_hybridization_type = "UNSPECIFIED"  # label used when the hybridization is not recognised

# fmt: off
# Backbone atom names in reference order. The nucleotide list is the RNA form,
# i.e. it includes O2'.
protein_backbone_atom_names = ["N", "CA", "C", "O"]
nucleic_backbone_atom_names = [
    "P", "OP1", "OP2", "O5'", "C5'", "C4'", "O4'", "C3'", "O3'", "C2'", "O2'", "C1'",
]

# Position of each backbone atom inside its backbone list.
protein_backbone_atom_index = dict(
    zip(protein_backbone_atom_names, range(len(protein_backbone_atom_names)))
)
nucleic_backbone_atom_index = dict(
    zip(nucleic_backbone_atom_names, range(len(nucleic_backbone_atom_names)))
)

# Scratch aliases used only to spell out ``ref_atoms``; deleted again below so
# the module namespace stays as it was.
_aa = protein_backbone_atom_names
_rna = nucleic_backbone_atom_names
_dna = [atom for atom in _rna if atom != "O2'"]  # DNA entries carry no O2'

# Heavy-atom names of every residue token, in reference order: backbone first,
# then side chain (amino acids) or base (nucleotides). Every value is its own
# list object. Note the padding entry is keyed "PAD" here.
ref_atoms = {
    "PAD": [],
    "UNK": [*_aa, "CB"],
    "-": [],
    "ALA": [*_aa, "CB"],
    "ARG": [*_aa, *"CB CG CD NE CZ NH1 NH2".split()],
    "ASN": [*_aa, *"CB CG OD1 ND2".split()],
    "ASP": [*_aa, *"CB CG OD1 OD2".split()],
    "CYS": [*_aa, *"CB SG".split()],
    "GLN": [*_aa, *"CB CG CD OE1 NE2".split()],
    "GLU": [*_aa, *"CB CG CD OE1 OE2".split()],
    "GLY": [*_aa],
    "HIS": [*_aa, *"CB CG ND1 CD2 CE1 NE2".split()],
    "ILE": [*_aa, *"CB CG1 CG2 CD1".split()],
    "LEU": [*_aa, *"CB CG CD1 CD2".split()],
    "LYS": [*_aa, *"CB CG CD CE NZ".split()],
    "MET": [*_aa, *"CB CG SD CE".split()],
    "PHE": [*_aa, *"CB CG CD1 CD2 CE1 CE2 CZ".split()],
    "PRO": [*_aa, *"CB CG CD".split()],
    "SER": [*_aa, *"CB OG".split()],
    "THR": [*_aa, *"CB OG1 CG2".split()],
    "TRP": [*_aa, *"CB CG CD1 CD2 NE1 CE2 CE3 CZ2 CZ3 CH2".split()],
    "TYR": [*_aa, *"CB CG CD1 CD2 CE1 CE2 CZ OH".split()],
    "VAL": [*_aa, *"CB CG1 CG2".split()],
    "A": [*_rna, *"N9 C8 N7 C5 C6 N6 N1 C2 N3 C4".split()],
    "G": [*_rna, *"N9 C8 N7 C5 C6 O6 N1 C2 N2 N3 C4".split()],
    "C": [*_rna, *"N1 C2 O2 N3 C4 N4 C5 C6".split()],
    "U": [*_rna, *"N1 C2 O2 N3 C4 O4 C5 C6".split()],
    "N": [*_rna],
    "DA": [*_dna, *"N9 C8 N7 C5 C6 N6 N1 C2 N3 C4".split()],
    "DG": [*_dna, *"N9 C8 N7 C5 C6 O6 N1 C2 N2 N3 C4".split()],
    "DC": [*_dna, *"N1 C2 O2 N3 C4 N4 C5 C6".split()],
    "DT": [*_dna, *"N1 C2 O2 N3 C4 O4 C5 C7 C6".split()],
    "DN": [*_dna],
}
del _aa, _rna, _dna

# Atom-index permutations per residue token. Each residue maps to a list of
# symmetries, and each symmetry is a list of (i, j) index pairs that always
# contains both directions of every swap.
# The indices appear to address the residue's atom order in ``ref_atoms``
# (e.g. ASP 6/7 = OD1/OD2, nucleotides 1/2 = OP1/OP2).
# "UNK", "-", "N" and "DN" deliberately have no entry here.
ref_symmetries = {
    residue: ([[pair for a, b in swaps for pair in ((a, b), (b, a))]] if swaps else [])
    for residue, swaps in {
        "PAD": (),
        "ALA": (),
        "ARG": (),
        "ASN": (),
        "ASP": ((6, 7),),
        "CYS": (),
        "GLN": (),
        "GLU": ((7, 8),),
        "GLY": (),
        "HIS": (),
        "ILE": (),
        "LEU": (),
        "LYS": (),
        "MET": (),
        "PHE": ((6, 7), (8, 9)),
        "PRO": (),
        "SER": (),
        "THR": (),
        "TRP": (),
        "TYR": ((6, 7), (8, 9)),
        "VAL": (),
        "A": ((1, 2),),
        "G": ((1, 2),),
        "C": ((1, 2),),
        "U": ((1, 2),),
        "DA": ((1, 2),),
        "DG": ((1, 2),),
        "DC": ((1, 2),),
        "DT": ((1, 2),),
    }.items()
}


# Residue token -> name of its "center" atom: CA for every protein token,
# C1' for every nucleotide token.
res_to_center_atom = {
    **dict.fromkeys(
        (
            "UNK ALA ARG ASN ASP CYS GLN GLU GLY HIS ILE "
            "LEU LYS MET PHE PRO SER THR TRP TYR VAL"
        ).split(),
        "CA",
    ),
    **dict.fromkeys("A G C U N DA DG DC DT DN".split(), "C1'"),
}

# Residue token -> name of its "disto" atom.
# NOTE(review): presumably the atom used for distogram distances — confirm.
# Proteins use CB except GLY (which has no CB and uses CA); purines use C4,
# pyrimidines C2, and the unknown nucleotides fall back to C1'.
res_to_disto_atom = {
    **dict.fromkeys(
        (
            "UNK ALA ARG ASN ASP CYS GLN GLU GLY HIS ILE "
            "LEU LYS MET PHE PRO SER THR TRP TYR VAL"
        ).split(),
        "CB",
    ),
    "GLY": "CA",  # overrides the CB default above; key position is unchanged
    "A": "C4",
    "G": "C4",
    "C": "C2",
    "U": "C2",
    "N": "C1'",
    "DA": "C4",
    "DG": "C4",
    "DC": "C2",
    "DT": "C2",
    "DN": "C1'",
}

# Same lookups as ``res_to_center_atom`` / ``res_to_disto_atom``, but resolved to
# the atom's position inside the residue's ``ref_atoms`` list.
# ``list.index`` raises ValueError at import time if a named atom is missing.
res_to_center_atom_id = {
    residue: ref_atoms[residue].index(res_to_center_atom[residue])
    for residue in res_to_center_atom
}

res_to_disto_atom_id = {
    residue: ref_atoms[residue].index(res_to_disto_atom[residue])
    for residue in res_to_disto_atom
}

# fmt: on

####################################################################################################
# BONDS
####################################################################################################

# NOTE(review): presumably distance cutoffs (atom-level and coarser) used to
# decide what counts as an interface; units are not stated here — confirm.
atom_interface_cutoff, interface_cutoff = 5.0, 15.0

# Bond categories; position in the list is the integer id.
bond_types = ["OTHER", "SINGLE", "DOUBLE", "TRIPLE", "AROMATIC", "COVALENT"]
bond_type_ids = dict(zip(bond_types, range(len(bond_types))))
unk_bond_type = "OTHER"  # category used for bonds that match nothing else

# Integer codes of a second bond encoding (bond type combined with direction).
# The per-code meanings below are carried over from the original annotations.
bond_types_mol_enc = [
    2,   # DOUBLE,   direction NONE
    3,   # AROMATIC, direction NONE
    4,   # SINGLE,   direction ENDUPRIGHT (wedge up)
    5,   # SINGLE,   direction ENDDOWNRIGHT (dash down)
    6,   # AROMATIC, direction ENDDOWNRIGHT
    7,   # AROMATIC, direction ENDUPRIGHT
    8,   # DOUBLE,   direction EITHERDOUBLE (cis/trans)
    9,   # DATIVE,   direction NONE
    10,  # TRIPLE,   direction NONE
    11,  # SINGLE,   direction NONE
]
# Code -> dense index (position of the code in ``bond_types_mol_enc``).
bond_type_ids_mol_enc = dict(zip(bond_types_mol_enc, range(len(bond_types_mol_enc))))

####################################################################################################
# Contacts
####################################################################################################


# Label -> integer code for pocket conditioning; codes follow the listing order.
pocket_contact_info = {
    label: code
    for code, label in enumerate(("UNSPECIFIED", "UNSELECTED", "POCKET", "BINDER"))
}

# Label -> integer code for pairwise contact conditioning. The first two codes
# coincide with ``pocket_contact_info``.
contact_conditioning_info = {
    label: code
    for code, label in enumerate(
        ("UNSPECIFIED", "UNSELECTED", "POCKET>BINDER", "BINDER>POCKET", "CONTACT")
    )
}


####################################################################################################
# MSA
####################################################################################################

# NOTE(review): presumably the maximum number of MSA sequences kept overall and
# the maximum number of paired sequences, respectively — confirm against the
# MSA featurizer. Both values are powers of two (2**14 and 2**13).
max_msa_seqs = 16384
max_paired_seqs = 8192


####################################################################################################
# CHUNKING
####################################################################################################

# NOTE(review): presumably the size above which computations are chunked; what
# is being counted is not visible in this file — confirm at the call sites.
chunk_size_threshold = 384

####################################################################################################
# Method conditioning
####################################################################################################

# Methods
# Method / data-source name (lower-cased) -> conditioning id.
# All of the less common methods share id 4. The names are written in upper
# case below and lower-cased once while building the dict.
method_types_ids = {
    name.lower(): method_id
    for name, method_id in {
        "MD": 0,
        "X-RAY DIFFRACTION": 1,
        "ELECTRON MICROSCOPY": 2,
        "SOLUTION NMR": 3,
        **dict.fromkeys(
            [
                "SOLID-STATE NMR",
                "NEUTRON DIFFRACTION",
                "ELECTRON CRYSTALLOGRAPHY",
                "FIBER DIFFRACTION",
                "POWDER DIFFRACTION",
                "INFRARED SPECTROSCOPY",
                "FLUORESCENCE TRANSFER",
                "EPR",
                "THEORETICAL MODEL",
                "SOLUTION SCATTERING",
                "OTHER",
            ],
            4,
        ),
        "AFDB": 5,
        "BOLTZ-1": 6,
        # Placeholders for future supervision sources.
        "FUTURE1": 7,
        "FUTURE2": 8,
        "FUTURE3": 9,
        "FUTURE4": 10,
        "FUTURE5": 11,
    }.items()
}
# Number of distinct ids (several names share one id, so this is not len(dict)).
num_method_types = len({*method_types_ids.values()})

# Temperature
# Temperature intervals and their bin ids, plus a catch-all "other" bin that
# takes the next free id.
# NOTE(review): the values look like kelvin, and whether an interval end is
# inclusive is decided by the consumer — confirm both.
temperature_bins = [(265, 280), (280, 295), (295, 310)]
temperature_bins_ids = {
    **{interval: idx for idx, interval in enumerate(temperature_bins)},
    "other": len(temperature_bins),
}
num_temp_bins = len(temperature_bins_ids)


# pH intervals, binned the same way as the temperature above.
ph_bins = [(0, 6), (6, 8), (8, 14)]
ph_bins_ids = {
    **{interval: idx for idx, interval in enumerate(ph_bins)},
    "other": len(ph_bins),
}
num_ph_bins = len(ph_bins_ids)

####################################################################################################
# VDW_RADII
####################################################################################################

# fmt: off
# Van der Waals radii table: 118 values, laid out ten per row.
# NOTE(review): presumably one radius per element in atomic-number order
# (hydrogen first), in angstroms — confirm against the consumer. The table is
# shorter than `num_elements` (128) defined earlier in this file, so any code
# indexing it by element id has to account for the difference.
vdw_radii = [
    1.2, 1.4, 2.2, 1.9, 1.8, 1.7, 1.6, 1.55, 1.5, 1.54,
    2.4, 2.2, 2.1, 2.1, 1.95, 1.8, 1.8, 1.88, 2.8, 2.4,
    2.3, 2.15, 2.05, 2.05, 2.05, 2.05, 2.0, 2.0, 2.0, 2.1,
    2.1, 2.1, 2.05, 1.9, 1.9, 2.02, 2.9, 2.55, 2.4, 2.3,
    2.15, 2.1, 2.05, 2.05, 2.0, 2.05, 2.1, 2.2, 2.2, 2.25,
    2.2, 2.1, 2.1, 2.16, 3.0, 2.7, 2.5, 2.48, 2.47, 2.45,
    2.43, 2.42, 2.4, 2.38, 2.37, 2.35, 2.33, 2.32, 2.3, 2.28,
    2.27, 2.25, 2.2, 2.1, 2.05, 2.0, 2.0, 2.05, 2.1, 2.05,
    2.2, 2.3, 2.3, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.4,
    2.0, 2.3, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
    2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0,
    2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0
]
# fmt: on

####################################################################################################
# Excluded ligands
####################################################################################################

# Component codes that are excluded as ligands (131 entries, kept alphabetical).
# NOTE(review): the codes look like PDB Chemical Component Dictionary ids
# (buffers, ions, solvents, modified residues, ...) — confirm where the
# exclusion is applied.
ligand_exclusion = set(
    """
    144 15P 1PE 2F2 2JC 3HR 3SY 7N5 7PE 9JE
    AAE ABA ACE ACN ACT ACY AZI BAM BCN BCT
    BDN BEN BME BO3 BTB BTC BU1 C8E CAD CAQ
    CBM CCN CIT CL  CLR CM  CMO CO3 CPT CXS
    D10 DEP DIO DMS DN  DOD DOX EDO EEE EGL
    EOH EOX EPE ETF FCY FJO FLC FMT FW5 GOL
    GSH GTT GYF HED IHP IHS IMD IOD IPA IPH
    LDA MB3 MEG MES MLA MLI MOH MPD MRD MSE
    MYR N   NA  NH2 NH4 NHE NO3 O4B OHE OLA
    OLC OMB OME OXA P6G PE3 PE4 PEG PEO PEP
    PG0 PG4 PGE PGR PLM PO4 POL POP PVO SAR
    SCN SEO SEP SIN SO4 SPD SPM SR  STE STO
    STU TAR TBU TME TPO TRS UNK UNL UNX UPL
    URE
    """.split()
)


####################################################################################################
# TEMPLATES
####################################################################################################

# NOTE(review): presumably the minimum template coverage required, expressed as
# an absolute residue count and as a fraction — how the two thresholds are
# combined is not visible in this file; confirm at the call sites.
min_coverage_residues = 10
min_coverage_fraction = 0.1


####################################################################################################
# Ambiguous atoms
####################################################################################################

ambiguous_atoms = {
    "CA": {
        "*": "C",
        "OEX": "CA",
        "OEC": "CA",
        "543": "CA",
        "OC6": "CA",
        "OC1": "CA",
        "OC7": "CA",
        "OEY": "CA",
        "OC4": "CA",
        "OC3": "CA",
        "ICA": "CA",
        "CA": "CA",
        "OC2": "CA",
        "OC5": "CA",
    },
    "CD": {"*": "C", "CD": "CD", "CD3": "CD", "CD5": "CD", "CD1": "CD"},
    "BR": "BR",
    "CL": {
        "*": "CL",
        "C8P": "C",
        "L3T": "C",
        "TLC": "C",
        "TZ0": "C",
        "471": "C",
        "NLK": "C",
        "PGM": "C",
        "PNE": "C",
        "RCY": "C",
        "11F": "C",
        "PII": "C",
        "C1Q": "C",
        "4MD": "C",
        "R5A": "C",
        "KW2": "C",
        "I7M": "C",
        "R48": "C",
        "FC3": "C",
        "55V": "C",
        "KPF": "C",
        "SPZ": "C",
        "0TT": "C",
        "R9A": "C",
        "5NA": "C",
        "C55": "C",
        "NIX": "C",
        "5PM": "C",
        "PP8": "C",
        "544": "C",
        "812": "C",
        "NPM": "C",
        "KU8": "C",
        "A1AMM": "C",
        "4S0": "C",
        "AQC": "C",
        "2JK": "C",
        "WJR": "C",
        "A1AAW": "C",
        "85E": "C",
        "MB0": "C",
        "ZAB": "C",
        "85K": "C",
        "GBP": "C",
        "A1H80": "C",
        "A1AFR": "C",
        "L9M": "C",
        "MYK": "C",
        "MB9": "C",
        "38R": "C",
        "EKB": "C",
        "NKF": "C",
        "UMQ": "C",
        "T4K": "C",
        "3PT": "C",
        "A1A7S": "C",
        "1Q9": "C",
        "11R": "C",
        "D2V": "C",
        "SM8": "C",
        "IFC": "C",
        "DB5": "C",
        "L2T": "C",
        "GNB": "C",
        "PP7": "C",
        "072": "C",
        "P88": "C",
        "DRL": "C",
        "C9W": "C",
        "NTP": "C",
        "4HJ": "C",
        "7NA": "C",
        "LPC": "C",
        "T8W": "C",
        "63R": "C",
        "570": "C",
        "R4A": "C",
        "3BG": "C",
        "4RB": "C",
        "GSO": "C",
        "BQ6": "C",
        "R4P": "C",
        "5CP": "C",
        "TTR": "C",
        "6UZ": "C",
        "SPJ": "C",
        "0SA": "C",
        "ZL1": "C",
        "BYG": "C",
        "F0E": "C",
        "PC0": "C",
        "B2Q": "C",
        "KV6": "C",
        "NTO": "C",
        "CLG": "C",
        "R7U": "C",
        "SMQ": "C",
        "GM2": "C",
        "Z7P": "C",
        "NXF": "C",
        "C6Q": "C",
        "A1G": "C",
        "433": "C",
        "L9N": "C",
        "7OX": "C",
        "A1H84": "C",
        "97L": "C",
        "HDV": "C",
        "LUO": "C",
        "R6A": "C",
        "1PC": "C",
        "4PT": "C",
        "SBZ": "C",
        "EAB": "C",
        "FL4": "C",
        "OPS": "C",
        "C2X": "C",
        "SLL": "C",
        "BFC": "C",
        "GIP": "C",
        "7CP": "C",
        "CLH": "C",
        "34E": "C",
        "5NE": "C",
        "PBF": "C",
        "ABD": "C",
        "ABC": "C",
        "LPF": "C",
        "TIZ": "C",
        "4HH": "C",
        "AFC": "C",
        "WQH": "C",
        "9JL": "C",
        "CS3": "C",
        "NL0": "C",
        "KPY": "C",
        "DNA": "C",
        "B3C": "C",
        "TKL": "C",
        "KVS": "C",
        "HO6": "C",
        "NLH": "C",
        "1PB": "C",
        "CYF": "C",
        "G4M": "C",
        "R5B": "C",
        "N4S": "C",
        "N11": "C",
        "C8F": "C",
        "PIJ": "C",
        "WIN": "C",
        "NT1": "C",
        "WJW": "C",
        "HF7": "C",
        "TY1": "C",
        "VM1": "C",
    },
    "OS": {"*": "O", "DWC": "OS", "OHX": "OS", "OS": "OS", "8WV": "OS", "OS4": "OS"},
    "PB": {"*": "P", "ZN9": "PB", "ZN7": "PB", "PBM": "PB", "PB": "PB", "CSB": "PB"},
    "CE": {"*": "C", "CE": "CE"},
    "FE": {"*": "FE", "TFR": "F", "PF5": "F", "IFC": "F", "F5C": "F"},
    "NA": {"*": "N", "CGO": "NA", "R2K": "NA", "LVQ": "NA", "NA": "NA"},
    "ND": {"*": "N", "ND": "ND"},
    "CF": {"*": "C", "CF": "CF"},
    "RU": "RU",
    "BRAF": "BR",
    "EU": "EU",
    "CLAA": "CL",
    "CLBQ": "CL",
    "CM": {"*": "C", "ZCM": "CM"},
    "SN": {"*": "SN", "TAP": "S", "SND": "S", "TAD": "S", "XPT": "S"},
    "AG": "AG",
    "CLN": "CL",
    "CLM": "CL",
    "CLA": {"*": "CL", "PII": "C", "TDL": "C", "D0J": "C", "GM2": "C", "PIJ": "C"},
    "CLB": {
        "*": "CL",
        "TD5": "C",
        "PII": "C",
        "TDL": "C",
        "GM2": "C",
        "TD7": "C",
        "TD6": "C",
        "PIJ": "C",
    },
    "CR": {
        "*": "C",
        "BW9": "CR",
        "CQ4": "CR",
        "AC9": "CR",
        "TIL": "CR",
        "J7U": "CR",
        "CR": "CR",
    },
    "CLAY": "CL",
    "CLBC": "CL",
    "PD": {
        "*": "P",
        "F6Q": "PD",
        "SVP": "PD",
        "SXC": "PD",
        "U5U": "PD",
        "PD": "PD",
        "PLL": "PD",
    },
    "CO": {
        "*": "C",
        "J1S": "CO",
        "OCN": "CO",
        "OL3": "CO",
        "OL4": "CO",
        "B12": "CO",
        "XCO": "CO",
        "UFU": "CO",
        "CON": "CO",
        "OL5": "CO",
        "B13": "CO",
        "7KI": "CO",
        "PL1": "CO",
        "OCO": "CO",
        "J1R": "CO",
        "COH": "CO",
        "SIR": "CO",
        "6KI": "CO",
        "NCO": "CO",
        "9CO": "CO",
        "PC3": "CO",
        "BWU": "CO",
        "B1Z": "CO",
        "J83": "CO",
        "CO": "CO",
        "COY": "CO",
        "CNC": "CO",
        "3CO": "CO",
        "OCL": "CO",
        "R5Q": "CO",
        "X5Z": "CO",
        "CBY": "CO",
        "OLS": "CO",
        "F0X": "CO",
        "I2A": "CO",
        "OCM": "CO",
    },
    "CU": {
        "*": "C",
        "8ZR": "CU",
        "K7E": "CU",
        "CU3": "CU",
        "SI9": "CU",
        "35N": "CU",
        "C2O": "CU",
        "SI7": "CU",
        "B15": "CU",
        "SI0": "CU",
        "CUP": "CU",
        "SQ1": "CU",
        "CUK": "CU",
        "CUL": "CU",
        "SI8": "CU",
        "IC4": "CU",
        "CUM": "CU",
        "MM2": "CU",
        "B30": "CU",
        "S32": "CU",
        "V79": "CU",
        "IMF": "CU",
        "CUN": "CU",
        "MM1": "CU",
        "MP1": "CU",
        "IME": "CU",
        "B17": "CU",
        "C2C": "CU",
        "1CU": "CU",
        "CU6": "CU",
        "C1O": "CU",
        "CU1": "CU",
        "B22": "CU",
        "CUS": "CU",
        "RUQ": "CU",
        "CUF": "CU",
        "CUA": "CU",
        "CU": "CU",
        "CUO": "CU",
        "0TE": "CU",
        "SI4": "CU",
    },
    "CS": {"*": "C", "CS": "CS"},
    "CLQ": "CL",
    "CLR": "CL",
    "CLU": "CL",
    "TE": "TE",
    "NI": {
        "*": "N",
        "USN": "NI",
        "NFO": "NI",
        "NI2": "NI",
        "NFS": "NI",
        "NFR": "NI",
        "82N": "NI",
        "R5N": "NI",
        "NFU": "NI",
        "A1ICD": "NI",
        "NI3": "NI",
        "M43": "NI",
        "MM5": "NI",
        "BF8": "NI",
        "TCN": "NI",
        "NIK": "NI",
        "CUV": "NI",
        "MM6": "NI",
        "J52": "NI",
        "NI": "NI",
        "SNF": "NI",
        "XCC": "NI",
        "F0L": "NI",
        "UWE": "NI",
        "NFC": "NI",
        "3NI": "NI",
        "HNI": "NI",
        "F43": "NI",
        "RQM": "NI",
        "NFE": "NI",
        "NFB": "NI",
        "B51": "NI",
        "NI1": "NI",
        "WCC": "NI",
        "NUF": "NI",
    },
    "SB": {"*": "S", "UJI": "SB", "SB": "SB", "118": "SB", "SBO": "SB", "3CG": "SB"},
    "MO": "MO",
    "SEG": "SE",
    "CLL": "CL",
    "CLAH": "CL",
    "CLC": {
        "*": "CL",
        "TD5": "C",
        "PII": "C",
        "TDL": "C",
        "GM2": "C",
        "TD7": "C",
        "TD6": "C",
        "PIJ": "C",
    },
    "CLD": {"*": "CL", "PII": "C", "GM2": "C", "PIJ": "C"},
    "CLAD": "CL",
    "CLAE": "CL",
    "LA": "LA",
    "RH": "RH",
    "BRAC": "BR",
    "BRAD": "BR",
    "CLBN": "CL",
    "CLAC": "CL",
    "BRAB": "BR",
    "BRAE": "BR",
    "MG": "MG",
    "IR": "IR",
    "SE": {
        "*": "SE",
        "HII": "S",
        "NT2": "S",
        "R2P": "S",
        "S2P": "S",
        "0IU": "S",
        "QMB": "S",
        "81S": "S",
        "0QB": "S",
        "UB4": "S",
        "OHS": "S",
        "Q78": "S",
        "0Y2": "S",
        "B3M": "S",
        "NT1": "S",
        "81R": "S",
    },
    "BRAG": "BR",
    "CLF": {"*": "CL", "PII": "C", "GM2": "C", "PIJ": "C"},
    "CLE": {"*": "CL", "PII": "C", "GM2": "C", "PIJ": "C"},
    "BRAX": "BR",
    "CLK": "CL",
    "ZN": "ZN",
    "AS": "AS",
    "AU": "AU",
    "PT": "PT",
    "CLAS": "CL",
    "MN": "MN",
    "CLBE": "CL",
    "CLBF": "CL",
    "CLAF": "CL",
    "NA'": {"*": "N", "CGO": "NA"},
    "BRAH": "BR",
    "BRAI": "BR",
    "BRA": "BR",
    "BRB": "BR",
    "BRAV": "BR",
    "HG": {
        "*": "HG",
        "BBA": "H",
        "MID": "H",
        "APM": "H",
        "4QQ": "H",
        "0ZG": "H",
        "APH": "H",
    },
    "AR": "AR",
    "D": "H",
    "CLAN": "CL",
    "SI": "SI",
    "CLS": "CL",
    "ZR": "ZR",
    "CLAR": {"*": "CL", "ZM4": "C"},
    "HO": "HO",
    "CLI": {"*": "CL", "GM2": "C"},
    "CLH": {"*": "CL", "GM2": "C"},
    "CLAP": "CL",
    "CLBL": "CL",
    "CLBM": "CL",
    "PR": {"*": "PR", "UF0": "P", "252": "P"},
    "IN": "IN",
    "CLJ": "CL",
    "BRU": "BR",
    "SC": {"*": "S", "SFL": "SC"},
    "CLG": {"*": "CL", "GM2": "C"},
    "BRAT": "BR",
    "BRAR": "BR",
    "CLAG": "CL",
    "CLAB": "CL",
    "CLV": "CL",
    "TI": "TI",
    "CLAX": "CL",
    "CLAJ": "CL",
    "CL'": {"*": "CL", "BNR": "C", "25A": "C", "BDA": "C"},
    "CLAW": "CL",
    "BRF": "BR",
    "BRE": "BR",
    "RE": "RE",
    "GD": "GD",
    "SM": {"*": "S", "SM": "SM"},
    "CLBH": "CL",
    "CLBI": "CL",
    "CLAI": "CL",
    "CLY": "CL",
    "CLZ": "CL",
    "AC": "AC",
    "BR'": "BR",
    "CLT": "CL",
    "CLO": "CL",
    "CLP": "CL",
    "LU": "LU",
    "BA": {"*": "B", "BA": "BA"},
    "CLAU": "CL",
    "RB": "RB",
    "LI": "LI",
    "MOM": "MO",
    "BRAQ": "BR",
    "SR": {"*": "S", "SR": "SR", "OER": "SR"},
    "CLAT": "CL",
    "BRAL": "BR",
    "SEB": "SE",
    "CLW": "CL",
    "CLX": "CL",
    "BE": "BE",
    "BRG": "BR",
    "SEA": "SE",
    "BRAW": "BR",
    "BRBB": "BR",
    "ER": "ER",
    "TH": "TH",
    "BRR": "BR",
    "CLBV": "CL",
    "AL": "AL",
    "CLAV": "CL",
    "BRH": "BR",
    "CLAQ": "CL",
    "GA": "GA",
    "X": "*",
    "TL": "TL",
    "CLBB": "CL",
    "TB": "TB",
    "CLAK": "CL",
    "XE": {"*": "*", "XE": "XE"},
    "SEL": "SE",
    "PU": {"*": "P", "4PU": "PU"},
    "CLAZ": "CL",
    "SE'": "SE",
    "CLBA": "CL",
    "SEN": "SE",
    "SNN": "SN",
    "MOB": "MO",
    "YB": "YB",
    "BRC": "BR",
    "BRD": "BR",
    "CLAM": "CL",
    "DA": "H",
    "DB": "H",
    "DC": "H",
    "DXT": "H",
    "DXU": "H",
    "DXX": "H",
    "DXY": "H",
    "DXZ": "H",
    "DY": "DY",
    "TA": "TA",
    "XD": "*",
    "SED": "SE",
    "CLAL": "CL",
    "BRAJ": "BR",
    "AM": "AM",
    "CLAO": "CL",
    "BI": "BI",
    "KR": "KR",
    "BRBJ": "BR",
    "UNK": "*",
}


####################################################################################################
# ESM
####################################################################################################

# Three-letter residue name -> one-letter code. ACE, SCH and UNK stand in for
# non-standard residues and all map to "X".
aa_mapping = dict(
    ALA="A",
    ARG="R",
    ASN="N",
    ASP="D",
    CYS="C",
    GLN="Q",
    GLU="E",
    GLY="G",
    HIS="H",
    ILE="I",
    LEU="L",
    LYS="K",
    MET="M",
    PHE="F",
    PRO="P",
    SER="S",
    THR="T",
    TRP="W",
    TYR="Y",
    VAL="V",
    ACE="X",
    SCH="X",
    UNK="X",
)

####################################################################################################
# MSAPF
####################################################################################################

# 28-symbol alphabet -> token id. The id is the symbol's position in this order:
#   0-19   the twenty standard amino acids (A R N D C E Q G H I L K M F P S T W Y V)
#   20     X (unknown)
#   21-24  B (ASP or ASN), Z (GLU or GLN), U (SEC), O (PYL)
#   25     "-" (gap)
#   26-27  "<pad>" (padded positions) and "<mask>"
aa2tok_d_28 = {
    symbol: idx
    for idx, symbol in enumerate([*"ARNDCEQGHILKMFPSTWYVXBZUO-", "<pad>", "<mask>"])
}

# Inverse lookup: token id -> symbol.
tok2aa_d_28 = dict(zip(aa2tok_d_28.values(), aa2tok_d_28))

# Three-letter -> one-letter codes for the twenty standard amino acids only
# (no unknown / non-standard entries). Used below to translate residue tokens
# into the 28-symbol alphabet.
three_to_single = dict(
    zip(
        (
            "ALA ARG ASN ASP CYS GLU GLN GLY HIS ILE "
            "LEU LYS MET PHE PRO SER THR TRP TYR VAL"
        ).split(),
        "ARNDCEQGHILKMFPSTWYV",
    )
)

import torch

# Lookup tensor translating residue-token ids (positions in ``tokens``) into ids
# of the 28-symbol alphabet ``aa2tok_d_28``: ``mapping_33_to_28[token_id]`` is
# the 28-alphabet id.
#
# Translation rules:
#   * standard amino acids map through ``three_to_single`` to their letter's id;
#   * "<pad>" and "-" map to the symbols of the same name;
#   * everything else ("UNK" and all RNA/DNA tokens) maps to "X" (unknown).
#     RNA tokens such as "A" or "C" must NOT be looked up directly in
#     ``aa2tok_d_28``, where those letters mean amino acids.
#
# The tensor length follows ``len(tokens)`` (currently 33, hence the name)
# instead of a hard-coded size, so the table cannot silently go out of sync
# with the vocabulary. Building it in one expression also avoids leaving loop
# variables behind in the module namespace.
mapping_33_to_28 = torch.tensor(
    [
        aa2tok_d_28[
            three_to_single[token]
            if token in three_to_single
            else token
            if token in ("<pad>", "-")
            else "X"
        ]
        for token in tokens
    ],
    dtype=torch.long,
)
