import torch
import torch.nn.functional as F
import numpy as np

from core.datasets.pl_data import ProteinLigandData
from core.datasets import utils as utils_data

# Entry of utils_data.ATOM_FAMILIES_ID for the 'Aromatic' family. FeaturizeLigandAtom
# uses it as an index to pick the aromatic value out of each item of
# data.ligand_atom_feature.
AROMATIC_FEAT_MAP_IDX = utils_data.ATOM_FAMILIES_ID['Aromatic']

# Ligand atom vocabulary for mode 'full'.
# Key: (atomic number, hybridization string, is_aromatic) -> class index.
# Only atomic numbers 1, 6, 7, 8, 9, 15, 16, 17 (H, C, N, O, F, P, S, Cl) are covered.
# get_index() maps any key missing from this table to the index of (1, 'S', False).
MAP_ATOM_TYPE_FULL_TO_INDEX = {
    (1, 'S', False): 0,
    (6, 'SP', False): 1,
    (6, 'SP2', False): 2,
    (6, 'SP2', True): 3,
    (6, 'SP3', False): 4,
    (7, 'SP', False): 5,
    (7, 'SP2', False): 6,
    (7, 'SP2', True): 7,
    (7, 'SP3', False): 8,
    (8, 'SP2', False): 9,
    (8, 'SP2', True): 10,
    (8, 'SP3', False): 11,
    (9, 'SP3', False): 12,
    (15, 'SP2', False): 13,
    (15, 'SP2', True): 14,
    (15, 'SP3', False): 15,
    (15, 'SP3D', False): 16,
    (15, 'SP3D2', False): 17,
    (16, 'SP2', False): 18,
    (16, 'SP2', True): 19,
    (16, 'SP3', False): 20,
    (16, 'SP3D', False): 21,
    (16, 'SP3D2', False): 22,
    (17, 'SP3', False): 23,
}

# Ligand atom vocabulary for mode 'basic': atomic number -> class index.
# get_index() indexes this table directly, so an atomic number that is not
# listed here raises KeyError in that mode.
MAP_ATOM_TYPE_ONLY_TO_INDEX = {
    1: 0,
    6: 1,
    7: 2,
    8: 3,
    9: 4,
    15: 5,
    16: 6,
    17: 7,
}

# Ligand atom vocabulary for mode 'add_aromatic'.
# Key: (atomic number, is_aromatic) -> class index.
# get_index() prints any key missing from this table and maps it to the
# index of (1, False).
MAP_ATOM_TYPE_AROMATIC_TO_INDEX = {
    (1, False): 0,
    (6, False): 1,
    (6, True): 2,
    (7, False): 3,
    (7, True): 4,
    (8, False): 5,
    (8, True): 6,
    (9, False): 7,
    (15, False): 8,
    (15, True): 9,
    (16, False): 10,
    (16, True): 11,
    (17, False): 12,
}

# pdb   aromatic_counter {(6, False): 803564, (8, False): 264143, (7, False): 198073, (6, True): 1371671, (7, True): 209316, (8, True): 8567, (9, False): 72259, (16, False): 21263, (17, False): 26781, (16, True): 11529, (35, False): 3359, (15, False): 2051, (1, False): 998, (53, False): 852, (14, False): 92, (5, False): 57} 
# pdb   full_counter {(6, 'SP3', False): 672520, (6, 'SP2', False): 120223, (8, 'SP2', False): 218819, (6, 'SP', False): 10821, (7, 'SP', False): 6030, (6, 'SP2', True): 1371671, (7, 'SP2', False): 149454, (7, 'SP2', True): 209316, (8, 'SP3', False): 45324, (8, 'SP2', True): 8567, (9, 'SP3', False): 72259, (16, 'SP2', False): 496, (16, 'SP3', False): 20745, (17, 'SP3', False): 26781, (16, 'SP2', True): 11529, (35, 'SP3', False): 3359, (7, 'SP3', False): 42589, (53, 'SP3', False): 852, (15, 'SP3', False): 2051, (5, 'SP3', False): 3, (5, 'SP2', False): 54, (1, 'S', False): 998, (14, 'SP3', False): 92, (16, 'SP3D2', False): 22}

