from .data import load_instance
from .utils import get_nll
import numpy as np
import os

# # # CALIBRATION # # #

def get_calibration_plot_data(algorithms, datasets):
    """Run every algorithm once on each dataset and collect the raw outputs.

    Args:
        algorithms: Mapping from algorithm name to a callable that takes the
            loaded instance and returns its output.
        datasets: Iterable of dataset identifiers accepted by ``load_instance``.

    Returns:
        A dict keyed by dataset. Each value holds the loaded instance under
        ``'instance'`` plus one entry per algorithm name with that
        algorithm's output.
    """
    collected = {}
    for dataset in datasets:
        instance = load_instance(dataset)
        per_dataset = {'instance': instance}
        # Algorithm outputs are added after the instance, in mapping order.
        per_dataset.update(
            (name, run(instance)) for name, run in algorithms.items()
        )
        collected[dataset] = per_dataset
    return collected

def load_calibration_results(filename, algorithms, datasets):
    """Read cached calibration scores back from ``filename``.

    Each non-empty line of the file is the ``str()`` of a dict that maps
    algorithm names to a score and ``'dataset'`` to the dataset the row
    belongs to.

    Args:
        filename: Path of the cache file to read.
        algorithms: Iterable of algorithm names to collect.
        datasets: Iterable of dataset names to collect.

    Returns:
        ``{dataset: {algo_name: [score, ...]}}`` with one list entry per
        cached row. Requested datasets without cached rows map to empty lists.

    Raises:
        KeyError: If a row of a requested dataset lacks a requested algorithm.
    """
    results = {dataset: {algo_name: [] for algo_name in algorithms} for dataset in datasets}
    with open(filename, 'r') as f:
        for line in f:
            # Blank lines (e.g. a trailing newline) are not rows.
            if not line.strip():
                continue
            # NOTE(security): eval executes whatever the cache file contains,
            # so the file must be trusted. It is kept because rows are written
            # with str(), which may embed numpy scalar reprs that
            # ast.literal_eval cannot parse.
            saved = eval(line)
            per_dataset = results.get(saved['dataset'])
            if per_dataset is None:
                # The cache may hold rows for datasets that were not requested.
                continue
            for algo_name in algorithms:
                if algo_name == 'dataset':
                    continue
                per_dataset[algo_name].append(saved[algo_name])
    return results

def get_calibration_table_data(is_binary, algorithms, datasets, num_runs=10, folder=None):
    """Compute (or load from cache) the mean NLL of each algorithm per dataset.

    Args:
        is_binary: Selects the ``'binary'`` or ``'regression'`` cache file and
            is forwarded to ``get_nll``.
        algorithms: Mapping from algorithm name to a callable that takes the
            loaded instance and returns a dict with ``'pred'`` and ``'std'``.
        datasets: Iterable of dataset identifiers accepted by ``load_instance``.
        num_runs: Number of times each dataset is reloaded and evaluated.
        folder: Optional cache directory. If the cache file already exists it
            is returned as-is and nothing is recomputed.

    Returns:
        ``{dataset: {algo_name: [nll, ...]}}``. With ``folder`` set, the
        values are re-read from the cache file, so runs that contained a NaN
        score are absent.
    """
    if folder is not None:
        task = 'binary' if is_binary else 'regression'
        filename = folder + f'/calibration_{task}.csv'
        if os.path.exists(filename):
            return load_calibration_results(filename, algorithms, datasets)

    results = {}
    for dataset in datasets:
        results[dataset] = {algo_name: [] for algo_name in algorithms}
        for _ in range(num_runs):
            instance = load_instance(dataset)
            for algo_name, algorithm in algorithms.items():
                output = algorithm(instance)
                nll = get_nll(output['pred'], output['std'], instance['y_test'], is_binary).mean()
                results[dataset][algo_name].append(nll)
            if folder is not None:
                saved = {algo_name: results[dataset][algo_name][-1] for algo_name in algorithms}
                # Test the scores themselves: a substring search on the whole
                # row would also match names that merely contain "nan".
                has_nan = any(np.isnan(float(score)) for score in saved.values())
                saved['dataset'] = dataset
                # Open unconditionally so the cache file exists for the final
                # reload even when every run was dropped.
                with open(filename, 'a') as f:
                    if not has_nan:
                        f.write(str(saved) + '\n')

    if folder is not None:
        results = load_calibration_results(filename, algorithms, datasets)

    return results

# # # CONSISTENCY # # #

