from typing import Mapping, Optional, Sequence

import numpy as np

from entity import (
    residue_constants as rc, molecule_constants as mc
)

# Entity kinds a token can belong to; a kind's position in this list is its
# entity id (see entity_type_order).
entity_types = [
    "protein",
    "molecule",
]

entity_type_num = len(entity_types)
entity_type_order = {entity_type: i for i, entity_type in enumerate(entity_types)}

# Token vocabulary: protein residue names followed by molecule token names.
# A token is either a whole protein residue or a single molecule atom (each
# molecule token maps to exactly one atom type further below).
token_types = rc.resnames + mc.token_names
# Entity id of every token type, aligned index-for-index with token_types.
# Derived from entity_type_order rather than hard-coded 0/1 so the two cannot
# drift apart if entity_types is ever reordered.
token_entity_type = (
    [entity_type_order["protein"]] * len(rc.resnames)
    + [entity_type_order["molecule"]] * len(mc.token_names)
)

# Number of token types and token-name -> token-index lookup.
# NOTE(review): the original comment here was truncated ("remember to use this
# when dealing with ..."); confirm what it was meant to warn about.
token_type_num = len(token_types)
token_type_order = {token_type: i for i, token_type in enumerate(token_types)}


# Atom vocabulary: protein atom types followed by molecule atom types.
atom_types = rc.atom_types + mc.atom_types
atom_type_num = len(atom_types)
atom_order = {atom_type: i for i, atom_type in enumerate(atom_types)}

# Number of extra per-token features (after one-hot encoding), taken from mc.
extra_feat_num = mc.extra_feat_num

# Number of pair features, taken from mc.
pair_feat_num = mc.pair_feat_num

# Edge kinds; an edge kind's position in this list is its edge-type id.
edge_types = ["lig_bond", "prot_adjacent"]
edge_type_num = len(edge_types)
edge_type_order = {edge_type: i for i, edge_type in enumerate(edge_types)}

# Atom14 slot names per token type. Protein entries come from rc; each
# molecule token gets its single atom type in slot 0 and 13 empty ("") slots.
toktype_to_atom14_names = rc.restype_name_to_atom14_names | {
    mol_tok: [mc.token_name_to_atom_type[mol_tok]] + [""] * 13 for mol_tok in mc.token_names
}
# Atom names present in each token type: protein entries from rc, "UNK" with
# no atoms, and each molecule token with exactly its one atom type.
token_atoms = rc.residue_atoms | {
    "UNK": []
} | {
    mol_tok: [mc.token_name_to_atom_type[mol_tok]] for mol_tok in mc.token_names
}

# Atom renaming swaps, taken from rc only (no molecule entries are added).
# Presumably pairs of symmetric atoms whose names may be exchanged, as in
# AlphaFold's residue_constants -- confirm against rc.
token_atom_renaming_swaps = rc.residue_atom_renaming_swaps

# Chi-angle mask (4 flags) per token type, indexed like token_types.
# An all-zero row for UNK is inserted between the protein and molecule rows.
# NOTE(review): this assumes rc.chi_angles_mask has no UNK row of its own (as
# in AlphaFold's residue_constants); TODO confirm the length lines up with
# rc.resnames.
chi_angles_mask = rc.chi_angles_mask + [
    [0.0, 0.0, 0.0, 0.0] # for UNK
] + mc.chi_angles_mask

# Atom names defining each chi angle, keyed by token-type name; UNK has none.
chi_angles_atoms = rc.chi_angles_atoms | {
    "UNK": [] # for UNK
} | mc.chi_angles_atoms

# Atom names making up each default rigid group, per entity type
# (group index -> atom names).
entity_default_rigid_groups = {
    "protein": {
        0: ["C", "CA", "N"],
        2: ["CA", "C", "O"],
    },
    "molecule": {
        0: [mc.mol_atom_name],
    }
}

# Chi pi-periodicity flags per token type. Unlike chi_angles_mask, no UNK row
# is inserted here -- presumably rc.chi_pi_periodic already contains one (it
# does in AlphaFold's residue_constants); TODO confirm it aligns with
# token_types.
chi_pi_periodic = rc.chi_pi_periodic + mc.chi_pi_periodic

# (atom name, rigid group index, position) triples per token-type name, as
# unpacked by _make_rigid_group_constants; UNK has no entries.
rigid_group_atom_positions = rc.rigid_group_atom_positions | {
    "UNK": []
} | mc.rigid_group_atom_positions