# Ligand atom vocabulary for mode 'add_aromatic_PDB'.
# Key: (atomic number, is_aromatic) -> class index.
# Compared with MAP_ATOM_TYPE_AROMATIC_TO_INDEX it has no (15, True) entry and
# adds (35, False) (Br). The trailing numbers repeat the occurrence counts
# listed in the "pdb aromatic_counter" comment.
# get_index() raises ValueError for a key missing from this table.
MAP_ATOM_TYPE_AROMATIC_PDB_TO_INDEX = {
    (1, False): 0,  # 998
    (6, False): 1,  # 803564
    (6, True): 2,   # 1371671
    (7, False): 3,  # 198073
    (7, True): 4,   # 209316
    (8, False): 5,  # 264143
    (8, True): 6,   # 8567
    (9, False): 7,  # 72259
    (15, False): 8, # 2051
    (16, False): 9, # 21263
    (16, True): 10, # 11529
    (17, False): 11, # 26781
    (35, False): 12, # 3359
}

# Extended variant of MAP_ATOM_TYPE_AROMATIC_PDB_TO_INDEX.
# Key: (atomic number, is_aromatic) -> class index; the trailing numbers repeat
# the occurrence counts listed in the "pdb aromatic_counter" comment.
# NOTE(review): no function or inverse map visible in this file reads this
# table — confirm whether it is used elsewhere before relying on it.
MAP_ATOM_TYPE_AROMATIC_PDB_ALL_TO_INDEX = {
    (1, False): 0,  # 998
    (6, False): 1,  # 803564
    (6, True): 2,   # 1371671
    (7, False): 3,  # 198073
    (7, True): 4,   # 209316
    (8, False): 5,  # 264143
    (8, True): 6,   # 8567
    (9, False): 7,  # 72259
    (14, False): 8, # 92   (Si; absent from MAP_ATOM_TYPE_AROMATIC_PDB_TO_INDEX)
    (15, False): 9, # 2051
    (16, False): 10, # 21263
    (16, True): 11, # 11529
    (17, False): 12, # 26781
    (35, False): 13, # 3359
    (53, False): 14, # 852  (I; absent from MAP_ATOM_TYPE_AROMATIC_PDB_TO_INDEX)
}

# MAP_ATOM_TYPE_AROMATIC_FULL_TO_INDEX = {
#     (1, False): 0, # 0
#     (3, False): 1, # 5
#     (5, False): 2, # 333
#     (6, False): 3, # 1233508
#     (6, True): 4, # 1435006
#     (7, False): 5, # 262989
#     (7, True): 6, # 207743
#     (8, False): 7, # 622803
#     (8, True): 8, # 8984
#     (9, False): 9, # 52635
#     (12, False): 10, # 45
#     (13, False): 11, # 4
#     (14, False): 12, # 92
#     (15, False): 13, # 40252
#     (15, True): 14, # 1
#     (16, False): 15, # 28827
#     (16, True): 16, # 12789
#     (17, False): 17, # 25978
#     (21, False): 18, # 1
#     (23, False): 19, # 51
#     (24, False): 20, # 5
#     (26, False): 21, # 40
#     (33, False): 22, # 2
#     (34, False): 23, # 48
#     (34, True): 24, # 7
#     (35, False): 25, # 7524
#     (42, False): 26, # 80
#     (44, False): 27, # 53
#     (50, False): 28, # 12
#     (53, False): 29, # 1607
#     (74, False): 30, # 18
#     (79, False): 31, # 4
#     (80, False): 32, # 1
# }

# Reverse lookups (class index -> vocabulary key), read by the *_from_index helpers.
# Each one pairs the values of the forward table with its keys, in table order.
MAP_INDEX_TO_ATOM_TYPE_ONLY = dict(
    zip(MAP_ATOM_TYPE_ONLY_TO_INDEX.values(), MAP_ATOM_TYPE_ONLY_TO_INDEX.keys())
)
MAP_INDEX_TO_ATOM_TYPE_AROMATIC = dict(
    zip(MAP_ATOM_TYPE_AROMATIC_TO_INDEX.values(), MAP_ATOM_TYPE_AROMATIC_TO_INDEX.keys())
)
MAP_INDEX_TO_ATOM_TYPE_FULL = dict(
    zip(MAP_ATOM_TYPE_FULL_TO_INDEX.values(), MAP_ATOM_TYPE_FULL_TO_INDEX.keys())
)
MAP_INDEX_TO_ATOM_TYPE_AROMATIC_PDB = dict(
    zip(MAP_ATOM_TYPE_AROMATIC_PDB_TO_INDEX.values(), MAP_ATOM_TYPE_AROMATIC_PDB_TO_INDEX.keys())
)
# MAP_INDEX_TO_ATOM_TYPE_AROMATIC_FULL = {v: k for k, v in MAP_ATOM_TYPE_AROMATIC_FULL_TO_INDEX.items()}

