import os
import pandas
from shutil import copy
from sklearn.model_selection import train_test_split
import numpy
import hdbscan

from ProtLig_GPCRclassA.base_cross_validation import BaseCVSplit, BaseCVPostProcess


class InadequateTestSetSizeError(Exception):
    """Raised when a sampled test set is larger or smaller than the allowed size bounds."""


# -----------------------------------
# Deorphanization splits - Random OR:
# -----------------------------------
class EC50_LeaveSingleOut_OR_keep_screening(BaseCVSplit):
    """
    Take all occurrences of randomly selected ORs, use their EC50 data as the test set and keep
    their screening occurrences in the training set.

    Here we want to investigate the effect of predicting pairs for new sequences which can be
    similar to the training ones, e.g. predicting pairs for a new mutant.

    Notes:
    ------
    We keep mutants of selected sequences.
    """
    def __init__(self, data_dir, seed = None, split_kwargs = None):
        """
        Parameters:
        -----------
        data_dir : str
            directory containing full preprocessed data.

        seed : int or None
            random seed, forwarded to the base class.

        split_kwargs : dict
            must contain 'seq_id_col' (name of the sequence-id column in the data). The remaining
            entries are forwarded to the base class, presumably to be passed on to
            func_split_data (verify against BaseCVSplit).

        Raises:
        -------
        KeyError
            if 'seq_id_col' is missing from split_kwargs.
        """
        # Work on a private copy: popping from the caller's dict (or from a shared default) would
        # make a second instantiation with the same dict fail with KeyError.
        split_kwargs = dict(split_kwargs) if split_kwargs is not None else {}
        self.seq_id_col = split_kwargs.pop('seq_id_col')
        super(EC50_LeaveSingleOut_OR_keep_screening, self).__init__(data_dir = data_dir, seed = seed, split_kwargs = split_kwargs)

        self.func_split_data_name = 'EC50_LeaveSingleOut_OR'

    def load_data(self):
        """
        Read the preprocessed full data ('full_data.csv' in self.data_dir) into a DataFrame.

        Returns:
        --------
        full_data : pandas.DataFrame
        """
        full_data_path = os.path.join(self.data_dir, 'full_data.csv')
        full_data = pandas.read_csv(full_data_path, sep = self.sep, index_col = None, header = 0, low_memory=False)
        return full_data

    def get_ORs(self, data, min_n_ligands, portion, seed):
        """
        Randomly select sequence ids that have enough responsive ligands.

        Parameters:
        -----------
        data : pandas.DataFrame
            must contain the column self.seq_id_col and a 'responsive' column.

        min_n_ligands : int
            minimum number of ligands that a sequence must have in order to be considered for test set.

        portion : float
            fraction of sequences with enough ligands to include in test set.

        seed : int
            random seed for numpy.random.default_rng

        Returns:
        --------
        _idx : numpy.ndarray
            unique sequence ids selected for the test set.

        Notes:
        ------
        If we want to put mutants and wild types together:
            # def _select_ORs_with_enough_ligands(df, min_n_ligands):
            #     if df['num_ligands'].sum() >= min_n_ligands:
            #         return df.index.values
            #     else:
            #         return float('nan')
            # _idx = _seqs.groupby('_Sequence').apply(lambda x: self._select_ORs_with_enough_ligands(x, min_n_ligands = min_n_ligands))
            # _idx = _idx.dropna()
            # n_idx = len(_idx)
            # print('Number of unique sequences with more than (or equal) {} ligands: {} (mutants added to wild type)'.format(min_n_ligands, n_idx))
        """
        # Number of responsive records per sequence.
        num_ligands = data.groupby(self.seq_id_col)['responsive'].sum()
        num_ligands.name = 'num_ligands'
        _idx = num_ligands[num_ligands >= min_n_ligands].index

        n_idx = len(_idx)
        print('Number of sequences with more than (or equal) {} ligands: {}'.format(min_n_ligands, n_idx))

        # Sample WITHOUT replacement: the default (replace=True) can draw the same sequence
        # several times, which silently shrinks the number of unique sequences left out.
        n_sample = min(n_idx, int(numpy.ceil(portion*n_idx)))
        np_rng = numpy.random.default_rng(seed)
        _idx = np_rng.choice(_idx, size = n_sample, replace = False)
        return _idx

    def func_split_data(self, data, seed, **kwargs):
        """
        Split data into train/valid/test, leaving the EC50 rows of randomly selected sequences out.

        Parameters:
        -----------
        data : pandas.DataFrame
            full data; must contain 'data_quality', 'responsive' and self.seq_id_col columns.

        seed : int
            random seed used both for selecting sequences and for the validation split.

        kwargs :
            valid_ratio : float
                ratio used to size the validation set (suggestion: 0.1).
            portion_seqs : float
                fraction of eligible sequences to leave out (suggestion: 0.25).
            min_n_ec50_ligands : int
                minimum number of responsive EC50 records per eligible sequence (suggestion: 2).
            max_test_size : float
                upper bound on the test set size as a fraction of EC50 data (suggestion: 0.4).
            min_test_size : float
                lower bound on the test set size as a fraction of EC50 data (suggestion: 0.1).

        Returns:
        --------
        data_train, data_valid, data_test : pandas.DataFrame
            data_test holds EC50 rows of the selected sequences, data_valid is sampled from the
            remaining EC50 rows, data_train is everything else (screening rows of test sequences
            included).

        Raises:
        -------
        InadequateTestSetSizeError
            if the test set size is outside [min_test_size, max_test_size] of the EC50 data.
        ValueError
            if the test set takes more than 50% of the EC50 data.
        """
        valid_ratio = kwargs.pop('valid_ratio') # default suggestion: 0.1

        portion_seqs = kwargs.pop('portion_seqs') # default suggestion: 0.25
        min_n_ec50_ligands = kwargs.pop('min_n_ec50_ligands') # default suggestion: 2

        max_test_size = kwargs.pop('max_test_size') # default suggestion: 0.4
        min_test_size = kwargs.pop('min_test_size') # default suggestion: 0.1

        assert (valid_ratio > 0.0), 'valid_ratio must be positive.'

        data_ec50 = data[data['data_quality'] == 'ec50']
        print('Number of EC50 measurements: Positive: {}, Negative: {}'.format(len(data_ec50[data_ec50['responsive'] == 1]), len(data_ec50[data_ec50['responsive'] == 0])))

        # Size bounds are expressed in number of EC50 rows.
        _max_test_size = max_test_size*len(data_ec50)
        _min_test_size = min_test_size*len(data_ec50)

        # Take proportions based only on ec50 data:
        _idx = self.get_ORs(data_ec50, min_n_ligands = min_n_ec50_ligands, portion = portion_seqs, seed = seed)
        print('Number of unique sequences to leave out: {}'.format(len(_idx)))
        data_test = data_ec50[data_ec50[self.seq_id_col].isin(_idx)]

        if len(data_test) > _max_test_size:
            raise InadequateTestSetSizeError('data_test is bigger than max size: Size: {}; max_size: {}'.format(len(data_test), _max_test_size))
        if len(data_test) < _min_test_size:
            raise InadequateTestSetSizeError('data_test is smaller than min size: Size: {}; min_size: {}'.format(len(data_test), _min_test_size))

        # Both classes must be represented in the test set by more than 1% of the EC50 data.
        assert len(data_test[data_test['responsive'] == 1]) > 0.01*len(data_ec50)
        assert len(data_test[data_test['responsive'] == 0]) > 0.01*len(data_ec50)
        if len(data_test) > 0.5*len(data_ec50):
            raise ValueError('Test data is taking more than 50% of all EC50 data.')

        data_ec50_rest = data_ec50.loc[data_ec50.index.difference(data_test.index)]
        # NOTE(review): test_ratio is relative to the full data (EC50 + screening) while the
        # validation rows are drawn from the remaining EC50 rows only - confirm this is intended.
        test_ratio = len(data_test)/len(data)

        # valid split
        _, data_valid = train_test_split(data_ec50_rest, 
                                        test_size = valid_ratio/(1-test_ratio),
                                        random_state = seed)

        # Everything that is neither test nor valid goes to train.
        data_train = data.loc[data.index.difference(data_test.index.union(data_valid.index))]

        print('Number of positive in data_test: {}'.format(len(data_test[data_test['responsive'] == 1])))
        print('Number of negative in data_test: {}'.format(len(data_test[data_test['responsive'] == 0])))

        print('\nWARNING: EC50_LeaveSingleOut_OR needs to be followed by discarding screening in Post-processing!!\n')
        return data_train, data_valid, data_test



