import warnings
import pandas as pd
import numpy as np

# Load all result files from the 'experiment_results/CL' directory.
# Each run lives in its own folder, which contains a results file for that experiment (located by its `.pyd` extension).
import os

from scipy import stats
from dataset.CelebADataset import CelebADataset
from dataset.FairFaceDataset import FairFaceDataset
from dataset.MTFLDataset import MTFLDataset
from dataset.PhysiQDataset import PhysiQDataset
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from eval import eval_all_metrics, eval_final_row
from utilities.Logger import Logger, read_logger
from argparser import get_args
import numpy as np
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt

# Train-subsample ratios that are evaluated and plotted, in ascending order.
# The x-axis position of each ratio is ratio * number of training samples.
# Ratios 0.1-0.45 were previously part of this list and are currently disabled.
# NOTE(review): 0.95 is not included — confirm that is intentional.
TSR_LIST = [0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 1.0]
def remove_keys(d, keys_to_remove):
    """Return a new dict holding the entries of *d* whose key is not in *keys_to_remove*.

    *d* itself is left untouched; insertion order of the kept entries is preserved.
    """
    kept = {}
    for k, v in d.items():
        if k in keys_to_remove:
            continue
        kept[k] = v
    return kept


def analyze_and_plot(path, args):
    """Evaluate the loggers stored at *path*, bucketed by train subsample ratio.

    Despite its name, this function does not plot anything; it returns the
    per-ratio evaluation results for the caller to plot.

    Args:
        path: Location passed to ``read_logger``.
        args: Run-argument mapping. Only ``args["dataset"]`` is read here, to
            choose which dataset's ``prediction_targets`` are evaluated.

    Returns:
        A list the same length as ``TSR_LIST``. Entry ``i`` is the
        ``eval_final_row(..., display=False)`` result for the logger whose
        ``train_subsample_ratio`` equals ``TSR_LIST[i]``, or ``None`` if no
        such logger was found. If several loggers share a ratio, the last
        one read wins.

    Raises:
        ValueError: If ``args["dataset"]`` is not "celeba", "physiq" or
            "fairface".
    """
    tsr = "train_subsample_ratio"
    loggers = read_logger(path)
    # TODO: check if there is more than one logger with the same args;
    # currently a later logger silently overwrites an earlier one below.

    if args["dataset"] == "celeba":
        prediction_targets = CelebADataset.prediction_targets
    elif args["dataset"] == "physiq":
        prediction_targets = PhysiQDataset.prediction_targets
    elif args["dataset"] == "fairface":
        prediction_targets = FairFaceDataset.prediction_targets
    else:
        raise ValueError("Dataset not found.")

    results = [None] * len(TSR_LIST)
    for each_logger in loggers:
        # A logger without a recorded ratio cannot be placed on the x-axis;
        # skip it with a warning rather than failing on the key lookup.
        if tsr not in each_logger.args:
            warnings.warn(f"Skipping a logger in {path!r}: no '{tsr}' argument.")
            continue
        ratio = each_logger.args[tsr]
        print(ratio)
        # NOTE(review): exact float match — assumes stored ratios compare
        # equal to the TSR_LIST literals.
        if ratio not in TSR_LIST:
            continue
        results[TSR_LIST.index(ratio)] = eval_final_row(
            each_logger, prediction_targets, display=False
        )
    return results


def _find_result_files(args):
    """Collect the ``.pyd`` result files for one job/dataset/model combination.

    Walks ``experiment_results/<job>/<dataset>/<model>`` (each component
    lower-cased) and returns a dict mapping the name of each directory that
    holds a ``.pyd`` file to that file's path. If one directory holds several
    ``.pyd`` files, the last one visited wins.
    """
    base_dir = os.path.join(
        "experiment_results",
        args["job"].lower(),
        args["dataset"].lower(),
        args["model"].lower(),
    )
    found = {}
    for root, _dirs, files in os.walk(base_dir):
        for file in files:
            if file.endswith(".pyd"):
                # Key by the innermost directory name of ``root``.
                found[root.split(os.path.sep)[-1]] = os.path.join(root, file)
    return found


