import os
import pandas
import json
from sklearn.model_selection import train_test_split
import numpy

from ProtLig_GPCRclassA.base_cross_validation import BaseCrossValidation, BaseCVPreProcess, BaseCVSplit

try:
    from tdc.multi_pred import DTI
except ModuleNotFoundError as e:
    print('Module tdc is not found. DAVIS CV_pairs_30 and CV_pairs_log_5 will fail.')


class CV_pairs_HyperAttentionDTI(BaseCVPreProcess):
    """Pre-process the DAVIS pairs file distributed with HyperAttentionDTI.

    The input is a space-separated text file whose five columns are renamed to
    ``Drug_ID``, ``Target_ID``, ``Drug``, ``Target`` and ``Y``. ``func_data``
    turns it into a pairs table indexed by ``(Drug_ID, Target_ID)`` (label
    column renamed to ``Responsive``) and writes de-duplicated sequence and
    molecule tables to ``self.working_dir``.
    """
    def __init__(self, base_working_dir, data_path): # , seed = None, split_kwargs = {}):
        super().__init__(base_working_dir, data_path) # , seed, split_kwargs)

        print('\nWARNING: func_data_name in CV_pairs_HyperAttentionDTI used to be "all_HyperAttentionDTI"\n')

        # Names consumed by BaseCVPreProcess (presumably to label the read /
        # processed data — confirm in the base class).
        self.read_data_name = 'HyperAttentionDTI_DAVIS'
        self.func_data_name = 'all_haDTI'

    def read_data(self, data_path):
        """Read the raw space-separated pairs file.

        Parameters
        ----------
        data_path : mapping
            Must contain the key ``'data'`` holding the path of the file.

        Returns
        -------
        pandas.DataFrame
            Columns ``Drug_ID``, ``Target_ID``, ``Drug``, ``Target``, ``Y``.
        """
        # NOTE(review): read_csv uses the first line as a header (its names are
        # overwritten just below). If the file has no header row, the first
        # pair is silently dropped and ``header=None`` should be passed — confirm.
        raw_data = pandas.read_csv(data_path['data'], sep=' ')
        raw_data.columns = ['Drug_ID', 'Target_ID', 'Drug', 'Target', 'Y']
        return raw_data

    def _func_data(self, raw_data):
        """Split ``raw_data`` into pairs, sequences and molecules.

        ``raw_data`` is copied first so the caller's DataFrame is not renamed
        or re-indexed as a side effect.

        Returns
        -------
        tuple of pandas.DataFrame
            ``(data, seqs, mols)``: ``data`` is indexed by
            ``(Drug_ID, Target_ID)`` with ``Y`` renamed to ``Responsive``;
            ``seqs`` maps ``Target_ID`` -> ``Target`` and ``mols`` maps
            ``Drug_ID`` -> ``Drug``, keeping the first occurrence of each ID.
        """
        data = raw_data.copy()

        # First occurrence of each drug / target ID wins.
        mols = data[['Drug_ID', 'Drug']].drop_duplicates(subset = 'Drug_ID').set_index('Drug_ID')
        seqs = data[['Target_ID', 'Target']].drop_duplicates(subset = 'Target_ID').set_index('Target_ID')

        data = data.rename(columns = {'Y' : 'Responsive'}).set_index(['Drug_ID', 'Target_ID'])

        return data, seqs, mols

    def func_data(self, raw_data):
        """Build the pairs table and save sequences / molecules to disk.

        For each of ``seqs`` and ``mols`` this writes ``<name>_raw.csv`` (an
        unchanged copy) and ``<name>/<name>.csv`` under ``self.working_dir``,
        all ';'-separated with index and header. ``os.mkdir`` raises
        ``FileExistsError`` if the ``<name>`` sub-directory already exists.

        Returns
        -------
        pandas.DataFrame
            Pairs indexed by ``(Drug_ID, Target_ID)``.
        """
        data, seqs, mols = self._func_data(raw_data)

        # Save sequences and molecules:
        for name, frame in (('seqs', seqs), ('mols', mols)):
            # NOTE: the *_raw.csv file is kept as an unchanged copy.
            frame.to_csv(os.path.join(self.working_dir, name + '_raw.csv'), sep=';', index = True, header = True)
            os.mkdir(os.path.join(self.working_dir, name))
            frame.to_csv(os.path.join(self.working_dir, name, name + '.csv'), sep=';', index = True, header = True)

        return data


