# -------------------------------------------------------------------------------------
# NOTE: None of the postprocess functions modify data_test (except CVPP for mixtures)
# -------------------------------------------------------------------------------------
import json
import os
from shutil import copy

import pandas
from ProtLig_GPCRclassA.base_cross_validation import BaseCVPostProcess

# ----------------------------
# NOTE: Create class_dist.json:
# ----------------------------
class CVPP_class_dist(BaseCVPostProcess):
    """
    Count the samples per label class in the train + valid splits and write the
    counts to ``<working_dir>/class_dist.json`` as ``{label: count}``.

    Only rows whose sequence ID (``seq_id_col``) is present in ``seqs_csv`` and
    whose molecule ID (``mol_id_col``) is present in ``mols_csv`` are counted.
    The test split is loaded but never used.

    Notes:
    ------
    This distribution depends on the split, since it is computed from that
    split's train/valid data.
    """
    def __init__(self, data_dir, mols_csv, seqs_csv, seq_id_col = 'Target_ID', mol_id_col = 'Drug_ID', label_col = 'responsive'):
        name = None
        super(CVPP_class_dist, self).__init__(name, data_dir)
        self.seq_id_col = seq_id_col
        self.mol_id_col = mol_id_col
        self.label_col = label_col
        self.mols_csv = mols_csv
        self.seqs_csv = seqs_csv

    def serialize_hparams(self):
        """
        returns dictionary with all hyperparameters that will be saved. self.working_dir will be added
        to the dict in self.save_hparams.
        """
        return {'seq_id_col' : self.seq_id_col,
                'mol_id_col' : self.mol_id_col,
                'label_col' : self.label_col,
                'mols_csv' : self.mols_csv,
                'seqs_csv' : self.seqs_csv}

    def load_mols(self):
        """Load the molecule table from ``mols_csv``, indexed by ``mol_id_col``."""
        # NOTE(review): self.sep is not set in this class; presumably provided
        # by BaseCVPostProcess — confirm.
        mols = pandas.read_csv(self.mols_csv, sep = self.sep, index_col = self.mol_id_col, header = 0)
        return mols

    def load_seqs(self):
        """Load the sequence table from ``seqs_csv``, indexed by ``seq_id_col``."""
        seqs = pandas.read_csv(self.seqs_csv, sep = self.sep, index_col = self.seq_id_col, header = 0)
        return seqs

    def postprocess(self):
        """
        Compute the class distribution and save it together with the hparams.

        Raises:
        -------
        ValueError
            If both the train and the valid split are empty.
        """
        data_train, data_valid, _ = self.load_data()

        mols = self.load_mols()
        seqs = self.load_seqs()

        frames = [frame for frame in (data_train, data_valid) if not frame.empty]
        if not frames:
            # pandas.concat([]) would raise an opaque "No objects to concatenate".
            raise ValueError('CVPP_class_dist: data_train and data_valid are both empty, '
                             'cannot compute a class distribution.')
        data = pandas.concat(frames)

        # Keep only pairs whose sequence and molecule are both known.
        data = data[data[self.seq_id_col].isin(seqs.index)]
        data = data[data[self.mol_id_col].isin(mols.index)]

        # size() counts rows per label directly. The previous approach counted
        # values in the "first fully non-null column", which broke when that
        # column was the label column itself (KeyError after groupby) or when
        # no such column existed (IndexError).
        cls_dist = data.groupby(self.label_col).size().to_dict()

        with open(os.path.join(self.working_dir, 'class_dist.json'), 'w') as outfile:
            json.dump(cls_dist, outfile)

        self.save_hparams(prefix = 'class_dist_')
        return None
    
# -------------------
# NOTE: Class weight:
# -------------------
class CVPP_addWeights_Class(BaseCVPostProcess):
    """
    Add a per-sample class-weight column (``class_weight``) to the train and
    valid splits, using a precomputed class distribution loaded from
    ``auxiliary_data_path['class_dist']`` (a JSON object ``{label: count}``).

    The weight of class ``c`` is ``n_samples / (n_classes * count_c)``, i.e. the
    same formula as scikit-learn's ``class_weight='balanced'``. The test split
    is written back unchanged.

    Notes:
    ------
    This weight depends on the split.
    """
    def __init__(self, data_dir, auxiliary_data_path, class_col = 'responsive'):
        name = None # 'weights__LogHarmonic' + '_' + 'class'
        super(CVPP_addWeights_Class, self).__init__(name, data_dir)

        self.col_name = 'class'
        # Name of the column added to the data: 'class_weight'.
        self.weight_col_name = self.col_name + '_weight'
        # Mapping of auxiliary inputs; must contain a 'class_dist' JSON path.
        self.auxiliary_data_path = auxiliary_data_path

        # Column holding the label whose class weight is looked up.
        self.class_col = class_col

    def serialize_hparams(self):
        """
        returns dictionary with all hyperparameters that will be saved. self.working_dir will be added
        to the dict in self.save_hparams.
        """
        # class_col is recorded too, so the saved hparams fully describe the run.
        return {'data_dir' : self.data_dir,
                'auxiliary_data_path' : self.auxiliary_data_path,
                'class_col' : self.class_col}

    def load_auxiliary(self):
        """
        Load the class distribution JSON and save a copy of it in the working
        directory as ``addWeights_Class_auxiliary.json``.

        Returns:
        --------
        dict
            ``{'class_dist': {label_str: count}}``; JSON object keys are strings.
        """
        auxiliary = {}
        with open(self.auxiliary_data_path['class_dist'], 'r') as jsonfile:
            auxiliary['class_dist'] = json.load(jsonfile)

        with open(os.path.join(self.working_dir, 'addWeights_Class_auxiliary.json'), 'w') as jsonfile:
            json.dump(auxiliary, jsonfile)

        return auxiliary

    def _postprocess(self, data, auxilary):
        """
        Add ``self.weight_col_name`` to *data* (in place) and return *data*.

        Rows whose label does not appear in the class distribution get NaN.
        """
        class_dist = auxilary['class_dist']
        n_samples = sum(class_dist.values())
        n_classes = len(class_dist)
        # Labels are matched as strings because JSON object keys are strings.
        class_weight_map = {str(label) : n_samples/(n_classes*count) for label, count in class_dist.items()}

        data[self.weight_col_name] = data[self.class_col].astype(str).map(class_weight_map)
        return data

    def postprocess(self):
        """Weight the train/valid splits and write all three splits to the working dir."""
        data_train, data_valid, data_test = self.load_data()
        auxilary = self.load_auxiliary()

        if not data_train.empty:
            data_train = self._postprocess(data_train, auxilary)
        data_train.to_csv(os.path.join(self.working_dir, 'data_train.csv'), sep=';', index = False, header = True)
        if not data_valid.empty:
            data_valid = self._postprocess(data_valid, auxilary)
        data_valid.to_csv(os.path.join(self.working_dir, 'data_valid.csv'), sep=';', index = False, header = True)
        # data_test is deliberately left unweighted and written back unchanged.
        data_test.to_csv(os.path.join(self.working_dir, 'data_test.csv'), sep=';', index = False, header = True)

        self.save_hparams(prefix = self.col_name + '_')
        return None