# ------------------------------------
# Deorphanization splits - Cluster OR:
# ------------------------------------
class EC50_LeaveClusterOut_OR_keep_screening(BaseCVSplit):
    """
    Leave cluster of ORs out based on some similarity measure.

    Here we want to investigate the effect of predicting pairs for entirely new sequences which are
    not similar to the training ones.

    Notes:
    ------
    Mutants are in the test set.
    """
    def __init__(self, data_dir, seed = None, split_kwargs = None):
        """
        Parameters:
        -----------
        data_dir : str
            directory containing full preprocessed data.

        seed : int or None
            random seed, forwarded to the base class.

        split_kwargs : dict
            must contain 'auxiliary_data_path' (dict mapping auxiliary names to csv paths or None;
            func_split_data reads the 'seqs_csv' and 'seqs_blast_similarity_csv' entries) and
            'seq_id_col' (name of the sequence-id column). The remaining entries are forwarded to
            the base class, presumably to be passed on to func_split_data (verify against
            BaseCVSplit).

        Raises:
        -------
        KeyError
            if 'auxiliary_data_path' or 'seq_id_col' is missing from split_kwargs.
        """
        # Work on a private copy: popping from the caller's dict (or from a shared default) would
        # make a second instantiation with the same dict fail with KeyError.
        split_kwargs = dict(split_kwargs) if split_kwargs is not None else {}
        self.auxiliary_data_path = split_kwargs.pop('auxiliary_data_path')
        self.seq_id_col = split_kwargs.pop('seq_id_col')

        super(EC50_LeaveClusterOut_OR_keep_screening, self).__init__(data_dir = data_dir, seed = seed, split_kwargs = split_kwargs)

        self.func_split_data_name = 'EC50_LeaveClusterOut_OR'

    def load_data(self):
        """
        Read the preprocessed full data ('full_data.csv' in self.data_dir) into a DataFrame.

        Returns:
        --------
        full_data : pandas.DataFrame
        """
        full_data_path = os.path.join(self.data_dir, 'full_data.csv')
        full_data = pandas.read_csv(full_data_path, sep = self.sep, index_col = None, header = 0, low_memory=False)
        return full_data

    def load_auxiliary(self):
        """
        Load every auxiliary csv listed in self.auxiliary_data_path.

        Returns:
        --------
        auxiliary : dict
            maps each key to a DataFrame read with ';' separator and the first column as index,
            or to an empty dict when the configured path is None.
        """
        auxiliary = {}
        for key, path in self.auxiliary_data_path.items():
            if path is not None:
                auxiliary[key] = pandas.read_csv(path, sep=';', index_col = 0)
            else:
                auxiliary[key] = {}
        return auxiliary

    def get_seqs_pident_dist_from_blast_similarity_dataframe(self, seqs_blast_similarity):
        """
        Turn a long-format BLAST similarity table into a square-like distance table.

        Parameters:
        -----------
        seqs_blast_similarity : pandas.DataFrame
            must provide 'qseqid', 'sseqid' and 'pident' columns after reset_index().

        Returns:
        --------
        seqs_pident_dist : pandas.DataFrame
            rows are 'qseqid', columns are 'sseqid', values are (100 - pident)/100. Pairs absent
            from the input become NaN.
        """
        pairs = seqs_blast_similarity.reset_index().set_index(['qseqid','sseqid'])

        seqs_pident_dist = (100.0 - pairs['pident'])/100.0
        print('\nWARNING: BLAST created duplicated records. Investigate.\n')
        # Keep only the first record of each (qseqid, sseqid) pair, otherwise unstack fails.
        seqs_pident_dist = seqs_pident_dist[~seqs_pident_dist.index.duplicated()] # TODO: BLAST created duplicated records. Investigate.

        seqs_pident_dist = seqs_pident_dist.unstack()
        return seqs_pident_dist

    def get_ORs(self, seqs, data, seqs_dist, min_n_ligands, min_n_seqs_with_enough_ligands_per_cluster, n_cluster_sample, seed, hdbscan_min_samples, hdbscan_min_cluster_size):
        """
        Cluster sequences with HDBSCAN and return the ids of all sequences in randomly sampled clusters.

        Parameters:
        -----------
        seqs : pandas.DataFrame
            currently unused; kept for interface compatibility.

        data : pandas.Dataframe
            must contain the column self.seq_id_col and a 'responsive' column.

        seqs_dist : pandas.DataFrame
            precomputed pairwise distance matrix between sequences, indexed by sequence id.

        min_n_ligands : int
            minimal number of responsive records that a sequence needs to count as having enough ligands.

        min_n_seqs_with_enough_ligands_per_cluster : int
            minimal number of sequences with enough ligands that a cluster must contain to be
            eligible for sampling.

        n_cluster_sample : int
            number of clusters to put to a test set.

        seed : int
            random seed for numpy.random.default_rng

        hdbscan_min_samples : int
            the number of samples in a neighbourhood for a point to be considered a core point in HDBSCAN.
            See: https://hdbscan.readthedocs.io/en/latest/api.html

        hdbscan_min_cluster_size : int
            the minimum size of clusters; single linkage splits that contain fewer points than this will be 
            considered points "falling out" of a cluster rather than a cluster splitting into two new clusters.
            See: https://hdbscan.readthedocs.io/en/latest/api.html

        Returns:
        --------
        _idx : pandas.Index
            ids of all sequences (index entries of seqs_dist) belonging to the sampled clusters.

        Raises:
        -------
        ValueError
            if clusters are requested but no cluster is eligible.
        """
        # Number of responsive records per sequence.
        num_ligands = data.groupby(self.seq_id_col)['responsive'].sum()
        num_ligands.name = 'num_ligands'
        _idx_ligands = num_ligands[num_ligands >= min_n_ligands].index

        n_idx = len(_idx_ligands)
        print('Number of ORs with more than (or equal) {} ligands: {}'.format(min_n_ligands, n_idx))

        # Sequence clustering on the precomputed distance matrix - HDBSCAN:
        clusterer = hdbscan.HDBSCAN(min_samples = hdbscan_min_samples,
                                    min_cluster_size = hdbscan_min_cluster_size,
                                    metric = 'precomputed',
                                    cluster_selection_epsilon = 0.0)
        clusterer.fit(seqs_dist.values)
        _cluster_labels_map = pandas.Series(clusterer.labels_, index = seqs_dist.index, name = '_cluster_hdbscan')

        _cluster_labels_map = _cluster_labels_map[_cluster_labels_map >= 0] # NOTE: Ignoring "-1" cluster. 
        print('Number of unique clusters (not including "-1" cluster): {}'.format(len(_cluster_labels_map.unique())))

        # Count, per cluster, how many member sequences have enough ligands and keep eligible clusters:
        _n_enough_per_cluster = _cluster_labels_map.groupby(_cluster_labels_map).apply(lambda x: len(x.index.intersection(_idx_ligands)))
        _n_enough_per_cluster = _n_enough_per_cluster[_n_enough_per_cluster >= min_n_seqs_with_enough_ligands_per_cluster]
        print('Number of unique clusters (not including "-1" cluster) with at least {} sequences with at least {} ligands: {}'.format(min_n_seqs_with_enough_ligands_per_cluster, min_n_ligands, len(_n_enough_per_cluster)))

        # Sample clusters WITHOUT replacement: the default (replace=True) can draw the same
        # cluster several times, which silently shrinks the number of clusters left out.
        _eligible_labels = _n_enough_per_cluster.index
        if n_cluster_sample > 0 and len(_eligible_labels) == 0:
            raise ValueError('No eligible cluster to sample from.')
        n_sample = min(n_cluster_sample, len(_eligible_labels))
        np_rng = numpy.random.default_rng(seed)
        _sampled_cluster_labels = np_rng.choice(_eligible_labels, size = n_sample, replace = False)
        _idx = _cluster_labels_map[_cluster_labels_map.isin(_sampled_cluster_labels)].index

        return _idx

    def func_split_data(self, data, seed, **kwargs):
        """
        Split data into train/valid/test, leaving the EC50 rows of whole sequence clusters out.

        Parameters:
        -----------
        data : pandas.DataFrame
            full data; must contain 'data_quality', 'responsive' and self.seq_id_col columns.

        seed : int
            random seed used both for sampling clusters and for the validation split.

        kwargs :
            valid_ratio : float
                ratio used to size the validation set (suggestion: 0.1).
            min_n_ec50_ligands : int
                minimum number of responsive EC50 records per sequence (suggestion: 2).
            min_n_seqs_with_enough_ligands_per_cluster : int
                minimum number of such sequences per eligible cluster (suggestion: 2).
            n_cluster_sample : int
                number of clusters to leave out (suggestion: 5).
            max_test_size : float
                upper bound on the test set size as a fraction of EC50 data (suggestion: 0.4).
            min_test_size : float
                lower bound on the test set size as a fraction of EC50 data (suggestion: 0.2).
            hdbscan_min_samples : int
                HDBSCAN min_samples (suggestion: 1).
            hdbscan_min_cluster_size : int
                HDBSCAN min_cluster_size (suggestion: 15).

        Returns:
        --------
        data_train, data_valid, data_test : pandas.DataFrame
            data_test holds EC50 rows of the sequences in sampled clusters, data_valid is sampled
            from the remaining EC50 rows, data_train is everything else (screening rows of test
            sequences included).

        Raises:
        -------
        InadequateTestSetSizeError
            if the test set size is outside [min_test_size, max_test_size] of the EC50 data.
        ValueError
            if the test set takes more than 50% of the EC50 data.
        """
        valid_ratio = kwargs.get('valid_ratio') # default suggestion: 0.1

        min_n_ec50_ligands = kwargs.get('min_n_ec50_ligands') # default suggestion: 2
        min_n_seqs_with_enough_ligands_per_cluster = kwargs.get('min_n_seqs_with_enough_ligands_per_cluster') # default suggestion: 2
        n_cluster_sample = kwargs.get('n_cluster_sample') # default suggestion: 5

        max_test_size = kwargs.get('max_test_size') # default suggestion: 0.4
        min_test_size = kwargs.get('min_test_size') # default suggestion: 0.2

        hdbscan_min_samples = kwargs.get('hdbscan_min_samples') # default suggestion: 1
        hdbscan_min_cluster_size = kwargs.get('hdbscan_min_cluster_size') # default suggestion: 15

        assert (valid_ratio > 0.0), 'valid_ratio must be positive.'

        auxiliary = self.load_auxiliary()
        seqs = auxiliary['seqs_csv']

        seqs_blast_similarity = auxiliary['seqs_blast_similarity_csv']

        # Distance between sequences derived from BLAST percent identity.
        seqs_pident_dist = self.get_seqs_pident_dist_from_blast_similarity_dataframe(seqs_blast_similarity)

        data_ec50 = data[data['data_quality'] == 'ec50']
        print('Number of EC50 measurements: Positive: {}, Negative: {}'.format(len(data_ec50[data_ec50['responsive'] == 1]), len(data_ec50[data_ec50['responsive'] == 0])))

        # Size bounds are expressed in number of EC50 rows.
        _max_test_size = max_test_size*len(data_ec50)
        _min_test_size = min_test_size*len(data_ec50)

        _idx = self.get_ORs(seqs, data_ec50, 
                            seqs_dist = seqs_pident_dist,
                            min_n_ligands = min_n_ec50_ligands,
                            min_n_seqs_with_enough_ligands_per_cluster = min_n_seqs_with_enough_ligands_per_cluster,
                            n_cluster_sample = n_cluster_sample, 
                            seed = seed, 
                            hdbscan_min_samples = hdbscan_min_samples,
                            hdbscan_min_cluster_size = hdbscan_min_cluster_size)
        print('Number of unique sequences to leave out: {}'.format(len(_idx)))

        data_test = data_ec50[data_ec50[self.seq_id_col].isin(_idx)]

        if len(data_test) > _max_test_size:
            raise InadequateTestSetSizeError('data_test is bigger than max size: Size: {}; max_size: {}'.format(len(data_test), _max_test_size))
        if len(data_test) < _min_test_size:
            raise InadequateTestSetSizeError('data_test is smaller than min size: Size: {}; min_size: {}'.format(len(data_test), _min_test_size))

        # Both classes must be represented in the test set by more than 1% of the EC50 data.
        assert len(data_test[data_test['responsive'] == 1]) > 0.01*len(data_ec50)
        assert len(data_test[data_test['responsive'] == 0]) > 0.01*len(data_ec50)
        if len(data_test) > 0.5*len(data_ec50):
            raise ValueError('Test data is taking more than 50% of all EC50 data.')

        data_ec50_rest = data_ec50.loc[data_ec50.index.difference(data_test.index)]
        # NOTE(review): test_ratio is relative to the full data (EC50 + screening) while the
        # validation rows are drawn from the remaining EC50 rows only - confirm this is intended.
        test_ratio = len(data_test)/len(data)

        # valid split
        _, data_valid = train_test_split(data_ec50_rest, 
                                        test_size = valid_ratio/(1-test_ratio),
                                        random_state = seed)

        # Everything that is neither test nor valid goes to train.
        data_train = data.loc[data.index.difference(data_test.index.union(data_valid.index))]

        print('Number of positive in data_test: {}'.format(len(data_test[data_test['responsive'] == 1])))
        print('Number of negative in data_test: {}'.format(len(data_test[data_test['responsive'] == 0])))

        print('\nWARNING: EC50_LeaveClusterOut_OR needs to be followed by discarding screening in Post-processing!!\n')
        return data_train, data_valid, data_test


