import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import argparse
from pathlib import Path
from collections import defaultdict
import pickle
import seaborn as sns
from hapi_analysis import (
    generate_task_dataset_dict,
    generate_prediction_df,
    prediction_to_failure_df,
    generate_year_predictions,
    prediction_to_confidence_df,
    identify_change_over_time,
    YEARS
)
from plotting import plot_error_matrices_and_histograms, calculate_independent_pmf, count_combinations, sort_col_row

# Global matplotlib styling applied to every figure created by this script.
font = dict(family="sans-serif", size=16)
plt.rc("font", **font)
plt.rc("figure", dpi=300)


def main(dir_name, temporal=False, task_list=None, only_single_label=False):
    """
    Build the task/dataset dictionary, create the output directory and generate the graphs.

    Parameters:
    - dir_name: Name of the directory (created under "../results") where the graphs will be saved. Provided by user
    - temporal: If true, generates graphs via generate_graphs_change_over_time, i.e. for systems that
        experience a change in error rates over time. If false, generate_all_graphs is used instead,
        which saves graphs for every dataset and every year in YEARS.
    - task_list: List of tasks to generate graphs for. If None, will generate graphs for all tasks.
    - only_single_label: If true, will only generate graphs for datasets where instances have a single ground truth label.
    """
    task_dataset_dict = generate_task_dataset_dict(
        task_list=task_list, only_single_label=only_single_label
    )

    # The output location is relative to the current working directory.
    dir_path = Path("../results").joinpath(dir_name)
    dir_path.mkdir(parents=True, exist_ok=True)

    if not temporal:
        print("Generating all matrix plots...")
        generate_all_graphs(dir_path, task_dataset_dict)
        return

    print("Generating matrix plots for years that exhibit change")
    generate_graphs_change_over_time(dir_path, task_dataset_dict)


def generate_graphs_change_over_time(dir_path, task_dataset_dict, threshold=0.005):
    """
    Generate the error graphs for datasets that experience a change in error rates over time.

    For every task/dataset pair the results of identify_change_over_time are plotted and saved
    under dir_path / task / dataset: the error matrices and histograms, the error rates
    (x axis labelled "year"), and the filtered plots produced by plot_filtered_on_failures.
    Datasets for which no outcome matrices are returned only get their (empty) directory.
    The data collected by plot_filtered_on_failures is pickled to dir_path / "improvements.pickle".

    Parameters:
    - dir_path: Path object to the directory where the graphs will be saved. Provided by user
    - task_dataset_dict: Dictionary of tasks and datasets.
    - threshold: Threshold for how much the error rate must change to be considered a change. Default is 0.005.
    """
    collected_improvements = {}

    for task, datasets in task_dataset_dict.items():
        print("*** TASK: {} ***".format(task))
        for dataset in datasets:
            print("*** DATASET: {} ***".format(dataset))
            output_dir = dir_path / task / dataset
            output_dir.mkdir(parents=True, exist_ok=True)

            change_info = identify_change_over_time(task, dataset, threshold=threshold)
            outcome_matrices, error_rates, error_changes_over_time = change_info

            # Nothing to plot for this dataset.
            if len(outcome_matrices) == 0:
                continue

            # The title doubles as the file name of the saved figure.
            title = "Task: {}, Dataset: {}, Years: {}".format(
                task, dataset, list(outcome_matrices.keys())
            )
            plot_error_matrices_and_histograms(outcome_matrices, title, display_title=False)
            plt.savefig(output_dir / title, bbox_inches='tight')

            rates_frame = pd.DataFrame.from_dict(error_rates, orient="index")
            rates_frame.plot(ylabel="error rate", xlabel="year", colormap='tab10')
            plt.savefig(output_dir / "{} error rates".format(dataset), bbox_inches='tight')

            plot_filtered_on_failures(
                outcome_matrices,
                error_changes_over_time,
                output_dir,
                collected_improvements,
                (task, dataset),
            )
            # Release every figure opened for this dataset before moving on.
            plt.close('all')

    pickle_path = (dir_path / "improvements").with_suffix(".pickle")
    with pickle_path.open("wb") as handle:
        pickle.dump(collected_improvements, handle)