def get_atomic_number_from_index(index, mode):
    """Decode class indices back to atomic numbers.

    Args:
        index: object exposing ``.tolist()`` that yields class indices
            (presumably a tensor or ndarray — confirm against callers).
        mode: one of 'basic', 'add_aromatic', 'add_aromatic_PDB', 'full';
            selects which index -> atom-type table is consulted.

    Returns:
        A list with one atomic number per entry of ``index``.

    Raises:
        ValueError: if ``mode`` is none of the values above.
        KeyError: if an index is absent from the selected table.
    """
    if mode == 'basic':
        # The 'basic' table stores the atomic number itself, not a tuple.
        return [MAP_INDEX_TO_ATOM_TYPE_ONLY[i] for i in index.tolist()]

    if mode == 'add_aromatic':
        decode = MAP_INDEX_TO_ATOM_TYPE_AROMATIC
    elif mode == 'add_aromatic_PDB':
        decode = MAP_INDEX_TO_ATOM_TYPE_AROMATIC_PDB
    elif mode == 'full':
        decode = MAP_INDEX_TO_ATOM_TYPE_FULL
    else:
        raise ValueError

    # In the tuple-keyed vocabularies the atomic number is always the first field.
    return [decode[i][0] for i in index.tolist()]


def is_aromatic_from_index(index, mode):
    """Decode class indices back to aromaticity flags.

    Args:
        index: object exposing ``.tolist()`` that yields class indices
            (not touched when ``mode`` is 'basic').
        mode: one of 'basic', 'add_aromatic', 'add_aromatic_PDB', 'full'.

    Returns:
        ``None`` for mode 'basic' (that vocabulary carries no aromatic flag);
        otherwise a list with one boolean per entry of ``index``.

    Raises:
        ValueError: if ``mode`` is none of the values above.
        KeyError: if an index is absent from the selected table.
    """
    if mode == 'basic':
        return None

    # Pick the reverse table and the tuple position holding the aromatic flag.
    if mode == 'add_aromatic':
        decode, flag_pos = MAP_INDEX_TO_ATOM_TYPE_AROMATIC, 1
    elif mode == 'add_aromatic_PDB':
        decode, flag_pos = MAP_INDEX_TO_ATOM_TYPE_AROMATIC_PDB, 1
    elif mode == 'full':
        decode, flag_pos = MAP_INDEX_TO_ATOM_TYPE_FULL, 2
    else:
        raise ValueError

    return [decode[i][flag_pos] for i in index.tolist()]


def get_hybridization_from_index(index, mode):
    """Decode class indices of the 'full' vocabulary back to hybridization strings.

    Args:
        index: object exposing ``.tolist()`` that yields class indices.
        mode: must be 'full'; it is the only vocabulary whose keys carry a
            hybridization field.

    Returns:
        A list with one hybridization string (e.g. 'SP2') per entry of ``index``.

    Raises:
        ValueError: if ``mode`` is not 'full'.
        KeyError: if an index is absent from MAP_INDEX_TO_ATOM_TYPE_FULL.
    """
    if mode != 'full':
        raise ValueError(
            f"hybridization is only encoded in mode 'full', got {mode!r}"
        )
    # Keys of the 'full' vocabulary are (atomic number, hybridization, is_aromatic),
    # so the hybridization is field 1 of the *full* reverse table. The aromatic
    # reverse table must not be used here: its field 1 is the aromatic flag and
    # it covers fewer indices.
    return [MAP_INDEX_TO_ATOM_TYPE_FULL[i][1] for i in index.tolist()]


