# -------------------------------------------------------------------------------------
# NOTE: None of the postprocess functions change data_test (except the CVPP for mixtures)
# -------------------------------------------------------------------------------------
import json
import os
from shutil import copy

import numpy
import pandas
from matplotlib import pyplot as plt
from mol2graph.utils import get_num_atoms_and_bonds
from ProtLig_GPCRclassA.base_cross_validation import BaseCVPostProcess


# --------
# Weights:
# --------
class CVPP_addWeights_Harmonic(BaseCVPostProcess):
    """
    Add a per-row sample weight column ('harmonic_weight') to the train and valid data.

    The weight of a row is 0.5 * (k / count_seq + k / count_mol), where count_seq
    (count_mol) is the occurrence count of the row's sequence id (molecule id) in the
    concatenation of the train and valid data.

    Parameters:
    -----------
    data_dir :
        forwarded unchanged to BaseCVPostProcess.__init__.

    k : float
        scaling constant in the numerator of both terms of the weight.

    seq_id_col : str
        name of the column holding the sequence id.
        NOTE(review): the '' default presumably is a placeholder - pass the real column name.

    mol_id_col : str
        name of the column holding the molecule id.
        NOTE(review): the '' default presumably is a placeholder - pass the real column name.

    use_count_seq : bool
        if False, every sequence count is replaced by 1.

    use_count_mol : bool
        if False, every molecule count is replaced by 1.

    Notes:
    ------
    This weight is dependent on split: counts are computed on train + valid only,
    data_test is written back unchanged.
    """
    def __init__(self, data_dir, k = 100.0, seq_id_col = '', mol_id_col = '', use_count_seq = True, use_count_mol = True):
        name = None # 'weighted_Harmonic'
        super(CVPP_addWeights_Harmonic, self).__init__(name, data_dir)
        self.k = k
        self.seq_id_col = seq_id_col
        self.mol_id_col = mol_id_col
        self.col_name = 'harmonic'
        self.weight_col_name = self.col_name + '_weight'

        self.use_count_seq = use_count_seq
        self.use_count_mol = use_count_mol

    def serialize_hparams(self):
        """
        returns dictionary with all hyperparameters that will be saved. self.working_dir will be added
        to the dict in self.save_hparams.

        NOTE(review): 'func_weight_name' is a fixed string that mentions k100 regardless of the
        actual value of self.k; the real value is stored under 'k'.
        """
        return {'func_weight_name' : 'reciprocal_harmonic_mean_k100',
                'k' : self.k}

    def _count_occurences(self, data):
        """
        Count occurrences of every sequence id and every molecule id in ``data``.

        Parameters:
        -----------
        data : pandas.DataFrame
            frame containing the columns self.seq_id_col and self.mol_id_col.

        Returns:
        --------
        count_seq : pandas.Series
            named 'count_seq', indexed by sequence id; number of non-null molecule ids per
            sequence id, or all ones if self.use_count_seq is False.

        count_mol : pandas.Series
            named 'count_mol', indexed by molecule id; number of non-null sequence ids per
            molecule id, or all ones if self.use_count_mol is False.
        """
        # Select the single column before counting so that only that column is aggregated.
        count_seq = data.groupby(by = self.seq_id_col)[self.mol_id_col].count()
        count_seq.name = 'count_seq'

        if not self.use_count_seq:
            # Keep index, dtype and name so that the later join is unaffected.
            count_seq = pandas.Series(data=1, index=count_seq.index, dtype=count_seq.dtype, name=count_seq.name)

        count_mol = data.groupby(by = self.mol_id_col)[self.seq_id_col].count()
        count_mol.name = 'count_mol'

        if not self.use_count_mol:
            count_mol = pandas.Series(data=1, index=count_mol.index, dtype=count_mol.dtype, name=count_mol.name)

        return count_seq, count_mol

    def _weight_function(self, x):
        """
        Weight of a single row ``x`` that carries 'count_seq' and 'count_mol'.

        change k for scaling and numerical stability.
        """
        return 0.5*(self.k/float(x['count_seq']) + self.k/float(x['count_mol']))

    def _postprocess(self, data, count_seq, count_mol):
        """
        Return ``data`` with the weight column self.weight_col_name added.

        Parameters:
        -----------
        data : pandas.DataFrame
            split to be weighted.

        count_seq, count_mol : pandas.Series
            outputs of self._count_occurences.

        Notes:
        ------
        The counts are attached with inner joins, so rows whose ids are missing from the
        counts are dropped. The temporary count columns are removed before returning.
        """
        data = data.join(count_seq, on = self.seq_id_col, how = 'inner')
        data = data.join(count_mol, on = self.mol_id_col, how = 'inner')
        # Row-wise apply on purpose: subclasses override _weight_function with a per-row signature.
        data[self.weight_col_name] = data.apply(self._weight_function, axis = 1)
        return data.drop(['count_seq', 'count_mol'], axis = 1)

    def postprocess(self):
        """
        Load the splits, add weights to non-empty train/valid data and write all three
        splits as ';'-separated csv files into self.working_dir, then save the hyperparameters.

        Empty splits are written back unchanged (without weight column); data_test is never
        modified.
        """
        data_train, data_valid, data_test = self.load_data()

        non_empty = [split for split in (data_train, data_valid) if not split.empty]

        # pandas.concat raises ValueError on an empty list, so only count when there is data.
        if non_empty:
            count_seq, count_mol = self._count_occurences(pandas.concat(non_empty))

            if not data_train.empty:
                data_train = self._postprocess(data_train, count_seq = count_seq, count_mol = count_mol)
            if not data_valid.empty:
                data_valid = self._postprocess(data_valid, count_seq = count_seq, count_mol = count_mol)

        data_train.to_csv(os.path.join(self.working_dir, 'data_train.csv'), sep=';', index = False, header = True)
        data_valid.to_csv(os.path.join(self.working_dir, 'data_valid.csv'), sep=';', index = False, header = True)
        data_test.to_csv(os.path.join(self.working_dir, 'data_test.csv'), sep=';', index = False, header = True)

        self.save_hparams(prefix = self.col_name + '_')
        return None
        

