import warnings
import pandas as pd
import numpy as np

# Load all the result files from the 'experiment_results/CL' directory.
# Each entry is a folder that contains a csv file with the results of the experiment.
import os

from scipy import stats
from dataset.CelebADataset import CelebADataset
from dataset.FairFaceDataset import FairFaceDataset
from dataset.MTFLDataset import MTFLDataset
from dataset.PhysiQDataset import PhysiQDataset
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from utilities.Logger import Logger, read_logger
from argparser import get_args
import numpy as np
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import re


def remove_keys(d, keys_to_remove):
    """Return a new dict with the entries of ``d`` whose key is not in ``keys_to_remove``.

    ``d`` itself is not modified; the insertion order of the kept entries is
    preserved.
    """
    kept = {}
    for name, item in d.items():
        if name in keys_to_remove:
            continue
        kept[name] = item
    return kept


def extract_numbers_from_save_name(string):
    """Return the first run of decimal digits found in ``string`` as an int.

    Only the first run is used (``"10tasks_v2"`` gives ``10``). When
    ``string`` contains no digit, ``re.search`` yields ``None`` and the
    ``.group()`` call raises ``AttributeError``.
    """
    first_digits = re.search(r"\d+", string)
    return int(first_digits.group())


def analyze_and_plot(path, args):
    """Load the logger(s) stored at ``path`` and compute the evaluation summary.

    Args:
        path: Location handed to ``read_logger``.
        args: Dict of run arguments. The keys ``"dataset"``, ``"save_name"``
            and ``"job"`` are read here; the whole dict (minus ``"model"``)
            is compared against each logger's stored args for diagnostics.

    Returns:
        ``None`` when ``read_logger`` yields no logger. Otherwise a dict with
        the keys ``"res"``, ``"metric_diff"``, ``"diff_std"``, ``"cm"`` and
        ``"prediction_targets"``; ``"metric_diff"``, ``"diff_std"`` and
        ``"cm"`` are ``None`` unless ``args["job"] == "cl"``.

    Raises:
        ValueError: If ``args["dataset"]`` is not ``"celeba"``, ``"physiq"``
            or ``"fairface"``.
    """
    loggers = read_logger(path)
    # TODO: check if there is more than one logger with the same args
    reference_args = remove_keys(args, ["model"])
    logger_counter = 0
    for each_logger in loggers:
        # NOTE(review): every logger is counted, whether or not its args match
        # ``args``; a mismatch only triggers the diagnostics below. Confirm
        # whether non-matching loggers should be skipped instead.
        logger_counter += 1
        if remove_keys(each_logger.args, ["model"]) == reference_args:
            continue
        # Only print diagnostics when there are few loggers, to limit noise.
        if len(loggers) <= 3:
            print(
                "Potential missing keys:",
                set(each_logger.args.keys()).symmetric_difference(
                    set(args.keys())
                ),
            )
            # Compare only keys present on both sides: keys missing from
            # ``args`` are already reported above, and indexing ``args`` with
            # them would raise KeyError.
            print(
                "Potential different values:",
                {
                    key: value
                    for key, value in each_logger.args.items()
                    if key in args and value != args[key]
                },
            )
    if logger_counter == 0:
        print("No logger found with the specified arguments.")
        return
    if logger_counter > 1:
        print(
            f"\033[91m{logger_counter} loggers found with the same arguments. Please select the logger to analyze.\033[0m"
        )
        inp = input(
            f"Multiple loggers found with the same arguments. Please select the logger to analyze (0-{logger_counter-1}): "
        )
        each_logger = loggers[int(inp)]
    # With exactly one logger, ``each_logger`` is still bound to it from the
    # loop above.
    if args["dataset"] == "celeba":
        prediction_targets = CelebADataset.prediction_targets
    elif args["dataset"] == "physiq":
        prediction_targets = PhysiQDataset.prediction_targets
    elif args["dataset"] == "fairface":
        prediction_targets = FairFaceDataset.prediction_targets
    else:
        raise ValueError("Dataset not found.")

    # if save_name is given, parse the num of tasks and prune out prediction targets
    if len(args["save_name"]) > 0:
        num_tasks = extract_numbers_from_save_name(args["save_name"])
        prediction_targets = prediction_targets[:num_tasks]
    # physiq keeps only the rows in which every target has a value; the other
    # datasets use every logged row.
    if args["dataset"] in ["physiq"]:
        res = eval_final_row(each_logger, prediction_targets)
    else:
        res = eval_all_metrics(each_logger, prediction_targets)
    if args["job"] == "cl":
        metric_cm = get_cm_from_logger(
            each_logger, prediction_targets, each_logger.args["num_seed"]
        )  # acc by default
        metric_diff, diff_std = get_backward_transfer(
            each_logger, prediction_targets
        )
    else:
        metric_cm = None
        metric_diff = None
        diff_std = None
    return {
        "res": res,
        "metric_diff": metric_diff,
        "diff_std": diff_std,
        "cm": metric_cm,
        "prediction_targets": prediction_targets,
    }