def get_index(atom_num, hybridization, is_aromatic, mode):
    """Encode one ligand atom as a class index of the vocabulary chosen by ``mode``.

    Args:
        atom_num: atomic number; converted with ``int()``.
        hybridization: hybridization label; converted with ``str()`` and only
            used in mode 'full'.
        is_aromatic: aromatic flag; converted with ``bool()`` and ignored in
            mode 'basic'.
        mode: 'basic', 'add_aromatic', 'add_aromatic_PDB' or 'full'.

    Returns:
        The integer class index. Atoms missing from the vocabulary are printed
        and mapped to the hydrogen entry in modes 'add_aromatic' and 'full'.

    Raises:
        KeyError: mode 'basic' with an atomic number outside the vocabulary.
        ValueError: mode 'add_aromatic_PDB' with a key outside the vocabulary.
        NotImplementedError: any other ``mode``.
    """
    if mode == 'basic':
        return MAP_ATOM_TYPE_ONLY_TO_INDEX[int(atom_num)]

    if mode == 'add_aromatic':
        # self.atomic_numbers = torch.LongTensor([1, 6, 7, 8, 9, 15, 16, 17])  # H, C, N, O, F, P, S, Cl
        key = (int(atom_num), bool(is_aromatic))
        found = MAP_ATOM_TYPE_AROMATIC_TO_INDEX.get(key)
        if found is None:
            # Report the unknown atom type, then fall back to the hydrogen class.
            print(*key)
            found = MAP_ATOM_TYPE_AROMATIC_TO_INDEX[(1, False)]
        return found

    if mode == 'add_aromatic_PDB':
        # self.atomic_numbers = torch.LongTensor([1, 6, 7, 8, 9, 15, 16, 17, 35])  # H, C, N, O, F, P, S, Cl, Br
        key = (int(atom_num), bool(is_aromatic))
        found = MAP_ATOM_TYPE_AROMATIC_PDB_TO_INDEX.get(key)
        if found is None:
            # No fallback in this mode: report the atom type and fail.
            print(*key, 'to', 35, False)
            raise ValueError
        return found

    # elif mode == 'add_aromatic_full':
        # return MAP_ATOM_TYPE_AROMATIC_FULL_TO_INDEX[int(atom_num), bool(is_aromatic)]

    if mode == 'full':
        key = (int(atom_num), str(hybridization), bool(is_aromatic))
        found = MAP_ATOM_TYPE_FULL_TO_INDEX.get(key)
        if found is None:
            # Report the unknown atom type, then fall back to the hydrogen class.
            print(*key)
            found = MAP_ATOM_TYPE_FULL_TO_INDEX[(1, 'S', False)]
        return found

    raise NotImplementedError


