# -------------------------------------------------------------------------------------
# NOTE: None of the postprocess functions modify data_test (except the CVPP for mixtures).
# -------------------------------------------------------------------------------------
import json
import os
from shutil import copy

import pandas
from mol2graph.utils import get_num_atoms_and_bonds
from ProtLig_GPCRclassA.base_cross_validation import BaseCVPostProcess

# --------
# Size cut
# --------
class CVPP_SizeCut(BaseCVPostProcess):
    """
    Split the molecules of a molecules CSV into size groups defined by paired node/edge thresholds.

    For a threshold pair (N, E), a molecule is "small" when ``num_atoms < N`` and
    ``bond_multiplier * num_bonds < E`` (see ``_test_small``). ``postprocess`` writes one CSV per
    disjoint size band, one per cumulative ("UPTO") group, and one ('reminder') for the molecules
    that are small for none of the threshold pairs.
    """
    def __init__(self, data_dir, n_node_thresholds = (32,), n_edge_thresholds = (64,), mols_csv = None, mol_id_col = 'Drug_ID', mol_col = 'Drug', bond_multiplier = 2, auxiliary_data_path = None):
        """
        Parameters:
        -----------
        data_dir : str
            Passed to BaseCVPostProcess; also the directory of the default 'mols.csv'.
        n_node_thresholds, n_edge_thresholds : iterable of int
            Thresholds on the number of atoms and on ``bond_multiplier * num_bonds``, paired by
            position. Each one is sorted independently (ascending) and stored as a new list, so the
            caller's sequences are not modified. Both must have the same length.
        mols_csv : str or None
            Path to the molecules CSV. Defaults to ``os.path.join(self.data_dir, 'mols.csv')``.
        mol_id_col : str
            Column used as the index of the molecules CSV. It is also used to match rows of the
            auxiliary 'stats_data' table.
        mol_col : str
            Column holding the molecule passed to ``get_num_atoms_and_bonds``. It is also part of
            this post-process name ('size_cut_' + mol_col).
        bond_multiplier : int
            Factor applied to the bond count before comparing with the edge threshold
            (2 because graph edges are directed).
        auxiliary_data_path : dict or None
            Optional mapping with key 'stats_data' -> path of a ';'-separated CSV. It is only used
            to print how many pairs fall into each group.

        Raises:
        -------
        ValueError
            If n_node_thresholds and n_edge_thresholds have different lengths.

        Notes:
        ------
        n_node_threshold = 32, n_edge_threshold = 64 was chosen so that we don't lose too much (around ~43 unique molecules/mixtures, ~4000 pairs) but the graphs are small enough.

        The size test is strict ('<'). A molecule whose size equals a threshold is therefore NOT in
        that threshold's group; it falls into the next larger group or into 'reminder'.
        """
        name = 'size_cut_' + mol_col

        # NOTE: thresholds need to be sorted so that each band in _postprocess only adds molecules.
        # sorted() returns new lists: neither the caller's sequences nor shared defaults are mutated.
        n_node_thresholds = sorted(n_node_thresholds)
        n_edge_thresholds = sorted(n_edge_thresholds)
        if len(n_node_thresholds) != len(n_edge_thresholds):
            raise ValueError(
                'n_node_thresholds and n_edge_thresholds must have the same length, got {} and {}.'.format(
                    len(n_node_thresholds), len(n_edge_thresholds)))

        super(CVPP_SizeCut, self).__init__(name, data_dir)
        self.n_node_thresholds = n_node_thresholds
        self.n_edge_thresholds = n_edge_thresholds

        if mols_csv is None:
            mols_csv = os.path.join(self.data_dir, 'mols.csv')

        self.mols_csv_name = os.path.basename(mols_csv)
        self.mols_csv = mols_csv

        self.mol_id_col = mol_id_col
        self.mol_col = mol_col
        self.bond_multiplier = bond_multiplier # Because of directed edges.

        self.auxiliary_data_path = auxiliary_data_path # To get how many pairs are in each group.

    def serialize_hparams(self):
        """
        returns dictionary with all hyperparameters that will be saved. self.working_dir will be added
        to the dict in self.save_hparams.

        Includes every constructor argument that changes the produced groups.
        """
        return {'n_node_thresholds' : self.n_node_thresholds,
                'n_edge_thresholds' : self.n_edge_thresholds,
                'mols_csv' : self.mols_csv,
                'mol_id_col' : self.mol_id_col,
                'mol_col' : self.mol_col,
                'bond_multiplier' : self.bond_multiplier,
                'auxiliary_data_path' : self.auxiliary_data_path}

    def load_mols(self):
        """Read the molecules CSV, indexed by ``self.mol_id_col``."""
        # NOTE(review): self.sep is not set in this class; presumably it is provided by
        # BaseCVPostProcess. Outputs below are always written with ';' -- confirm these agree.
        mols = pandas.read_csv(self.mols_csv, sep = self.sep, index_col = self.mol_id_col, header = 0)
        return mols

    def load_auxiliary(self):
        """Load the ';'-separated 'stats_data' table referenced by ``self.auxiliary_data_path``."""
        auxiliary = {}
        auxiliary['stats_data'] = pandas.read_csv(self.auxiliary_data_path['stats_data'], sep = ';')
        return auxiliary

    @staticmethod
    def _test_small(x, n_node_threshold, n_edge_threshold, bond_multiplier):
        """Return True if molecule ``x`` is strictly below both the node and the (multiplied) edge threshold."""
        num_atoms, num_bonds = get_num_atoms_and_bonds(x)
        return num_atoms < n_node_threshold and bond_multiplier * num_bonds < n_edge_threshold

    def _postprocess(self, mols):
        """
        Split ``mols`` into disjoint size bands and cumulative "up to" groups.

        Returns:
        --------
        mols_datas : dict of str -> pandas.DataFrame
            'node{N}_edge{E}': molecules small for (N, E) but not for any earlier (smaller) pair.
            'reminder': molecules small for none of the pairs.
        mols_datas_cumulative : dict of str -> pandas.DataFrame
            'nodeUPTO{N}_edgeUPTO{E}': union of all bands up to and including (N, E).
        """
        # Sizes do not depend on the thresholds, so parse every molecule only once.
        sizes = [get_num_atoms_and_bonds(x) for x in mols[self.mol_col]]

        mols_datas = {}
        mols_datas_cumulative = {}
        available_idx = mols.index
        _idx_cumulative = pandas.Index([], name=mols.index.name)
        for n_node_tr, n_edge_tr in zip(self.n_node_thresholds, self.n_edge_thresholds):
            _name = 'node' + str(n_node_tr) + '_' + 'edge' + str(n_edge_tr)

            # Same strict test as _test_small, evaluated on the precomputed sizes.
            is_small = pandas.Series(
                [num_atoms < n_node_tr and self.bond_multiplier * num_bonds < n_edge_tr
                 for num_atoms, num_bonds in sizes],
                index=mols.index, dtype=bool)
            _idx = mols[is_small].index
            # Keep only molecules not already assigned to a smaller band.
            _idx = _idx.intersection(available_idx)

            mols_datas[_name] = mols.loc[_idx]

            _name_cumulative = 'nodeUPTO' + str(n_node_tr) + '_' + 'edgeUPTO' + str(n_edge_tr)
            _idx_cumulative = _idx_cumulative.union(_idx)

            mols_datas_cumulative[_name_cumulative] = mols.loc[_idx_cumulative]

            print('Num of unique molecules in {}: {}'.format(_name, len(mols_datas[_name].index)))
            print('Num of unique molecules in cumulative {}: {}'.format(_name_cumulative, len(mols_datas_cumulative[_name_cumulative].index)))
            available_idx = available_idx.difference(mols_datas[_name].index)

        # Key kept as 'reminder' (sic, "remainder") because it determines the output file name.
        _name = 'reminder'
        mols_datas[_name] = mols.loc[available_idx]
        print('Num of unique molecules in {}: {}'.format(_name, len(available_idx)))
        return mols_datas, mols_datas_cumulative

    def _print_num_of_affected_pairs(self, mols_datas, auxiliary):
        """Print, per group, how many rows of 'stats_data' reference a molecule of that group."""
        stats_data = auxiliary['stats_data']

        for name, group in mols_datas.items():
            # stats_data is expected to contain a self.mol_id_col column.
            _data = stats_data[stats_data[self.mol_id_col].isin(group.index)]
            print('Num of pairs in {}: {}'.format(name, len(_data)))

        return None

    def postprocess(self):
        """
        Load the molecules, split them by size and write one ';'-separated CSV per group
        ('mols_<group name>.csv') into ``self.working_dir``, then save the hyperparameters.
        """
        mols = self.load_mols()

        mols_datas, mols_datas_cumulative = self._postprocess(mols)
        if self.auxiliary_data_path is not None:
            auxiliary = self.load_auxiliary()
            self._print_num_of_affected_pairs(mols_datas, auxiliary)
            self._print_num_of_affected_pairs(mols_datas_cumulative, auxiliary)

        # Disjoint bands first, then cumulative groups (their names never collide).
        for groups in (mols_datas, mols_datas_cumulative):
            for name, group in groups.items():
                group.to_csv(os.path.join(self.working_dir, 'mols_' + name + '.csv'), sep=';', index = True, header = True)

        self.save_hparams()
        return None