# Lookup tables indexed by token type, filled in place by
# _make_rigid_group_constants():
#   * per (token type, atom), both in the full atom order and in the 14-slot
#     atom14 layout: rigid-group index, atom-present mask and the atom's
#     reference (x, y, z) position from rigid_group_atom_positions;
#   * per (token type, rigid group): the default 4x4 homogeneous transform of
#     each of the 8 groups into its parent frame (backbone for groups 0-4,
#     the preceding chi frame for groups 5-7).
toktype_atomFull_to_rigid_group = np.zeros((token_type_num, atom_type_num), dtype=int)
toktype_atomFull_mask = np.zeros((token_type_num, atom_type_num), dtype=np.float32)
toktype_atomFull_rigid_group_positions = np.zeros(
    (token_type_num, atom_type_num, 3), dtype=np.float32
)
toktype_atom14_to_rigid_group = np.zeros((token_type_num, 14), dtype=int)
toktype_atom14_mask = np.zeros((token_type_num, 14), dtype=np.float32)
toktype_atom14_rigid_group_positions = np.zeros(
    (token_type_num, 14, 3), dtype=np.float32
)
toktype_rigid_group_default_frame = np.zeros(
    (token_type_num, 8, 4, 4), dtype=np.float32
)

# Van der Waals radii: protein (rc) entries merged with molecule (mc) entries;
# with dict union the mc value wins on any shared key.
van_der_waals_radius = rc.van_der_waals_radius | mc.van_der_waals_radius

def _make_rigid_transformation_4x4(ex, ey, translation):
    """Create a rigid 4x4 transformation matrix from two axes and transl."""
    # Normalize ex.
    ex_normalized = ex / np.linalg.norm(ex)

    # make ey perpendicular to ex
    ey_normalized = ey - np.dot(ey, ex_normalized) * ex_normalized
    ey_normalized /= np.linalg.norm(ey_normalized)

    # compute ez as cross product
    eznorm = np.cross(ex_normalized, ey_normalized)
    m = np.stack(
        [ex_normalized, ey_normalized, eznorm, translation]
    ).transpose()
    m = np.concatenate([m, [[0.0, 0.0, 0.0, 1.0]]], axis=0)
    return m


def _make_rigid_group_constants():
    """Fill the module-level ``toktype_*`` lookup arrays in place."""
    _fill_atom_tables()
    _fill_default_frames()


def _fill_atom_tables():
    """Populate the per-(token type, atom) tables from rigid_group_atom_positions.

    For every (atom name, group index, position) entry of a token type, record
    the rigid-group index, a presence flag of 1 and the position. Each entry is
    written twice: once at the atom's index in ``atom_order``
    (``toktype_atomFull_*``) and once at its slot in
    ``toktype_to_atom14_names[toktype]`` (``toktype_atom14_*``).
    """
    for toktype_idx, toktype in enumerate(token_types):
        for atomname, group_idx, atom_position in rigid_group_atom_positions[
            toktype
        ]:
            atomtype = atom_order[atomname]
            toktype_atomFull_to_rigid_group[toktype_idx, atomtype] = group_idx
            toktype_atomFull_mask[toktype_idx, atomtype] = 1
            toktype_atomFull_rigid_group_positions[
                toktype_idx, atomtype, :
            ] = atom_position

            # list.index raises ValueError if the atom has no atom14 slot.
            atom14idx = toktype_to_atom14_names[toktype].index(atomname)
            toktype_atom14_to_rigid_group[toktype_idx, atom14idx] = group_idx
            toktype_atom14_mask[toktype_idx, atom14idx] = 1
            toktype_atom14_rigid_group_positions[
                toktype_idx, atom14idx, :
            ] = atom_position


