import os
import pandas
import json
from sklearn.model_selection import train_test_split
from shutil import copy
import numpy

import itertools
import multiprocessing
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import DataStructs
import hdbscan
from sklearn.cluster import AgglomerativeClustering

from ProtLig_GPCRclassA.base_cross_validation import BaseCrossValidation, BaseCVPreProcess, BaseCVSplit
from ProtLig_GPCRclassA.datasets.ORligand.utils import *

class InadequateTestSetSizeError(Exception):
    """Error signalling that a generated test set is of inadequate size."""


class TestOnly(BaseCVSplit):
    """
    Split that puts the entire preprocessed dataset into the test set.

    For convenience when creating a test set only: the train and validation
    sets returned by func_split_data are empty.
    """
    def __init__(self, data_dir, seed = None, split_kwargs = None):
        """
        Parameters:
        -----------
        data_dir : str
            directory containing the full preprocessed data ('full_data.csv').

        seed : int or None
            random seed, forwarded to BaseCVSplit.

        split_kwargs : dict or None
            kwargs are passed to func_split_data. None (the default) means no
            kwargs. A fresh dict is created per instance so that instances
            never share one mutable default dict.
        """
        if split_kwargs is None:
            split_kwargs = {}
        super(TestOnly, self).__init__(data_dir = data_dir, seed = seed, split_kwargs = split_kwargs)

        self.func_split_data_name = 'test_only'

    def load_data(self):
        """
        Read the preprocessed full data into a DataFrame.

        Returns:
        --------
        pandas.DataFrame
            contents of <data_dir>/full_data.csv, parsed with the separator
            stored in self.sep (presumably set by BaseCVSplit - verify), using
            the first row as header and no index column.
        """
        full_data_path = os.path.join(self.data_dir, 'full_data.csv')
        return pandas.read_csv(full_data_path, sep = self.sep, index_col = None, header = 0)

    def func_split_data(self, data, seed, **kwargs):
        """
        Put all of `data` into the test set.

        Parameters:
        -----------
        data : pandas.DataFrame
            full dataset (presumably the output of self.load_data, passed in
            by BaseCVSplit - verify).

        seed : ignored
            accepted for interface compatibility only.

        **kwargs : ignored
            accepted for interface compatibility only.

        Returns:
        --------
        (data_train, data_valid, data_test) : tuple of pandas.DataFrame
            data_train and data_valid are empty frames with the same columns
            and dtypes as `data`. data_test is `data` itself.
        """
        # data.iloc[:0] keeps each column's dtype. DataFrame([], columns=...)
        # would make every column object-typed.
        empty = data.iloc[:0].copy()
        return empty, empty.copy(), data


class TrainOnly(BaseCVSplit):
    """
    Split that puts the entire preprocessed dataset into the training set.

    For convenience when creating a training set only: the validation and
    test sets returned by func_split_data are empty.
    """
    def __init__(self, data_dir, seed = None, split_kwargs = None):
        """
        Parameters:
        -----------
        data_dir : str
            directory containing the full preprocessed data ('full_data.csv').

        seed : int or None
            random seed, forwarded to BaseCVSplit.

        split_kwargs : dict or None
            kwargs are passed to func_split_data. None (the default) means no
            kwargs. A fresh dict is created per instance so that instances
            never share one mutable default dict.
        """
        if split_kwargs is None:
            split_kwargs = {}
        super(TrainOnly, self).__init__(data_dir = data_dir, seed = seed, split_kwargs = split_kwargs)

        self.func_split_data_name = 'train_only'

    def load_data(self):
        """
        Read the preprocessed full data into a DataFrame.

        Returns:
        --------
        pandas.DataFrame
            contents of <data_dir>/full_data.csv, parsed with the separator
            stored in self.sep (presumably set by BaseCVSplit - verify), using
            the first row as header and no index column.
        """
        full_data_path = os.path.join(self.data_dir, 'full_data.csv')
        return pandas.read_csv(full_data_path, sep = self.sep, index_col = None, header = 0)

    def func_split_data(self, data, seed, **kwargs):
        """
        Put all of `data` into the training set.

        Parameters:
        -----------
        data : pandas.DataFrame
            full dataset (presumably the output of self.load_data, passed in
            by BaseCVSplit - verify).

        seed : ignored
            accepted for interface compatibility only.

        **kwargs : ignored
            accepted for interface compatibility only.

        Returns:
        --------
        (data_train, data_valid, data_test) : tuple of pandas.DataFrame
            data_train is `data` itself. data_valid and data_test are empty
            frames with the same columns and dtypes as `data`.
        """
        # data.iloc[:0] keeps each column's dtype. DataFrame([], columns=...)
        # would make every column object-typed.
        empty = data.iloc[:0].copy()
        return data, empty, empty.copy()
    

