
import os
import warnings
from collections import Counter
import torch
import numpy as np
import pandas as pd
from scipy.cluster import hierarchy
from data_prep import MasterDataset, load_hdf5, get_data, split_data, similarity_vectors
warnings.simplefilter(action='ignore', category=FutureWarning)
from rdkit import Chem
from tqdm import tqdm





# Lookup tables used to turn RDKit atom/bond properties into integer indices:
# a property is encoded as its position within the corresponding list.
allowable_features = {
    # atomic numbers 1..118
    'possible_atomic_num_list': list(range(1, 119)),
    # formal charges -5..+5
    'possible_formal_charge_list': list(range(-5, 6)),
    'possible_chirality_list': [
        Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        Chem.rdchem.ChiralType.CHI_OTHER,
    ],
    'possible_hybridization_list': [
        Chem.rdchem.HybridizationType.S,
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
        Chem.rdchem.HybridizationType.UNSPECIFIED,
    ],
    # hydrogen counts 0..8
    'possible_numH_list': list(range(9)),
    # implicit valences 0..6
    'possible_implicit_valence_list': list(range(7)),
    # atom degrees 0..10
    'possible_degree_list': list(range(11)),
    'possible_bonds': [
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    ],
    # only for double bond stereo information
    'possible_bond_dirs': [
        Chem.rdchem.BondDir.NONE,
        Chem.rdchem.BondDir.ENDUPRIGHT,
        Chem.rdchem.BondDir.ENDDOWNRIGHT,
    ],
}


def mol_to_graph_data_obj_simple_3D(mol):
    """
    Converts an rdkit mol object into the index-encoded tensors needed to
    build a pytorch geometric graph. NB: Uses simplified atom and bond
    features, and represents them as indices into ``allowable_features``.

    :param mol: rdkit mol object
    :return: tuple ``(x, edge_index, edge_attr)`` of ``torch.long`` tensors
        x:          [num_atoms, 2]  (atomic-number index, chirality-tag index)
        edge_index: [2, num_edges]  COO connectivity; every bond is stored in
                                    both directions, so num_edges = 2 * num_bonds
        edge_attr:  [num_edges, 2]  (bond-type index, bond-direction index)
    :raises ValueError: if an atom or bond property is not present in the
        corresponding ``allowable_features`` list (raised by ``list.index``)
    """
    num_atom_features = 2
    num_bond_features = 2

    # Look the feature tables up once instead of once per atom / bond.
    atomic_nums = allowable_features['possible_atomic_num_list']
    chirality_tags = allowable_features['possible_chirality_list']
    bond_types = allowable_features['possible_bonds']
    bond_dirs = allowable_features['possible_bond_dirs']

    # todo: more atom/bond features in the future
    # atoms, two features: atom type, chirality tag
    atom_features_list = [
        [atomic_nums.index(atom.GetAtomicNum()),
         chirality_tags.index(atom.GetChiralTag())]
        for atom in mol.GetAtoms()
    ]
    if atom_features_list:
        x = torch.tensor(np.array(atom_features_list), dtype=torch.long)
    else:
        # np.array([]) would give shape (0,); keep the documented 2-D shape,
        # mirroring how the no-bond case is handled below.
        x = torch.empty((0, num_atom_features), dtype=torch.long)

    # bonds, two features: bond type, bond direction
    edges_list = []
    edge_features_list = []
    for bond in mol.GetBonds():
        begin = bond.GetBeginAtomIdx()
        end = bond.GetEndAtomIdx()
        edge_feature = [bond_types.index(bond.GetBondType()),
                        bond_dirs.index(bond.GetBondDir())]
        # store each bond in both directions, sharing the same features
        edges_list.append((begin, end))
        edge_features_list.append(edge_feature)
        edges_list.append((end, begin))
        edge_features_list.append(edge_feature)

    if edges_list:  # mol has bonds
        # data.edge_index: Graph connectivity in COO format with shape [2, num_edges]
        edge_index = torch.tensor(np.array(edges_list).T, dtype=torch.long)
        # data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
        edge_attr = torch.tensor(np.array(edge_features_list), dtype=torch.long)
    else:  # mol has no bonds
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.empty((0, num_bond_features), dtype=torch.long)

    # Conformer positions are not extracted (kept for reference):
    # every CREST conformer gets its own mol object,
    # every mol object has only one RDKit conformer
    # ref: https://github.com/learningmatter-mit/geom/blob/master/tutorials/
    # conformer = mol.GetConformers()[0]
    # positions = conformer.GetPositions()
    # positions = torch.Tensor(positions)

    # data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
    return x, edge_index, edge_attr







ROOT_DIR = ".."


