import os
import pandas
from shutil import copy
from sklearn.model_selection import train_test_split
import numpy
import json

import itertools
import multiprocessing
from rdkit import Chem
from rdkit.Chem import rdFingerprintGenerator
from rdkit.Chem import AllChem
from rdkit import DataStructs
import hdbscan
from sklearn.cluster import AgglomerativeClustering

from ProtLig_GPCRclassA.base_cross_validation import BaseCVSplit, BaseCVPostProcess


class InadequateTestSetSizeError(Exception):
    """Raised when a generated test set is larger or smaller than the allowed size limits."""


# ------------------------------------------
# Deorphanization splits - Random Molecules:
# ------------------------------------------
class EC50_LeaveSingleOut_Mol_keep_screening(BaseCVSplit):
    """
    Take randomly selected molecules and put all of their EC50 occurrences into the test set.
    Screening occurrences are NOT discarded here. This should be done in post-processing.

    Here we want to investigate the effect of predicting pairs for new molecules which can still be
    similar to the training ones.
    """
    def __init__(self, data_dir, seed = None, split_kwargs = None):
        """
        Parameters:
        -----------
        data_dir : str
            directory containing full preprocessed data.

        seed : int, optional
            random seed, forwarded to BaseCVSplit.

        split_kwargs : dict, optional
            must contain 'mol_id_col' (name of the molecule id column), which is removed here; the
            remaining kwargs are forwarded to BaseCVSplit (presumably passed on to func_split_data).
            The caller's dict is not modified.
        """
        # Copy before popping so neither the caller's dict nor a shared default is mutated.
        split_kwargs = dict(split_kwargs) if split_kwargs is not None else {}
        self.mol_id_col = split_kwargs.pop('mol_id_col')
        super().__init__(data_dir = data_dir, seed = seed, split_kwargs = split_kwargs)

        self.func_split_data_name = 'EC50_LeaveSingleOut_Mol'

    def load_data(self):
        """
        Read the preprocessed full data ('full_data.csv' in self.data_dir) into a DataFrame.

        The column separator is self.sep (presumably set by BaseCVSplit).
        """
        full_data = pandas.read_csv(os.path.join(self.data_dir, 'full_data.csv'), sep = self.sep, index_col = None, header = 0, low_memory=False)
        return full_data

    def get_Mols(self, data, min_n_ligands, portion, seed):
        """
        Randomly select molecules for the test set.

        Parameters:
        -----------
        data : pandas.DataFrame
            pairs with columns self.mol_id_col and 'responsive' (1 = positive, 0 = negative).

        min_n_ligands : int
            minimum number of sequences that a molecule needs to activate to be considered for test set.

        portion : float
            fraction of the molecules with enough activated sequences to include in test set (rounded up).

        seed : int
            random seed for numpy.random.default_rng

        Returns:
        --------
        numpy.ndarray
            ids of the selected molecules, without duplicates.
        """
        # Number of responsive pairs per molecule.
        num_ligands = data.groupby(self.mol_id_col)['responsive'].sum()
        num_ligands.name = 'num_ligands'
        _idx = num_ligands[num_ligands >= min_n_ligands].index

        n_idx = len(_idx)
        print('Number of molecules with more than (or equal) {} active ORs: {}'.format(min_n_ligands, n_idx))

        # Sample WITHOUT replacement: numpy's default (replace=True) can draw the same molecule several
        # times, silently leaving out fewer unique molecules than requested.
        # NOTE: for a given seed this selection differs from the one produced by the former replace=True draw.
        n_sample = min(int(numpy.ceil(portion*n_idx)), n_idx)
        np_rng = numpy.random.default_rng(seed)
        return np_rng.choice(_idx, size = n_sample, replace = False)

    def func_split_data(self, data, seed, **kwargs):
        """
        Split data into train / valid / test sets.

        The test set holds all EC50 pairs of the molecules chosen by get_Mols. The validation set is a
        random sample of the remaining EC50 pairs. Everything else goes to train, including the screening
        pairs of the test molecules (to be discarded in post-processing).

        Parameters:
        -----------
        data : pandas.DataFrame
            full data with columns 'data_quality', 'responsive' and self.mol_id_col.

        seed : int
            random seed for molecule sampling and for the validation split.

        **kwargs
            required: valid_ratio (suggested 0.1), portion_mols (0.25), min_n_ec50_ligands (5),
            max_test_size (0.4), min_test_size (0.1). max/min_test_size are fractions of the EC50 pairs.

        Returns:
        --------
        tuple of pandas.DataFrame
            (data_train, data_valid, data_test)

        Raises:
        -------
        InadequateTestSetSizeError
            if the test set is bigger than max_test_size or smaller than min_test_size.
        ValueError
            if the test set takes more than 50% of the EC50 data.
        """
        valid_ratio = kwargs.pop('valid_ratio') # default suggestion: 0.1
        portion_mols = kwargs.pop('portion_mols') # default suggestion: 0.25
        min_n_ec50_ligands = kwargs.pop('min_n_ec50_ligands') # default suggestion: 5
        max_test_size = kwargs.pop('max_test_size') # default suggestion: 0.4
        min_test_size = kwargs.pop('min_test_size') # default suggestion: 0.1

        assert valid_ratio > 0.0, 'valid_ratio must be positive.'

        data_ec50 = data[data['data_quality'] == 'ec50']
        print('Number of EC50 measurements: Positive: {}, Negative: {}'.format(len(data_ec50[data_ec50['responsive'] == 1]), len(data_ec50[data_ec50['responsive'] == 0])))

        # Size limits are relative to the number of EC50 pairs.
        _max_test_size = max_test_size*len(data_ec50)
        _min_test_size = min_test_size*len(data_ec50)

        # Take proportions based only on ec50 data:
        _idx = self.get_Mols(data_ec50, min_n_ligands = min_n_ec50_ligands, portion = portion_mols, seed = seed)
        print('Number of unique molecules to leave out: {}'.format(len(_idx)))
        # All EC50 pairs of the selected molecules form the test set.
        data_test = data_ec50[data_ec50[self.mol_id_col].isin(_idx)]

        if len(data_test) > _max_test_size:
            raise InadequateTestSetSizeError('data_test is bigger than max size: Size: {}; max_size: {}'.format(len(data_test), _max_test_size))
        if len(data_test) < _min_test_size:
            raise InadequateTestSetSizeError('data_test is smaller than min size: Size: {}; min_size: {}'.format(len(data_test), _min_test_size))

        # Both classes must be represented by more than 1% of the EC50 pairs.
        assert len(data_test[data_test['responsive'] == 1]) > 0.01*len(data_ec50), 'Too few positives in data_test.'
        assert len(data_test[data_test['responsive'] == 0]) > 0.01*len(data_ec50), 'Too few negatives in data_test.'
        if len(data_test) > 0.5*len(data_ec50):
            raise ValueError('Test data is taking more than 50% of all EC50 data.')

        data_ec50_rest = data_ec50.loc[data_ec50.index.difference(data_test.index)]
        # NOTE(review): test_ratio is relative to ALL data, while the validation set is drawn from the
        # remaining EC50 pairs only - confirm this is the intended definition of valid_ratio.
        test_ratio = len(data_test)/len(data)

        # valid split
        _, data_valid = train_test_split(data_ec50_rest,
                                        test_size = valid_ratio/(1-test_ratio),
                                        random_state = seed)

        # Everything not in test or valid (including screening pairs of test molecules) is train.
        data_train = data.loc[data.index.difference(data_test.index.union(data_valid.index))]

        print('Number of positive in data_test: {}'.format(len(data_test[data_test['responsive'] == 1])))
        print('Number of negative in data_test: {}'.format(len(data_test[data_test['responsive'] == 0])))

        print('\nWARNING: EC50_LeaveSingleOut_Mol needs to be followed by discarding screening in Post-processing!!\n')
        return data_train, data_valid, data_test




