
import numpy as np

import cfg


class Result:
    """Settings and per-iteration scores of one run.

    The constructor stores the run settings unchanged and derives all score
    attributes from ``df_al``, ``y_true`` and ``y_pred`` through the module
    level helpers ``_get_selected_training_samples`` and
    ``_get_fscore_parts_with_tag``.

    Score attributes are named ``<prefix>_<field>`` where ``<field>`` is one
    of ``_SCORE_FIELDS`` and ``<prefix>`` is one of ``evaluation``,
    ``training``, ``validation`` (each also with a ``_ceiling`` variant,
    computed with ``ceiling=True``) and ``unlabelled`` (no ceiling variant).
    """

    # Order in which _get_fscore_parts_with_tag returns its ten result lists.
    _SCORE_FIELDS = (
        'tp', 'tn', 'fp', 'fn',
        'precision_micro', 'precision_macro',
        'recall_micro', 'recall_macro',
        'f1_micro', 'f1_macro',
    )

    def __init__(self, y_true, y_pred, df_al, sampling_method, random_seed, processing_time,
                 processing_time_incl_training, dataset, init_train_samples, add_train_samples, max_train_samples):
        """Store the run settings and compute every derived result.

        Args:
            y_true: label array; summed over axis 0 to get the per-class
                sample counts.
            y_pred: prediction array with at least three axes; the size of
                the third axis is taken as the number of classes.
            df_al: frame holding one tag column per iteration (column names
                come from ``cfg.get_iteration_col``).
                NOTE(review): presumably the active-learning bookkeeping
                table - confirm against the caller.
            sampling_method, random_seed, processing_time,
            processing_time_incl_training, dataset, init_train_samples,
            add_train_samples, max_train_samples: stored as given.
        """
        # run settings, kept verbatim
        self.dataset = dataset
        self.init_train_samples = init_train_samples
        self.add_train_samples = add_train_samples
        self.max_train_samples = max_train_samples
        self.sampling_method = sampling_method
        self.random_seed = random_seed
        self.processing_time = processing_time
        self.processing_time_incl_training = processing_time_incl_training

        # basic shape information
        self.nr_classes = np.shape(y_pred)[2]
        self.samples_per_class = np.sum(y_true, axis=0)

        # selected samples for training
        training_summary = _get_selected_training_samples(df_al, y_true)
        (self.nr_training_samples,
         self.nr_events_in_training_samples,
         self.selected_training_sample_indices_by_al) = training_summary

        # rows tagged for evaluation in the first iteration column
        first_col = cfg.get_iteration_col(0)
        evaluation_rows = df_al[df_al[first_col] == cfg.tag_evaluate]
        self.evaluation_indices = evaluation_rows.index.to_list()

        # evaluation / training / validation: regular scores first, then the ceiling scores
        for prefix, tag in (('evaluation', cfg.tag_evaluate),
                            ('training', cfg.tag_train),
                            ('validation', cfg.tag_validate)):
            self._store_scores(
                prefix, _get_fscore_parts_with_tag(df_al, y_true, y_pred, tag))
            self._store_scores(
                prefix + '_ceiling', _get_fscore_parts_with_tag(df_al, y_true, y_pred, tag, ceiling=True))

        # unlabelled rows only get the regular scores
        self._store_scores(
            'unlabelled', _get_fscore_parts_with_tag(df_al, y_true, y_pred, cfg.tag_unlabelled))

    def _store_scores(self, prefix, scores):
        """Assign the ten score lists in ``scores`` to ``self.<prefix>_<field>``."""
        for field, values in zip(self._SCORE_FIELDS, scores):
            setattr(self, prefix + '_' + field, values)


def _get_selected_training_samples(df_al, y_true):
    """Summarise the training rows of every iteration column in ``df_al``.

    Iteration columns are resolved with ``cfg.get_iteration_col(0)``,
    ``cfg.get_iteration_col(1)``, ... and the walk stops at the first column
    name that is not present in ``df_al``. A row counts as a training row when
    its tag equals ``cfg.tag_train`` or ``cfg.tag_validate``.

    Args:
        df_al: frame with one tag column per iteration. Its index labels are
            used directly as row indices into ``y_true``.
        y_true: 2-D label array; axis 1 is iterated class by class.

    Returns:
        Tuple of three lists, each with one entry per iteration:
        ``nr_training_samples`` (number of training rows),
        ``nr_events_in_training_samples`` (per class, the sum of ``y_true``
        over the training rows) and ``selected_samples_by_al`` (index labels
        that did not appear in any earlier iteration).
    """
    # init output lists
    nr_training_samples = []
    nr_events_in_training_samples = []
    selected_samples_by_al = []

    # Index labels handed out in earlier iterations. A set keeps the
    # "is this one new?" test constant-time instead of rescanning every
    # previous sublist for every index.
    already_selected = set()

    iteration = 0
    while True:
        tag_col = cfg.get_iteration_col(iteration)
        # stop at the first iteration column that does not exist
        if tag_col not in df_al.columns:
            break

        # all rows used for training (train + validate) in this iteration
        tags = df_al[tag_col]
        current_indices = df_al[(tags == cfg.tag_train) | (tags == cfg.tag_validate)].index.to_list()

        # Only rows not seen in an earlier iteration are "added". The set is
        # updated after the filter so duplicates inside one iteration are
        # kept, exactly as before.
        added_indices = [item for item in current_indices if item not in already_selected]
        already_selected.update(added_indices)

        # number of positive samples per class among the current training rows
        positive_samples_iter = [
            sum(y_true[current_indices, class_index])
            for class_index in range(np.shape(y_true)[1])
        ]

        nr_training_samples.append(len(current_indices))
        nr_events_in_training_samples.append(positive_samples_iter)
        selected_samples_by_al.append(added_indices)

        iteration += 1

    return nr_training_samples, nr_events_in_training_samples, selected_samples_by_al


