import os
import pandas
from shutil import copy
from sklearn.model_selection import train_test_split

from ProtLig_GPCRclassA.base_cross_validation import BaseCVSplit

class TrainOnly(BaseCVSplit):
    """
    Split that keeps every row in the training set.

    For convenience when creating a training set only: the validation and
    test sets are returned as empty dataframes.
    """
    def __init__(self, data_dir, seed = None, split_kwargs = None):
        """
        Parameters:
        -----------
        data_dir : str
            directory containing full preprocessed data.

        seed : optional
            passed on unchanged to BaseCVSplit.

        split_kwargs : dict or None
            forwarded to BaseCVSplit (presumably handed on to
            func_split_data as its keyword arguments; this class ignores
            them). None, the default, stands for an empty dict.
        """
        # A literal ``{}`` default would be a single dict object shared by
        # every instance created without split_kwargs; build a fresh one.
        if split_kwargs is None:
            split_kwargs = {}
        super(TrainOnly, self).__init__(data_dir = data_dir, seed = seed, split_kwargs = split_kwargs)

        self.func_split_data_name = 'train_only'

    def load_data(self):
        """
        Read the preprocessed full data into a DataFrame.

        Reads 'full_data.csv' from self.data_dir with self.sep as the
        delimiter (both attributes are expected to be set by BaseCVSplit).
        The first line of the file is the header and no column is used as
        the index.

        Returns:
        --------
        pandas.DataFrame
        """
        csv_path = os.path.join(self.data_dir, 'full_data.csv')
        full_data = pandas.read_csv(csv_path, sep = self.sep, index_col = None, header = 0)
        # NOTE: copying full_data.csv, mols.csv, seqs.csv and
        # CV_data_hparams.json from data_dir into self.working_dir used to
        # happen here and is currently disabled.
        return full_data

    def func_split_data(self, data, seed, **kwargs):
        """
        Put all of ``data`` into the training set.

        Parameters:
        -----------
        data : pandas.DataFrame
            full data to split (presumably the frame returned by
            self.load_data).

        seed :
            unused; accepted for interface compatibility.

        kwargs :
            unused.

        Returns:
        --------
        tuple of three pandas.DataFrame
            ``data`` itself (not a copy), followed by two distinct empty
            dataframes that carry the same columns as ``data``. The sibling
            split classes in this file return (train, valid, test) in that
            order.
        """
        empty_valid = pandas.DataFrame([], columns=data.columns)
        empty_test = pandas.DataFrame([], columns=data.columns)
        return data, empty_valid, empty_test
    

class Random_split(BaseCVSplit):
    """
    Split that delegates to self.random_data.

    random_data is not defined in this class; it is expected to be inherited
    from BaseCVSplit.
    """
    def __init__(self, data_dir, seed = None, split_kwargs = None):
        """
        Parameters:
        -----------
        data_dir : str
            directory containing full preprocessed data.

        seed : optional
            passed on unchanged to BaseCVSplit.

        split_kwargs : dict or None
            forwarded to BaseCVSplit (presumably handed on to
            func_split_data as its keyword arguments). None, the default,
            stands for an empty dict.
        """
        # A literal ``{}`` default would be a single dict object shared by
        # every instance created without split_kwargs; build a fresh one.
        if split_kwargs is None:
            split_kwargs = {}
        super(Random_split, self).__init__(data_dir, seed = seed, split_kwargs = split_kwargs)

    def load_data(self):
        """
        Read the preprocessed full data into a DataFrame.

        Reads 'full_data.csv' from self.data_dir with self.sep as the
        delimiter (both attributes are expected to be set by BaseCVSplit).
        The first line of the file is the header and no column is used as
        the index.

        Returns:
        --------
        pandas.DataFrame
        """
        csv_path = os.path.join(self.data_dir, 'full_data.csv')
        full_data = pandas.read_csv(csv_path, sep = self.sep, index_col = None, header = 0)
        # NOTE: copying full_data.csv, mols.csv and seqs.csv from data_dir
        # into self.working_dir used to happen here and is currently disabled.
        return full_data

    def func_split_data(self, data, seed, **kwargs):
        """
        Split ``data`` by calling self.random_data.

        Parameters:
        -----------
        data : pandas.DataFrame
            full data to split (presumably the frame returned by
            self.load_data).

        seed :
            forwarded to self.random_data.

        kwargs :
            forwarded to self.random_data.

        Returns:
        --------
        Whatever self.random_data returns. The sibling split classes in this
        file return (train, valid, test) dataframes in that order;
        random_data presumably does the same -- TODO confirm against
        BaseCVSplit.
        """
        return self.random_data(data, seed, **kwargs)



