import statistics

import matplotlib.pyplot as plt

from uq_diagcfm.checkpointing import load_run_info_according_to_criteria
from uq_diagcfm.paths import PAPER_FIGURES_DIR, ensure_paper_dirs_exist


def _final_validation_loss(run_info):
    """Return the last recorded validation loss of a single run.

    The final entry of ``val_surrogate_loss_trajectory`` is used when that key
    holds a non-empty list whose last entry is not ``None``; otherwise the
    final entry of ``val_loss_trajectory`` is returned.
    """
    surrogate_key = "val_surrogate_loss_trajectory"
    if surrogate_key in run_info:
        trajectory = run_info[surrogate_key]
        if isinstance(trajectory, list) and trajectory and trajectory[-1] is not None:
            return trajectory[-1]
    return run_info["val_loss_trajectory"][-1]


def _group_final_losses(run_info_list):
    """Group final losses as ``{diag_cfm: {shuffle_params_seed: [loss, ...]}}``."""
    buckets = {}
    for run_info in run_info_list:
        # Skip INN models (they don't have a "diag_cfm" key).
        if "diag_cfm" not in run_info:
            continue
        per_seed = buckets.setdefault(run_info["diag_cfm"], {})
        per_seed.setdefault(run_info["shuffle_params_seed"], []).append(
            _final_validation_loss(run_info)
        )
    return buckets


def main(run_criteria: dict) -> None:
    """Print and plot the final validation losses of the Diag-CFM ablation.

    Runs are loaded via ``load_run_info_according_to_criteria(run_criteria)``
    and grouped by their ``diag_cfm`` flag and ``shuffle_params_seed``. For
    every group the mean, sample standard deviation and run count are printed,
    and all individual losses are drawn as one scatter column per group on a
    log-scaled y axis. The figure is written to ``PAPER_FIGURES_DIR`` as
    ``ablation_diag_cfm.svg`` and ``ablation_diag_cfm.pdf``.

    Args:
        run_criteria: Selection criteria forwarded unchanged to
            ``load_run_info_according_to_criteria``.

    Raises:
        ValueError: If none of the loaded runs carries a ``diag_cfm`` key, so
            there is nothing to summarise or plot.
    """
    run_info_list = load_run_info_according_to_criteria(run_criteria)
    buckets = _group_final_losses(run_info_list)
    if not buckets:
        # Without this guard the axis limits below evaluate to NaN
        # (inf - 0.2 * inf) and matplotlib fails with an unrelated-looking
        # error after a figure has already been opened.
        raise ValueError(
            f"No runs with a 'diag_cfm' entry found for criteria {run_criteria!r}; "
            "nothing to plot."
        )

    # Running min/max on purpose: a NaN loss never compares smaller/larger and
    # is therefore ignored instead of poisoning the axis limits.
    overall_min = float("inf")
    overall_max = float("-inf")
    for per_seed in buckets.values():
        for losses in per_seed.values():
            for loss in losses:
                overall_min = min(overall_min, loss)
                overall_max = max(overall_max, loss)

    print("\nMeans and Standard Deviations:")
    for diag_cfm, per_seed in buckets.items():
        cfm_str = "Diag-CFM" if diag_cfm else "CFM"
        print(f"  {cfm_str}:")
        for seed, losses in per_seed.items():
            mean_loss = sum(losses) / len(losses)
            # The sample standard deviation needs at least two data points.
            std_loss = statistics.stdev(losses) if len(losses) > 1 else 0.0
            print(f"    Seed: {seed}, Mean: {mean_loss:.2e}, Std: {std_loss:.2e}, N: {len(losses)}")

    # Pad the observed range by 20 % of its magnitude on either side.
    axis_limits = (
        overall_min - 0.2 * abs(overall_min),
        overall_max + 0.2 * abs(overall_max),
    )

    # One global rank per seed (order of first appearance). Both the colour and
    # the "Ord. N" label derive from it, so a given seed looks and reads the
    # same for CFM and Diag-CFM, and two different seeds never share a rank.
    seed_rank = {}
    for per_seed in buckets.values():
        for seed in per_seed:
            seed_rank.setdefault(seed, len(seed_rank))

    # plotting
    x_ticks = []
    plt.figure()
    for diag_cfm, per_seed in buckets.items():
        cfm_str = "Diag-CFM" if diag_cfm else "CFM"
        # Triangles for Diag-CFM, inverted triangles for CFM
        marker_style = "^" if diag_cfm else "v"
        # One x position per (variant, seed); y values are that group's losses.
        for seed, losses in per_seed.items():
            rank = seed_rank[seed]
            plt.scatter(
                [len(x_ticks)] * len(losses),
                losses,
                label=f"{cfm_str} w. Ord. {rank + 1}",
                color=plt.cm.tab10(rank % 10),  # colours repeat after ten seeds
                marker=marker_style,
            )
            x_ticks.append(f"{cfm_str} w. O. {rank + 1}")

    plt.ylim(axis_limits)
    # log scale for y axis
    plt.yscale("log")
    plt.ylabel("Final Round-Trip Error")
    plt.legend()
    plt.grid()
    # Replace the numeric x positions with the group labels, slightly rotated.
    plt.xticks(range(len(x_ticks)), x_ticks, rotation=15)

    ensure_paper_dirs_exist()
    svg_path = PAPER_FIGURES_DIR / "ablation_diag_cfm.svg"
    pdf_path = PAPER_FIGURES_DIR / "ablation_diag_cfm.pdf"
    plt.savefig(svg_path, bbox_inches="tight")
    plt.savefig(pdf_path, bbox_inches="tight")
    print(f"\nPlot saved to: {svg_path}")
    print(f"Plot saved to: {pdf_path}")
    plt.close()


if __name__ == "__main__":
    # Script entry point: the figure is only generated when the script is
    # called with exactly one argument, the sub-command "ablation_diag_cfm".
    # Any other invocation does nothing.
    import sys

    from uq_diagcfm.data_utils_gas_turbine import GAS_TURBINE_DATASET_NAME
    from uq_diagcfm.data_utils_unifoil import UNIFOIL_DATASET_NAME

    if sys.argv[1:] == ["ablation_diag_cfm"]:
        # Alternative configuration, kept for reference:
        #   {"epochs": 100, "dataset": UNIFOIL_DATASET_NAME}
        main({"epochs": 20, "dataset": GAS_TURBINE_DATASET_NAME})