def calculate_net_improvement_vs_expected(initial_year_failure_matrix, next_year_failure_matrix):
    """
    Compare the observed per-profile change in counts between two years with a proportional baseline.

    Both matrices are passed through sort_col_row and count_combinations. The observed change of a
    profile is its next-year count minus its initial-year count. The expected change scales every
    initial-year count by (sum(next_counts) / sum(init_counts) - 1), i.e. it spreads the total change
    over the profiles in proportion to their initial-year counts.

    Parameters:
    - initial_year_failure_matrix: outcome matrix of the first year (layout as expected by sort_col_row).
    - next_year_failure_matrix: outcome matrix of the following year, same layout.

    Returns a 4-tuple of dictionaries, each keyed by the profile strings of the initial year:
    - observed proportions: (next - init) divided by the total observed change
    - observed counts: init - next (sign flipped, so a drop in count is a positive number)
    - expected proportions: expected change divided by the total expected change
    - expected counts: sign flipped in the same way as the observed counts

    Proportions are undefined when the corresponding total is zero; they are then reported as NaN
    instead of dividing by zero.

    Raises:
    - AssertionError if the two years do not yield the same profiles in the same order.
    """
    initial_year_failure_matrix, _, _ = sort_col_row(
        initial_year_failure_matrix, col_sorting_order=None
    )
    # NOTE(review): the column ordering returned by the call above is not forwarded here, so the
    # next-year matrix is sorted independently. Confirm against sort_col_row whether it should be
    # passed as col_sorting_order so both years share the same column order.
    next_year_failure_matrix, _, _ = sort_col_row(next_year_failure_matrix)

    init_unique_vals_str, _, init_counts, init_error_rates = count_combinations(initial_year_failure_matrix)
    next_unique_vals_str, _, next_counts, _ = count_combinations(
        next_year_failure_matrix, error_rates=init_error_rates
    )

    # zip() stops at the shorter sequence, so the lengths have to be compared explicitly;
    # otherwise a missing or extra profile would go unnoticed.
    assert len(init_unique_vals_str) == len(next_unique_vals_str) and all(
        x == y for x, y in zip(init_unique_vals_str, next_unique_vals_str)
    ), "Ensure ordering and elements are the same"
    print(init_unique_vals_str)
    print(init_error_rates)

    observed_diff_counts = np.array(
        [next_count - init_count for next_count, init_count in zip(next_counts, init_counts)]
    )
    net_improvement = np.sum(observed_diff_counts)
    if net_improvement == 0:
        # No net change: the share of the change per profile is undefined.
        observed_diff_proportions = np.full(observed_diff_counts.shape, np.nan)
    else:
        observed_diff_proportions = observed_diff_counts / net_improvement

    init_total = np.sum(init_counts)
    if init_total == 0:
        # Without any initial-year instances there is no baseline to scale.
        expected_diff_counts = np.full(observed_diff_counts.shape, np.nan)
        expected_diff_proportions = np.full(observed_diff_counts.shape, np.nan)
    else:
        reduction = np.sum(next_counts) / init_total
        expected_diff_counts = np.array([init_count * (reduction - 1) for init_count in init_counts])
        expected_total = np.sum(expected_diff_counts)
        if expected_total == 0:
            expected_diff_proportions = np.full(expected_diff_counts.shape, np.nan)
        else:
            expected_diff_proportions = expected_diff_counts / expected_total

    observed_net_improvement_proportions = {}
    expected_net_improvement_proportions = {}
    observed_net_improvement_counts = {}
    expected_net_improvement_counts = {}
    for i, unique_val in enumerate(init_unique_vals_str):
        observed_net_improvement_proportions[unique_val] = observed_diff_proportions[i]
        expected_net_improvement_proportions[unique_val] = expected_diff_proportions[i]

        # Counts are negated so that fewer instances in the next year reads as a positive number.
        observed_net_improvement_counts[unique_val] = -observed_diff_counts[i]
        expected_net_improvement_counts[unique_val] = -expected_diff_counts[i]

    return (
        observed_net_improvement_proportions,
        observed_net_improvement_counts,
        expected_net_improvement_proportions,
        expected_net_improvement_counts,
    )