# -------------------------------------------
# Deorphanization splits - Cluster Molecules:
# -------------------------------------------
def tanimoto_similarity(smi1, smi2, radius = 3, fp_size = 2048):
    """
    Tanimoto similarity between the Morgan fingerprints of two molecules.

    Parameters:
    -----------
    smi1, smi2 : str
        SMILES strings of the two molecules.

    radius : int
        Morgan fingerprint radius (default 3, the previously hard-coded value).

    fp_size : int
        fingerprint length in bits (default 2048, the previously hard-coded value).

    Returns:
    --------
    float
        DataStructs.TanimotoSimilarity of the two fingerprints.

    Raises:
    -------
    ValueError
        if either SMILES cannot be parsed by RDKit.
    """
    mfpgen = rdFingerprintGenerator.GetMorganGenerator(radius = radius, fpSize = fp_size)
    fingerprints = []
    for smi in (smi1, smi2):
        mol = Chem.MolFromSmiles(smi)
        # MolFromSmiles returns None instead of raising on unparsable input; report the offending
        # SMILES here rather than failing with an opaque error inside GetFingerprint(None).
        if mol is None:
            raise ValueError('Could not parse SMILES: {!r}'.format(smi))
        fingerprints.append(mfpgen.GetFingerprint(mol))
    return DataStructs.TanimotoSimilarity(fingerprints[0], fingerprints[1])

