# -------------------------------------------------------------------------------------
# NOTE: None of the postprocess functions change data_test (except CVPP for mixtures)
# -------------------------------------------------------------------------------------
import json
import os
from shutil import copy

import numpy
import pandas
from ProtLig_GPCRclassA.base_cross_validation import BaseCVPostProcess

# ----------------------------------------------------------------------------
# Create mols_broadness and seqs_broadness based on data_train and data_valid:
# ----------------------------------------------------------------------------
class CVPP_calculate_broadness(BaseCVPostProcess):
    """
    Calculate how "broad" every molecule and every sequence is within one split.

    For each molecule id (and, independently, each sequence id) occurring in
    data_train and data_valid, the fraction of its rows labelled 0, the fraction
    labelled 1 and the number of rows are written to
    'mols_broadness__<suffix>.csv' and 'seqs_broadness__<suffix>.csv' inside
    self.working_dir. data_test is loaded but neither used nor written.

    Notes:
    ------
    The result is dependent on the split.
    self.sep, self.working_dir, self.load_data and self.save_hparams are not defined
    in this class; presumably they are provided by BaseCVPostProcess.
    """
    def __init__(self, data_dir, mols_csv, seqs_csv, mol_id_col, seq_id_col, label_col):
        """
        Parameters:
        -----------
        data_dir : passed through to BaseCVPostProcess.__init__.
        mols_csv : path of the csv listing the molecules to keep (indexed by mol_id_col).
        seqs_csv : path of the csv listing the sequences to keep (indexed by seq_id_col).
        mol_id_col : name of the molecule id column.
        seq_id_col : name of the sequence id column.
        label_col : name of the column holding the class label (0 or 1).
        """
        name = None
        super(CVPP_calculate_broadness, self).__init__(name, data_dir)

        self.mols_csv = mols_csv
        self.seqs_csv = seqs_csv
        self.mol_id_col = mol_id_col
        self.seq_id_col = seq_id_col
        self.label_col = label_col

        # File names without directory and extension; used to tag the output files.
        self.mols_basename = os.path.splitext(os.path.basename(mols_csv))[0]
        self.seqs_basename = os.path.splitext(os.path.basename(seqs_csv))[0]

        self.suffix = self.mols_basename + '__' + self.seqs_basename

    def serialize_hparams(self):
        """
        returns dictionary with all hyperparameters that will be saved. self.working_dir will be added
        to the dict in self.save_hparams.
        """
        return {'mols_csv' : self.mols_csv,
                'seqs_csv' : self.seqs_csv,
                'mol_id_col' : self.mol_id_col,
                'seq_id_col' : self.seq_id_col,
                'label_col' : self.label_col}

    def load_mols(self):
        """
        Read self.mols_csv into a DataFrame indexed by self.mol_id_col.
        """
        mols = pandas.read_csv(self.mols_csv, sep = self.sep, index_col = self.mol_id_col, header = 0)
        return mols

    def load_seqs(self):
        """
        Read self.seqs_csv into a DataFrame indexed by self.seq_id_col.
        """
        seqs = pandas.read_csv(self.seqs_csv, sep = self.sep, index_col = self.seq_id_col, header = 0)
        return seqs

    @staticmethod
    def broadness_measure(x, cls):
        """
        Return the fraction of entries of array ``x`` that are equal to ``cls``.

        ``x`` is not guarded against being empty (the denominator is x.shape[0]).
        """
        n_tested = x.shape[0]
        x_bool = x == cls
        return numpy.sum(x_bool).astype(numpy.float32) / n_tested

    def _broadness_table(self, data, id_col):
        """
        Summarise the labels of ``data`` per value of ``id_col``.

        Returns a DataFrame indexed by ``id_col`` with the columns 'broadness_cls0',
        'broadness_cls1' (fraction of rows with label 0 / 1) and 'count' (number of rows).
        """
        # One array of labels per id.
        experiments = data.groupby(id_col).apply(lambda x: numpy.array(x[self.label_col]))
        experiments.name = 'experiments'
        table = experiments.to_frame()
        for cls in (0, 1):
            # cls is bound as a default argument so that the lambda does not depend on the loop variable.
            table['broadness_cls' + str(cls)] = table['experiments'].apply(lambda x, cls = cls: self.broadness_measure(x, cls = cls))
        table['count'] = table['experiments'].apply(len)
        # The raw label arrays are only an intermediate result.
        table.drop(columns = ['experiments'], inplace = True)
        return table

    def postprocess(self):
        """
        Build the molecule and sequence broadness tables from data_train and data_valid
        and write them, together with the hyperparameters, to self.working_dir.

        Raises:
        -------
        ValueError : if both data_train and data_valid are empty.
        """
        # data_test is deliberately ignored: broadness must not depend on the test set.
        data_train, data_valid, _ = self.load_data()

        mols = self.load_mols()
        seqs = self.load_seqs()

        parts = [part for part in (data_train, data_valid) if not part.empty]
        if not parts:
            # pandas.concat would raise an uninformative 'No objects to concatenate' here.
            raise ValueError('Both data_train and data_valid are empty. Broadness can not be calculated.')

        data = pandas.concat(parts)

        # Keep only the pairs whose sequence and molecule are listed in the given csv files.
        data = data[data[self.seq_id_col].isin(seqs.index)]
        data = data[data[self.mol_id_col].isin(mols.index)]

        mols_experiments = self._broadness_table(data, self.mol_id_col)
        seqs_experiments = self._broadness_table(data, self.seq_id_col)

        mols_experiments.to_csv(os.path.join(self.working_dir, 'mols_broadness__'+ self.suffix +'.csv'), sep=';')
        seqs_experiments.to_csv(os.path.join(self.working_dir, 'seqs_broadness__'+ self.suffix +'.csv'), sep=';')

        # NOTE: 'broandess' is misspelled, but the prefix is handed to save_hparams (presumably it
        # becomes part of an output name), so it is kept unchanged for compatibility.
        self.save_hparams(prefix = 'broandess__' + self.suffix + '_')
        return None
    

