# -------------------------------------------------------------------------------------
# NOTE: No postprocess function changes data_test (except CVPP for mixtures)
# -------------------------------------------------------------------------------------
import json
import os
from shutil import copy

import numpy
import pandas
from ProtLig_GPCRclassA.base_cross_validation import BaseCVPostProcess

# --------------------------
# NOTE: Data quality weight:
# --------------------------
class CVPP_addWeights_DataQuality(BaseCVPostProcess):
    """
    Add a per-sample weight column derived from the screening quality
    (primaryScreening / secondaryScreening / ec50) and the label of each row.

    The weight is looked up from a JSON file of screening confidence
    probabilities and written to the column ``data_quality_weight`` of the
    train and validation splits. The test split is written back unchanged.

    Notes:
    ------
    This weight is independent of the split.

    Parameters:
    -----------
    data_dir : forwarded to ``BaseCVPostProcess.__init__``.

    label_col : str
        name of the column holding the binary label (1 or 0).

    data_quality_col : str
        name of the column holding the screening type of each row.

    auxiliary_data_path : dict
        must contain the key 'screening_confidence_probs' pointing to a JSON
        file with the keys 'posEC50_if_posPrimary', 'posEC50_if_posSecondary',
        'negEC50_if_negPrimary' and 'negEC50_if_negSecondary'.
    """
    # label -> (key in probs for primaryScreening, key in probs for secondaryScreening)
    _PROB_KEYS_BY_LABEL = {
        1: ('posEC50_if_posPrimary', 'posEC50_if_posSecondary'),
        0: ('negEC50_if_negPrimary', 'negEC50_if_negSecondary'),
    }

    def __init__(self, data_dir, label_col = 'responsive', data_quality_col = 'data_quality', auxiliary_data_path = None):
        name = None
        super(CVPP_addWeights_DataQuality, self).__init__(name, data_dir)

        self.col_name = 'data_quality'
        self.weight_col_name = self.col_name + '_weight'

        self.label_col = label_col
        self.data_quality_col = data_quality_col
        self.auxiliary_data_path = auxiliary_data_path

    def serialize_hparams(self):
        """
        returns dictionary with all hyperparameters that will be saved. self.working_dir will be added
        to the dict in self.save_hparams.
        """
        return {'data_dir' : self.data_dir,
                'label_col' : self.label_col,
                'data_quality_col' : self.data_quality_col,
                'auxiliary_data_path' : self.auxiliary_data_path}

    def load_auxiliary(self):
        """
        Load the screening confidence probabilities and store a copy of them
        in self.working_dir as 'addWeights_DataQuality_auxiliary.json'.

        Returns:
        --------
        dict with the single key 'screening_confidence_probs'.

        Raises:
        -------
        ValueError
            if auxiliary_data_path was not provided.
        """
        if self.auxiliary_data_path is None:
            # Fail with an explicit message instead of "'NoneType' object is not subscriptable".
            raise ValueError("auxiliary_data_path must be provided and contain the key "
                             "'screening_confidence_probs'.")

        auxiliary = {}
        with open(self.auxiliary_data_path['screening_confidence_probs'], 'r') as jsonfile:
            auxiliary['screening_confidence_probs'] = json.load(jsonfile)

        # Keep a copy of what was actually used next to the processed data.
        with open(os.path.join(self.working_dir, 'addWeights_DataQuality_auxiliary.json'), 'w') as jsonfile:
            json.dump(auxiliary, jsonfile)

        return auxiliary

    @staticmethod
    def _weight_screening(row, probs, label_col, data_quality_col):
        """
        Return the data quality weight of a single row.

        'ec50' rows get weight 1.0. Rows from primary/secondary screening get
        the corresponding entry of ``probs``, chosen by the label (1 or 0).

        Raises:
        -------
        ValueError
            if the label is neither 1 nor 0, or the data quality value is not
            one of 'primaryScreening', 'secondaryScreening', 'ec50'.
        """
        label = row[label_col]
        quality = row[data_quality_col]

        if label == 1:
            primary_key, secondary_key = CVPP_addWeights_DataQuality._PROB_KEYS_BY_LABEL[1]
        elif label == 0:
            primary_key, secondary_key = CVPP_addWeights_DataQuality._PROB_KEYS_BY_LABEL[0]
        else:
            # Previously this fell through to an UnboundLocalError on `weight`.
            raise ValueError("Unexpected label {!r} in column {!r}; expected 1 or 0.".format(label, label_col))

        if quality == 'primaryScreening':
            return probs[primary_key]
        if quality == 'secondaryScreening':
            return probs[secondary_key]
        if quality == 'ec50':
            return 1.0

        # Previously this fell through to an UnboundLocalError on `weight`.
        raise ValueError("Unexpected data quality {!r} in column {!r}; expected 'primaryScreening', "
                         "'secondaryScreening' or 'ec50'.".format(quality, data_quality_col))

    def _postprocess(self, data, auxilary):
        """
        Add the weight column (self.weight_col_name) to ``data`` in place and return it.
        """
        probs = auxilary['screening_confidence_probs']
        data[self.weight_col_name] = data.apply(
            lambda row: self._weight_screening(row, probs,
                                               label_col = self.label_col,
                                               data_quality_col = self.data_quality_col),
            axis = 1)
        return data

    def postprocess(self):
        """
        Add data quality weights to the train and validation splits and write
        all three splits as ';'-separated CSV files into self.working_dir.

        Empty splits are written as they are, without the weight column.
        """
        data_train, data_valid, data_test = self.load_data()
        auxilary = self.load_auxiliary()

        for filename, split in (('data_train.csv', data_train), ('data_valid.csv', data_valid)):
            if not split.empty:
                split = self._postprocess(split, auxilary)
            split.to_csv(os.path.join(self.working_dir, filename), sep=';', index = False, header = True)

        # data_test is intentionally written unchanged: no weights are added to the test split.
        data_test.to_csv(os.path.join(self.working_dir, 'data_test.csv'), sep=';', index = False, header = True)

        self.save_hparams(prefix = self.col_name + '_')
        return None