def dist_func(smiles):
    """
    Tanimoto distance for one pair of molecules (used as the pool.map target when building distances).

    Parameters:
    -----------
    smiles : tuple
        ((id1, smi1), (id2, smi2)) - two (molecule id, SMILES) pairs.

    Returns:
    --------
    dict
        {'id_1': id1, 'id_2': id2, 'Distance': 1 - Tanimoto similarity}
    """
    (first_id, first_smi), (second_id, second_smi) = smiles
    similarity = tanimoto_similarity(first_smi, second_smi)
    return {'id_1': first_id, 'id_2': second_id, 'Distance': 1 - similarity}

class EC50_LeaveClusterOut_Mol_keep_screening(BaseCVSplit):
    """
    Leave clusters of molecules out, where clusters are found by HDBSCAN on pairwise Tanimoto
    distances between Morgan fingerprints. All EC50 pairs of the molecules in the sampled clusters
    form the test set. Screening occurrences are NOT discarded here; this must be done in post-processing.

    Here we want to investigate the effect of predicting pairs for entirely new molecules which are not similar to
    the training ones.

    Notes:
    ------
    Mutants are in the test set.
    """
    def __init__(self, data_dir, seed = None, split_kwargs = None):
        """
        Parameters:
        -----------
        data_dir : str
            directory containing full preprocessed data.

        seed : int, optional
            random seed, forwarded to BaseCVSplit.

        split_kwargs : dict, optional
            must contain 'auxiliary_data_path' (dict of auxiliary name -> file path, including
            'mols_csv' and 'map_inchikey_to_canonicalSMILES') and 'mol_id_col'. Both are removed here;
            the remaining kwargs are forwarded to BaseCVSplit (presumably passed on to func_split_data).
            The caller's dict is not modified.
        """
        # Copy before popping so neither the caller's dict nor a shared default is mutated.
        split_kwargs = dict(split_kwargs) if split_kwargs is not None else {}
        self.auxiliary_data_path = split_kwargs.pop('auxiliary_data_path')
        self.mol_id_col = split_kwargs.pop('mol_id_col')

        super().__init__(data_dir = data_dir, seed = seed, split_kwargs = split_kwargs)

        self.func_split_data_name = 'EC50_LeaveClusterOut_Mol'

    def load_data(self):
        """
        Read the preprocessed full data ('full_data.csv' in self.data_dir) into a DataFrame.

        The column separator is self.sep (presumably set by BaseCVSplit).
        """
        full_data = pandas.read_csv(os.path.join(self.data_dir, 'full_data.csv'), sep = self.sep, index_col = None, header = 0, low_memory=False)
        return full_data

    def load_auxiliary(self):
        """
        Load the files listed in self.auxiliary_data_path.

        Returns:
        --------
        dict
            keyed like self.auxiliary_data_path: '.csv' files become DataFrames (sep=';', first
            column as index), '.json' files are parsed with json.load, and None paths become {}.
            Paths with any other extension are skipped (their key is absent from the result).
        """
        auxiliary = {}
        for key, path in self.auxiliary_data_path.items():
            if path is None:
                auxiliary[key] = {}
                continue
            _, ext = os.path.splitext(path)
            if ext == '.csv':
                auxiliary[key] = pandas.read_csv(path, sep=';', index_col = 0)
            elif ext == '.json':
                with open(path, 'r') as jsonfile:
                    auxiliary[key] = json.load(jsonfile)
        return auxiliary

    @staticmethod
    def _get_canonicalSMILES(row, _map_to_canonical):
        """
        Return the SMILES to use for one row of the molecules table.

        Rows with an 'inchi_key' are mapped through _map_to_canonical (InChIKey -> canonical SMILES);
        rows without one (NaN / None) fall back to their own 'smiles' value.
        """
        # pandas.notna also treats None and pd.NA as missing (the former `x == x` test did not).
        if pandas.notna(row['inchi_key']):
            return _map_to_canonical[row['inchi_key']]
        return row['smiles']

    @staticmethod
    def _tanimoto_distance_matrix(smiles, n_processes = 20):
        """
        Build a square, symmetric DataFrame of pairwise Tanimoto distances (zero diagonal).

        smiles : pandas.Series
            SMILES strings indexed by molecule id.

        n_processes : int
            number of worker processes used to evaluate the pairwise distances.
        """
        # Only the n*(n-1)/2 upper-triangle pairs are computed; the rest is mirrored below.
        # The context manager tears the pool down even if a worker raises. pool.map blocks until
        # every result is collected, so the terminate() on exit cannot interrupt any work.
        with multiprocessing.Pool(processes = n_processes) as pool:
            _dist = pool.map(dist_func, itertools.combinations(smiles.items(), 2))

        # Explicit columns keep this working when there are no pairs (fewer than 2 molecules).
        _upper = pandas.DataFrame(_dist, columns = ['id_1', 'id_2', 'Distance']).set_index(['id_1', 'id_2'])
        _lower = _upper.swaplevel()
        _diag = pandas.DataFrame(numpy.zeros(len(smiles)), columns = ['Distance'], index = pandas.MultiIndex.from_arrays([smiles.index, smiles.index]))
        _full = pandas.concat([_upper, _lower, _diag]).sort_index()
        return _full['Distance'].unstack()

    def get_Mols(self, mols, data, map_inchikey_to_canonicalSMILES, min_n_ligands, min_n_mols_with_enough_actives_per_cluster, n_cluster_sample, seed, hdbscan_min_samples, hdbscan_min_cluster_size, n_processes = 20):
        """
        Cluster the molecules by Tanimoto distance and return the ids of the molecules in randomly sampled clusters.

        Parameters:
        -----------
        mols : pandas.DataFrame
            molecules table with 'inchi_key' and 'smiles' columns, indexed by molecule id (assumed to
            match the values of self.mol_id_col in data - TODO confirm).

        data : pandas.Dataframe
            pairs with columns self.mol_id_col and 'responsive'.

        map_inchikey_to_canonicalSMILES : dict
            InChIKey -> canonical SMILES.

        min_n_ligands : int
            minimal number of sequences that a molecule needs to activate to be considered for test set.

        min_n_mols_with_enough_actives_per_cluster : int
            minimal number of such molecules a cluster must contain to be eligible for sampling.

        n_cluster_sample : int
            number of clusters to put to a test set. If fewer clusters are eligible, all of them are used.

        seed : int
            random seed for numpy.random.default_rng

        hdbscan_min_samples : int
            the number of samples in a neighbourhood for a point to be considered a core point in HDBSCAN.
            See: https://hdbscan.readthedocs.io/en/latest/api.html

        hdbscan_min_cluster_size : int
            the minimum size of clusters; single linkage splits that contain fewer points than this will be
            considered points "falling out" of a cluster rather than a cluster splitting into two new clusters.
            See: https://hdbscan.readthedocs.io/en/latest/api.html

        n_processes : int
            number of worker processes for the pairwise distance computation (default 20).

        Returns:
        --------
        pandas.Index
            ids of all molecules belonging to the sampled clusters.
        """
        # Number of responsive pairs per molecule.
        num_ligands = data.groupby(self.mol_id_col)['responsive'].sum()
        num_ligands.name = 'num_ligands'
        _idx_ligands = num_ligands[num_ligands >= min_n_ligands].index

        n_idx = len(_idx_ligands)
        print('Number of molecules with more than (or equal) {} ligands: {}'.format(min_n_ligands, n_idx))

        _smiles = mols.apply(lambda x: self._get_canonicalSMILES(x, _map_to_canonical = map_inchikey_to_canonicalSMILES), axis = 1)
        _dist_data = self._tanimoto_distance_matrix(_smiles, n_processes = n_processes)

        # Tanimoto Clustering - HDBSCAN on the precomputed distance matrix:
        hdbscan_clustering = hdbscan.HDBSCAN(min_samples = hdbscan_min_samples,
                                            min_cluster_size = hdbscan_min_cluster_size,
                                            metric = 'precomputed',
                                            cluster_selection_epsilon = 0.0)
        hdbscan_clustering.fit(_dist_data.values)
        _cluster_labels_map = pandas.Series(hdbscan_clustering.labels_, index = _dist_data.index, name = 'tanimoto_cluster_hdbscan')

        # HDBSCAN labels noise points "-1"; they are not a cluster and are never sampled.
        _cluster_labels_map = _cluster_labels_map[_cluster_labels_map >= 0]
        print('Number of unique clusters (not including "-1" cluster): {}'.format(len(_cluster_labels_map.unique())))

        # Per cluster, count the molecules with enough active sequences and keep the eligible clusters:
        _cluster_labels_mols_with_enough_active = _cluster_labels_map.groupby(_cluster_labels_map).apply(lambda x: len(x.index.intersection(_idx_ligands)))
        _cluster_labels_mols_with_enough_active = _cluster_labels_mols_with_enough_active[_cluster_labels_mols_with_enough_active >= min_n_mols_with_enough_actives_per_cluster]
        print('Number of unique clusters (not including "-1" cluster) with at least {} molecules with at least {} active sequences: {}'.format(min_n_mols_with_enough_actives_per_cluster, min_n_ligands, len(_cluster_labels_mols_with_enough_active)))

        # Sample clusters WITHOUT replacement: with numpy's default (replace=True) a cluster could be drawn
        # twice, and the duplicates collapse in `isin`, so fewer than n_cluster_sample clusters were left out.
        # NOTE: for a given seed this selection differs from the one produced by the former replace=True draw.
        n_available = len(_cluster_labels_mols_with_enough_active)
        n_sample = min(n_cluster_sample, n_available)
        if n_sample < n_cluster_sample:
            print('WARNING: only {} eligible clusters available; sampling all of them instead of {}.'.format(n_available, n_cluster_sample))
        np_rng = numpy.random.default_rng(seed)
        _sampled_cluster_labels = np_rng.choice(_cluster_labels_mols_with_enough_active.index, size = n_sample, replace = False)
        _idx = _cluster_labels_map[_cluster_labels_map.isin(_sampled_cluster_labels)].index

        return _idx

    def func_split_data(self, data, seed, **kwargs):
        """
        Split data into train / valid / test sets by leaving whole molecule clusters out.

        The test set holds all EC50 pairs of the molecules chosen by get_Mols. The validation set is a
        random sample of the remaining EC50 pairs. Everything else goes to train, including the screening
        pairs of the test molecules (to be discarded in post-processing).

        Parameters:
        -----------
        data : pandas.DataFrame
            full data with columns 'data_quality', 'responsive' and self.mol_id_col.

        seed : int
            random seed for cluster sampling and for the validation split.

        **kwargs
            valid_ratio (suggested 0.1), min_n_ec50_ligands (2), min_n_mols_with_enough_actives_per_cluster (2),
            n_cluster_sample (5), max_test_size (0.4), min_test_size (0.2), hdbscan_min_samples (1),
            hdbscan_min_cluster_size (10), and optionally n_processes (default 20).
            max/min_test_size are fractions of the EC50 pairs.

        Returns:
        --------
        tuple of pandas.DataFrame
            (data_train, data_valid, data_test)

        Raises:
        -------
        InadequateTestSetSizeError
            if the test set is bigger than max_test_size or smaller than min_test_size.
        ValueError
            if the test set takes more than 50% of the EC50 data.
        """
        valid_ratio = kwargs.get('valid_ratio') # default suggestion: 0.1

        min_n_ec50_ligands = kwargs.get('min_n_ec50_ligands') # default suggestion: 2
        min_n_mols_with_enough_actives_per_cluster = kwargs.get('min_n_mols_with_enough_actives_per_cluster') # default suggestion: 2
        n_cluster_sample = kwargs.get('n_cluster_sample') # default suggestion: 5

        max_test_size = kwargs.get('max_test_size') # default suggestion: 0.4
        min_test_size = kwargs.get('min_test_size') # default suggestion: 0.2

        hdbscan_min_samples = kwargs.get('hdbscan_min_samples') # default suggestion: 1
        hdbscan_min_cluster_size = kwargs.get('hdbscan_min_cluster_size') # default suggestion: 10

        n_processes = kwargs.get('n_processes', 20)

        assert valid_ratio > 0.0, 'valid_ratio must be positive.'

        auxiliary = self.load_auxiliary()
        mols = auxiliary['mols_csv']

        data_ec50 = data[data['data_quality'] == 'ec50']
        print('Number of EC50 measurements: Positive: {}, Negative: {}'.format(len(data_ec50[data_ec50['responsive'] == 1]), len(data_ec50[data_ec50['responsive'] == 0])))

        # Size limits are relative to the number of EC50 pairs.
        _max_test_size = max_test_size*len(data_ec50)
        _min_test_size = min_test_size*len(data_ec50)

        _idx = self.get_Mols(mols, data_ec50,
                            map_inchikey_to_canonicalSMILES = auxiliary['map_inchikey_to_canonicalSMILES'],
                            min_n_ligands = min_n_ec50_ligands,
                            min_n_mols_with_enough_actives_per_cluster = min_n_mols_with_enough_actives_per_cluster,
                            n_cluster_sample = n_cluster_sample,
                            seed = seed,
                            hdbscan_min_samples = hdbscan_min_samples,
                            hdbscan_min_cluster_size = hdbscan_min_cluster_size,
                            n_processes = n_processes)
        print('Number of unique molecules to leave out: {}'.format(len(_idx)))

        # All EC50 pairs of the selected molecules form the test set.
        data_test = data_ec50[data_ec50[self.mol_id_col].isin(_idx)]

        if len(data_test) > _max_test_size:
            raise InadequateTestSetSizeError('data_test is bigger than max size: Size: {}; max_size: {}'.format(len(data_test), _max_test_size))
        if len(data_test) < _min_test_size:
            raise InadequateTestSetSizeError('data_test is smaller than min size: Size: {}; min_size: {}'.format(len(data_test), _min_test_size))

        # Both classes must be represented by more than 1% of the EC50 pairs.
        assert len(data_test[data_test['responsive'] == 1]) > 0.01*len(data_ec50), 'Too few positives in data_test.'
        assert len(data_test[data_test['responsive'] == 0]) > 0.01*len(data_ec50), 'Too few negatives in data_test.'
        if len(data_test) > 0.5*len(data_ec50):
            raise ValueError('Test data is taking more than 50% of all EC50 data.')

        data_ec50_rest = data_ec50.loc[data_ec50.index.difference(data_test.index)]
        # NOTE(review): test_ratio is relative to ALL data, while the validation set is drawn from the
        # remaining EC50 pairs only - confirm this is the intended definition of valid_ratio.
        test_ratio = len(data_test)/len(data)

        # valid split
        _, data_valid = train_test_split(data_ec50_rest,
                                        test_size = valid_ratio/(1-test_ratio),
                                        random_state = seed)

        # Everything not in test or valid (including screening pairs of test molecules) is train.
        data_train = data.loc[data.index.difference(data_test.index.union(data_valid.index))]

        print('Number of positive in data_test: {}'.format(len(data_test[data_test['responsive'] == 1])))
        print('Number of negative in data_test: {}'.format(len(data_test[data_test['responsive'] == 0])))

        print('\nWARNING: EC50_LeaveClusterOut_Mol needs to be followed by discarding screening in Post-processing!!\n')
        return data_train, data_valid, data_test


