# ----------------------------------------------------------------------------------------
# NOTE: None of the postprocess functions change data_test (except CVPP for mixtures).
# ----------------------------------------------------------------------------------------
import json
import os
from shutil import copy

import numpy
import pandas
from ProtLig_GPCRclassA.base_cross_validation import BaseCVPostProcess


# ----------------------
# NOTE: Testing weights:
# ----------------------
class CVPP_addWeights_Harmonic(BaseCVPostProcess):
    """
    Add a per-row sample weight column to the train and validation splits.

    For every row the weight is ``0.5 * (k / count_seq + k / count_mol)`` where
    ``count_seq`` is the number of non-null ``mol_id_col`` entries among the rows that share
    the row's ``seq_id_col`` value, and ``count_mol`` is the number of non-null ``seq_id_col``
    entries among the rows that share the row's ``mol_id_col`` value. The result is written to
    the column ``self.weight_col_name`` (``'harmonic_weight'``).

    Parameters:
    -----------
    data_dir :
        Forwarded unchanged to ``BaseCVPostProcess.__init__``.
    k : float
        Numerator of both terms of the weight; change it for scaling and numerical stability.
    seq_id_col : str
        Name of the column holding sequence identifiers.
    mol_id_col : str
        Name of the column holding molecule identifiers.

    Notes:
    ------
    This weight is dependent on the split: the counts are computed separately on each frame
    passed to ``_postprocess``.
    """
    def __init__(self, data_dir, k = 100.0, seq_id_col = 'seq_id', mol_id_col = 'mol_id'):
        name = None
        super(CVPP_addWeights_Harmonic, self).__init__(name, data_dir)
        self.k = k
        self.seq_id_col = seq_id_col
        self.mol_id_col = mol_id_col
        self.col_name = 'harmonic'
        self.weight_col_name = self.col_name + '_weight'

    def serialize_hparams(self):
        """
        returns dictionary with all hyperparameters that will be saved. self.working_dir will be added
        to the dict in self.save_hparams.
        """
        return {'col_name' : self.col_name,
                'k' : self.k,
                'seq_id_col' : self.seq_id_col,
                'mol_id_col' : self.mol_id_col}

    def _weight_function(self, x):
        """
        Compute the weight of a single row.

        change k for scaling and numerical stability.

        Parameters:
        -----------
        x :
            Row of the joined frame; must provide ``'count_seq'`` and ``'count_mol'``.

        Returns:
        --------
        float
            ``0.5 * (k / count_seq + k / count_mol)``.
        """
        return 0.5*(self.k/float(x['count_seq']) + self.k/float(x['count_mol']))

    def _postprocess(self, data):
        """
        Return a copy of ``data`` extended by the weight column.

        Parameters:
        -----------
        data : pandas.DataFrame
            Frame containing ``self.seq_id_col`` and ``self.mol_id_col``.

        Returns:
        --------
        pandas.DataFrame
            The joined frame with the extra column ``self.weight_col_name``. Because both
            joins are inner joins, rows whose id has no entry in the corresponding count
            table (presumably rows with a missing id) are not present in the result.
        """
        # Select the counted column before aggregating so that only one column is counted
        # instead of every column of the frame.
        count_seq = data.groupby(by=self.seq_id_col)[self.mol_id_col].count()
        count_seq.name = 'count_seq'

        count_mol = data.groupby(by=self.mol_id_col)[self.seq_id_col].count()
        count_mol.name = 'count_mol'

        data = data.join(count_seq, on=self.seq_id_col, how='inner')
        data = data.join(count_mol, on=self.mol_id_col, how='inner')

        if data.empty:
            # The inner joins can remove every row. A row-wise apply on a frame without rows
            # does not produce a per-row Series, so build the empty weight column explicitly.
            data[self.weight_col_name] = pandas.Series([], index=data.index, dtype=float)
        else:
            # The full row is passed so that subclasses overriding _weight_function keep
            # receiving the same input as before.
            data[self.weight_col_name] = data.apply(self._weight_function, axis=1)

        # The helper count columns are only needed to evaluate the weight.
        return data.drop(['count_seq', 'count_mol'], axis=1)

    def postprocess(self):
        """
        Load the splits, add weights to train and validation, and write all three splits.

        The frames come from ``self.load_data()``. Non-empty train and validation frames are
        passed through ``_postprocess``; the test frame is written back without weights. Each
        frame is saved as a ``;``-separated CSV in ``self.working_dir`` and the
        hyperparameters are saved with the prefix ``self.col_name + '_'``.

        Returns:
        --------
        None
        """
        data_train, data_valid, data_test = self.load_data()

        # (file name, frame, whether weights are added); data_test is intentionally left
        # without weights.
        splits = (('data_train.csv', data_train, True),
                  ('data_valid.csv', data_valid, True),
                  ('data_test.csv', data_test, False))
        for file_name, frame, add_weights in splits:
            if add_weights and not frame.empty:
                frame = self._postprocess(frame)
            frame.to_csv(os.path.join(self.working_dir, file_name), sep=';', index=False, header=True)

        self.save_hparams(prefix=self.col_name + '_')
        return None
        

class CVPP_addWeights_LogHarmonic(CVPP_addWeights_Harmonic):
    """
    Variant of ``CVPP_addWeights_Harmonic`` that stores ``log(weight + 1)``.

    The weight column is named ``'logHarmonic_weight'``; everything else is inherited from
    ``CVPP_addWeights_Harmonic``.

    Notes:
    ------
    This weight is dependent on the split.
    """
    def __init__(self, data_dir, k = 100.0, seq_id_col = 'seq_id', mol_id_col = 'mol_id'):
        # Delegate to the direct parent instead of skipping it, so the shared attribute setup
        # lives in one place; only the column names differ.
        super(CVPP_addWeights_LogHarmonic, self).__init__(data_dir,
                                                          k = k,
                                                          seq_id_col = seq_id_col,
                                                          mol_id_col = mol_id_col)
        self.col_name = 'logHarmonic'
        self.weight_col_name = self.col_name + '_weight'

    def _weight_function(self, x):
        """
        Compute the weight of a single row as ``log(harmonic_weight + 1)``.

        change k for scaling and numerical stability.

        Parameters:
        -----------
        x :
            Row of the joined frame; must provide ``'count_seq'`` and ``'count_mol'``.
        """
        harmonic = super(CVPP_addWeights_LogHarmonic, self)._weight_function(x)
        return numpy.log(harmonic + 1)