class EC50_Random(BaseCVSplit):
    """
    Randomly hold out a subset of the EC50 rows as the test set.

    The validation set is then drawn at random from all remaining rows
    (EC50 or not).
    """
    def __init__(self, data_dir, seed = None, split_kwargs = None):
        """
        Parameters:
        -----------
        data_dir : str
            directory containing full preprocessed data.

        seed : optional
            passed on unchanged to BaseCVSplit.

        split_kwargs : dict or None
            forwarded to BaseCVSplit (presumably handed on to
            func_split_data as its keyword arguments, see there for the
            recognised keys). None, the default, stands for an empty dict.
        """
        # A literal ``{}`` default would be a single dict object shared by
        # every instance created without split_kwargs; build a fresh one.
        if split_kwargs is None:
            split_kwargs = {}
        super(EC50_Random, self).__init__(data_dir = data_dir, seed = seed, split_kwargs = split_kwargs)
        self.func_split_data_name = 'EC50_random_data'

    def load_data(self):
        """
        Read the preprocessed full data into a DataFrame.

        Reads 'full_data.csv' from self.data_dir with self.sep as the
        delimiter (both attributes are expected to be set by BaseCVSplit).
        The first line of the file is the header and no column is used as
        the index.

        Returns:
        --------
        pandas.DataFrame
        """
        csv_path = os.path.join(self.data_dir, 'full_data.csv')
        full_data = pandas.read_csv(csv_path, sep = self.sep, index_col = None, header = 0)
        # NOTE: copying full_data.csv, mols.csv and seqs.csv from data_dir
        # into self.working_dir used to happen here and is currently disabled.
        return full_data

    def func_split_data(self, data, seed, **kwargs):
        """
        Random EC50 split: test rows come from the EC50 data only, validation
        rows from everything that is left.

        Parameters:
        -----------
        data : pandas.DataFrame
            must contain the columns 'parameter' and 'responsive'. EC50 rows
            are those with parameter == 'ec50'; responsive == 1 is counted as
            positive and responsive == 0 as negative. Index labels are
            assumed to be unique (true for the frame built by load_data),
            because the training rows are selected by index difference.

        seed :
            used as random_state of both train_test_split calls.

        kwargs :
            test_ratio : float, required
                fraction of the EC50 rows that goes to the test set; must be
                strictly between 0 and 1.
            valid_ratio : float, default 0.1
                size of the validation set expressed as a fraction of the
                number of EC50 rows; must be positive.

        Returns:
        --------
        data_train, data_valid, data_test : pandas.DataFrame

        Raises:
        -------
        ValueError
            if test_ratio is missing or out of range, if valid_ratio is not
            positive, if either class makes up 1% or less of the EC50 row
            count in the test set, or if the test set holds more than 50%
            of the EC50 rows.
        """
        valid_ratio = kwargs.get('valid_ratio', 0.1)
        test_ratio = kwargs.get('test_ratio', None)

        # Explicit checks instead of assert: asserts vanish under ``python -O``
        # and a missing test_ratio would otherwise fail on ``None > 0.0``.
        if test_ratio is None:
            raise ValueError("'test_ratio' is required for the EC50 random split.")
        if not (0.0 < test_ratio < 1.0):
            raise ValueError('test_ratio must be strictly between 0 and 1, got {}.'.format(test_ratio))
        if not (valid_ratio > 0.0):
            raise ValueError('valid_ratio must be positive, got {}.'.format(valid_ratio))

        data_ec50 = data[data['parameter'] == 'ec50']
        n_ec50 = len(data_ec50)
        print('Number of EC50 measurements: Positive: {}, Negative: {}'.format(len(data_ec50[data_ec50['responsive'] == 1]), len(data_ec50[data_ec50['responsive'] == 0])))

        # test split: only EC50 rows are eligible for the test set.
        _, data_test = train_test_split(data_ec50,
                                        test_size = test_ratio,
                                        random_state = seed)
        n_test_pos = len(data_test[data_test['responsive'] == 1])
        n_test_neg = len(data_test[data_test['responsive'] == 0])
        # Each class must account for more than 1% of the EC50 row count.
        min_per_class = 0.01*n_ec50
        if n_test_pos <= min_per_class or n_test_neg <= min_per_class:
            raise ValueError('Test data needs more than {} rows of each class, got {} positive and {} negative.'.format(min_per_class, n_test_pos, n_test_neg))
        if len(data_test) > 0.5*n_ec50:
            raise ValueError('Test data is taking more than 50% of all EC50 data.')

        # Everything that is not in the test set (EC50 or not).
        data_train = data.loc[data.index.difference(data_test.index)]

        # valid split: valid_ratio is relative to the EC50 row count, so it is
        # rescaled to a fraction of the remaining rows.
        data_train, data_valid = train_test_split(data_train,
                                                  test_size = (valid_ratio*n_ec50)/len(data_train),
                                                  random_state = seed)
        print(n_test_pos)
        print(n_test_neg)
        return data_train, data_valid, data_test
    


