# -------------------------------------------------------------------------------------
# NOTE: None of the postprocess functions modify data_test (except the CVPP for mixtures)
# -------------------------------------------------------------------------------------
import json
import os
from shutil import copy
import datetime

import pandas
from mol2graph.utils import get_num_atoms_and_bonds
from ProtLig_GPCRclassA.base_cross_validation import BaseCVPostProcess

# -----------------
# Discard mixtures:
# -----------------
class CVPP_MixtureDiscard(BaseCVPostProcess):
    """
    Post-processing step that separates mixtures from the rest of the molecule table.

    A row counts as a mixture when its ``mixture`` column equals the string ``'mixture'``.
    Mixtures are written to ``mols_exclude_mix.csv`` and every other row to
    ``mols_mix.csv``, both inside ``self.working_dir``.
    """
    def __init__(self, data_dir, mols_csv = None, mol_id_col = 'mol_id', auxiliary_data_path = None):
        """
        Parameters:
        -----------
        data_dir : str
            Forwarded to BaseCVPostProcess; ``mols.csv`` inside ``self.data_dir`` is the
            default molecule table.
        mols_csv : str, optional
            Path to the molecule table.
        mol_id_col : str
            Index column of the molecule table; also the column of the auxiliary
            ``stats_data`` table that is matched against that index.
        auxiliary_data_path : dict, optional
            Mapping with a ``'stats_data'`` CSV path (';'-separated). When set, the number
            of pairs falling into each group is printed.
        """
        super().__init__('discard_mix', data_dir)
        self._data_name = 'mix'

        # self.data_dir is not set in this class (presumably by BaseCVPostProcess.__init__),
        # so the default path is resolved only after the base initializer has run.
        self.mols_csv = mols_csv if mols_csv is not None else os.path.join(self.data_dir, 'mols.csv')
        self.mols_csv_name = os.path.basename(self.mols_csv)
        self.mol_id_col = mol_id_col
        # Only used to count how many pairs fall into each group.
        self.auxiliary_data_path = auxiliary_data_path

    def serialize_hparams(self):
        """
        Return the hyperparameters of this step as a dict; self.working_dir is added to it
        by self.save_hparams.
        """
        return dict(mol_id_col=self.mol_id_col,
                    mols_csv=self.mols_csv,
                    auxiliary_data_path=self.auxiliary_data_path)

    def load_mols(self):
        """Read the molecule table, indexed by ``self.mol_id_col``."""
        # NOTE(review): reads with self.sep (from the base class) but postprocess writes
        # with ';' -- confirm the two agree.
        return pandas.read_csv(self.mols_csv, sep=self.sep, index_col=self.mol_id_col, header=0)

    def load_auxiliary(self):
        """Read the ';'-separated ``stats_data`` table listed in ``self.auxiliary_data_path``."""
        stats_path = self.auxiliary_data_path['stats_data']
        return {'stats_data': pandas.read_csv(stats_path, sep=';')}

    def _postprocess(self, mols):
        """
        Split ``mols`` into mixtures and non-mixtures.

        Returns a dict whose first entry, ``'exclude_' + self._data_name``, holds the rows
        whose ``mixture`` column equals ``'mixture'``, and whose second entry,
        ``self._data_name``, holds all remaining rows. The size of each part is printed.
        """
        is_mixture = mols['mixture'] == 'mixture'
        mols_datas = {
            'exclude_' + self._data_name: mols[is_mixture],
            self._data_name: mols[~is_mixture],
        }
        for subset_name, subset in mols_datas.items():
            print('Num of unique molecules in {}: {}'.format(subset_name, len(subset.index)))
        return mols_datas

    def _print_num_of_affected_pairs(self, mols_datas):
        """Print, per subset, how many ``stats_data`` rows reference one of its molecules."""
        stats_data = self.load_auxiliary()['stats_data']
        for subset_name, subset in mols_datas.items():
            in_subset = stats_data[self.mol_id_col].isin(subset.index)
            print('Num of pairs in {}: {}'.format(subset_name, int(in_subset.sum())))
        return None

    def postprocess(self):
        """
        Run the step: load the molecule table, split it, optionally report the affected
        pairs, write one ';'-separated CSV per subset to self.working_dir, and save the
        hyperparameters.
        """
        mols_datas = self._postprocess(self.load_mols())
        if self.auxiliary_data_path is not None:
            self._print_num_of_affected_pairs(mols_datas)

        for subset_name, subset in mols_datas.items():
            out_csv = os.path.join(self.working_dir, 'mols_{}.csv'.format(subset_name))
            subset.to_csv(out_csv, sep=';', index=True, header=True)

        self.save_hparams()
        return None