def calculate_net_improvement(better_instances, worse_instances):
    """
    Compute, per outcome profile, how many more instances improved than declined.

    Parameters:
    - better_instances: outcome matrix of the instances that improved (converted with np.array).
    - worse_instances: outcome matrix of the instances that got worse (converted with np.array).

    Returns a 3-tuple:
    - total_net_improvement: sum over all profiles of (better count - worse count)
    - observed_net_improvement_proportions: {profile: net count / total_net_improvement}
    - observed_net_improvement_counts: {profile: better count - worse count}

    When total_net_improvement is zero the proportions are undefined; they are reported as NaN
    instead of dividing by zero.
    """
    # Order models by error (ascending left to right)
    better_instances, error_rates, _ = sort_col_row(
        np.array(better_instances), col_sorting_order=None
    )
    # NOTE(review): the column ordering returned above is not forwarded to this call, so the two
    # matrices are sorted independently. Confirm against sort_col_row whether it should be passed
    # as col_sorting_order so both matrices share the same column order.
    worse_instances, _, _ = sort_col_row(np.array(worse_instances))

    # Count frequency of each profile
    outcome_profiles_str, _, better_counts, _ = count_combinations(better_instances, error_rates=error_rates)
    _, _, worse_counts, _ = count_combinations(worse_instances, error_rates=error_rates)

    # For each profile subtract frequency of worse instances from frequency of better instances.
    # For example, if the model improved on 50 systemic failures but got worse at 20 systemic failures,
    #   we say it had a net improvement of 30 on systemic failures
    # NOTE(review): the profiles of the worse matrix are discarded, so the counts are assumed to be
    # aligned with outcome_profiles_str — verify against count_combinations.
    improvement_counts = np.array(
        [better_count - worse_count for worse_count, better_count in zip(worse_counts, better_counts)]
    )
    total_net_improvement = np.sum(improvement_counts)
    print(improvement_counts, total_net_improvement)

    if total_net_improvement == 0:
        # Improvements and declines cancel out (or there are none): shares are undefined.
        observed_diff_proportions = np.full(improvement_counts.shape, np.nan)
    else:
        observed_diff_proportions = improvement_counts / total_net_improvement

    # Save results into dict and return
    observed_net_improvement_proportions = {
        profile: observed_diff_proportions[i] for i, profile in enumerate(outcome_profiles_str)
    }
    observed_net_improvement_counts = {
        profile: improvement_counts[i] for i, profile in enumerate(outcome_profiles_str)
    }

    return total_net_improvement, observed_net_improvement_proportions, observed_net_improvement_counts


def _build_rejection_matrices(this_year_outcomes, next_year_outcomes, model, this_year, next_year):
    """Build the instance subsets compared for one model and one pair of years.

    Returns {f"{failure_definition}_{this_year}": {label: outcomes without the model column}}
    for the definitions 'failures', 'improvements' and 'declines', in that order.
    """
    matrices = {}

    # Failures in both years, independent of each other
    matrices[f"failures_{this_year}"] = {
        this_year: this_year_outcomes[this_year_outcomes[model] == 1].drop(model, axis=1),
        next_year: next_year_outcomes[next_year_outcomes[model] == 1].drop(model, axis=1),
    }

    # Failure in this year; success in next year. The reference set is every failure of this year.
    improved = (this_year_outcomes[model] == 1) & (next_year_outcomes[model] == 0)
    matrices[f"improvements_{this_year}"] = {
        "improvements": this_year_outcomes[improved].drop(model, axis=1),
        "potential improvements": this_year_outcomes[this_year_outcomes[model] == 1].drop(model, axis=1),
    }

    # Success in this year; failure in next year. The reference set is every success of this year.
    declined = (this_year_outcomes[model] == 0) & (next_year_outcomes[model] == 1)
    matrices[f"declines_{this_year}"] = {
        "declines": this_year_outcomes[declined].drop(model, axis=1),
        "potential declines": this_year_outcomes[this_year_outcomes[model] == 0].drop(model, axis=1),
    }
    return matrices