def _summarize_metric(res, metric_indx, z_score):
    """Turn per-ratio results into plottable mean / lower / upper arrays.

    Args:
        res: List aligned with ``TSR_LIST``, as returned by
            ``analyze_and_plot``; entries may be ``None``.
        metric_indx: Index selecting the metric inside each result entry.
        z_score: Multiplier applied to the standard error of the mean to get
            the half-width of the shaded band.

    Returns:
        Three float arrays of length ``len(res)``: the NaN-ignoring mean, and
        the mean minus / plus ``z_score * SEM``. Missing (``None``) entries are
        NaN in all three, so matplotlib leaves a gap instead of drawing the
        band down to zero.
    """
    mean = np.full(len(res), np.nan)
    lower = np.full(len(res), np.nan)
    upper = np.full(len(res), np.nan)
    for i, entry in enumerate(res):
        if entry is None:
            continue
        values = entry[metric_indx]
        centre = np.nanmean(values)
        half_width = z_score * stats.sem(values, axis=None)
        mean[i] = centre
        lower[i] = centre - half_width
        upper[i] = centre + half_width
    return mean, lower, upper


def main():
    """Plot a test metric against the number of training samples on PhysIQ.

    For each continual-learning model, every ``.pyd`` result file under
    ``experiment_results/cl/physiq/<model>`` is evaluated with
    ``analyze_and_plot``. The NaN-ignoring mean of metric ``metric_indx`` is
    drawn as a line over ``TSR_LIST``, with a shaded band of +/- z * SEM, where
    z is chosen for a 50% two-sided normal interval. The figure is saved to
    ``./figures/physiq_training_sample.png`` and then shown.
    """
    import itertools  # local: only needed here, for cycling plot colours

    metric_indx = 1
    # Training-sample count at ratio 1.0; x-values are this times each ratio.
    # NOTE(review): assumed to be the PhysIQ training-set size — confirm.
    total_samples = 1472
    x_values = [total_samples * ratio for ratio in TSR_LIST]
    z_score = stats.norm.ppf(1 - (1 - 0.5) / 2)  # cv = .5

    fig1, ax1 = plt.subplots(figsize=(10, 6))
    predefined_colors = [
        "#2196F3",  # Muted Blue-Grey
        "#A5D6A7",  # Muted Green-Grey
        "#FFAB91",  # Muted Orange-Grey
        "#FFE082",  # Muted Yellow-Grey
        "#CE93D8",  # Muted Purple-Grey
        "#80DEEA",  # Muted Cyan-Grey
        "#C5E1A5",  # Muted Light Green-Grey
        "#FFF59D",  # Muted Bright Yellow-Grey
        "#FFCC80",  # Muted Orange-Grey
    ]
    # Cycle rather than pop so that more result files than colours do not
    # raise IndexError; the first len(predefined_colors) lines are unchanged.
    colors = itertools.cycle(predefined_colors)

    for model_name in ["lwp", "der", "er", "lwf", "si", "fdr", "derpp"]:
        args = get_args()
        args["job"] = "cl"
        args["dataset"] = "physiq"
        args["model"] = model_name

        for result_path in _find_result_files(args).values():
            res = analyze_and_plot(result_path, args)
            mean, lower, upper = _summarize_metric(res, metric_indx, z_score)
            color = next(colors)
            ax1.plot(x_values, mean, label=model_name, color=color)
            ax1.fill_between(x_values, lower, upper, alpha=0.2, color=color)

    # Dashed red reference line at ratio 0.8.
    # NOTE(review): an earlier comment described this as ".7 (buffer size ==
    # batch size)" while the line is drawn at 0.8 — confirm the intended ratio.
    ax1.axvline(x=0.8 * total_samples, color="r", linestyle="--")
    ax1.grid()
    ax1.legend(loc="lower left", ncol=3)
    ax1.set_xlabel("Total Training Samples", fontsize=14)
    # NOTE(review): assumes metric index 1 is the test ECE — confirm.
    ax1.set_ylabel("Test ECE", fontsize=14)
    # Span exactly the plotted ratios (previously the lower limit was taken
    # from TSR_LIST[metric_indx], which cropped the first data point).
    ax1.set_xlim(x_values[0], x_values[-1])
    fig1.tight_layout()
    os.makedirs("./figures", exist_ok=True)
    fig1.savefig("./figures/physiq_training_sample.png", dpi=300)
    plt.show()


# Run the plotting pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()