class EC50_Random_test_and_valid(EC50_Random):
    """
    Randomly hold out two disjoint subsets of the EC50 rows, one as the test
    set and one as the validation set.
    """
    def func_split_data(self, data, seed, **kwargs):
        """
        Random EC50 split in which both the test and the validation rows come
        from the EC50 data only; all other rows form the training set.

        Parameters:
        -----------
        data : pandas.DataFrame
            must contain the columns 'parameter' and 'responsive'. EC50 rows
            are those with parameter == 'ec50'; responsive == 1 is counted as
            positive and responsive == 0 as negative. Index labels are
            assumed to be unique, because the training rows are selected by
            index difference.

        seed :
            used as random_state of both train_test_split calls.

        kwargs :
            test_ratio : float, required
                fraction of the EC50 rows that goes to the test set; must be
                strictly between 0 and 1.
            valid_ratio : float, default 0.1
                fraction of the EC50 rows that goes to the validation set;
                must be positive and smaller than 1 - test_ratio.

        Returns:
        --------
        data_train, data_valid, data_test : pandas.DataFrame

        Raises:
        -------
        ValueError
            if test_ratio is missing or out of range, if valid_ratio is not
            positive or does not fit into the EC50 rows left after the test
            split, if either class makes up 1% or less of the EC50 row count
            in the test set, or if the test set holds more than 50% of the
            EC50 rows.
        """
        valid_ratio = kwargs.get('valid_ratio', 0.1)
        test_ratio = kwargs.get('test_ratio', None)

        # Explicit checks instead of assert: asserts vanish under ``python -O``
        # and a missing test_ratio would otherwise fail on ``None > 0.0``.
        if test_ratio is None:
            raise ValueError("'test_ratio' is required for the EC50 random split.")
        if not (0.0 < test_ratio < 1.0):
            raise ValueError('test_ratio must be strictly between 0 and 1, got {}.'.format(test_ratio))
        if not (valid_ratio > 0.0):
            raise ValueError('valid_ratio must be positive, got {}.'.format(valid_ratio))

        # valid_ratio is relative to all EC50 rows; rescale it to a fraction
        # of the EC50 rows that remain once the test set has been removed.
        valid_fraction = valid_ratio/(1-test_ratio)
        if valid_fraction >= 1.0:
            raise ValueError('valid_ratio ({}) must be smaller than 1 - test_ratio ({}).'.format(valid_ratio, 1-test_ratio))

        data_ec50 = data[data['parameter'] == 'ec50']
        n_ec50 = len(data_ec50)
        print('Number of EC50 measurements: Positive: {}, Negative: {}'.format(len(data_ec50[data_ec50['responsive'] == 1]), len(data_ec50[data_ec50['responsive'] == 0])))

        # test split
        data_ec50_rest, data_test = train_test_split(data_ec50,
                                                     test_size = test_ratio,
                                                     random_state = seed)
        n_test_pos = len(data_test[data_test['responsive'] == 1])
        n_test_neg = len(data_test[data_test['responsive'] == 0])
        # Each class must account for more than 1% of the EC50 row count.
        min_per_class = 0.01*n_ec50
        if n_test_pos <= min_per_class or n_test_neg <= min_per_class:
            raise ValueError('Test data needs more than {} rows of each class, got {} positive and {} negative.'.format(min_per_class, n_test_pos, n_test_neg))
        if len(data_test) > 0.5*n_ec50:
            raise ValueError('Test data is taking more than 50% of all EC50 data.')

        # valid split: drawn from the EC50 rows not used for testing.
        _, data_valid = train_test_split(data_ec50_rest,
                                         test_size = valid_fraction,
                                         random_state = seed)

        # Training set: every row (EC50 or not) that is in neither held-out set.
        held_out = data_test.index.union(data_valid.index)
        data_train = data.loc[data.index.difference(held_out)]

        print(n_test_pos)
        print(n_test_neg)
        return data_train, data_valid, data_test