def plot_filtered_on_failures(
    outcome_matrices, error_changes_over_time, dataset_dir, result_dict, task_dataset
):
    """
    Generate the failures, improvements vs potential improvements and declines vs potential declines
    plots for each model that experiences a change in error rates over time. We only plot two years in
    a single graph, so if a model changes between 2020 and 2021 and then between 2021 and 2022,
    separate plots will be generated for each pair of years.

    All three kinds of plots are generated for every pair of years, regardless of whether the net
    change was an improvement or a decline.

    Parameters:
    - outcome_matrices: dictionary of form {year: outcome_matrix}
    - error_changes_over_time: dictionary of form {model: {(year1, year2): change_in_error}}
    - dataset_dir: directory to save plots to
    - result_dict: dictionary that is filled in place. Keys have the form
        (task, dataset, model, year1, total_net_improvement, label) where label is
        'net improvements', 'net improvements count', or a key of the proportions returned by
        plot_error_matrices_and_histograms (presumably 'improvements', 'potential improvements',
        'declines', 'potential declines' — verify against that function); for the 'failures' plot
        the year keys are stored as 'initial year' / 'next year'.
    - task_dataset: tuple of form (task, dataset)

    Returns:
    - None, but saves plots to dataset_dir and adds entries to result_dict
    """
    palettes = {"failures": "tab20c", "improvements": "tab20"}
    default_palette = ['darkorange', 'moccasin']

    task, dataset = task_dataset
    for model, changes in error_changes_over_time.items():
        for years, _ in changes.items():
            this_year = years[0]
            next_year = years[1]

            this_year_outcomes = outcome_matrices[this_year]
            next_year_outcomes = outcome_matrices[next_year]
            print(model, years)

            # Built from scratch for every pair of years: reusing one container per model made
            # later pairs re-plot the subsets of earlier pairs and store them under the wrong year.
            rejection_matrices = _build_rejection_matrices(
                this_year_outcomes, next_year_outcomes, model, this_year, next_year
            )

            ### Calculate net improvement
            improved_instances = rejection_matrices[f"improvements_{this_year}"]['improvements']
            worse_instances = rejection_matrices[f"declines_{this_year}"]['declines']
            (
                total_net_improvement,
                observed_net_improvement_proportions,
                observed_net_improvement_counts,
            ) = calculate_net_improvement(improved_instances, worse_instances)

            key_prefix = (task, dataset, model, this_year, total_net_improvement)
            result_dict[key_prefix + ('net improvements',)] = observed_net_improvement_proportions
            result_dict[key_prefix + ('net improvements count',)] = observed_net_improvement_counts

            # experiment_information is composed of failure_definition and year
            for experiment_information, matrices in rejection_matrices.items():
                # e.g. 'improvements', 'declines', 'failures'
                failure_definition = experiment_information.split("_")[0]
                palette = palettes.get(failure_definition, default_palette)

                title = f"{experiment_information} instances of {model} on {task} task, {dataset} dataset"
                _fig, _axs, proportions_data, _counts_data = plot_error_matrices_and_histograms(
                    matrices,
                    title,
                    additional_plot=None,
                    palette=palette,
                    plot_matrices=False,
                    display_title=False,
                )

                for instance_subset_condition, proportions_dict in proportions_data.items():
                    if failure_definition == 'failures':
                        # in this case, the instance_subset_condition is the year
                        if instance_subset_condition == this_year:
                            instance_subset_condition = "initial year"
                        else:
                            instance_subset_condition = "next year"
                    result_dict[key_prefix + (instance_subset_condition,)] = proportions_dict
                plt.savefig(dataset_dir / title, bbox_inches='tight')