# -----------------------------
# Discard unreliable molecules:
# -----------------------------
class CVPP_MoleculeDiscard(BaseCVPostProcess):
    """
    Post-processing step that discards an explicit list of molecules.

    Rows of the molecule table whose ``mol_key_col`` value (InChI Key by default) is in
    ``mol_keys`` are written to ``mols_exclude_n<K>.csv``; all other rows go to
    ``mols_n<K>.csv``. Both files are written inside ``self.working_dir``, and ``K`` is
    the number of keys supplied.
    """
    def __init__(self, data_dir, mol_keys = None, mols_csv = None, mol_id_col = 'mol_id', mol_key_col = 'inchi_key', auxiliary_data_path = None):
        """
        Parameters:
        -----------
        data_dir : str
            Forwarded to BaseCVPostProcess; ``mols.csv`` inside ``self.data_dir`` is the
            default molecule table.
        mol_keys : str or iterable of str
            Keys of the molecules to discard. A single string is treated as one key.
            Any other iterable (list, tuple, set, Series, generator, ...) is copied into a list.
        mols_csv : str, optional
            Path to the molecule table.
        mol_id_col : str
            Index column of the molecule table; also the column of the auxiliary
            ``stats_data`` table that is matched against that index.
        mol_key_col : str
            Column of the molecule table compared against ``mol_keys``.
        auxiliary_data_path : dict, optional
            Mapping with a ``'stats_data'`` CSV path (';'-separated). When set, the number
            of pairs falling into each group is printed.

        Raises:
        -------
        ValueError
            If ``mol_keys`` is missing or empty.
        """
        # None replaces the former mutable ``[]`` default; behavior is unchanged
        # (no keys -> ValueError).
        if mol_keys is None:
            mol_keys = []
        elif isinstance(mol_keys, str):
            # A bare string is one key. Iterating it would yield single characters,
            # and Series.isin rejects a plain string outright.
            mol_keys = [mol_keys]
        else:
            # Materialize so len() works on generators, and so serialize_hparams returns
            # a plain list rather than e.g. a set (presumably JSON-dumped by save_hparams).
            mol_keys = list(mol_keys)
        if not mol_keys:
            raise ValueError('No InChI Key provided.')

        # Timestamped name so repeated runs get distinct working directories
        # (second resolution).
        _datetime = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        name = 'discard_by_list_{}'.format(_datetime)
        super(CVPP_MoleculeDiscard, self).__init__(name, data_dir)

        self._data_name = 'n{}'.format(len(mol_keys))

        # self.data_dir is not set in this class (presumably by BaseCVPostProcess.__init__).
        if mols_csv is None:
            mols_csv = os.path.join(self.data_dir, 'mols.csv')

        self.mols_csv_name = os.path.basename(mols_csv)
        self.mols_csv = mols_csv

        self.mol_keys = mol_keys
        self.mol_id_col = mol_id_col
        self.mol_key_col = mol_key_col

        # Only used to count how many pairs fall into each group.
        self.auxiliary_data_path = auxiliary_data_path

    def serialize_hparams(self):
        """
        returns dictionary with all hyperparameters that will be saved. self.working_dir will be added
        to the dict in self.save_hparams.
        """
        return {'mol_keys' : self.mol_keys,
                'mol_id_col' : self.mol_id_col,
                'mol_key_col' : self.mol_key_col,
                'mols_csv' : self.mols_csv,
                'auxiliary_data_path' : self.auxiliary_data_path}

    def load_mols(self):
        """Read the molecule table, indexed by ``self.mol_id_col``."""
        # NOTE(review): reads with self.sep (from the base class) but postprocess writes
        # with ';' -- confirm the two agree.
        mols = pandas.read_csv(self.mols_csv, sep = self.sep, index_col = self.mol_id_col, header = 0)
        return mols

    def load_auxiliary(self):
        """Read the ';'-separated ``stats_data`` table listed in ``self.auxiliary_data_path``."""
        auxiliary = {}
        auxiliary['stats_data'] = pandas.read_csv(self.auxiliary_data_path['stats_data'], sep = ';')
        return auxiliary

    def _postprocess(self, mols):
        """
        Split ``mols`` by membership of ``mol_key_col`` in ``self.mol_keys``.

        Returns a dict with ``'exclude_' + self._data_name`` (matching rows) followed by
        ``self._data_name`` (all remaining rows). The size of each part is printed.
        """
        _exclude_cond = mols[self.mol_key_col].isin(self.mol_keys)

        mols_datas = {}
        mols_datas['exclude_' + self._data_name] = mols[_exclude_cond]
        mols_datas[self._data_name] = mols[~_exclude_cond]

        for _name, _data in mols_datas.items():
            print('Num of unique molecules in {}: {}'.format(_name, len(_data.index)))
        return mols_datas

    def _print_num_of_affected_pairs(self, mols_datas):
        """Print, per subset, how many ``stats_data`` rows reference one of its molecules."""
        stats_data = self.load_auxiliary()['stats_data']

        for name, _mols in mols_datas.items():
            _data = stats_data[stats_data[self.mol_id_col].isin(_mols.index)]
            print('Num of pairs in {}: {}'.format(name, len(_data)))

        return None

    def postprocess(self):
        """
        Run the step: load the molecule table, split it, optionally report the affected
        pairs, write one ';'-separated CSV per subset to self.working_dir, and save the
        hyperparameters.
        """
        mols = self.load_mols()

        mols_datas = self._postprocess(mols)
        if self.auxiliary_data_path is not None:
            self._print_num_of_affected_pairs(mols_datas)

        for name, _data in mols_datas.items():
            _data.to_csv(os.path.join(self.working_dir, 'mols_' + name + '.csv'), sep=';', index = True, header = True)

        self.save_hparams()
        return None