import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import argparse
from pathlib import Path
from collections import defaultdict
import pickle
import seaborn as sns
from hapi_analysis import (
    generate_task_dataset_dict,
    generate_prediction_df,
    prediction_to_failure_df,
    generate_year_predictions,
    prediction_to_confidence_df,
    identify_change_over_time,
    YEARS,
)
from plotting import (
    plot_error_matrices_and_histograms,
    calculate_independent_pmf,
    count_combinations,
    sort_col_row,
)

# Global matplotlib styling shared by every figure this script saves.
font = dict(family="sans-serif", size=16)
plt.rc("font", **font)
plt.rcParams.update({"figure.dpi": 300})


"""
Parameters:
- dir_name: Name of the directory where the graphs will be saved. Provided by user
- improvements: If True, generates matrix graphs for systems that experience a change in error rates over time.
    If False, matrix graphs are generated for all systems, and each matrix graph will include predictions from all 3 years, so the same instance may be represented multiple times in a single graph.
- task_list: List of tasks to generate graphs for. If None, will generate graphs for all tasks.
- only_single_label: If true, will only generate graphs for datasets where instances have a single ground truth label.
"""


def main(dir_name, improvements=True, task_list=None, only_single_label=True):
    """Build the task/dataset mapping and generate the requested graphs.

    Parameters:
    - dir_name: sub-directory of ``../results`` that receives all output.
    - improvements: if True, only plot systems whose error rates change over
      time; otherwise plot every dataset.
    - task_list: tasks to include; ``None`` means every task.
    - only_single_label: restrict to datasets whose instances have a single
      ground-truth label.
    """
    datasets_by_task = generate_task_dataset_dict(
        task_list=task_list, only_single_label=only_single_label
    )
    # NOTE(review): output always lands in an "improvements" sub-folder, even
    # when improvements=False -- confirm this is intended.
    output_dir = Path("../results").joinpath(dir_name, "improvements")
    output_dir.mkdir(parents=True, exist_ok=True)

    if not improvements:
        print("Generating all outcome profile plots for all datasets")
        generate_all_graphs(output_dir, datasets_by_task)
        return
    print("Generating results for periods of model improvement")
    generate_graphs_change_over_time(output_dir, datasets_by_task)


"""
Generates error graphs for datasets that experience a change in error rates over time, and pickles the accumulated per-model improvement results to dir_path/improvements.pickle.

Parameters:
- dir_path: Path object to the directory where the graphs will be saved (built by main from the user-supplied dir_name).
- task_dataset_dict: Dictionary of tasks and datasets. 
- threshold: Threshold for how much the error rate must change to be considered a change. Default is 0.001.
"""


def generate_graphs_change_over_time(dir_path, task_dataset_dict, threshold=0.001):
    """Plot outcome matrices and error-rate trends for systems that change over time.

    For each (task, dataset) pair, ``identify_change_over_time`` supplies the
    per-year outcome matrices, the error rates and the per-model error changes.
    When at least one outcome matrix exists, the following are written to
    ``dir_path/task/dataset``:
    - the outcome-matrix/histogram figure;
    - an error-rate-by-year line plot;
    - the per-model plots produced by ``plot_filtered_on_failures``.

    Everything ``plot_filtered_on_failures`` records is accumulated and pickled
    to ``dir_path/improvements.pickle``.

    Parameters:
    - dir_path: ``pathlib.Path`` of the output directory.
    - task_dataset_dict: mapping of task name -> iterable of dataset names.
    - threshold: minimum change in error rate that counts as a change.
    """
    improvements_by_key = {}
    for task, datasets in task_dataset_dict.items():
        print("*** TASK: {} ***".format(task))
        for dataset in datasets:
            print("*** DATASET: {} ***".format(dataset))
            out_dir = dir_path / task / dataset
            out_dir.mkdir(parents=True, exist_ok=True)

            matrices_by_year, error_rates, error_changes = identify_change_over_time(
                task, dataset, threshold=threshold
            )
            if len(matrices_by_year) == 0:
                continue

            title = "Task: {}, Dataset: {}, Years: {}".format(
                task, dataset, list(matrices_by_year.keys())
            )
            plot_error_matrices_and_histograms(
                matrices_by_year, title, display_title=False
            )
            plt.savefig(out_dir / title.replace(" ", "_"), bbox_inches="tight")

            rates_frame = pd.DataFrame.from_dict(error_rates, orient="index")
            rates_frame.plot(ylabel="error rate", xlabel="year", colormap="tab10")
            rates_name = "{}_error_rates".format(dataset)
            plt.savefig(out_dir / rates_name, bbox_inches="tight")

            plot_filtered_on_failures(
                matrices_by_year,
                error_changes,
                out_dir,
                improvements_by_key,
                (task, dataset),
            )
            plt.close("all")

    pickle_path = (dir_path / "improvements").with_suffix(".pickle")
    with open(pickle_path, "wb") as handle:
        pickle.dump(improvements_by_key, handle)


def calculate_net_improvement_vs_expected(
    initial_year_failure_matrix, next_year_failure_matrix
):
    """Compare the observed per-profile change in failures with a proportional baseline.

    Both matrices are column-sorted with the model ordering derived from the
    initial year. They are then bucketed into outcome profiles with
    ``count_combinations``; the next year reuses the initial year's error rates.

    - Observed change for a profile: ``next_count - init_count``.
    - Expected change: each profile's initial count is scaled by the same
      factor as the overall failure count, i.e. ``init_count * (reduction - 1)``
      with ``reduction = sum(next_counts) / sum(init_counts)``.

    Parameters:
    - initial_year_failure_matrix: failure matrix for the initial year.
    - next_year_failure_matrix: failure matrix for the next year, covering the
      same models (columns) as the initial year.

    Returns (four dicts keyed by outcome-profile string):
    - observed_net_improvement_proportions: share of the total observed change.
    - observed_net_improvement_counts: negated observed change, so a drop in
      failures is reported as a positive improvement.
    - expected_net_improvement_proportions: share of the total expected change.
    - expected_net_improvement_counts: negated expected change.

    Proportions are NaN when the corresponding total change is zero (or the
    initial year has no failures), because a share of nothing is undefined.

    Raises:
    - ValueError: if the two years do not yield the same outcome profiles in
      the same order.
    """
    initial_year_failure_matrix, _, model_ordering = sort_col_row(
        initial_year_failure_matrix, col_sorting_order=None
    )
    # Reuse the initial year's column order so each profile refers to the same models.
    next_year_failure_matrix, _, _ = sort_col_row(
        next_year_failure_matrix, col_sorting_order=model_ordering
    )

    init_profiles, _, init_counts, init_error_rates = count_combinations(
        initial_year_failure_matrix
    )
    next_profiles, _, next_counts, _ = count_combinations(
        next_year_failure_matrix, error_rates=init_error_rates
    )

    # Compare the full lists: a zip()-based check silently ignores a length mismatch.
    init_profiles = list(init_profiles)
    if init_profiles != list(next_profiles):
        raise ValueError("Ensure ordering and elements are the same")

    init_counts = np.asarray(init_counts)
    next_counts = np.asarray(next_counts)

    observed_diff_counts = next_counts - init_counts
    net_change = observed_diff_counts.sum()
    if net_change != 0:
        observed_diff_proportions = observed_diff_counts / net_change
    else:
        observed_diff_proportions = np.full(observed_diff_counts.shape, np.nan)

    total_init = init_counts.sum()
    total_next = next_counts.sum()
    # No initial failures means there is no proportional baseline.
    reduction = total_next / total_init if total_init != 0 else np.nan
    expected_diff_counts = init_counts * (reduction - 1)
    # sum(expected_diff_counts) == total_init * (reduction - 1), so each share
    # reduces to init_count / total_init whenever that sum is non-zero.
    if total_init != 0 and total_next != total_init:
        expected_diff_proportions = init_counts / total_init
    else:
        expected_diff_proportions = np.full(init_counts.shape, np.nan)

    observed_net_improvement_proportions = dict(
        zip(init_profiles, observed_diff_proportions)
    )
    observed_net_improvement_counts = dict(zip(init_profiles, -observed_diff_counts))
    expected_net_improvement_proportions = dict(
        zip(init_profiles, expected_diff_proportions)
    )
    expected_net_improvement_counts = dict(zip(init_profiles, -expected_diff_counts))

    return (
        observed_net_improvement_proportions,
        observed_net_improvement_counts,
        expected_net_improvement_proportions,
        expected_net_improvement_counts,
    )


"""
Given the failure matrices of better_instances and worse_instances, calculates the net improvement for each outcome profile (of the other models).
Parameters:
- better_instances (dataframe or np.array): Failure matrix of other models for instances that a model failed at in the first year and succeeded at in the second year
- worse_instances (dataframe or np.array): Failure matrix of other models for instances that a model succeeded at in the first year and failed at in the second year

Returns:
- total_net_improvement: the total number of net improvements that a model makes
- observed_net_improvement_proportions: proportion of net improvements for each outcome profile
- observed_net_improvement_counts: net improvement count for each outcome profile.
"""


def calculate_net_improvement(better_instances, worse_instances):
    """Compute a single model's per-outcome-profile net improvement between two years.

    Parameters:
    - better_instances (dataframe or np.array): failure matrix of the other
      models on instances the focus model failed in the first year and solved
      in the second.
    - worse_instances (dataframe or np.array): failure matrix of the other
      models on instances the focus model solved in the first year and failed
      in the second.

    Returns:
    - total_net_improvement: sum over profiles of (better count - worse count).
    - observed_net_improvement_proportions: {profile: share of
      total_net_improvement}; NaN for every profile when the total is zero.
    - observed_net_improvement_counts: {profile: better count - worse count}.
    """
    # Order models by error (ascending left to right) on the better instances.
    # The SAME column order is then applied to the worse instances; otherwise a
    # given profile string could refer to different models in the two matrices.
    better_instances, error_rates, model_ordering = sort_col_row(
        np.array(better_instances), col_sorting_order=None
    )
    worse_instances, _, _ = sort_col_row(
        np.array(worse_instances), col_sorting_order=model_ordering
    )

    # Count frequency of each profile
    better_profiles, _, better_counts, _ = count_combinations(
        better_instances, error_rates=error_rates
    )
    worse_profiles, _, worse_counts, _ = count_combinations(
        worse_instances, error_rates=error_rates
    )

    # Align by profile, not by position. A profile missing on one side then
    # counts as zero instead of being mis-paired or dropped by zip().
    better_by_profile = dict(zip(better_profiles, better_counts))
    worse_by_profile = dict(zip(worse_profiles, worse_counts))
    profiles = list(better_by_profile)
    profiles += [p for p in worse_by_profile if p not in better_by_profile]

    # For each profile, subtract worse frequency from better frequency: improving
    # on 50 systemic failures but getting worse on 20 is a net improvement of 30.
    improvement_counts = np.array(
        [better_by_profile.get(p, 0) - worse_by_profile.get(p, 0) for p in profiles]
    )
    total_net_improvement = np.sum(improvement_counts)
    if total_net_improvement != 0:
        observed_diff_proportions = improvement_counts / total_net_improvement
    else:
        # A share of a zero total is undefined.
        observed_diff_proportions = np.full(improvement_counts.shape, np.nan)

    observed_net_improvement_proportions = dict(zip(profiles, observed_diff_proportions))
    observed_net_improvement_counts = dict(zip(profiles, improvement_counts))

    return (
        total_net_improvement,
        observed_net_improvement_proportions,
        observed_net_improvement_counts,
    )


"""
This function generates the {improvements, declines} vs potential {improvements, declines} plots for each model that 
experiences a change in error rates over time. We only plot two years in a single graph, so if a model declines in 
between 2020 and 2021 and then between 2021 and 2022, separate plots will be generated.

Parameters:
- outcome_matrices: dictionary of form {year: outcome_matrix}
- error_changes_over_time: dictionary of form {model: {(year1, year2): change_in_error}}
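- dataset_dir: directory to save plots to
- result_dict: dictionary updated in place, with keys of form (task, dataset, model, initial year, change in error, total net improvement, subset label), e.g. subset label "improvements", "potential improvements", "initial year" or "net improvements"; values map system-level outcome profiles to the proportion of instances (or the count, for "net improvements count")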
- task_dataset: tuple of form (task, dataset)

Returns:
- None, but saves plots to dataset_dir

"""


def plot_filtered_on_failures(
    outcome_matrices, error_changes_over_time, dataset_dir, result_dict, task_dataset
):
    """Plot and record improvement/decline outcome profiles per model and year pair.

    For every model in ``error_changes_over_time`` and every
    ``(this_year, next_year)`` pair listed for it, the instances are split by
    ``_build_rejection_matrices`` into:
    - the failures of each year;
    - improvements vs. potential improvements;
    - declines vs. potential declines.

    One histogram figure per split is saved to ``dataset_dir``. The per-profile
    proportions the plotting helper reports, plus the model's net improvement,
    are stored in ``result_dict``.

    Parameters:
    - outcome_matrices: {year: outcome DataFrame}, one column per model
      (1 = failure, 0 = success).
    - error_changes_over_time: {model: {(year1, year2): change_in_error}}; a
      negative change means the error rate dropped.
    - dataset_dir: directory the figures are saved to.
    - result_dict: updated in place. Keys are
      (task, dataset, model, year1, change_in_error, total_net_improvement, label).
      The label is "net improvements", "net improvements count", or a subset
      label returned by plot_error_matrices_and_histograms; for failures, the
      year label is mapped to "initial year"/"next year". Values map outcome
      profiles to proportions (counts for "net improvements count").
    - task_dataset: (task, dataset) tuple.

    Returns:
    - None.
    """
    task, dataset = task_dataset
    for model, changes in error_changes_over_time.items():
        for years, diff in changes.items():  # diff is change in error, so negative is improvement
            print(f"{task} {dataset} change in error of {diff}")
            this_year = years[0]
            next_year = years[1]

            # Rebuilt for every year pair. Previously one dict was shared across
            # all pairs of a model, so later pairs re-plotted earlier pairs'
            # subsets, recorded them under the wrong year, and could mix stale
            # years into a failures plot.
            rejection_matrices = _build_rejection_matrices(
                outcome_matrices[this_year],
                outcome_matrices[next_year],
                model,
                this_year,
                next_year,
            )

            ### Calculate net improvement
            (
                total_net_improvement,
                observed_net_improvement_proportions,
                observed_net_improvement_counts,
            ) = calculate_net_improvement(
                rejection_matrices[f"improvements_{this_year}"]["improvements"],
                rejection_matrices[f"declines_{this_year}"]["declines"],
            )
            key_prefix = (task, dataset, model, this_year, diff, total_net_improvement)
            result_dict[key_prefix + ("net improvements",)] = (
                observed_net_improvement_proportions
            )
            result_dict[key_prefix + ("net improvements count",)] = (
                observed_net_improvement_counts
            )

            for experiment_information, subsets in rejection_matrices.items():
                # experiment_information is "<failure_definition>_<year>",
                # e.g. 'improvements', 'declines' or 'failures' plus the year.
                failure_definition = experiment_information.split("_")[0]
                if failure_definition == "failures":
                    palette = "tab20c"
                elif failure_definition == "improvements":
                    palette = "tab20"
                else:
                    palette = ["darkorange", "moccasin"]
                title = f"{experiment_information} instances of {model} on {task} task, {dataset} dataset"
                _, _, proportions_data, _ = plot_error_matrices_and_histograms(
                    subsets,
                    title,
                    additional_plot=None,
                    palette=palette,
                    plot_matrices=False,
                    display_title=False,
                )

                for instance_subset_condition, proportions_dict in proportions_data.items():
                    if failure_definition == "failures":
                        # For failures the subset label is a year; make it relative.
                        if instance_subset_condition == this_year:
                            instance_subset_condition = "initial year"
                        else:
                            instance_subset_condition = "next year"
                    result_dict[key_prefix + (instance_subset_condition,)] = (
                        proportions_dict
                    )
                plt.savefig(dataset_dir / title.replace(" ", "_"), bbox_inches="tight")
                plt.close("all")


def _build_rejection_matrices(
    this_year_outcomes, next_year_outcomes, model, this_year, next_year
):
    """Split the other models' outcomes into the instance subsets plotted for ``model``.

    Returns a dict keyed by ``"<failure_definition>_<this_year>"``. Each value
    maps a subset label to the outcome matrix of the remaining models, with
    ``model``'s own column dropped:

    - ``failures_<y>``: {this_year: rows ``model`` failed in this_year,
      next_year: rows ``model`` failed in next_year}
    - ``improvements_<y>``: {"improvements": failed in this_year and succeeded
      in next_year, "potential improvements": failed in this_year}
    - ``declines_<y>``: {"declines": succeeded in this_year and failed in
      next_year, "potential declines": succeeded in this_year}

    The improvements/declines subsets take the other models' outcomes from
    this_year.
    """
    this_failed = this_year_outcomes[model] == 1
    this_succeeded = this_year_outcomes[model] == 0
    next_failed = next_year_outcomes[model] == 1
    next_succeeded = next_year_outcomes[model] == 0

    def others(outcomes, mask):
        # Rows selected by ``mask`` with the focus model's column removed.
        return outcomes[mask].drop(model, axis=1)

    return {
        f"failures_{this_year}": {
            this_year: others(this_year_outcomes, this_failed),
            next_year: others(next_year_outcomes, next_failed),
        },
        f"improvements_{this_year}": {
            "improvements": others(this_year_outcomes, this_failed & next_succeeded),
            "potential improvements": others(this_year_outcomes, this_failed),
        },
        f"declines_{this_year}": {
            "declines": others(this_year_outcomes, this_succeeded & next_failed),
            "potential declines": others(this_year_outcomes, this_succeeded),
        },
    }


"""
Generates matrix graphs (observed vs. independently drawn failures) and confidence graphs for every dataset and year in the HAPI repository. """


def generate_all_graphs(dir_path, task_dataset_dict):
    """Save matrix and confidence graphs for every dataset and every year in ``YEARS``.

    Datasets whose failure dataframe is ``None`` are skipped. For each remaining
    dataset and year, two figures are saved:

    - A matrix graph of the observed failure matrix next to a matrix drawn with
      ``np.random.binomial`` from that year's error rates, saved to
      ``dir_path/"Matrix Graphs"/<year>``.
    - Violin plots of ``("confidence", "mean")`` and ``("confidence", "std")``
      against ``("is_correct", "sum")``, saved to
      ``dir_path/"Confidence Graphs"/<year>``. NOTE(review): this figure does
      not depend on ``year``, so the same plot is written once per year.

    Every figure is closed after saving so memory does not grow with the number
    of datasets and years.
    """
    for task in task_dataset_dict:
        print("*** TASK: {} ***".format(task))
        for dataset in task_dataset_dict[task]:
            print("*** DATASET: {} ***".format(dataset))
            R = generate_prediction_df(task, dataset)
            R_failures = prediction_to_failure_df(R)
            if R_failures is None:
                continue
            # Only derive confidences for datasets that are actually plotted.
            R_confidence = prediction_to_confidence_df(R)
            for year in YEARS:
                _save_matrix_graph(dir_path, task, dataset, year, R_failures)
                _save_confidence_graph(dir_path, task, dataset, year, R_confidence)


def _save_matrix_graph(dir_path, task, dataset, year, R_failures):
    """Plot observed vs. independently drawn failures for one year, save, and close."""
    R_failures_np, error_rates = generate_year_predictions(R_failures, year)
    R_independent_draws = np.random.binomial(1, error_rates, R_failures_np.shape)
    title = f"Matrix Graphs: Task: {task}, Dataset: {dataset}"
    plot_error_matrices_and_histograms(
        {"Observed": R_failures_np, "Independent": R_independent_draws},
        title,
    )
    # str(year): Path / int raises TypeError if YEARS holds integers.
    matrix_graphs_dir = dir_path / "Matrix Graphs" / str(year)
    matrix_graphs_dir.mkdir(parents=True, exist_ok=True)
    plt.savefig(matrix_graphs_dir / title.replace(" ", "_"), bbox_inches="tight")
    plt.close("all")


def _save_confidence_graph(dir_path, task, dataset, year, R_confidence):
    """Draw mean/std confidence violin plots, save them under ``year``, and close."""
    title = f"Confidence Graphs: Task: {task}, Dataset: {dataset}"
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10))
    palette = "coolwarm_r"
    for ax, statistic in ((ax1, "mean"), (ax2, "std")):
        sns.violinplot(
            data=R_confidence,
            x=("is_correct", "sum"),
            y=("confidence", statistic),
            scale="area",
            inner="box",
            ax=ax,
            palette=palette,
        )
    fig.suptitle(title, fontsize=16)
    confidence_graphs_dir = dir_path / "Confidence Graphs" / str(year)
    confidence_graphs_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(confidence_graphs_dir / title.replace(" ", "_"), bbox_inches="tight")
    plt.close(fig)


""" 
Unlike numpy.random.binomial(), this generates a matrix that represents the precise
independent reference distribution. Concretely, it builds the dataframe that would
be generated if each column were drawn independently using the error rates of the columns in the provided df.
"""


def generate_independent_df(df):
    """Build the deterministic independent reference dataframe for ``df``.

    Each column's error rate is the sum of its values divided by its non-null
    count. ``calculate_independent_pmf`` turns these rates into a joint pmf,
    which is converted into row counts for ``num_rows = df.shape[0]`` rows
    using largest-remainder rounding:
    - every outcome first gets ``floor(p * num_rows)`` rows;
    - the leftover rows go to the outcomes with the largest fractional parts.

    Unlike rounding each outcome independently, this keeps the total at
    ``num_rows`` whenever the pmf sums to 1.

    Parameters:
    - df: per-instance outcomes, one column per system -- presumably 0/1
      failure indicators (TODO confirm).

    Returns:
    - pd.DataFrame with one row per outcome draw, built from the pmf's outcomes.
    """
    num_rows = df.shape[0]
    error_rates = df.apply(lambda x: x.sum() / x.count())
    # Materialised once because the (outcome, probability) pairs are iterated twice.
    # NOTE(review): assumes iterating the pmf yields (outcome, probability) pairs.
    pmf = list(calculate_independent_pmf(error_rates))

    quotas = [probability * num_rows for _, probability in pmf]
    row_counts = [int(np.floor(quota)) for quota in quotas]
    leftover = num_rows - sum(row_counts)
    # Hand out remaining rows by largest fractional part. sorted() is stable
    # even with reverse=True, so ties keep pmf order.
    by_remainder = sorted(
        range(len(quotas)), key=lambda i: quotas[i] - row_counts[i], reverse=True
    )
    for i in by_remainder[: max(leftover, 0)]:
        row_counts[i] += 1

    reference_rows = []
    for (outcome, _), count in zip(pmf, row_counts):
        reference_rows.extend([outcome] * count)
    return pd.DataFrame(reference_rows)


"""
The script can be called as: "python3 hapi_matrix_graphs.py {dir_name}".
All results are written to a folder named dir_name inside ../results. """
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("dir_name", help="name to be used for naming results folders")
    # Both flags default to False, which reproduces the previous hard-coded
    # behaviour (improvements=True, only_single_label=True).
    parser.add_argument(
        "--all_graphs",
        action="store_true",
        help="generate matrix/confidence graphs for every dataset and year "
        "instead of only the periods of model improvement",
    )
    parser.add_argument(
        "--include_multi_label",
        action="store_true",
        help="also include datasets whose instances have multiple ground-truth labels",
    )
    parser.add_argument(
        "--task",
        action="append",
        help="task to include; repeat for several tasks (default: all tasks)",
    )
    args = parser.parse_args()

    main(
        args.dir_name,
        improvements=not args.all_graphs,
        task_list=args.task,
        only_single_label=not args.include_multi_label,
    )
    print("All results generated")