class CVPP_addWeights_LogHarmonic(CVPP_addWeights_Harmonic):
    """
    Variant of CVPP_addWeights_Harmonic that stores log(harmonic weight + 1) in the
    column 'logHarmonic_weight'.

    Parameters are the same as for CVPP_addWeights_Harmonic, except that seq_id_col and
    mol_id_col default to 'seq_id' and 'mol_id'.

    Notes:
    ------
    This weight is dependent on split.

    NOTE(review): serialize_hparams is inherited unchanged, so the saved 'func_weight_name'
    is still 'reciprocal_harmonic_mean_k100' for this class.
    """
    def __init__(self, data_dir, k = 100.0, seq_id_col = 'seq_id', mol_id_col = 'mol_id', use_count_seq = True, use_count_mol = True):
        # Delegate to the direct parent instead of skipping it; it sets up all shared
        # attributes and calls BaseCVPostProcess.__init__ with name = None.
        super(CVPP_addWeights_LogHarmonic, self).__init__(data_dir,
                                                          k = k,
                                                          seq_id_col = seq_id_col,
                                                          mol_id_col = mol_id_col,
                                                          use_count_seq = use_count_seq,
                                                          use_count_mol = use_count_mol)
        # Only the column naming differs from the parent.
        self.col_name = 'logHarmonic'
        self.weight_col_name = self.col_name + '_weight'

    def _weight_function(self, x):
        """
        Weight of a single row ``x``: log of (parent harmonic weight + 1).

        change k for scaling and numerical stability.
        """
        harmonic = super(CVPP_addWeights_LogHarmonic, self)._weight_function(x)
        return numpy.log(harmonic + 1)