# -----------------------
# NOTE: Broadness weight:
# -----------------------
class CVPP_addWeights_BroadnessClipNoBias(BaseCVPostProcess):
    """
    Add a sample-weight column based on sequence and molecule broadness.

    For a row with label i the weight is
        (1 / seq_broadness_cls<i>) * (1 / mol_broadness_cls<i>) / n_classes
    where the broadness values are read from the auxiliary broadness tables after
    they went through _preprocess_broadness (default for rarely tested entities,
    clipping from below). data_train and data_valid get the weight column;
    data_test is written back as loaded.

    Notes:
    ------
    This weight is dependent on the split.
    self.working_dir, self.load_data and self.save_hparams are not defined in this
    class; presumably they are provided by BaseCVPostProcess.
    """
    def __init__(self, data_dir, label_col, seq_id_col, mol_id_col, n_classes = 2, seq_min_count = 100, mol_min_count = 100, tested_enough_count = 100, seq_clip_min_broadness = 0.1, mol_clip_min_broadness = 0.1, auxiliary_data_path = None):
        """
        Parameters:
        -----------
        data_dir : passed through to BaseCVPostProcess.__init__.
        label_col, seq_id_col, mol_id_col : column names in the data files.
        n_classes : number of label values 0 .. n_classes - 1.
        seq_min_count, mol_min_count : minimal 'count' for a sequence / molecule to keep its own broadness.
        tested_enough_count : if no more than this many entities reach min_count, the default
            broadness falls back to the count-weighted mean over all entities.
        seq_clip_min_broadness, mol_clip_min_broadness : lower bound applied to broadness values.
        auxiliary_data_path : mapping with the keys 'seqs_broadness' and 'mols_broadness' holding csv paths.
        """
        name = None # 'weighted_Harmonic'
        super(CVPP_addWeights_BroadnessClipNoBias, self).__init__(name, data_dir)
        self.label_col = label_col
        self.seq_id_col = seq_id_col
        self.mol_id_col = mol_id_col
        self.n_classes = n_classes
        self.seq_min_count = seq_min_count
        self.mol_min_count = mol_min_count
        self.tested_enough_count = tested_enough_count
        self.seq_clip_min_broadness = seq_clip_min_broadness
        self.mol_clip_min_broadness = mol_clip_min_broadness
        self.col_name = 'broadness'
        self.weight_col_name = self.col_name + '_weight'
        self.auxiliary_data_path = auxiliary_data_path

    def serialize_hparams(self):
        """
        returns dictionary with all hyperparameters that will be saved. self.working_dir will be added
        to the dict in self.save_hparams.
        """
        return {'func_weight_name' : 'clip_no_bias',
                'label_col' : self.label_col,
                'seq_id_col' : self.seq_id_col,
                'mol_id_col' : self.mol_id_col,
                'seq_min_count' : self.seq_min_count,
                'mol_min_count' : self.mol_min_count,
                'tested_enough_count' : self.tested_enough_count,
                'seq_clip_min_broadness' : self.seq_clip_min_broadness,
                'mol_clip_min_broadness' : self.mol_clip_min_broadness,
                'weight_col_name' : self.weight_col_name,
                'auxiliary_data_path' : self.auxiliary_data_path}

    def load_auxiliary(self):
        """
        Read the sequence and molecule broadness tables (';' separated, first column as index).

        Raises:
        -------
        TypeError : if auxiliary_data_path was left at its default None.
        """
        if self.auxiliary_data_path is None:
            # Without this check the failure is an opaque "'NoneType' object is not subscriptable".
            raise TypeError("auxiliary_data_path must be a mapping with the keys 'seqs_broadness' and 'mols_broadness', got None.")
        auxiliary = {}
        for key in ('seqs_broadness', 'mols_broadness'):
            auxiliary[key] = pandas.read_csv(self.auxiliary_data_path[key], sep=';', index_col = 0)
        return auxiliary

    @staticmethod
    def _preprocess_broadness(df, min_count, tested_enough_count, clip_min_broadness):
        """
        Apply necessary preprocessing to seqs_broadness and mols_broadness.

        Rows with 'count' below ``min_count`` get a default broadness in every column whose
        name contains 'broadness'; afterwards all broadness values are clipped from below at
        ``clip_min_broadness``. The row order of ``df`` is preserved.

        Raises:
        -------
        ValueError : if any default broadness is lower than ``clip_min_broadness``.
        """
        init_ordering = df.index
        tested_enough = df[df['count'] >= min_count].copy()
        not_tested_enough = df[df['count'] < min_count].copy()

        cols = df.columns[df.columns.str.contains('broadness')]

        # Default: plain mean over the entities that were tested often enough.
        default_broadness = tested_enough[cols].mean(axis = 0)

        if len(tested_enough) <= tested_enough_count:
            print('WARNING: Not enough examples to estimate default_broadness. Setting to class distribution instead.')
            for col in cols:
                # Count-weighted mean over all rows, i.e. the overall class distribution.
                default_broadness[col] = (df[col]*df['count']).sum()/df['count'].sum()
            print(default_broadness)
            # TODO: reconsider whether the class distribution is a good fallback, or whether broadness
            # should rather be ignored (same default for all classes) when it can not be estimated.

        if (default_broadness < clip_min_broadness).any():
            print(default_broadness)
            raise ValueError('broadness is lower than clip_min_broadness. This means that all not_tested_enough are clipped. Please double check values.')

        for col in cols:
            not_tested_enough[col] = default_broadness[col]
            tested_enough[col] = tested_enough[col].clip(lower = clip_min_broadness)
            not_tested_enough[col] = not_tested_enough[col].clip(lower = clip_min_broadness)

        df = pandas.concat([tested_enough, not_tested_enough])
        # Restore the original row order after the split / concat.
        df = df.loc[init_ordering]
        return df

    def _weight_function(self, x):
        """
        Return the weight of one data row ``x``.

        The broadness columns of the row's own class are used:
        (1 / seq_broadness) * (1 / mol_broadness) / n_classes.

        Raises:
        -------
        ValueError : if the label is not one of 0 .. n_classes - 1.
        """
        label = x[self.label_col]
        for i in range(self.n_classes):
            if label == i:
                return (1/x['seq_broadness_cls' + str(i)])*(1/x['mol_broadness_cls' + str(i)])/self.n_classes
        raise ValueError('Unexpected value {!r} encountered in {} column.'.format(label, self.label_col))

    def _postprocess(self, input_data, auxiliary):
        """
        Return a copy of ``input_data`` with the weight column added (or overwritten).

        Only the original columns plus the weight column are returned; the joined
        broadness columns are dropped again.
        """
        data = input_data.copy()
        new_columns = data.columns.to_list()
        if self.weight_col_name not in data.columns:
            new_columns.append(self.weight_col_name)

        seqs_broadness = auxiliary['seqs_broadness'].copy()
        mols_broadness = auxiliary['mols_broadness'].copy()

        # Prefix the columns so sequence and molecule statistics can live side by side after the joins.
        seqs_broadness.columns = 'seq_' + seqs_broadness.columns
        mols_broadness.columns = 'mol_' + mols_broadness.columns

        # Left joins: every data row is kept; the broadness tables are matched on their index.
        data = data.join(seqs_broadness, on = self.seq_id_col, how = 'left')
        data = data.join(mols_broadness, on = self.mol_id_col, how = 'left')

        data[self.weight_col_name] = data.apply(self._weight_function, axis = 1)

        return data[new_columns]

    def postprocess(self):
        """
        Add the weight column to data_train and data_valid and write all three data
        files, together with the hyperparameters, to self.working_dir.
        """
        data_train, data_valid, data_test = self.load_data()
        auxiliary = self.load_auxiliary()

        auxiliary['seqs_broadness'] = self._preprocess_broadness(auxiliary['seqs_broadness'], min_count = self.seq_min_count, tested_enough_count = self.tested_enough_count, clip_min_broadness = self.seq_clip_min_broadness)
        auxiliary['mols_broadness'] = self._preprocess_broadness(auxiliary['mols_broadness'], min_count = self.mol_min_count, tested_enough_count = self.tested_enough_count, clip_min_broadness = self.mol_clip_min_broadness)

        # Empty sets are written back untouched (without the weight column).
        if not data_train.empty:
            data_train = self._postprocess(data_train, auxiliary)
        data_train.to_csv(os.path.join(self.working_dir, 'data_train.csv'), sep=';', index = False, header = True)
        if not data_valid.empty:
            data_valid = self._postprocess(data_valid, auxiliary)
        data_valid.to_csv(os.path.join(self.working_dir, 'data_valid.csv'), sep=';', index = False, header = True)
        # data_test is written back as loaded: no weights are added to the test set.
        data_test.to_csv(os.path.join(self.working_dir, 'data_test.csv'), sep=';', index = False, header = True)

        self.save_hparams(prefix = self.col_name + '_')
        return None



