import os
import pandas
import json
from sklearn.model_selection import train_test_split
import numpy

from ProtLig_GPCRclassA.base_cross_validation import BaseCrossValidation, BaseCVPreProcess, BaseCVSplit

try:
    from tdc.multi_pred import DTI
except ModuleNotFoundError as e:
    print('Module tdc is not found. KIBA CV_pairs_12_1 will fail.')


class CV_pairs_HyperAttentionDTI(BaseCVPreProcess):
    """Pre-processing for the KIBA pairs file used by HyperAttentionDTI.

    Reads a space-separated pairs table, extracts the unique molecules and
    sequences, writes them into ``self.working_dir`` and returns the pairs
    table indexed by ``(Drug_ID, Target_ID)``.
    """

    def __init__(self, base_working_dir, data_path):
        super().__init__(base_working_dir, data_path)

        print('\nWARNING: func_data_name in CV_pairs_HyperAttentionDTI used to be "all_HyperAttentionDTI"\n')

        self.read_data_name = 'HyperAttentionDTI_KIBA'
        self.func_data_name = 'all_haDTI'

    def read_data(self, data_path):
        """Read the raw pairs table.

        Parameters
        ----------
        data_path : mapping
            Must contain key ``'data'``: path to the space-separated file.

        Returns
        -------
        pandas.DataFrame
            Columns ``Drug_ID``, ``Target_ID``, ``Drug``, ``Target``, ``Y``.

        Raises
        ------
        ValueError
            If the file does not have exactly five columns.
        """
        # NOTE(review): read_csv's default header='infer' consumes the first line
        # as column names, which are then overwritten below. If the file has no
        # header row, the first pair is silently dropped -- confirm the format.
        raw_data = pandas.read_csv(data_path['data'], sep=' ')
        raw_data.columns = ['Drug_ID', 'Target_ID', 'Drug', 'Target', 'Y']
        return raw_data

    def _func_data(self, raw_data):
        """Split the raw table into pairs, unique sequences and unique molecules.

        ``raw_data`` is not modified.

        Returns
        -------
        data : pandas.DataFrame
            ``raw_data`` with ``Y`` renamed to ``Responsive``, indexed by
            ``(Drug_ID, Target_ID)``.
        seqs : pandas.DataFrame
            ``Target`` column indexed by ``Target_ID`` (first occurrence kept).
        mols : pandas.DataFrame
            ``Drug`` column indexed by ``Drug_ID`` (first occurrence kept).
        """
        mols = (raw_data[['Drug_ID', 'Drug']]
                .drop_duplicates(subset='Drug_ID')
                .set_index('Drug_ID'))

        seqs = (raw_data[['Target_ID', 'Target']]
                .drop_duplicates(subset='Target_ID')
                .set_index('Target_ID'))

        # Non-inplace operations so the caller's DataFrame is left untouched.
        data = (raw_data.rename(columns={'Y': 'Responsive'})
                        .set_index(['Drug_ID', 'Target_ID']))

        return data, seqs, mols

    def func_data(self, raw_data):
        """Build the pairs table and save sequences/molecules to ``working_dir``.

        Writes ``seqs_raw.csv`` / ``mols_raw.csv`` directly into
        ``self.working_dir`` and a second copy into ``seqs/seqs.csv`` /
        ``mols/mols.csv`` (``;``-separated, with index and header).

        Returns
        -------
        pandas.DataFrame
            The pairs table produced by ``_func_data``.
        """
        data, seqs, mols = self._func_data(raw_data)

        for name, frame in (('seqs', seqs), ('mols', mols)):
            # NOTE: *_raw.csv is kept as an unchanged copy of the table.
            frame.to_csv(os.path.join(self.working_dir, f'{name}_raw.csv'), sep=';', index=True, header=True)
            sub_dir = os.path.join(self.working_dir, name)
            # exist_ok: re-running into an existing working_dir must not crash
            # after *_raw.csv has already been overwritten.
            os.makedirs(sub_dir, exist_ok=True)
            frame.to_csv(os.path.join(sub_dir, f'{name}.csv'), sep=';', index=True, header=True)

        return data


