import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from eval.print_results import load_results


def visualization(
    dataset_list,
    method_list,
    cl_type="regular",
    show_legend=True,
    dir="./store/results/",
):
    """Plot average accuracy over time, one panel per dataset, and save a PDF.

    For every dataset the ``"avg_acc"`` entry returned by ``load_results`` is
    read; each method in it becomes one line (mean with standard-error band)
    in that dataset's panel. The figure is written to
    ``<dir>/<cl_type>_acc.pdf`` and then shown.

    Args:
        dataset_list: Dataset identifiers, one panel each. Known identifiers
            are mapped to a pretty title; unknown ones are used verbatim.
        method_list: Method names forwarded to ``load_results``.
        cl_type: Forwarded to ``load_results`` and used as the output file
            name prefix.
        show_legend: When true, a legend is drawn next to the first panel;
            when false, no legend is drawn at all.
        dir: Output directory for the PDF (a trailing separator is optional).

    Note:
        This function updates global matplotlib/seaborn settings
        (``sns.set_theme``, ``plt.rcParams``, ``sns.set_palette``).
    """
    # Only every `step_stride`-th entry starting at `first_step` is plotted.
    first_step, step_stride = 500, 200

    # Set the overall layout and font size.
    sns.set_theme(style="whitegrid")
    plt.rcParams.update({"font.size": 28})

    # Per-method colour, line style and marker.
    style_dict = {
        "Oracle": {"color": "darkgray", "linestyle": "--", "marker": "o"},
        "Finetune": {"color": "lightseagreen", "linestyle": "--", "marker": "v"},
        "ExpVAE": {"color": "burlywood", "linestyle": "--", "marker": "s"},
        "ER": {"color": "cornflowerblue", "linestyle": "solid", "marker": "D"},
        "AGEM": {"color": "seagreen", "linestyle": "solid", "marker": "^"},
        "OnlineEWC": {"color": "darkorange", "linestyle": "solid", "marker": "P"},
        "MIR": {"color": "darkgray", "linestyle": "solid", "marker": "X"},
        "Gdumb": {"color": "lightgreen", "linestyle": "solid", "marker": "h"},
        "LEARN": {"color": "red", "linestyle": "solid", "marker": "p"},
    }
    # Used for methods missing from `style_dict`; `color=None` lets seaborn
    # pick the next colour of the active palette.
    fallback_style = {"color": None, "linestyle": "solid", "marker": "o"}
    # Legend labels that differ from the internal method name.
    legend_names = {"LEARN": "LEARN (ours)", "Gdumb": "GDumb"}
    print_names_list = {
        "cifar10": "CIFAR10",
        "cifar100": "CIFAR100",
        "miniimagenet": "Mini-ImageNet",
        "tinyimagenet": "Tiny-ImageNet",
        "cub": "CUB-200",
    }

    dataset_names = list(dataset_list)
    # One column per dataset, 4 inches wide each (16 x 3.1 for four datasets).
    # `squeeze=False` keeps the result 2-D so a single dataset is indexable too.
    n_panels = max(len(dataset_names), 1)
    fig, axs = plt.subplots(
        nrows=1, ncols=n_panels, figsize=(4 * n_panels, 3.1), squeeze=False
    )
    axs = axs[0]
    sns.set_palette("deep")

    for i, dataset_name in enumerate(dataset_names):
        ax = axs[i]
        avg_acc_by_method = load_results(
            cl_type=cl_type, dataset_name=dataset_name, method_list=method_list
        )["avg_acc"]

        for name, avg_acc in avg_acc_by_method.items():
            # Presumably `avg_acc` is a torch tensor shaped (steps, runs) --
            # TODO confirm against `load_results`. After the transpose each
            # DataFrame column holds the values of one sampled step.
            timestamps = np.arange(first_step, avg_acc.shape[0], step_stride)
            sampled = avg_acc[first_step::step_stride].detach().cpu().numpy().T
            df_melted = pd.DataFrame(sampled * 100, columns=timestamps).melt(
                var_name="Time", value_name="Accuracy (%)"
            )

            style = style_dict.get(name, fallback_style)
            sns.lineplot(
                x="Time",
                y="Accuracy (%)",
                data=df_melted,
                label=legend_names.get(name, name),
                ax=ax,
                color=style["color"],
                linestyle=style["linestyle"],
                # `marker` (singular) is forwarded to matplotlib; seaborn's
                # `markers` only applies together with a `style` variable.
                marker=style["marker"],
                errorbar="se",
            )

        ax.set_title(print_names_list.get(dataset_name, dataset_name), fontsize=23)
        box = ax.get_position()
        ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])

        if show_legend and i == 0:
            # Single shared legend, anchored to the left of the first panel.
            ax.legend(loc="upper left", bbox_to_anchor=(-0.75, 1.05))
        else:
            legend = ax.get_legend()
            if legend is not None:
                legend.remove()
        # NOTE: the previous `ax.yaxis.visible = False` was a no-op (matplotlib
        # does not read that attribute), so y-axes stay visible on every panel.

    plt.tight_layout()
    plt.subplots_adjust(
        top=0.896, bottom=0.124, left=0.14, right=0.993, hspace=0.5, wspace=0.2
    )
    plt.savefig(
        os.path.join(dir, f"{cl_type}_acc.pdf"),
        format="pdf",
        dpi=300,
        bbox_inches="tight",
    )
    plt.show()
    plt.close()