def _fill_default_frames():
    """Populate toktype_rigid_group_default_frame for every token type.

    Group layout along axis 1: 0 backbone, 1 pre-omega, 2 phi, 3 psi,
    4-7 chi1-chi4. Groups 0-3 are always written. Groups 0 and 1 are identity.
    Groups 2 and 3 are built from the N/CA/C positions when those atoms exist
    and fall back to identity otherwise. Chi groups are written only where
    ``chi_angles_mask`` is set; otherwise they keep their initial zeros.
    """
    for toktype_idx, toktype in enumerate(token_types):
        atom_positions = {
            name: np.array(pos)
            for name, _, pos in rigid_group_atom_positions[toktype]
        }

        # backbone to backbone is the identity transform
        toktype_rigid_group_default_frame[toktype_idx, 0, :, :] = np.eye(4)

        # pre-omega-frame to backbone (currently dummy identity matrix)
        toktype_rigid_group_default_frame[toktype_idx, 1, :, :] = np.eye(4)

        if "N" in atom_positions and "CA" in atom_positions:
            # phi-frame to backbone
            mat = _make_rigid_transformation_4x4(
                ex=atom_positions["N"] - atom_positions["CA"],
                ey=np.array([1.0, 0.0, 0.0]),
                translation=atom_positions["N"],
            )
            toktype_rigid_group_default_frame[toktype_idx, 2, :, :] = mat

            if "C" in atom_positions:
                # psi-frame to backbone
                mat = _make_rigid_transformation_4x4(
                    ex=atom_positions["C"] - atom_positions["CA"],
                    ey=atom_positions["CA"] - atom_positions["N"],
                    translation=atom_positions["C"],
                )
                toktype_rigid_group_default_frame[toktype_idx, 3, :, :] = mat
            else:
                toktype_rigid_group_default_frame[toktype_idx, 3, :, :] = np.eye(4)
        else:
            # No N/CA pair: this covers the empty "UNK" entry and presumably
            # molecule tokens. Both phi AND psi fall back to identity,
            # matching the "C missing" branch above, so the psi frame is not
            # left as an all-zero (non-rigid) matrix.
            toktype_rigid_group_default_frame[toktype_idx, 2, :, :] = np.eye(4)
            toktype_rigid_group_default_frame[toktype_idx, 3, :, :] = np.eye(4)

        # chi1-frame to backbone
        if chi_angles_mask[toktype_idx][0]:
            base_atom_names = chi_angles_atoms[toktype][0]
            base_atom_positions = [
                atom_positions[name] for name in base_atom_names
            ]
            mat = _make_rigid_transformation_4x4(
                ex=base_atom_positions[2] - base_atom_positions[1],
                ey=base_atom_positions[0] - base_atom_positions[1],
                translation=base_atom_positions[2],
            )
            toktype_rigid_group_default_frame[toktype_idx, 4, :, :] = mat

        # chi2-frame to chi1-frame
        # chi3-frame to chi2-frame
        # chi4-frame to chi3-frame
        # All rotation axes for the next frame start at (0,0,0) of the
        # previous frame, so the axis end atom position alone defines ex.
        for chi_idx in range(1, 4):
            if chi_angles_mask[toktype_idx][chi_idx]:
                axis_end_atom_name = chi_angles_atoms[toktype][chi_idx][2]
                axis_end_atom_position = atom_positions[axis_end_atom_name]
                mat = _make_rigid_transformation_4x4(
                    ex=axis_end_atom_position,
                    ey=np.array([-1.0, 0.0, 0.0]),
                    translation=axis_end_atom_position,
                )
                toktype_rigid_group_default_frame[
                    toktype_idx, 4 + chi_idx, :, :
                ] = mat


_make_rigid_group_constants()

def token_seq_to_onehot(
    sequence: Sequence[str],
    mapping: Mapping[str, int],
    map_unknown_to: Optional[str] = None
) -> np.ndarray:
    """Map a sequence of token types to a one-hot encoded matrix.

    Args:
      sequence: Token types to encode, e.g. residue or molecule token names.
      mapping: Maps each token type to an integer id. The ids must be exactly
        ``0 .. num_entries - 1`` with no gaps.
      map_unknown_to: If truthy, token types missing from ``mapping`` are
        encoded with the id ``mapping[map_unknown_to]``. If None (the
        default) or empty, an unknown token type raises ``ValueError``.

    Returns:
      An int32 array of shape ``(len(sequence), num_entries)``, where
      ``num_entries = max(mapping.values()) + 1``. Row ``i`` holds a single 1
      at the id of ``sequence[i]``.

    Raises:
      ValueError: If the mapping ids are not contiguous from 0, or if a token
        is not in ``mapping`` and ``map_unknown_to`` is not set.
      KeyError: If ``map_unknown_to`` is set but is not a key of ``mapping``.
    """
    num_entries = max(mapping.values()) + 1

    if sorted(set(mapping.values())) != list(range(num_entries)):
        raise ValueError(
            "The mapping must have values from 0 to num_unique_aas-1 "
            "without any gaps. Got: %s" % sorted(mapping.values())
        )

    # Resolve the fallback id once, rather than re-evaluating
    # mapping[map_unknown_to] for every token. A misconfigured fallback is
    # reported up front with an explicit message.
    unknown_id = None
    if map_unknown_to:
        if map_unknown_to not in mapping:
            raise KeyError(
                f"map_unknown_to={map_unknown_to!r} is not a key of mapping"
            )
        unknown_id = mapping[map_unknown_to]

    one_hot_arr = np.zeros((len(sequence), num_entries), dtype=np.int32)

    for tok_index, tok_type in enumerate(sequence):
        # The mapping values are validated ints above, so None here means
        # "unknown token and no fallback configured".
        tok_id = mapping.get(tok_type, unknown_id)
        if tok_id is None:
            raise ValueError(
                f"Invalid token in the sequence: {tok_type}"
            )
        one_hot_arr[tok_index, tok_id] = 1

    return one_hot_arr