class EC50_Random(BaseCVSplit):
    """
    Random split whose test set is drawn from EC50 data only.

    The test set is a random subset of the rows with data_quality == 'ec50'.
    All remaining rows (EC50 or not) form the pool from which the validation
    set is drawn. Whatever is left after that is the training set.
    """
    def __init__(self, data_dir, seed = None, split_kwargs = None):
        """
        Parameters:
        -----------
        data_dir : str
            directory containing the full preprocessed data ('full_data.csv').

        seed : int or None
            random seed, forwarded to BaseCVSplit.

        split_kwargs : dict or None
            kwargs are passed to func_split_data (see there for the supported
            keys). None (the default) means no kwargs. A fresh dict is created
            per instance so that instances never share one mutable default dict.
        """
        if split_kwargs is None:
            split_kwargs = {}
        super(EC50_Random, self).__init__(data_dir = data_dir, seed = seed, split_kwargs = split_kwargs)

        self.func_split_data_name = 'EC50_random_data'

    def load_data(self):
        """
        Read the preprocessed full data into a DataFrame.

        Returns:
        --------
        pandas.DataFrame
            contents of <data_dir>/full_data.csv, parsed with the separator
            stored in self.sep (presumably set by BaseCVSplit - verify), using
            the first row as header and no index column.
        """
        full_data_path = os.path.join(self.data_dir, 'full_data.csv')
        return pandas.read_csv(full_data_path, sep = self.sep, index_col = None, header = 0)


    def func_split_data(self, data, seed, **kwargs):
        """
        Split `data` into train/valid/test, taking the test set from EC50 rows only.

        Parameters:
        -----------
        data : pandas.DataFrame
            must contain the columns 'data_quality' and 'responsive'.
            Rows with data_quality == 'ec50' are the EC50 measurements.

        seed : int or None
            random_state for both train_test_split calls.

        test_ratio : float, required keyword
            fraction of the EC50 rows put into the test set. Must be in (0, 1).

        valid_ratio : float, keyword, default 0.1
            size of the validation set as a fraction of the number of EC50
            rows. Must be > 0. Validation rows are drawn from all non-test
            rows, EC50 or not.

        Returns:
        --------
        (data_train, data_valid, data_test) : tuple of pandas.DataFrame

        Raises:
        -------
        ValueError
            if test_ratio is missing or not in (0, 1), if valid_ratio is not
            > 0, or if the test set holds more than 50% of the EC50 rows.
        InadequateTestSetSizeError
            if the test set does not contain more than 1% (relative to all
            EC50 rows) positives and more than 1% negatives.
        """
        valid_ratio = kwargs.get('valid_ratio', 0.1)
        test_ratio = kwargs.get('test_ratio', None)

        # Explicit checks instead of `assert`: asserts are stripped under
        # `python -O`, and a missing test_ratio (None) would then be passed to
        # train_test_split, which silently falls back to a test size of 0.25.
        if test_ratio is None:
            raise ValueError("split_kwargs must provide 'test_ratio'.")
        if not (0.0 < test_ratio < 1.0):
            raise ValueError('test_ratio must be in (0, 1), got {}.'.format(test_ratio))
        if not (valid_ratio > 0.0):
            raise ValueError('valid_ratio must be > 0, got {}.'.format(valid_ratio))

        data_ec50 = data[data['data_quality'] == 'ec50']
        n_ec50 = len(data_ec50)
        print('Number of EC50 measurements: Positive: {}, Negative: {}'.format(len(data_ec50[data_ec50['responsive'] == 1]), len(data_ec50[data_ec50['responsive'] == 0])))

        # test split (EC50 rows only)
        _, data_test = train_test_split(data_ec50,
                                        test_size = test_ratio,
                                        random_state = seed)
        n_test_pos = len(data_test[data_test['responsive'] == 1])
        n_test_neg = len(data_test[data_test['responsive'] == 0])
        # Both classes must make up more than 1% of all EC50 rows in the test set.
        min_per_class = 0.01*n_ec50
        if not (n_test_pos > min_per_class and n_test_neg > min_per_class):
            raise InadequateTestSetSizeError(
                'Test set must contain more than 1% (of {} EC50 measurements) positives and negatives; '
                'got Positive: {}, Negative: {}.'.format(n_ec50, n_test_pos, n_test_neg))
        if len(data_test) > 0.5*n_ec50:
            raise ValueError('Test data is taking more than 50% of all EC50 data.')

        # Everything not in the test set (EC50 or not) is the train/valid pool.
        data_train = data.loc[data.index.difference(data_test.index)]

        # valid split: expressed as a fraction of the pool, chosen so that the
        # validation set holds about valid_ratio * (number of EC50 rows) rows.
        data_train, data_valid = train_test_split(data_train,
                                                  test_size = (valid_ratio*n_ec50)/len(data_train),
                                                  random_state = seed)
        print('Test set: Positive: {}, Negative: {}'.format(n_test_pos, n_test_neg))
        return data_train, data_valid, data_test


