# -------------------------------------------------------------------------------------
# NOTE: No postprocess function modifies data_test (except CVPP for mixtures)
# -------------------------------------------------------------------------------------
import os
from shutil import copy

from ProtLig_GPCRclassA.base_cross_validation import BaseCVPostProcess


# ----------------
# Combine Weights:
# ----------------
class CVPP_combineWeights(BaseCVPostProcess):
    """
    Combine weight columns by multiplication.

    The product of the given weight columns is stored row-wise in a new column
    named 'combined_weight_' followed by the sorted column names (with every
    '_weight' substring removed) joined by '_'. For example,
    ['b_weight', 'a_weight'] produces the column 'combined_weight_a_b'.

    The new column is added to the train and valid data only; the test data
    is written out unchanged.
    """
    def __init__(self, data_dir, weight_cols):
        """
        data_dir : forwarded unchanged to BaseCVPostProcess.__init__.
        weight_cols : iterable of column names to multiply together. It is
            copied and sorted, so the caller's object is never modified and
            the resulting column name does not depend on the order given.
        """
        # sorted() instead of list.sort(): do not reorder the caller's list in
        # place, and accept tuples / other iterables as well.
        weight_cols = sorted(weight_cols)
        name = None
        super(CVPP_combineWeights, self).__init__(name, data_dir)

        self.col_name = 'combined_weight_' + '_'.join([col.replace('_weight', '') for col in weight_cols])

        self.weight_cols = weight_cols

    def serialize_hparams(self):
        """
        returns dictionary with all hyperparameters that will be saved. self.working_dir will be added
        to the dict in self.save_hparams.
        """
        return {'data_dir' : self.data_dir,
                'weight_cols' : self.weight_cols}

    def _postprocess(self, data):
        """
        Add the combined weight column to `data` (modified in place) and return it.

        skipna = False makes the product NaN for every row in which at least
        one of the weight columns is NaN, instead of silently ignoring it.
        """
        data[self.col_name] = data[self.weight_cols].prod(axis = 1, skipna = False)
        return data

    def postprocess(self):
        """
        Load the train/valid/test data, add the combined weight column to the
        non-empty train and valid data, and write all three as ';'-separated
        CSV files into self.working_dir. Finally save the hyperparameters with
        the combined column name as prefix.
        """
        data_train, data_valid, data_test = self.load_data()

        # (file name, data, whether to add the combined weight column).
        # data_test is intentionally only copied through: this postprocess
        # step leaves the test data unchanged.
        splits = (('data_train.csv', data_train, True),
                  ('data_valid.csv', data_valid, True),
                  ('data_test.csv', data_test, False))

        for filename, data, combine in splits:
            # Empty data is written as is; presumably an empty frame may lack
            # the weight columns, so selecting them would fail.
            if combine and not data.empty:
                data = self._postprocess(data)
            data.to_csv(os.path.join(self.working_dir, filename), sep=';', index = False, header = True)

        self.save_hparams(prefix = self.col_name + '_')