class CVPP_Mol_discard_screening(BaseCVPostProcess):
    """
    Discard all train/valid occurrences (e.g. screening measurements) of the molecules in the test set.

    A row of data_train or data_valid is dropped when its molecule id (column mol_id_col) also occurs
    in data_test. Here we want to investigate the effect of predicting pairs for new molecules which
    have never been seen during training.
    """
    def __init__(self, data_dir, mol_id_col):
        """
        Parameters:
        -----------
        data_dir : str
            directory passed to BaseCVPostProcess (presumably holding the split to post-process).

        mol_id_col : str
            name of the column holding molecule identifiers.
        """
        name = None
        super().__init__(name, data_dir)

        self.mol_id_col = mol_id_col

    def postprocess(self):
        """
        Load the split, remove train/valid rows of test molecules and write the results to
        'data_train_no_screening.csv' and 'data_valid_no_screening.csv' (sep=';') in self.working_dir.

        Returns:
        --------
        None
        """
        data_train, data_valid, data_test = self.load_data()

        print('Shape of data_train with screening:    {}'.format(data_train.shape))
        print('Shape of data_valid with screening:    {}'.format(data_valid.shape))

        # Molecule ids that must not appear in train/valid; computed once for both filters.
        test_mols = data_test[self.mol_id_col].unique()

        data_train_no_screening = data_train[~data_train[self.mol_id_col].isin(test_mols)].copy()
        data_valid_no_screening = data_valid[~data_valid[self.mol_id_col].isin(test_mols)].copy()

        print('Shape of data_train without screening:    {}'.format(data_train_no_screening.shape))
        print('Shape of data_valid without screening:    {}'.format(data_valid_no_screening.shape))
        print('Shape of data_test:    {}'.format(data_test.shape))

        data_train_no_screening.to_csv(os.path.join(self.working_dir, 'data_train_no_screening.csv'), sep=';')
        data_valid_no_screening.to_csv(os.path.join(self.working_dir, 'data_valid_no_screening.csv'), sep=';')

        return None