def eval_final_row(each_logger, prediction_targets, display=True):
    """Collect each metric from the rows that carry a value for every target.

    For each metric suffix (``""``, ``"_ece"``, ``"_f1"``, ``"_precision"``,
    ``"_recall"``) a row of ``each_logger.results`` is kept only when the key
    ``f"{target}{suffix}"`` is present and not ``None`` for every entry of
    ``prediction_targets``.

    Args:
        each_logger: Object exposing ``results`` (a sliceable sequence of
            dict-like rows).
        prediction_targets: Target names used as key prefixes.
        display: When true, print a summary per metric via
            ``display_results`` and the memory usage via
            ``display_memory_usage``.

    Returns:
        ``np.array`` built from the five per-metric float32 arrays, in the
        suffix order listed above.
    """
    suffixes = ["", "_ece", "_f1", "_precision", "_recall"]
    per_metric = []
    for suffix in suffixes:
        complete_rows = []
        for row in each_logger.results[:]:
            values = [
                row.get(f"{target}{suffix}", None)
                for target in prediction_targets
            ]
            # Skip rows that lack this metric for at least one target.
            if any(value is None for value in values):
                continue
            complete_rows.append(values)

        metric_array = np.array(complete_rows, dtype=np.float32)
        per_metric.append(metric_array)
        if display:
            display_results(suffix, metric_array)
    if display:
        display_memory_usage(each_logger)
    return np.array(per_metric)


def get_cm_from_logger(each_logger, prediction_targets, num_seed, metric=""):
    """Build the seed-averaged target-by-target metric matrix from a logger.

    Rows of ``each_logger.results`` are consumed in order, in groups of
    ``len(prediction_targets)``; each complete group fills one seed's square
    matrix (one logged row per matrix row, one target per column), which is
    then rotated with ``np.rot90(..., k=-3)`` (the same as a single
    counter-clockwise quarter turn).

    Args:
        each_logger: Object exposing ``results`` (a sliceable sequence of
            dict-like rows).
        prediction_targets: Target names; ``f"{target}{metric}"`` is the key
            looked up in every row.
        num_seed: Number of seed matrices to allocate. Supplying more
            complete groups of rows than this raises ``IndexError``.
        metric: Key suffix selecting the metric (``""`` by default).

    Returns:
        float32 array of shape ``(len(prediction_targets),
        len(prediction_targets))``: the NaN-ignoring mean over seeds. Cells
        never filled stay NaN.
    """
    num_targets = len(prediction_targets)
    per_seed = np.full((num_seed, num_targets, num_targets), np.nan)
    seed = 0
    row_in_seed = 0
    for logged_row in each_logger.results[:]:
        for col, target in enumerate(prediction_targets):
            per_seed[seed, row_in_seed, col] = logged_row.get(
                f"{target}{metric}", None
            )
        row_in_seed += 1
        if row_in_seed != num_targets:
            continue
        # A full group of rows was collected: rotate it and move on to the
        # next seed. A trailing incomplete group is left unrotated.
        per_seed[seed] = np.rot90(per_seed[seed], k=-3)
        seed += 1
        row_in_seed = 0
    return np.nanmean(per_seed.astype(np.float32), axis=0)