def generate_all_graphs(dir_path, task_dataset_dict):
    """
    Generate matrix plots and confidence plots for all datasets in task_dataset_dict.

    For every dataset whose failure dataframe is available, and for every year in YEARS, two figures
    are saved:
    - dir_path / "Matrix Graphs" / year: the observed failure matrix next to a matrix drawn
        independently per column from the observed error rates (np.random.binomial, so this
        reference matrix differs between runs unless the numpy seed is fixed).
    - dir_path / "Confidence Graphs" / year: violin plots of the mean and the standard deviation of
        the confidence, grouped by the ('is_correct', 'sum') column of the confidence dataframe.

    Parameters:
    - dir_path: Path object to the directory where the graphs will be saved.
    - task_dataset_dict: Dictionary of tasks and datasets.
    """
    for task in task_dataset_dict:
        print("*** TASK: {} ***".format(task))
        for dataset in task_dataset_dict[task]:
            print("*** DATASET: {} ***".format(dataset))
            R = generate_prediction_df(task, dataset)
            R_failures = prediction_to_failure_df(R)
            R_confidence = prediction_to_confidence_df(R)
            if R_failures is None:
                continue
            for year in YEARS:
                # Path components must be strings; this keeps the function usable with numeric years.
                year_dir_name = str(year)

                R_failures_np, error_rates = generate_year_predictions(R_failures, year)
                R_independent_draws = np.random.binomial(1, error_rates, R_failures_np.shape)
                title = f"Matrix Graphs: Task: {task}, Dataset: {dataset}"

                plot_error_matrices_and_histograms(
                    {"Observed": R_failures_np, "Independent": R_independent_draws}, title
                )
                matrix_graphs_dir = dir_path / "Matrix Graphs" / year_dir_name
                matrix_graphs_dir.mkdir(parents=True, exist_ok=True)
                plt.savefig(matrix_graphs_dir / title, bbox_inches='tight')

                # *** Generate confidence plots ***
                title = f'Confidence Graphs: Task: {task}, Dataset: {dataset}'
                fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
                palette = 'coolwarm_r'
                # Left panel: mean confidence, right panel: standard deviation of the confidence.
                for ax, statistic in ((ax1, 'mean'), (ax2, 'std')):
                    sns.violinplot(
                        data=R_confidence,
                        x=('is_correct', 'sum'),
                        y=('confidence', statistic),
                        scale='area',
                        inner='box',
                        ax=ax,
                        palette=palette,
                    )
                fig.suptitle(title, fontsize=16)
                confidence_graphs_dir = dir_path / "Confidence Graphs" / year_dir_name
                confidence_graphs_dir.mkdir(parents=True, exist_ok=True)
                plt.savefig(confidence_graphs_dir / title, bbox_inches='tight')

                # pyplot keeps every figure alive until it is closed; without this the figures of
                # all tasks, datasets and years accumulate in memory.
                plt.close('all')


def generate_independent_df(df):
    """
    Unlike numpy.random.binomial(), this builds the independent reference dataframe deterministically.
    Concretely, this generates the dataframe that would be expected if each column were generated
    independently from the error rates of the columns in the provided df.

    The error rate of a column is its sum divided by its count. Every (outcome, probability) pair
    returned by calculate_independent_pmf contributes round(probability * number of rows) copies of
    the outcome, so because of the rounding the result need not have exactly as many rows as df.
    """
    total_rows = df.shape[0]
    column_error_rates = df.apply(lambda column: column.sum() / column.count())

    reference_rows = []
    for outcome, probability in calculate_independent_pmf(column_error_rates):
        copies = round(probability * total_rows)
        reference_rows += [outcome] * copies
    return pd.DataFrame(reference_rows)



# The script can be called as such: "python3 hapi_matrix_graphs.py {dir_name}".
# A folder named dir_name will be created inside the results folder.
# Optional flags: --temporal, --single_label and --task (may be given several times).
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument("dir_name", help="name to be used for naming results folders")
    for switch in ("--temporal", "--single_label"):
        cli.add_argument(switch, action="store_true")
    cli.add_argument("--task", action="append")
    options = cli.parse_args()

    main(
        options.dir_name,
        temporal=options.temporal,
        task_list=options.task,
        only_single_label=options.single_label,
    )
    print("All results generated")