class CVPP_addWeights_SeqBroadnessClipNoBias(CVPP_addWeights_BroadnessClipNoBias):
    """
    Variant of CVPP_addWeights_BroadnessClipNoBias whose weight uses sequence broadness only.

    For a row with label i the weight is (1 / seq_broadness_cls<i>) / n_classes and it is
    stored in the column 'seq_broadness_weight'.

    Notes:
    ------
    Only _weight_function is overridden. The inherited postprocess still loads and
    preprocesses the molecule broadness table, so the mol_* parameters remain in effect there.
    """
    def __init__(self, data_dir, label_col, seq_id_col, mol_id_col, n_classes = 2, seq_min_count = 100, mol_min_count = 100, tested_enough_count = 100, seq_clip_min_broadness = 0.1, mol_clip_min_broadness = 0.1, auxiliary_data_path = None):
        """
        Takes the same parameters as CVPP_addWeights_BroadnessClipNoBias and forwards them unchanged.
        """
        super(CVPP_addWeights_SeqBroadnessClipNoBias, self).__init__(data_dir = data_dir,
                                                                  label_col = label_col,
                                                                  seq_id_col = seq_id_col,
                                                                  mol_id_col = mol_id_col,
                                                                  n_classes = n_classes,
                                                                  seq_min_count = seq_min_count,
                                                                  mol_min_count = mol_min_count,
                                                                  tested_enough_count = tested_enough_count,
                                                                  seq_clip_min_broadness = seq_clip_min_broadness,
                                                                  mol_clip_min_broadness = mol_clip_min_broadness,
                                                                  auxiliary_data_path = auxiliary_data_path)
        # Override the names set by the parent so the weight goes to its own column.
        self.col_name = 'seq_broadness'
        self.weight_col_name = self.col_name + '_weight'

    def _weight_function(self, x):
        """
        Return the weight of one data row ``x``: (1 / seq_broadness of the row's class) / n_classes.

        Raises:
        -------
        ValueError : if the label is not one of 0 .. n_classes - 1.
        """
        label = x[self.label_col]
        for i in range(self.n_classes):
            if label == i:
                return (1/x['seq_broadness_cls' + str(i)])/self.n_classes
        raise ValueError('Unexpected value {!r} encountered in {} column.'.format(label, self.label_col))
    