class EC50_Random_test_and_valid(EC50_Random):
    """
    Random split whose test and validation sets are both drawn from EC50 data only.

    The test set is a random fraction `test_ratio` of the rows with
    data_quality == 'ec50', and the validation set is drawn from the remaining
    EC50 rows. Every other row goes to the training set: all non-EC50 rows
    plus the EC50 rows not used for test or validation.

    NOTE(review): this class inherits __init__ from EC50_Random and therefore
    also its func_split_data_name ('EC50_random_data'). If that name is used
    to label outputs, the two split strategies cannot be told apart -
    confirm this is intended.
    """
    def func_split_data(self, data, seed, **kwargs):
        """
        Split `data` into train/valid/test, taking test and valid from EC50 rows only.

        Parameters:
        -----------
        data : pandas.DataFrame
            must contain the columns 'data_quality' and 'responsive'.
            Rows with data_quality == 'ec50' are the EC50 measurements.

        seed : int or None
            random_state for both train_test_split calls.

        test_ratio : float, required keyword
            fraction of the EC50 rows put into the test set. Must be in (0, 1).

        valid_ratio : float, keyword, default 0.1
            approximate fraction of all EC50 rows put into the validation set.
            Must be > 0, and valid_ratio + test_ratio must be < 1.

        Returns:
        --------
        (data_train, data_valid, data_test) : tuple of pandas.DataFrame

        Raises:
        -------
        ValueError
            if test_ratio is missing or not in (0, 1), if valid_ratio is not
            > 0, if valid_ratio + test_ratio is not < 1, or if the test set
            holds more than 50% of the EC50 rows.
        InadequateTestSetSizeError
            if the test set does not contain more than 1% (relative to all
            EC50 rows) positives and more than 1% negatives.
        """
        valid_ratio = kwargs.get('valid_ratio', 0.1)
        test_ratio = kwargs.get('test_ratio', None)

        # Explicit checks instead of `assert`: asserts are stripped under
        # `python -O`, and a missing test_ratio (None) would then be passed to
        # train_test_split, which silently falls back to a test size of 0.25.
        if test_ratio is None:
            raise ValueError("split_kwargs must provide 'test_ratio'.")
        if not (0.0 < test_ratio < 1.0):
            raise ValueError('test_ratio must be in (0, 1), got {}.'.format(test_ratio))
        if not (valid_ratio > 0.0):
            raise ValueError('valid_ratio must be > 0, got {}.'.format(valid_ratio))
        # Fraction of the non-test EC50 rows that goes to validation, chosen so
        # that the validation set is about valid_ratio of all EC50 rows.
        # train_test_split rejects a float test_size >= 1, so fail early with
        # a clear message instead.
        valid_fraction = valid_ratio/(1-test_ratio)
        if not (valid_fraction < 1.0):
            raise ValueError('valid_ratio + test_ratio must be < 1, got valid_ratio={}, test_ratio={}.'.format(valid_ratio, test_ratio))

        data_ec50 = data[data['data_quality'] == 'ec50']
        n_ec50 = len(data_ec50)
        print('Number of EC50 measurements: Positive: {}, Negative: {}'.format(len(data_ec50[data_ec50['responsive'] == 1]), len(data_ec50[data_ec50['responsive'] == 0])))

        # test split (EC50 rows only)
        data_ec50_rest, data_test = train_test_split(data_ec50,
                                                     test_size = test_ratio,
                                                     random_state = seed)
        n_test_pos = len(data_test[data_test['responsive'] == 1])
        n_test_neg = len(data_test[data_test['responsive'] == 0])
        # Both classes must make up more than 1% of all EC50 rows in the test set.
        min_per_class = 0.01*n_ec50
        if not (n_test_pos > min_per_class and n_test_neg > min_per_class):
            raise InadequateTestSetSizeError(
                'Test set must contain more than 1% (of {} EC50 measurements) positives and negatives; '
                'got Positive: {}, Negative: {}.'.format(n_ec50, n_test_pos, n_test_neg))
        if len(data_test) > 0.5*n_ec50:
            raise ValueError('Test data is taking more than 50% of all EC50 data.')

        # valid split (remaining EC50 rows only)
        _, data_valid = train_test_split(data_ec50_rest,
                                         test_size = valid_fraction,
                                         random_state = seed)

        # Every row not in the test or validation set is training data.
        data_train = data.loc[data.index.difference(data_test.index.union(data_valid.index))]

        print('Test set: Positive: {}, Negative: {}'.format(n_test_pos, n_test_neg))
        return data_train, data_valid, data_test