def eval_all_metrics(each_logger, prediction_targets):
    """Collect every metric for every logged row and print the summaries.

    For each metric suffix (``""``, ``"_ece"``, ``"_f1"``, ``"_precision"``,
    ``"_recall"``) a float32 table is built with one line per entry of
    ``each_logger.results`` and one column per prediction target; keys that
    are absent from a row are looked up as ``None``. Each table is passed to
    ``display_results``; memory and time usage are printed afterwards.

    Returns:
        ``np.array`` stacking the five per-metric tables in the suffix order
        listed above.
    """
    suffixes = ("", "_ece", "_f1", "_precision", "_recall")
    stacked = []
    for suffix in suffixes:
        table = np.array(
            [
                [
                    row.get(f"{target}{suffix}", None)
                    for target in prediction_targets
                ]
                for row in each_logger.results[:]
            ],
            dtype=np.float32,
        )
        stacked.append(table)
        display_results(suffix, table)
    display_memory_usage(each_logger)
    display_time_usage(each_logger)
    return np.array(stacked)


def get_metric_diff(each_logger, prediction_targets, metric=""):
    """Return mean and std of the average step-to-step change of a metric.

    The logged values are arranged as
    ``(-1, len(prediction_targets), len(prediction_targets))`` — labelled
    (num_seed, task_iteration, prediction_targets) by the original author —
    then differenced along axis 1. Each block's differences are averaged
    (ignoring NaN) into one number per seed.

    Args:
        each_logger: Object exposing ``results`` (a sliceable sequence of
            dict-like rows).
        prediction_targets: Target names; ``f"{target}{metric}"`` is the key
            looked up in every row.
        metric: Key suffix selecting the metric (``""`` by default).

    Returns:
        ``(mean_diff, std_diff)``: NaN-ignoring mean and standard deviation
        of the per-seed averages.
    """
    num_targets = len(prediction_targets)
    table = np.array(
        [
            [row.get(f"{target}{metric}", None) for target in prediction_targets]
            for row in each_logger.results[:]
        ],
        dtype=np.float32,
    )
    # (num_seed, task_iteration, prediction_targets)
    cube = table.reshape((-1, num_targets, num_targets))
    # (num_seed, task_iteration - 1, prediction_targets)
    step_changes = np.diff(cube, axis=1)
    per_seed_mean = np.nanmean(step_changes, axis=(1, 2))  # (num_seed)
    return np.nanmean(per_seed_mean), np.nanstd(per_seed_mean)


def get_backward_transfer(each_logger, prediction_targets, metric=""):
    """Return mean and std of "last row minus diagonal" across seeds.

    The logged values are arranged as
    ``(-1, len(prediction_targets), len(prediction_targets))`` — labelled
    (num_seed, task_iteration, prediction_targets) by the original author.
    For every seed, the diagonal entry of each target is subtracted from
    that target's value in the last row; the last target is dropped because
    its diagonal entry is the last row itself. The remaining differences are
    averaged (ignoring NaN) into one number per seed.

    Args:
        each_logger: Object exposing ``results`` (a sliceable sequence of
            dict-like rows).
        prediction_targets: Target names; ``f"{target}{metric}"`` is the key
            looked up in every row.
        metric: Key suffix selecting the metric (``""`` by default).

    Returns:
        ``(mean_diff, std_diff)``: NaN-ignoring mean and standard deviation
        of the per-seed averages.
    """
    num_targets = len(prediction_targets)
    table = np.array(
        [
            [row.get(f"{target}{metric}", None) for target in prediction_targets]
            for row in each_logger.results[:]
        ],
        dtype=np.float32,
    )
    # (num_seed, task_iteration, prediction_targets)
    cube = table.reshape((-1, num_targets, num_targets))
    final_row = cube[:, -1, :]
    diagonal = np.diagonal(cube, axis1=1, axis2=2)
    # Drop the last target: final row and diagonal coincide there.
    transfer = (final_row - diagonal)[:, :-1]  # (num_seed, num_tasks - 1)
    per_seed_mean = np.nanmean(transfer, axis=1)  # (num_seed)
    return np.nanmean(per_seed_mean), np.nanstd(per_seed_mean)


