import numpy as np
from xgboost import XGBClassifier
from denoiser.loss_functions import total_variation_loss, mean_squared_error_loss, jensen_shannon_divergence
from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score
from query import query_marginal
from constraints import ConstraintCompiler, ConstraintEvaluator


def statistics(arr):
    """
    Summarizes a list or an array by its mean, std, median, min, and max.

    :param arr: (list or np.ndarray) The data list or array.
    :return: (tuple) The mean, std, median, min, and max of the input, in that order.
    """
    # The order of the reducers fixes the order of the entries in the returned tuple.
    reducers = (np.mean, np.std, np.median, np.min, np.max)
    return tuple(reducer(arr) for reducer in reducers)


def evaluate_sampled_dataset(synthetic_dataset, workload, true_measured_workload, dataset, max_slice, random_seed):
    """
    Scores a synthetic dataset in two ways: by comparing its marginals on the workload against the marginals measured
    on the true data, and by training an XGBoost classifier on it and evaluating that classifier on the test split.

    :param synthetic_dataset: (torch.tensor) The generated synthetic dataset in full one hot encoding.
    :param workload: (list) The workload as list of tuples.
    :param true_measured_workload: (dict) The true workload measurements, indexed by the entries of the workload.
    :param dataset: (BaseDataset) The instantiated dataset object.
    :param max_slice: (int) Max size for marginal computations.
    :param random_seed: (int) Random seed for reproducibility.
    :return: (tuple) Six (mean, std, median, min, max) tuples, in order: TV error, L2 error, JS error, XGB accuracy,
        XGB balanced accuracy, and XGB F1 score. The first three summarize the per-marginal errors over the workload;
        the last three are each computed from a single value.
    """
    # Measure every workload marginal on the synthetic data.
    synthetic_marginals = {
        marginal: query_marginal(
            synthetic_dataset,
            marginal,
            dataset.full_one_hot_index_map,
            normalize=True,
            input_torch=True,
            max_slice=max_slice,
        )
        for marginal in workload
    }

    def _workload_errors(loss_fn):
        # One scalar error per workload entry: true marginal vs. synthetic marginal.
        return [
            loss_fn(true_measured_workload[marginal], synthetic_marginals[marginal]).item()
            for marginal in workload
        ]

    tv_errors = _workload_errors(total_variation_loss)
    l2_errors = _workload_errors(mean_squared_error_loss)
    js_errors = _workload_errors(jensen_shannon_divergence)

    # Build the feature/label splits: real test data for evaluation, synthetic data for training.
    real_test_data = dataset.get_Dtest_full_one_hot(return_torch=True)
    Xtest, ytest = ConstraintCompiler.prepare_data(
        real_test_data,
        list(dataset.train_features.keys()),
        dataset.label,
        dataset,
    )
    synthetic_train_data = synthetic_dataset.cpu()
    Xtrain_synth, ytrain_synth = ConstraintCompiler.prepare_data(
        synthetic_train_data,
        list(dataset.train_features.keys()),
        dataset.label,
        dataset,
    )

    # Avoid an encoding error when fitting the classifier on the synthetic labels.
    label_feature = dataset.features[dataset.label]
    ytrain_synth = ConstraintEvaluator.handle_missing_classes_in_training_data(ytrain_synth, label_feature)

    # Train an XGBoost on the synthetic data and predict on the test split.
    classifier = XGBClassifier(verbosity=0, random_state=random_seed)
    train_inputs = Xtrain_synth.cpu().numpy()
    train_labels = ytrain_synth.cpu().numpy().astype(int)
    classifier.fit(train_inputs, train_labels)
    predictions = classifier.predict(Xtest.cpu().numpy())

    test_labels = ytest.cpu().numpy()
    acc = accuracy_score(test_labels, predictions)
    bac = balanced_accuracy_score(test_labels, predictions)
    f1 = f1_score(test_labels, predictions, average='micro')

    # The classifier scores are single values, wrapped in one-element lists so that every
    # entry of the result has the same five-statistic layout.
    return (
        statistics(tv_errors),
        statistics(l2_errors),
        statistics(js_errors),
        statistics([acc]),
        statistics([bac]),
        statistics([f1]),
    )