class CV_pairs_12_1(BaseCVPreProcess):
    """Pre-processing for TDC's KIBA DTI dataset, binarized at a score of 12.1.

    Loads the dataset via ``tdc.multi_pred.DTI``, binarizes the labels,
    extracts the unique molecules and sequences, writes them into
    ``self.working_dir`` and returns the pairs table indexed by
    ``(Drug_ID, Target_ID)``.
    """

    def __init__(self, base_working_dir, data_path):
        super().__init__(base_working_dir, data_path)

        # read_data loads TDC's KIBA dataset; the previous value
        # 'tdc_multi_pred_DTI_DAVIS' mislabelled it as DAVIS.
        self.read_data_name = 'tdc_multi_pred_DTI_KIBA'
        self.func_data_name = 'all_12_1'

    def read_data(self, data_path):
        """Load TDC's KIBA DTI dataset.

        ``data_path`` is not used by this loader.

        Returns
        -------
        tdc.multi_pred.DTI
            The dataset object for ``name='KIBA'``.

        Raises
        ------
        ModuleNotFoundError
            If the optional ``tdc`` package could not be imported.
        """
        try:
            loader = DTI
        except NameError as err:
            # DTI is only defined when the module-level tdc import succeeded.
            raise ModuleNotFoundError(
                'Module tdc is not found; it is required by CV_pairs_12_1.read_data.'
            ) from err
        return loader(name='KIBA')

    def _func_data(self, raw_data):
        """Binarize the labels and split into pairs, sequences and molecules.

        Parameters
        ----------
        raw_data : tdc.multi_pred.DTI
            Dataset object returned by ``read_data``. ``binarize`` is called on
            it directly (its return value is not used), so the object itself is
            expected to carry the binarized labels afterwards.

        Returns
        -------
        data : pandas.DataFrame
            Pairs with ``Y`` renamed to ``Responsive``, indexed by
            ``(Drug_ID, Target_ID)``.
        seqs : pandas.DataFrame
            ``Target`` column indexed by ``Target_ID`` (first occurrence kept).
        mols : pandas.DataFrame
            ``Drug`` column indexed by ``Drug_ID`` (first occurrence kept).
        """
        raw_data.binarize(threshold = 12.1, order = 'ascending') # https://jcheminf.biomedcentral.com/articles/10.1186/s13321-017-0209-z
        pairs = raw_data.get_data()

        mols = (pairs[['Drug_ID', 'Drug']]
                .drop_duplicates(subset='Drug_ID')
                .set_index('Drug_ID'))

        seqs = (pairs[['Target_ID', 'Target']]
                .drop_duplicates(subset='Target_ID')
                .set_index('Target_ID'))

        data = (pairs.rename(columns={'Y': 'Responsive'})
                     .set_index(['Drug_ID', 'Target_ID']))

        return data, seqs, mols

    def func_data(self, raw_data):
        """Build the pairs table and save sequences/molecules to ``working_dir``.

        Writes ``seqs_raw.csv`` / ``mols_raw.csv`` directly into
        ``self.working_dir`` and a second copy into ``seqs/seqs.csv`` /
        ``mols/mols.csv`` (``;``-separated, with index and header).

        Returns
        -------
        pandas.DataFrame
            The pairs table produced by ``_func_data``.
        """
        data, seqs, mols = self._func_data(raw_data)

        for name, frame in (('seqs', seqs), ('mols', mols)):
            # NOTE: *_raw.csv is kept as an unchanged copy of the table.
            frame.to_csv(os.path.join(self.working_dir, f'{name}_raw.csv'), sep=';', index=True, header=True)
            sub_dir = os.path.join(self.working_dir, name)
            # exist_ok: re-running into an existing working_dir must not crash
            # after *_raw.csv has already been overwritten.
            os.makedirs(sub_dir, exist_ok=True)
            frame.to_csv(os.path.join(sub_dir, f'{name}.csv'), sep=';', index=True, header=True)

        return data