class FeaturizeProteinAtom(object):
    """Transform that writes per-atom protein features onto a sample.

    ``data.protein_atom_feature`` is the concatenation, along the last
    dimension, of:
      * a match of each atom's element against ``self.atomic_numbers``
        (an element outside that list yields an all-zero group),
      * a one-hot of the atom's amino-acid type over ``self.max_num_aa`` classes,
      * a single 0/1 column taken from ``data.protein_is_backbone``.
    """

    def __init__(self):
        super().__init__()
        # Element vocabulary: H, C, N, O, S, Se
        self.atomic_numbers = torch.LongTensor([1, 6, 7, 8, 16, 34])
        self.max_num_aa = 20

        # Disabled: atom-name vocabulary.
        # self.atom_names = [
        #     'CA', 'C', 'N', 'O', 'H', 'CG', 'SD', 'NH2', 'OH', 'OG', 'CZ3', 'CD', 'CH2', 'NH1', 'CD1', 'ND2', 'OD1',
        #     'OD2', 'OE1', 'CD2', 'CZ', 'OG1', 'CG1', 'OE2', 'CE3', 'CE', 'NE1', 'CE2', 'CE1',
        #     'CZ2', 'NZ', 'NE', 'SG', 'ND1', 'CG2', 'NE2', 'CB'
        # ]
        # self.atom_name_2_idx = {name:idx for idx,name in enumerate(self.atom_names)}

    @property
    def feature_dim(self):
        """Width of the feature rows: element slots + amino-acid classes + 1 backbone column."""
        n_elements = self.atomic_numbers.size(0)
        n_backbone_cols = 1
        return n_elements + self.max_num_aa + n_backbone_cols

    def __call__(self, data: ProteinLigandData):
        """Attach ``protein_atom_feature`` to ``data`` and return the same object."""
        # (N_atoms, N_elements) boolean match of every atom against the element vocabulary.
        element_match = torch.eq(
            data.protein_element.view(-1, 1),
            self.atomic_numbers.view(1, -1),
        )
        aa_one_hot = F.one_hot(data.protein_atom_to_aa_type, num_classes=self.max_num_aa)
        backbone_col = data.protein_is_backbone.view(-1, 1).long()

        # TODO: is_backbone is 0/1 values, not sure the feature is treated as categorical, if so, change to 2-hot
        data.protein_atom_feature = torch.cat([element_match, aa_one_hot, backbone_col], dim=-1)

        # Disabled experiment: flag N/O atoms and aromatic side-chain ring atoms.
        # select_mask = [False] * data.protein_pos.shape[0]
        # for idx in range(data.protein_pos.shape[0]):
        #     # N,O
        #     if data.protein_element[idx] in [7,8] \
        #         or (data.protein_atom_to_aa_type[idx] == 4 and data.protein_atom_name[idx] in ['CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ']) \
        #         or (data.protein_atom_to_aa_type[idx] == 6  and data.protein_atom_name[idx] in ['CG', 'ND1', 'CD2', 'CE1', 'NE2']) \
        #         or (data.protein_atom_to_aa_type[idx] == 18  and data.protein_atom_name[idx] in ['CG', 'CD1', 'CD2', 'NE1', 'CE2', 'CE3', 'CZ2', 'CZ3', 'CH2']) \
        #         or (data.protein_atom_to_aa_type[idx] == 19  and data.protein_atom_name[idx] in ['CG', 'CD1', 'CD2', 'CE1', 'CE2', 'CZ']):  # 4 for Phe, 6 for His, 18 for Trp, 19 for Tyr
        #         select_mask[idx] = True

        # is_N_O = torch.tensor((data.protein_element == 7) | (data.protein_element == 8)).long()
        # is_aromatic = torch.tensor(select_mask, dtype=torch.long) & ~(is_N_O)

        # x = torch.cat([element, amino_acid, is_backbone, is_N_O.view(-1, 1), is_aromatic.view(-1, 1)], dim=-1)
        # data.protein_atom_feature = x

        # Disabled variant: keep only the selected atoms.
        # data.protein_pos = data.protein_pos[select_mask]
        # data.protein_atom_feature = data.protein_atom_feature[select_mask]
        # data.protein_element = data.protein_element[select_mask]

        # Disabled variant: append the selected atoms a second time.
        # data.protein_pos = torch.concat([data.protein_pos, data.protein_pos[select_mask]], dim=0)
        # data.protein_atom_feature = torch.concat([data.protein_atom_feature, data.protein_atom_feature[select_mask]], dim=0)
        # data.protein_element = torch.concat([data.protein_element, data.protein_element[select_mask]], dim=0)

        # Disabled debugging aid.
        # print(data.protein_filename)
        # for k,v in data.items():
        #     if k in ['protein_pos', 'protein_atom_feature', 'protein_element']:
        #         print(k,v.shape)
        # raise ValueError

        return data


class FeaturizeLigandAtom(object):
    """Transform that encodes every ligand atom as a class index.

    Args:
        mode: vocabulary selector forwarded to ``get_index``. The constructor
            accepts 'basic', 'add_aromatic', 'full', 'add_aromatic_full' and
            'add_aromatic_PDB'; note that ``feature_dim`` raises
            NotImplementedError for 'add_aromatic_full'.
    """

    def __init__(self, mode='basic'):
        super().__init__()
        accepted_modes = ['basic', 'add_aromatic', 'full', 'add_aromatic_full', 'add_aromatic_PDB']
        assert mode in accepted_modes
        self.mode = mode

    @property
    def feature_dim(self):
        """Number of classes in the vocabulary selected by ``self.mode``."""
        vocabularies = (
            ('basic', MAP_ATOM_TYPE_ONLY_TO_INDEX),
            ('add_aromatic', MAP_ATOM_TYPE_AROMATIC_TO_INDEX),
            ('add_aromatic_PDB', MAP_ATOM_TYPE_AROMATIC_PDB_TO_INDEX),
            # ('add_aromatic_full', MAP_ATOM_TYPE_AROMATIC_FULL_TO_INDEX),
            ('full', MAP_ATOM_TYPE_FULL_TO_INDEX),
        )
        for mode_name, vocabulary in vocabularies:
            if self.mode == mode_name:
                return len(vocabulary)
        raise NotImplementedError

    def __call__(self, data: ProteinLigandData):
        """Write ``ligand_atom_feature_full`` (and ``gen_flag_lig`` if applicable) onto ``data``.

        ``gen_flag_lig`` is only produced when ``data`` has a ``fix_index``
        attribute: atoms whose position is in ``fix_index`` get 0.0, all
        others 1.0.
        """
        elements = data.ligand_element
        hybridizations = data.ligand_hybridization
        # The aromatic flag sits at AROMATIC_FEAT_MAP_IDX in each atom's feature entry.
        aromatic_flags = [atom_feat[AROMATIC_FEAT_MAP_IDX] for atom_feat in data.ligand_atom_feature]

        class_ids = []
        for elem, hyb, arom in zip(elements, hybridizations, aromatic_flags):
            class_ids.append(get_index(elem, hyb, arom, self.mode))
        x = torch.tensor(class_ids)
        data.ligand_atom_feature_full = x

        if hasattr(data, 'fix_index'):
            gen_flags = []
            for atom_idx in range(x.shape[0]):
                gen_flags.append(0.0 if atom_idx in data.fix_index else 1.0)
            data.gen_flag_lig = torch.tensor(gen_flags)
        return data