def get_consistency_results(filename, algorithms, parameter_name, parameter):
    """Read cached consistency rows and group them by algorithm and parameter.

    Each non-empty line of the file is the ``str()`` of a dict that maps
    ``parameter_name`` to the parameter value of the row and every other key
    (an algorithm name) to that algorithm's saved output.

    Args:
        filename: Path of the cache file to read.
        algorithms: Iterable of algorithm names to collect.
        parameter_name: Key under which each row stores its parameter value.
        parameter: Unused; kept so existing callers keep working.

    Returns:
        A tuple ``(results, observed_parameters)`` where ``results`` is
        ``{algo_name: {parameter_value: [output, ...]}}`` and
        ``observed_parameters`` is the set of parameter values seen in the file.
    """
    results = {name: {} for name in algorithms}
    observed_parameters = set()
    with open(filename, 'r') as f:
        for line in f:
            # Blank lines (e.g. a trailing newline) are not rows.
            if not line.strip():
                continue
            # NOTE(security): eval executes whatever the cache file contains,
            # so the file must be trusted.
            saved = eval(line)
            value = saved[parameter_name]
            for name, output in saved.items():
                # Skip the parameter column and any cached algorithm that was
                # not requested.
                if name == parameter_name or name not in results:
                    continue
                results[name].setdefault(value, []).append(output)
            observed_parameters.add(value)
    return results, observed_parameters

def run_consistency_experiment(instance, algorithms, parameter_name, parameters, filename):
    """Re-run every algorithm for each parameter value and append rows to a file.

    For each value in ``parameters`` the entry ``instance[parameter_name]`` is
    overwritten in place, every algorithm is called on the instance, and one
    line holding the ``str()`` of ``{parameter_name: value, algo_name: std}``
    is appended to ``filename``. Nothing is written for a value when the row
    holds no algorithm output.

    Args:
        instance: Mutable mapping passed to each algorithm; must already
            contain ``parameter_name``.
        algorithms: Mapping from algorithm name to a callable that returns a
            dict with a ``'std'`` entry.
        parameter_name: Key of ``instance`` to vary.
        parameters: Iterable of values to assign to that key.
        filename: Path of the file rows are appended to.
    """
    for value in parameters:
        assert parameter_name in instance
        instance[parameter_name] = value
        row = {parameter_name: value}
        row.update(
            (name, list(run(instance)['std'])) for name, run in algorithms.items()
        )
        # A row holding only the parameter carries no results.
        if len(row) <= 1:
            continue
        with open(filename, 'a') as handle:
            handle.write(f'{row}\n')

def get_consistency_data(dataset, algorithms, parameter_name, parameters):
    """Return consistency results for a dataset, computing them if not cached.

    The cache lives at ``cached/consistency_<parameter_name>_<dataset>.csv``
    relative to the working directory. When that file is missing the
    experiment is run first; the results are then always read back from the
    file. A warning is printed when the parameter values found in the file
    differ from the requested ones.

    Args:
        dataset: Dataset identifier accepted by ``load_instance``.
        algorithms: Mapping from algorithm name to callable.
        parameter_name: Name of the instance entry that is varied.
        parameters: Iterable of values for that entry.

    Returns:
        The per-algorithm results read from the cache file.
    """
    cache_path = f'cached/consistency_{parameter_name}_{dataset}.csv'
    instance = load_instance(dataset)
    cache_missing = not os.path.exists(cache_path)
    if cache_missing:
        run_consistency_experiment(instance, algorithms, parameter_name, parameters, cache_path)
    cached_results, seen_values = get_consistency_results(
        cache_path, algorithms, parameter_name, parameters
    )
    requested_values = set(parameters)
    if requested_values != set(seen_values):
        print(f'Warning: Missing {parameter_name} for {dataset}')
    return cached_results

# # # BINARY FAIRNESS # # #

def compute_binary_fairness(results, algorithms, metrics):
    """Evaluate fairness metrics per algorithm, letting uncertain ones abstain.

    Predictions are clipped to [0, 1] and rounded. An algorithm without a
    ``'std'`` output is scored on all test points. An algorithm with one is
    scored only on the points whose ``std`` is at or below a percentile
    threshold; the percentile (75 to 100) is the one that minimises the sum of
    the metrics, each divided by the same metric on the ``'Baseline'``
    predictions of that run.

    Args:
        results: Mapping from run id to a dict with ``'instance'`` (holding
            ``'y_test'`` and ``'group_test'``), one output dict per algorithm,
            and a ``'Baseline'`` output when any algorithm abstains.
        algorithms: Iterable of algorithm names to evaluate.
        metrics: Mapping from metric name to a callable
            ``metric(pred, y, group, run_checks=False)``.

    Returns:
        ``{metric_name: {algo_name: [value per run]}}`` plus an extra key for
        the percentage of test points included, with the same inner layout.

    Raises:
        AssertionError: If the rounded predictions do not contain exactly the
            two values 0 and 1.
    """
    # Same runtime value as the former '\%' literal, minus the invalid escape.
    included_key = 'Included \\%'
    metric_values = {metric: {algo_name: [] for algo_name in algorithms} for metric in metrics}
    metric_values[included_key] = {algo_name: [] for algo_name in algorithms}
    percentiles = np.arange(75, 101, 1)

    for num_run in results:
        run = results[num_run]
        instance = run['instance']
        y = instance['y_test']
        group = instance['group_test'] == 1
        for algo_name in algorithms:
            output = run[algo_name]
            # Gridsearch returns floats
            pred = np.round(output['pred'].clip(0, 1))
            assert np.all(np.unique(pred) == np.array([0, 1])), 'Predictions are not binary'

            # No abstaining
            if 'std' not in output:
                for metric, metric_fn in metrics.items():
                    val = metric_fn(pred, y, group, run_checks=False)
                    metric_values[metric][algo_name].append(val)
                metric_values[included_key][algo_name].append(100)
                continue

            std = output['std']
            # The baseline score does not depend on the percentile, so it is
            # computed once per metric instead of once per candidate.
            normalization = {
                metric: metric_fn(run['Baseline']['pred'], y, group, run_checks=False)
                for metric, metric_fn in metrics.items()
            }

            # Abstaining at the best percentile. Start from "include all" so a
            # search in which no score is comparable (e.g. NaN metrics) cannot
            # reuse the percentile chosen for a previous algorithm or run.
            best_val = np.inf
            best_percentile = percentiles[-1]
            for percentile in percentiles:
                include = std <= np.percentile(std, percentile)
                val = 0
                for metric, metric_fn in metrics.items():
                    metric_val = metric_fn(pred[include], y[include], group[include], run_checks=False)
                    val += metric_val / normalization[metric]
                if val < best_val:
                    best_val = val
                    best_percentile = percentile

            include = std <= np.percentile(std, best_percentile)
            for metric, metric_fn in metrics.items():
                metric_val = metric_fn(pred[include], y[include], group[include], run_checks=False)
                metric_values[metric][algo_name].append(metric_val)
            metric_values[included_key][algo_name].append(int(100 * include.mean()))

    return metric_values

