# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Adapted by TDC.

import os
import torch
import numpy as np
import pandas as pd
import pickle

def get_dataset_class(dataset_name):
    """Return the dataset class with the given name."""
    if dataset_name not in globals():
        raise NotImplementedError("Dataset not found: {}".format(dataset_name))
    return globals()[dataset_name]


def num_environments(dataset_name):
    return len(get_dataset_class(dataset_name).ENVIRONMENTS)


class MultipleDomainDataset:
    N_STEPS = 5001           # Default, subclasses may override
    CHECKPOINT_FREQ = 100    # Default, subclasses may override
    N_WORKERS = 0            # Default, subclasses may override
    ENVIRONMENTS = None      # Subclasses should override
    INPUT_SHAPE = None       # Subclasses should override

    def __getitem__(self, index):
        return self.datasets[index]

    def __len__(self):
        return len(self.datasets)


amino_char = ['?', 'A', 'C', 'B', 'E', 'D', 'G', 'F', 'I', 'H', 'K', 'M', 'L', 'O',
       'N', 'Q', 'P', 'S', 'R', 'U', 'T', 'W', 'V', 'Y', 'X', 'Z']

smiles_char = ['?', '#', '%', ')', '(', '+', '-', '.', '1', '0', '3', '2', '5', '4',
       '7', '6', '9', '8', '=', 'A', 'C', 'B', 'E', 'D', 'G', 'F', 'I',
       'H', 'K', 'M', 'L', 'O', 'N', 'P', 'S', 'R', 'U', 'T', 'W', 'V',
       'Y', '[', 'Z', ']', '_', 'a', 'c', 'b', 'e', 'd', 'g', 'f', 'i',
       'h', 'm', 'l', 'o', 'n', 's', 'r', 'u', 't', 'y']

MAX_SEQ_PROTEIN = 1000
MAX_SEQ_DRUG = 100

from sklearn.preprocessing import OneHotEncoder
enc_protein = OneHotEncoder().fit(np.array(amino_char).reshape(-1, 1))
enc_drug = OneHotEncoder().fit(np.array(smiles_char).reshape(-1, 1))

def protein_2_embed(x):
	return enc_protein.transform(np.array(x).reshape(-1,1)).toarray().T

def drug_2_embed(x):
	return enc_drug.transform(np.array(x).reshape(-1,1)).toarray().T

def trans_protein(x):
	temp = list(x.upper())
	temp = [i if i in amino_char else '?' for i in temp]
	if len(temp) < MAX_SEQ_PROTEIN:
		temp = temp + ['?'] * (MAX_SEQ_PROTEIN-len(temp))
	else:
		temp = temp [:MAX_SEQ_PROTEIN]
	return temp

def trans_drug(x):
	temp = list(x)
	temp = [i if i in smiles_char else '?' for i in temp]
	if len(temp) < MAX_SEQ_DRUG:
		temp = temp + ['?'] * (MAX_SEQ_DRUG-len(temp))
	else:
		temp = temp [:MAX_SEQ_DRUG]
	return temp

from torch.utils import data

class dti_tensor_dataset(data.Dataset):

    def __init__(self, df):
        self.df = df
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        d = self.df.iloc[index].Drug_Enc
        t = self.df.iloc[index].Target_Enc
        
        d = drug_2_embed(d)
        t = protein_2_embed(t)
        
        y = self.df.iloc[index].Y
        return d, t, y

class TdcDtiDg(MultipleDomainDataset):
    def __init__(self, root, test_envs, hparams):
        super().__init__()
        
        ENVIRONMENTS = [str(i) for i in list(range(2013, 2022))]
        TRAIN_ENV = [str(i) for i in list(range(2013, 2019))]
        TEST_ENV = ['2019', '2020', '2021']
        #ENVIRONMENTS = [str(i) for i in list(range(2013, 2019))]
        #TRAIN_ENV = [str(i) for i in list(range(2013, 2017))]
        #TEST_ENV = ['2017', '2018']
        
        #TRAIN_ENV = ['2019', '2020']
        #TEST_ENV = ['2021']
        
        self.input_shape = [(26, 100), (63, 1000)] # drug and protein
        self.num_classes = 1
        
        if root is None:
            raise ValueError('Data directory not specified!')
       
        ## create a datasets object
        self.datasets = []
        from tdc import BenchmarkGroup
        self.group = BenchmarkGroup(name = 'DTI_DG_Group', path = root)
        benchmark = self.group.get('BindingDB_Patent') 
        train_val, test, name = benchmark['train_val'], benchmark['test'], benchmark['name']
        
        unique_drug = pd.Series(train_val['Drug'].unique()).apply(trans_drug)
        print(f'unique_drug.shape= {unique_drug.shape}')

        unique_dict_drug = dict(zip(train_val['Drug'].unique(), unique_drug))

        #print(f'unique_dict_drug.shape = {unique_dict_drug.shape}')

        train_val['Drug_Enc'] = [unique_dict_drug[i] for i in train_val['Drug']]

        unique_target = pd.Series(train_val['Target'].unique()).apply(trans_protein)
        unique_dict_target = dict(zip(train_val['Target'].unique(), unique_target))
        train_val['Target_Enc'] = [unique_dict_target[i] for i in train_val['Target']]

        for i in TRAIN_ENV:
            df_data = train_val[train_val.Year == int(i)]
            self.datasets.append(dti_tensor_dataset(df_data))
            print('Year ' + i + ' loaded...')

        unique_drug = pd.Series(test['Drug'].unique()).apply(trans_drug)
        unique_dict_drug = dict(zip(test['Drug'].unique(), unique_drug))
        test['Drug_Enc'] = [unique_dict_drug[i] for i in test['Drug']]

        unique_target = pd.Series(test['Target'].unique()).apply(trans_protein)
        unique_dict_target = dict(zip(test['Target'].unique(), unique_target))
        test['Target_Enc'] = [unique_dict_target[i] for i in test['Target']]
        
        print(f'test = {test.columns}')
        print(f'train_val.shape = {train_val.shape}, test.shape = {test.shape}')

        for i in TEST_ENV:
            df_data = test[test.Year == int(i)]
            self.datasets.append(dti_tensor_dataset(df_data))
            print('Year ' + i + ' loaded...')
