import pandas as pd
import numpy as np
import pickle
from tdc.multi_pred import DTI
from tqdm import tqdm

def get_splits(data):
    """Merge the train, valid and test partitions of *data* into one frame.

    Args:
        data: Object exposing ``get_split()``, which must return a mapping
            with ``'train'``, ``'valid'`` and ``'test'`` entries that
            ``pd.concat`` can combine.

    Returns:
        The three partitions concatenated in train, valid, test order.
        Each partition's original index labels are kept as-is.
    """
    partitions = data.get_split()
    ordered = [partitions[key] for key in ('train', 'valid', 'test')]
    return pd.concat(ordered)

class LoadDKI():
    """Load BindingDB interaction tables and build embedding lookups.

    On construction the ``BindingDB_Kd``, ``BindingDB_IC50`` and
    ``BindingDB_Ki`` datasets are loaded through ``tdc.multi_pred.DTI``,
    passed through ``convert_to_log(form='binding')`` and stored, with their
    train/valid/test splits concatenated, in ``self.dti_dict`` keyed by
    dataset name.
    """

    def __init__(
        self,
        data_dir
    ):
        """Load every BindingDB dataset into ``self.dti_dict``.

        Args:
            data_dir: Passed to ``DTI`` as its ``path`` argument.
        """
        dti_names = [
            'BindingDB_Kd',
            'BindingDB_IC50',
            'BindingDB_Ki'
        ]

        self.dti_dict = {}
        for name in dti_names:
            data = DTI(name=name, path=data_dir)
            data.convert_to_log(form = 'binding')
            self.dti_dict[name] = get_splits(data)

    def get_drug_dki(self, drug_file):
        """Build a lookup from drug identifier to embedding.

        The rows of ``drug_file`` are paired positionally with the entries of
        ``./data/pk_dti_fp_smiles.npy``. As a sanity check, the number of
        unique ``Drug`` values found in the two CSV tables is printed next to
        the number of embedding rows.

        Args:
            drug_file: Path to an array file readable by ``np.load``.

        Returns:
            A dict mapping each entry of the SMILES-order array to the row of
            ``drug_file`` at the same position. If the two arrays differ in
            length, the extra entries of the longer one are ignored; if an
            identifier repeats, the last row wins.
        """
        pk_df = pd.read_csv("./data/xtended_data_all.csv")
        dti_df = pd.read_csv("./data/xtended_dti_data_all.csv")

        # Only used for the diagnostic print below.
        all_smiles = set(pk_df['Drug'].tolist()) | set(dti_df['Drug'].tolist())

        data = np.load(drug_file)
        data_smile_order = np.load("./data/pk_dti_fp_smiles.npy")

        print(len(all_smiles), " == ", len(data))

        return dict(zip(data_smile_order, data))

    def get_target_dki(self, target_file):
        """Build a lookup from target value to embedding.

        Args:
            target_file: Path to a pickle file holding a dict keyed by the
                values found in the ``Target_ID`` column of the loaded
                datasets.

        Returns:
            A dict mapping the ``Target`` column value associated with each
            id to that id's embedding.

        Raises:
            KeyError: If the pickle contains an id that does not occur in any
                dataset in ``self.dti_dict``.
        """
        # Target_ID -> Target across all datasets; later datasets overwrite
        # earlier ones when the same id appears more than once.
        prot_map = {}
        for frame in self.dti_dict.values():
            prot_map.update(zip(frame["Target_ID"], frame["Target"]))

        # NOTE: unpickling can execute arbitrary code; only pass files from a
        # trusted source. The context manager guarantees the handle is closed.
        with open(target_file, "rb") as handle:
            data = pickle.load(handle)

        return {prot_map[p_id]: emb for p_id, emb in data.items()}