def compute_binary_fairness_extended(abstention_method, results, algorithms, metrics):
    """Score all algorithms on the points kept by one abstaining method.

    For each run, the abstention method's ``std`` picks the subset of test
    points: points with ``std`` at or below a percentile threshold are kept,
    and the percentile (75 to 100) is the one that minimises the sum of the
    metrics, each divided by the same metric on the ``'Baseline'`` predictions
    of that run. The abstention method and every algorithm without a ``'std'``
    output are then scored on that subset. Algorithms that do have a ``'std'``
    output are validated but no value is recorded for them.

    Args:
        abstention_method: Name of the result entry whose ``'pred'`` and
            ``'std'`` drive the abstention.
        results: Mapping from run id to a dict with ``'instance'`` (holding
            ``'y_test'`` and ``'group_test'``), ``'Baseline'``, the abstention
            method's output and one output dict per algorithm.
        algorithms: Mapping whose keys are the algorithm names to evaluate.
        metrics: Mapping from metric name to a callable
            ``metric(pred, y, group, run_checks=False)``.

    Returns:
        ``{metric_name: {name: [value per run]}}`` over the algorithms and the
        abstention method, plus an extra key for the percentage of test points
        included, with the same inner layout.

    Raises:
        AssertionError: If an algorithm's rounded predictions do not contain
            exactly the two values 0 and 1.
    """
    # Same runtime value as the former '\%' literal, minus the invalid escape.
    included_key = 'Included \\%'
    all_algorithms = list(algorithms.keys()) + [abstention_method]
    metric_values = {metric: {algo_name: [] for algo_name in all_algorithms} for metric in metrics}
    metric_values[included_key] = {algo_name: [] for algo_name in all_algorithms}
    percentiles = np.arange(75, 101, 1)

    for num_run in results:
        run = results[num_run]
        instance = run['instance']
        y = instance['y_test']
        group = instance['group_test'] == 1
        abstainer = run[abstention_method]
        pred = np.round(abstainer['pred'].clip(0, 1))
        std = abstainer['std']

        # The baseline score does not depend on the percentile, so it is
        # computed once per metric instead of once per candidate.
        normalization = {
            metric: metric_fn(run['Baseline']['pred'], y, group, run_checks=False)
            for metric, metric_fn in metrics.items()
        }

        # Look for best abstention percentile. Start from "include all" so a
        # search in which no score is comparable (e.g. NaN metrics) cannot
        # reuse the percentile chosen for a previous run.
        best_val = np.inf
        best_percentile = percentiles[-1]
        for percentile in percentiles:
            include = std <= np.percentile(std, percentile)
            val = 0
            for metric, metric_fn in metrics.items():
                metric_val = metric_fn(pred[include], y[include], group[include], run_checks=False)
                val += metric_val / normalization[metric]
            if val < best_val:
                best_val = val
                best_percentile = percentile

        # Choose best
        include = std <= np.percentile(std, best_percentile)
        for metric, metric_fn in metrics.items():
            metric_val = metric_fn(pred[include], y[include], group[include], run_checks=False)
            metric_values[metric][abstention_method].append(metric_val)
        include_rate = int(100 * include.mean())
        metric_values[included_key][abstention_method].append(include_rate)

        for algo_name in algorithms:
            output = run[algo_name]
            # Gridsearch returns floats
            pred = np.round(output['pred'].clip(0, 1))
            assert np.all(np.unique(pred) == np.array([0, 1])), 'Predictions are not binary'
            # Only non-abstaining algorithms are scored on the shared subset.
            if 'std' not in output:
                for metric, metric_fn in metrics.items():
                    val = metric_fn(pred[include], y[include], group[include], run_checks=False)
                    metric_values[metric][algo_name].append(val)
                metric_values[included_key][algo_name].append(include_rate)

    return metric_values