if __name__ == '__main__':
    # (dataset name, number of leading rows of the score table written out as
    # actives). Presumably the score CSVs are already ordered so that the
    # leading rows are the best-scoring molecules -- verify against the data.
    datasets = [("Enamine50k", 500), ("EnamineHTS", 1000)]

    for dataset, _ in datasets:
        for usage in ("screen", "test"):
            os.makedirs(f"{ROOT_DIR}/data/{dataset}/{usage}", exist_ok=True)
    # os.makedirs(ROOT_DIR+"/data/AmpC/screen", exist_ok=True)
    # os.makedirs(ROOT_DIR+"/data/AmpC/test", exist_ok=True)

    # ---- AmpC preparation (disabled, kept for reference) ----
    # df = pd.read_csv(ROOT_DIR+"/data/AmpC/original/AmpC_screen_table.csv") # 99459561, no_score 3245355
    #
    # df.replace({"dockscore":{"no_score":999999}}, inplace=True)
    # df['dockscore'] = df["dockscore"].map(lambda x:float(x))
    #
    # active = df.nsmallest(50000, 'dockscore', keep='first')
    # active_smiles = active["smiles"].to_list()
    # active_smiles_lines = [s+"\n" for s in active_smiles]
    #
    # df.drop(index=list(active.index), inplace=True)
    # smiles = df["smiles"].to_list()
    # smiles_lines = [s+"\n" for s in smiles]
    #
    # with open(ROOT_DIR+"/data/AmpC/original/actives.smi", "w") as file:
    #     file.writelines(active_smiles_lines)
    # with open(ROOT_DIR+"/data/AmpC/original/inactives.smi", "w") as file:
    #     file.writelines(smiles_lines)
    #
    # for dataset in ["AmpC"]:
    #     df = get_data(dataset=dataset)
    #     df_screen, df_test = split_data(df, screen_size=df.shape[0]-2, test_size=2, dataset=dataset)
    #
    #     MasterDataset(name='screen', df=df_screen, overwrite=True, dataset=dataset)
    #     MasterDataset(name='test', df=df_test, overwrite=True, dataset=dataset)

    for dataset, n_actives in datasets:
        original_dir = f"{ROOT_DIR}/data/{dataset}/original"

        df = pd.read_csv(f"{original_dir}/{dataset}_scores.csv")
        smiles_lines = [s + "\n" for s in df["smiles"].to_list()]

        # Split at a single cut-off so every row lands in exactly one file.
        # The previous slices ([0:n] and [n+1:]) silently dropped row n.
        with open(f"{original_dir}/actives.smi", "w") as file:
            file.writelines(smiles_lines[:n_actives])
        with open(f"{original_dir}/inactives.smi", "w") as file:
            file.writelines(smiles_lines[n_actives:])

        # Build the screen/test datasets; all but two molecules go to 'screen'.
        df = get_data(dataset=dataset)
        df_screen, df_test = split_data(df, screen_size=df.shape[0]-2, test_size=2, dataset=dataset)

        MasterDataset(name='screen', df=df_screen, overwrite=True, dataset=dataset)
        MasterDataset(name='test', df=df_test, overwrite=True, dataset=dataset)







# Post-processing: attach the fingerprint and the index-encoded atom/bond
# tensors to every stored graph and save the result as "graphs2".
# NOTE(review): this loop is outside the ``if __name__ == '__main__'`` guard,
# so it also runs whenever this module is imported -- confirm that is intended.
for dataset in ["Enamine50k", "EnamineHTS"]:
    for usage in ["screen", "test"]:
        usage_dir = os.path.join(ROOT_DIR, "data", dataset, usage)

        fp = torch.load(os.path.join(usage_dir, "x"))
        graph_list = torch.load(os.path.join(usage_dir, "graphs"))
        graph2_list = []

        # Wrap the list itself (not the enumerate iterator) so tqdm can read
        # its length and display a total / ETA.
        for i, graph in enumerate(tqdm(graph_list)):
            graph.fp = torch.tensor([fp[i]], dtype=torch.float32)

            mol = Chem.MolFromSmiles(graph.smiles, sanitize=True)
            if mol is None:
                # MolFromSmiles signals a parse/sanitization failure by
                # returning None; fail with the offending SMILES instead of an
                # opaque AttributeError further down.
                raise ValueError(
                    f"RDKit could not parse SMILES {graph.smiles!r} "
                    f"({dataset}/{usage}, index {i})"
                )
            xp, edgep_index, edgep_attr = mol_to_graph_data_obj_simple_3D(mol)
            graph.xp = xp
            graph.edgep_index = edgep_index
            graph.edgep_attr = edgep_attr

            graph2_list.append(graph)

        torch.save(graph2_list, os.path.join(usage_dir, "graphs2"), pickle_protocol=4)