class CVPP_addWeights_RootBroadnessClipNoBias(CVPP_addWeights_BroadnessClipNoBias):
    """
    Variant of CVPP_addWeights_BroadnessClipNoBias that takes the square root of the
    combined inverse broadness.

    For a row with label i the weight is
        sqrt((1 / seq_broadness_cls<i>) * (1 / mol_broadness_cls<i>)) / n_classes
    and it is stored in the column 'root_broadness_weight'.
    """
    def __init__(self, data_dir, label_col, seq_id_col, mol_id_col, n_classes = 2, seq_min_count = 100, mol_min_count = 100, tested_enough_count = 100, seq_clip_min_broadness = 0.1, mol_clip_min_broadness = 0.1, auxiliary_data_path = None):
        """
        Takes the same parameters as CVPP_addWeights_BroadnessClipNoBias and forwards them unchanged.
        """
        super(CVPP_addWeights_RootBroadnessClipNoBias, self).__init__(data_dir = data_dir,
                                                                  label_col = label_col,
                                                                  seq_id_col = seq_id_col,
                                                                  mol_id_col = mol_id_col,
                                                                  n_classes = n_classes,
                                                                  seq_min_count = seq_min_count,
                                                                  mol_min_count = mol_min_count,
                                                                  tested_enough_count = tested_enough_count,
                                                                  seq_clip_min_broadness = seq_clip_min_broadness,
                                                                  mol_clip_min_broadness = mol_clip_min_broadness,
                                                                  auxiliary_data_path = auxiliary_data_path)
        # Override the names set by the parent so the weight goes to its own column.
        self.col_name = 'root_broadness'
        self.weight_col_name = self.col_name + '_weight'

    def _weight_function(self, x):
        """
        Return the weight of one data row ``x``: the square root of the product of the inverse
        sequence and molecule broadness of the row's class, divided by n_classes.

        Raises:
        -------
        ValueError : if the label is not one of 0 .. n_classes - 1.
        """
        label = x[self.label_col]
        for i in range(self.n_classes):
            if label == i:
                return numpy.sqrt((1/x['seq_broadness_cls' + str(i)])*(1/x['mol_broadness_cls' + str(i)]))/self.n_classes
        raise ValueError('Unexpected value {!r} encountered in {} column.'.format(label, self.label_col))