def _get_fscore_parts_with_tag(df_al, y_true, y_pred, tag, ceiling=False):
    """Score the rows carrying ``tag``, iteration by iteration.

    Args:
        df_al: frame with one tag column per iteration; its index labels are
            used directly as row indices into ``y_true`` / ``y_pred``.
        y_true: label array indexed as ``[rows, :]``.
        y_pred: prediction array indexed as ``[iteration, rows, :]``.
        tag: tag value a row must have in the column to be included.
        ceiling: when true, only the column ``cfg.get_iteration_col('all')``
            is used and it is paired with the last entry along the first
            axis of ``y_pred``; otherwise columns 0, 1, 2, ... are walked
            until one is missing from ``df_al``.

    Returns:
        Whatever ``_get_fscore_parts`` returns for the collected slices.
    """
    truth_slices = []
    prediction_slices = []

    if ceiling:
        # single pass over the 'all' column, scored against the final predictions
        column = cfg.get_iteration_col('all')
        if column in df_al.columns:
            rows = df_al[df_al[column] == tag].index.to_list()
            truth_slices.append(y_true[rows, :])
            prediction_slices.append(y_pred[-1, rows, :])
        return _get_fscore_parts(truth_slices, prediction_slices)

    step = 0
    while True:
        column = cfg.get_iteration_col(step)
        # the first missing iteration column ends the walk
        if column not in df_al.columns:
            break

        rows = df_al[df_al[column] == tag].index.to_list()
        truth_slices.append(y_true[rows, :])
        prediction_slices.append(y_pred[step, rows, :])
        step += 1

    return _get_fscore_parts(truth_slices, prediction_slices)


def _get_fscore_parts(y_true, y_pred):
    """Compute confusion counts and micro/macro precision, recall and F1.

    Predictions are thresholded at 0.5: a score ``>= 0.5`` counts as a
    positive prediction and a score ``< 0.5`` as a negative one (a NaN score
    is neither). The thresholding works on boolean masks, so the arrays
    passed in are never modified.

    Args:
        y_true: sequence with one array-like per iteration, laid out as
            (nr_samples, nr_classes); entries equal to 1 are positives and
            entries equal to 0 are negatives.
        y_pred: sequence with the matching prediction scores, indexed with
            the same iteration positions as ``y_true``.

    Returns:
        Tuple of ten lists with one entry per iteration, in this order:
        ``tp``, ``tn``, ``fp``, ``fn`` (per-class count arrays),
        ``precision_micro``, ``precision_macro``, ``recall_micro``,
        ``recall_macro``, ``f1_micro``, ``f1_macro``.
        A ratio with a zero denominator comes out as NaN. The macro precision
        and recall ignore NaN classes (``np.nanmean``), and ``f1_macro`` is
        the harmonic mean of the macro precision and the macro recall.
    """
    tp = []
    tn = []
    fp = []
    fn = []
    precision_micro = []
    precision_macro = []
    recall_micro = []
    recall_macro = []
    f1_micro = []
    f1_macro = []

    for index, labels in enumerate(y_true):
        labels = np.asarray(labels)
        scores = np.asarray(y_pred[index])

        # Boolean masks instead of writing 0/1 back into the score array:
        # callers keep their raw prediction scores.
        actual_pos = labels == 1
        actual_neg = labels == 0
        predicted_pos = scores >= 0.5
        predicted_neg = scores < 0.5

        # per-class confusion counts (summed over the sample axis)
        tp_iter = np.asarray(np.sum(actual_pos & predicted_pos, axis=0))
        tn_iter = np.asarray(np.sum(actual_neg & predicted_neg, axis=0))
        fp_iter = np.asarray(np.sum(actual_neg & predicted_pos, axis=0))
        fn_iter = np.asarray(np.sum(actual_pos & predicted_neg, axis=0))

        # micro combination: pool the counts of all classes first
        tp_micro = np.sum(tp_iter)
        fp_micro = np.sum(fp_iter)
        fn_micro = np.sum(fn_iter)

        precision_micro_iter = tp_micro / (tp_micro + fp_micro)
        recall_micro_iter = tp_micro / (tp_micro + fn_micro)
        f1_micro_iter = 2 * tp_micro / (2 * tp_micro + fp_micro + fn_micro)

        # macro combination: average the per-class ratios, skipping NaN classes
        precision_macro_iter = np.nanmean(tp_iter / (tp_iter + fp_iter))
        recall_macro_iter = np.nanmean(tp_iter / (tp_iter + fn_iter))
        f1_macro_iter = 2 * precision_macro_iter * recall_macro_iter / (precision_macro_iter + recall_macro_iter)

        # save results
        tp.append(tp_iter)
        tn.append(tn_iter)
        fp.append(fp_iter)
        fn.append(fn_iter)

        precision_micro.append(precision_micro_iter)
        recall_micro.append(recall_micro_iter)
        f1_micro.append(f1_micro_iter)

        precision_macro.append(precision_macro_iter)
        recall_macro.append(recall_macro_iter)
        f1_macro.append(f1_macro_iter)

    return tp, tn, fp, fn, precision_micro, precision_macro, recall_micro, recall_macro, f1_micro, f1_macro