class EC50_Random_valid_only(EC50_Random):
    """
    Randomly hold out a subset of the EC50 rows as the validation set; the
    test set is left empty.
    """
    def func_split_data(self, data, seed, **kwargs):
        """
        Random EC50 split without a test set.

        Parameters:
        -----------
        data : pandas.DataFrame
            must contain the columns 'parameter' and 'responsive'. EC50 rows
            are those with parameter == 'ec50'; responsive == 1 is counted as
            positive and responsive == 0 as negative. Index labels are
            assumed to be unique, because the training rows are selected by
            index difference.

        seed :
            used as random_state of the train_test_split call.

        kwargs :
            valid_ratio : float, default 0.1
                fraction of the EC50 rows that goes to the validation set;
                must be strictly between 0 and 1. Any 'test_ratio' entry is
                ignored.

        Returns:
        --------
        data_train, data_valid, data_test : pandas.DataFrame
            data_test is an empty dataframe with the columns of data_train.

        Raises:
        -------
        ValueError
            if valid_ratio is out of range, if either class makes up 1% or
            less of the EC50 row count in the validation set, or if the
            validation set holds more than 50% of the EC50 rows.
        """
        valid_ratio = kwargs.get('valid_ratio', 0.1)

        # Explicit check instead of assert: asserts vanish under ``python -O``.
        if not (0.0 < valid_ratio < 1.0):
            raise ValueError('valid_ratio must be strictly between 0 and 1, got {}.'.format(valid_ratio))

        data_ec50 = data[data['parameter'] == 'ec50']
        n_ec50 = len(data_ec50)
        print('Number of EC50 measurements: Positive: {}, Negative: {}'.format(len(data_ec50[data_ec50['responsive'] == 1]), len(data_ec50[data_ec50['responsive'] == 0])))

        # valid split: only EC50 rows are eligible for the validation set.
        _, data_valid = train_test_split(data_ec50,
                                         test_size = valid_ratio,
                                         random_state = seed)
        n_valid_pos = len(data_valid[data_valid['responsive'] == 1])
        n_valid_neg = len(data_valid[data_valid['responsive'] == 0])
        # Each class must account for more than 1% of the EC50 row count.
        min_per_class = 0.01*n_ec50
        if n_valid_pos <= min_per_class or n_valid_neg <= min_per_class:
            raise ValueError('Validation data needs more than {} rows of each class, got {} positive and {} negative.'.format(min_per_class, n_valid_pos, n_valid_neg))
        if len(data_valid) > 0.5*n_ec50:
            raise ValueError('Validation data is taking more than 50% of all EC50 data.')

        # Training set: every row (EC50 or not) outside the validation set.
        data_train = data.loc[data.index.difference(data_valid.index)]

        data_test = pandas.DataFrame([], columns=data_train.columns)
        return data_train, data_valid, data_test