def plot_forward_info_loss_diagram(diffs, stds, models, args):
    """Draw one bar per model for ``diffs`` (with ``stds`` error bars) and save it.

    The y-axis spans from ``-1.5 * max|diffs|`` to ``0.25 * max|diffs|``, so
    most of the room is given to negative values. The figure is written to
    ``figures/<dataset>_backward_transfer_diagram.png``; the ``figures``
    directory is created when missing.

    Args:
        diffs: One numeric value per model (bar height).
        stds: One numeric value per model (error-bar half length).
        models: Model names, used as x-axis labels.
        args: Dict of run arguments; only ``args["dataset"]`` is read.
    """
    # Define a list of colors (extend this list if more unique colors are needed)
    predefined_colors = [
        "#2196F3",  # Muted Blue-Grey (ours)
        "#A5D6A7",  # Muted Green-Grey
        "#FFAB91",  # Muted Orange-Grey
        "#FFE082",  # Muted Yellow-Grey
        "#CE93D8",  # Muted Purple-Grey
        "#80DEEA",  # Muted Cyan-Grey
        "#C5E1A5",  # Muted Light Green-Grey
        "#FFF59D",  # Muted Bright Yellow-Grey
        "#FFCC80",  # Muted Orange-Grey
    ]

    # Repeat the predefined colors if the number of models exceeds the number of colors
    colors = (predefined_colors * (len(models) // len(predefined_colors) + 1))[
        : len(models)
    ]

    # Create the bar plot using the generated colors list
    plt.figure(figsize=(7, 4))
    bars = plt.bar(
        x=models,
        height=diffs,
        yerr=stds,
        ecolor="#424242",
        capsize=5,
        color=colors,
        width=0.8,
    )

    # Keep the x-tick labels horizontal
    plt.xticks(rotation=0, fontsize=12)

    # Draw a horizontal line at y=0 for reference
    plt.axhline(0, color="black", linewidth=0.8)

    # Label the axes and the plot
    plt.xlabel("Model", fontsize=14)
    plt.ylabel(
        "Average Diff in Accuracy from t-1 to t",
        fontsize=14,
    )
    plt.title(
        f"Backward Transfer Diagram ({args['dataset']})",
        fontsize=14,
    )

    # Adding a grid for better readability
    plt.grid(True, axis="y", linestyle="--", linewidth=0.7)

    # Annotate each bar with its value; the label sits on the y=0 baseline,
    # horizontally centred on the bar, with no extra offset.
    for bar in bars:
        height = bar.get_height()
        plt.annotate(
            f"{height:.4f}",
            xy=(bar.get_x() + bar.get_width() / 2, 0),
            xytext=(0, 0),
            textcoords="offset points",
            ha="center",
            va="bottom",
            fontsize=7,
        )

    # Set y-axis limits to make the graph more visually appealing
    largest_magnitude = np.max(np.abs(diffs))
    plt.ylim(-largest_magnitude * 1.5, largest_magnitude * 0.25)

    # Save the plot. The output directory is created here so that this
    # function does not depend on another plotting function having run first.
    plt.tight_layout()
    os.makedirs("figures", exist_ok=True)
    plt.savefig(
        f"figures/{args['dataset']}_backward_transfer_diagram.png",
        dpi=300,
    )


def plot_confusion_matrix_diagram(
    conf_matrices, dataset, model_names, prediction_targets
):
    """Save one annotated heatmap per model to ``figures/cm``.

    Cell colours come from a per-row min-max normalisation in which the
    minimum and maximum of each matrix row are taken over all models and all
    columns, so the same row is comparable between models. Rows whose range
    is at most 0.005 are coloured with the constant 1. NaN cells stay NaN and
    show the grey axes background. Cell annotations are the original values
    (rounded to 4 decimals) multiplied by 100.

    Args:
        conf_matrices: Sequence of equally shaped square matrices, one per
            model.
        dataset: Dataset name, used in the output file name.
        model_names: One name per matrix, used in the output file name.
        prediction_targets: Currently unused (tick labels are disabled).
    """
    conf_matrices = np.array(conf_matrices)
    conf_matrices = np.round(conf_matrices, 4)
    # Reduce over models (axis 0) and columns (axis 2): one min/max per row.
    min_per_row = np.nanmin(conf_matrices, axis=(0, 2))
    max_per_row = np.nanmax(conf_matrices, axis=(0, 2))
    range_per_row = max_per_row - min_per_row

    # ``np.where`` evaluates both branches eagerly, so the division also runs
    # for rows with a zero range; silence the resulting warnings since those
    # quotients are discarded below.
    with np.errstate(divide="ignore", invalid="ignore"):
        scaled = (conf_matrices - min_per_row[:, None]) / range_per_row[:, None]
    normalized_cm = np.where(
        np.isnan(conf_matrices),  # If the value is NaN
        np.nan,  # Keep it as NaN
        np.where(
            range_per_row[:, None] <= 0.005,  # Range too small to normalise
            1,
            scaled,
        ),
    ).astype(conf_matrices.dtype, copy=False)

    os.makedirs("figures/cm", exist_ok=True)

    # Shared colour scale for all models.
    color_min = np.nanmin(normalized_cm)
    color_max = np.nanmax(normalized_cm)
    for i, cm in enumerate(normalized_cm):
        plt.figure(figsize=(5, 5))
        g = sns.heatmap(
            cm,
            annot=np.round(conf_matrices[i], 4) * 100,
            cmap="Blues",
            cbar=False,
            vmin=color_min,
            vmax=color_max,
            fmt=".3g",
            linewidths=0.5,
            yticklabels=False,
        )
        g.set_facecolor("xkcd:light grey")

        plt.yticks(rotation=60, fontsize=5)
        plt.subplots_adjust(left=0.2, right=0.9)
        plt.tight_layout()
        plt.savefig(
            f"figures/cm/{dataset}_{model_names[i]}_confusion_matrix.png",
            dpi=300,
        )
        # pyplot keeps every figure alive until it is closed; release it once
        # it has been written to disk.
        plt.close()


def display_time_usage(each_logger):
    """Print mean and standard deviation of the logged ``"time"`` values.

    Rows of ``each_logger.results`` without a ``"time"`` key are ignored;
    when no row has one, ``time usage: N/A`` is printed instead.
    """
    durations = np.array(
        [row["time"] for row in each_logger.results if "time" in row]
    )
    if len(durations) == 0:
        print("time usage: N/A")
        return
    print("time usage:", np.mean(durations), np.std(durations))


def display_memory_usage(each_logger):
    """Print mean and standard deviation of the combined CPU + GPU memory readings.

    The values of every mapping in ``each_logger.gpu_res`` are flattened, in
    order, into one array that is added element-wise to
    ``each_logger.cpu_res``.
    NOTE(review): this presumably expects one GPU value per CPU sample so
    that both arrays line up — confirm against the logger.
    """
    gpu_values = []
    for gpu_entry in each_logger.gpu_res:
        gpu_values.extend(gpu_entry.values())
    combined = np.array(each_logger.cpu_res) + np.array(gpu_values)
    print("cpu+gpu memory usage:", np.mean(combined), np.std(combined))


def display_results(each_metrics, np_res_temp):
    """Print ``mean $\\pm$ std`` (NaN-ignoring) of ``np_res_temp`` for one metric.

    Args:
        each_metrics: Metric key suffix. ``""`` is shown as ``Accuracy`` and
            ``"_f1"`` as ``F1``, both scaled to percent; any other suffix is
            printed verbatim with unscaled values.
        np_res_temp: Array of metric values.
    """
    mean = np.nanmean(np_res_temp)
    std = np.nanstd(np_res_temp)
    # The backslash of the LaTeX ``\pm`` is escaped explicitly: a bare ``\p``
    # is an invalid escape sequence in a normal string literal.
    if each_metrics in ["", "_f1"]:
        # percentage, with 3 decimal places
        show_name = "Accuracy" if each_metrics == "" else "F1"
        print(f"\t{show_name}: {mean * 100:.3f} $\\pm$ {std * 100:.3f}")
    else:
        print(f"\t{each_metrics}: {mean:.3f} $\\pm$ {std:.3f}")
    return


# Example invocation:
# python eval.py --job cl --dataset celeba  -is 32 --num_seed 5 --save_name 10tasks
if __name__ == "__main__":
    args = get_args()
    save_results = {}
    if args["dataset"] == "celeba":
        prediction_targets = CelebADataset.prediction_targets
    elif args["dataset"] == "physiq":
        prediction_targets = PhysiQDataset.prediction_targets
    elif args["dataset"] == "fairface":
        prediction_targets = FairFaceDataset.prediction_targets
    else:
        raise ValueError("Dataset not found.")
    # The job directory is upper-cased here. The original author kept an
    # alternative with a "./experiment_results" prefix and a lower-cased job
    # name, marked as the variant to use on Linux:
    #   os.path.join(
    #       "./experiment_results", args["job"].lower(), args["dataset"].lower()
    #   )
    for root, dirs, files in os.walk(
        os.path.join(
            "experiment_results", args["job"].upper(), args["dataset"].lower()
        )
    ):
        for file in files:
            # Within an "lwp" tree only the top-level "lwp" folder is used.
            if "lwp" in root and root.split(os.path.sep)[-1] != "lwp":
                continue
            # Keep only files whose name contains the number of targets.
            if str(len(prediction_targets)) not in file:
                continue
            if file.endswith(f"{args['save_name']}.pyd"):
                # The model name is taken from the trailing path component(s).
                print(root)
                if root.split(os.path.sep)[-3] == "lwp":
                    model_name = "_".join(root.split(os.path.sep)[-3:])
                else:
                    model_name = root.split(os.path.sep)[-1]

                save_results[model_name] = os.path.join(root, file)

    print(save_results.keys())
    models = []
    metric_diffs = []
    diff_stds = []
    all_cm = []
    # Sorted by model name, but with "lwp" first.
    keys = list(save_results.keys())
    keys.sort()
    if "lwp" in keys:
        keys.remove("lwp")
        keys = ["lwp"] + keys
    print(keys)
    last_analysis = None
    for key in keys:
        print(
            "--------------------------------------------------------------------"
        )
        print(key.upper())
        analysis = analyze_and_plot(save_results[key], args)
        if analysis is None:
            # analyze_and_plot found no logger for this entry (and said so).
            continue
        last_analysis = analysis
        if key != "single":  # it doesn't apply to "single"
            models.append(key.upper())
            metric_diffs.append(analysis["metric_diff"])
            diff_stds.append(analysis["diff_std"])
            all_cm.append(analysis["cm"])
    if last_analysis is None:
        raise SystemExit(
            "No experiment results could be analyzed; nothing to plot."
        )
    predictions = last_analysis["prediction_targets"]
    predictions = [x.replace("_", " ") for x in predictions]
    # The matrices and transfer values are only computed for the "cl" job;
    # for any other job they are None and cannot be plotted.
    if args["job"] == "cl" and models:
        plot_confusion_matrix_diagram(
            all_cm, args["dataset"].lower(), models, predictions)
        plot_forward_info_loss_diagram(metric_diffs, diff_stds, models, args)