class FeaturizeLigandBond(object):
    """Transform that attaches a fully-connected ligand edge list and, optionally, its bond types.

    Args:
        mode: only the value 'divide' changes behaviour; it remaps bond types
            0, 1, 2, 3, 4, 5 -> 0, 1, 1, 2, 2, 3. Any other value leaves them as is.
        set_bond_type: when falsy, ``ligand_fc_bond_type`` is never written.
    """

    def __init__(self, mode='fc', set_bond_type=True):
        super().__init__()
        self.mode = mode
        self.set_bond_type = set_bond_type

    def __call__(self, data: ProteinLigandData):
        """Write ``ligand_fc_bond_index`` (2 x E) and maybe ``ligand_fc_bond_type`` onto ``data``.

        Pairs with no entry in ``data.ligand_bond_index`` get bond type 0.
        """
        n_atoms = len(data.ligand_element)  # only ligand atom mask is reset in beta prior sampling

        # Enumerate every ordered atom pair, then drop the self-pairs.
        atom_ids = torch.arange(n_atoms)
        dst_all = atom_ids.repeat_interleave(n_atoms)
        src_all = atom_ids.repeat(n_atoms)
        off_diagonal = dst_all != src_all
        data.ligand_fc_bond_index = torch.stack(
            [src_all[off_diagonal], dst_all[off_diagonal]], dim=0
        )
        assert data.ligand_fc_bond_index.size(0) == 2

        if not (hasattr(data, 'ligand_bond_index') and self.set_bond_type):
            return data

        # Scatter the known bonds into a dense n_atoms x n_atoms table (0 = no bond) ...
        dense_types = torch.zeros(n_atoms, n_atoms).long()
        bond_src, bond_dst = data.ligand_bond_index
        dense_types[bond_src, bond_dst] = data.ligand_bond_type
        # assert data.ligand_bond_type.max() < 5, data.ligand_bond_type.max()
        if self.mode == 'divide':
            # 0, 1, 2, 3, 4, 5 -> 0, 1, 1, 2, 2, 3
            dense_types = (dense_types.float() / 2).ceil().long()

        # ... then read it back in fully-connected edge order.
        fc_rows = data.ligand_fc_bond_index[0]
        fc_cols = data.ligand_fc_bond_index[1]
        data.ligand_fc_bond_type = dense_types[fc_rows, fc_cols]
        return data


class RandomRotation(object):
    """Transform that applies one shared random orthogonal 3x3 matrix to ligand and protein positions.

    The matrix is the Q factor of the QR decomposition of a 3x3 matrix drawn
    with ``np.random.randn``, so it follows NumPy's global random state.
    """

    def __init__(self):
        super().__init__()

    def __call__(self,  data: ProteinLigandData):
        """Right-multiply ``ligand_pos`` and ``protein_pos`` by the same matrix; return ``data``."""
        gaussian = np.random.randn(3, 3)
        q_factor = np.linalg.qr(gaussian)[0]
        # Cast to float32 before handing the matrix to torch
        # (presumably to match the dtype of the position tensors — confirm).
        rotation = torch.from_numpy(q_factor.astype(np.float32))
        data.ligand_pos = data.ligand_pos @ rotation
        data.protein_pos = data.protein_pos @ rotation
        return data