class CV_pairs_30(BaseCVPreProcess):
    """Pre-process TDC's DAVIS drug-target interaction dataset.

    Labels come from ``DTI.binarize(threshold=30.0, order='descending')``;
    subclasses change the labelling by overriding ``_prepare_data``. The
    result is a pairs table indexed by ``(Drug_ID, Target_ID)`` plus
    de-duplicated sequence and molecule tables saved to ``self.working_dir``.
    """
    def __init__(self, base_working_dir, data_path): # , seed = None, split_kwargs = {}):
        super().__init__(base_working_dir, data_path) # , seed, split_kwargs)

        self.read_data_name = 'tdc_multi_pred_DTI_DAVIS'
        self.func_data_name = 'all_30'

    def read_data(self, data_path):
        """Return TDC's DAVIS ``DTI`` loader; ``data_path`` is not used."""
        return DTI(name = 'DAVIS')

    def _prepare_data(self, raw_data):
        """Binarize ``raw_data`` at 30.0 (descending) and return it as a DataFrame.

        NOTE(review): ``binarize`` is invoked on ``raw_data`` itself, so the
        loader object is presumably modified in place — confirm against TDC.
        """
        raw_data.binarize(threshold = 30.0, order = 'descending')
        return raw_data.get_data()

    def _func_data(self, raw_data):
        """Split the prepared data into pairs, sequences and molecules.

        Returns
        -------
        tuple of pandas.DataFrame
            ``(pairs, seqs, mols)``: ``pairs`` is indexed by
            ``(Drug_ID, Target_ID)`` with ``Y`` renamed to ``Responsive``;
            ``seqs`` maps ``Target_ID`` -> ``Target`` and ``mols`` maps
            ``Drug_ID`` -> ``Drug``, keeping the first occurrence of each ID.
        """
        pairs = self._prepare_data(raw_data)

        mols = pairs[['Drug_ID', 'Drug']].drop_duplicates(subset = 'Drug_ID').set_index('Drug_ID')
        seqs = pairs[['Target_ID', 'Target']].drop_duplicates(subset = 'Target_ID').set_index('Target_ID')

        pairs = pairs.rename(columns = {'Y' : 'Responsive'}).set_index(['Drug_ID', 'Target_ID'])

        return pairs, seqs, mols

    def func_data(self, raw_data):
        """Build the pairs table and persist sequences / molecules.

        For ``seqs`` then ``mols`` this writes ``<name>_raw.csv`` (an unchanged
        copy) and ``<name>/<name>.csv`` under ``self.working_dir``, all
        ';'-separated with index and header. ``os.mkdir`` raises
        ``FileExistsError`` if the sub-directory already exists.

        Returns
        -------
        pandas.DataFrame
            Pairs indexed by ``(Drug_ID, Target_ID)``.
        """
        pairs, seqs, mols = self._func_data(raw_data)

        base = self.working_dir
        for name, frame in (('seqs', seqs), ('mols', mols)):
            frame.to_csv(os.path.join(base, name + '_raw.csv'), sep=';', index = True, header = True)
            os.mkdir(os.path.join(base, name))
            frame.to_csv(os.path.join(base, name, name + '.csv'), sep=';', index = True, header = True)

        return pairs


class CV_pairs_log_5(CV_pairs_30):
    """Variant of ``CV_pairs_30`` that labels DAVIS pairs on a log scale.

    Affinities are converted with ``convert_to_log(form='binding')`` and then
    binarized at 5.0 with ``order='ascending'``; reading and saving are
    inherited from ``CV_pairs_30``.
    """
    def __init__(self, base_working_dir, data_path): # , seed = None, split_kwargs = {}):
        # Run CV_pairs_30.__init__ rather than jumping straight to
        # BaseCVPreProcess, so any set-up done by the direct parent is inherited;
        # then override the names specific to this variant.
        super().__init__(base_working_dir, data_path) # , seed, split_kwargs)

        self.read_data_name = 'tdc_multi_pred_DTI_DAVIS'
        self.func_data_name = 'all_log_5'

    def _prepare_data(self, raw_data):
        """Log-transform, binarize at 5.0 (ascending) and return a DataFrame.

        NOTE(review): both TDC calls are made on ``raw_data`` itself, so the
        loader is presumably modified in place; calling this twice on the same
        object would then apply the log conversion twice — confirm against TDC.
        """
        data = raw_data
        data.convert_to_log(form = 'binding')
        data.binarize(threshold = 5.0, order = 'ascending')
        return data.get_data()