class CVPP_OR_discard_screening(BaseCVPostProcess):
    """
    Post-processing step that removes, from the train and valid sets, every row whose sequence
    also appears in the test set (i.e. the screening occurrences of the test sequences).

    The goal is to investigate predicting pairs for new sequences which have never been tested before.
    """
    def __init__(self, data_dir, seq_id_col):
        """
        Parameters:
        -----------
        data_dir : str
            forwarded to the base class.

        seq_id_col : str
            name of the column holding the sequence id.
        """
        # The base class is initialised with no name.
        super(CVPP_OR_discard_screening, self).__init__(None, data_dir)

        self.seq_id_col = seq_id_col

    def postprocess(self):
        """
        Drop rows of test-set sequences from train/valid and write the filtered sets to
        'data_train_no_screening.csv' and 'data_valid_no_screening.csv' in self.working_dir.
        """
        data_train, data_valid, data_test = self.load_data()

        print('Shape of data_train with screening:    {}'.format(data_train.shape))
        print('Shape of data_valid with screening:    {}'.format(data_valid.shape))

        # Sequence ids present in the test set; rows sharing them are filtered out below.
        test_seq_ids = data_test[self.seq_id_col]
        train_in_test = data_train[self.seq_id_col].isin(test_seq_ids)
        valid_in_test = data_valid[self.seq_id_col].isin(test_seq_ids)

        data_train_no_screening = data_train[~train_in_test].copy()
        data_valid_no_screening = data_valid[~valid_in_test].copy()

        print('Shape of data_train without screening:    {}'.format(data_train_no_screening.shape))
        print('Shape of data_valid without screening:    {}'.format(data_valid_no_screening.shape))
        print('Shape of data_test:    {}'.format(data_test.shape))

        outputs = (
            ('data_train_no_screening.csv', data_train_no_screening),
            ('data_valid_no_screening.csv', data_valid_no_screening),
        )
        for file_name, frame in outputs:
            frame.to_csv(os.path.join(self.working_dir, file_name), sep=';')

        return None