import argparse
import os

import numpy as np
import seaborn as sns
import torch
from matplotlib import pyplot as plt

# Dropout probability of the Drift run to show for each dataset; print_all and
# print_batch_size read results from ``drop_<p>/Drift/`` using this value.
best = dict(cifar10=0.8, cifar100=0.7)

# Plot styling per method, as (linestyle, marker, color) tuples.
print_options = dict(
    Oracle=("--", "o", "black"),
    Finetune=("--", "^", "cyan"),
    Joint=("--", "s", "purple"),
    ER=("-", "D", "green"),  # diamond
    EWC=("-", "p", "orange"),  # pentagon
    Drift=("-", "*", "red"),  # star
    GenCl=("-", "X", "blue"),  # filled X
    Gen=("-", "x", "blue"),  # thin x
    Gate=("-", ">", "brown"),  # right-pointing triangle
)

# Methods plotted from ``<method>/`` result directories; "Drift" is added
# separately by the plotting functions.
method_list = [
    "Oracle",
    "Finetune",
    "Joint",
    "ER",
    "EWC",
    "Gen",
]


def load_acc(dir, n_tasks, test_batch_size, repeat=False):
    """Load the mean stored accuracy for each of the first ``n_tasks`` tasks.

    Args:
        dir: Directory containing ``task_idx_<i>/batch_size_<b>.pt`` files.
            A trailing path separator is optional.
        n_tasks: Number of entries to load (task indices ``0 .. n_tasks-1``).
        test_batch_size: Batch size ``<b>`` selecting which file to read.
        repeat: If True, every entry is read from ``task_idx_0`` instead of
            ``task_idx_<i>`` (print_all uses this for the "Joint" method).

    Returns:
        A 1-D float tensor of length ``n_tasks`` holding the mean of each
        loaded tensor.
    """
    # NOTE: ``dir`` shadows the builtin; kept because it is part of the signature.
    overall_acc = torch.zeros(n_tasks)
    for task_idx in range(n_tasks):
        load_idx = 0 if repeat else task_idx
        path = os.path.join(
            dir, f"task_idx_{load_idx}", f"batch_size_{test_batch_size}.pt"
        )
        # map_location="cpu" lets results saved from GPU runs load on CPU-only hosts.
        acc = torch.load(path, map_location="cpu")
        overall_acc[task_idx] = acc.mean().item()
    return overall_acc


def load_final_acc(dir, n_tasks, test_batch_size):
    """Return the mean stored accuracy for a single task index.

    Args:
        dir: Directory containing ``task_idx_<i>/batch_size_<b>.pt`` files.
            A trailing path separator is optional.
        n_tasks: Used directly as the task index ``<i>`` to read (callers
            pass ``n_tasks - 1`` for the last task, or ``0`` for "Joint").
        test_batch_size: Batch size ``<b>`` selecting which file to read.

    Returns:
        The mean of the loaded tensor as a Python float.
    """
    path = os.path.join(dir, f"task_idx_{n_tasks}", f"batch_size_{test_batch_size}.pt")
    # map_location="cpu" lets results saved from GPU runs load on CPU-only hosts.
    acc = torch.load(path, map_location="cpu")
    return acc.mean().item()


def print_dropout(dataset, test_batch_size):
    """Plot the final Drift accuracy against the dropout probability.

    Scans ``./store/acc/<dataset>/`` for directories whose name contains
    "drop" and that hold a ``Drift`` subdirectory. For each, it loads the
    accuracy stored for the last task index and plots accuracy vs. the
    probability parsed from the directory name (the part after the first
    ``_``). The figure is saved to ``store/figs/<dataset>_dropout.png``.

    Args:
        dataset: "cifar10" (5 tasks) or any other name (treated as 10 tasks).
        test_batch_size: Batch size whose stored results are plotted.
    """
    dataset_name = {"cifar10": "CIFAR10", "cifar100": "CIFAR100"}
    acc_dir = f"./store/acc/{dataset}"
    n_tasks = 5 if dataset == "cifar10" else 10

    # Collect (probability, directory name) pairs. The directory name is kept
    # verbatim so the path is not rebuilt from a float (e.g. "drop_1" would
    # otherwise become "drop_1.0").
    runs = []
    for item in os.listdir(acc_dir):
        item_path = os.path.join(acc_dir, item)
        if "drop" not in item or not os.path.isdir(item_path):
            continue
        if "Drift" not in os.listdir(item_path):
            continue
        runs.append((float(item.split("_")[1]), item))
    runs.sort()

    dropout_prob = [prob for prob, _ in runs]
    final_acc = [
        # Trailing "" keeps a path separator at the end of the directory.
        load_final_acc(os.path.join(acc_dir, item, "Drift", ""), n_tasks - 1, test_batch_size)
        for _, item in runs
    ]

    # Create exactly one figure so plt.show() does not also pop up an empty one.
    plt.figure(figsize=(10, 7))
    sns.set_context("talk")
    sns.set_style("darkgrid")
    sns.lineplot(x=dropout_prob, y=final_acc)

    plt.title(f"{dataset_name[dataset]}: Test accuracy with different dropout rates")
    plt.xlabel("Dropout probability")
    plt.ylabel("Accuracy")
    os.makedirs("store/figs/", exist_ok=True)
    plt.savefig(f"store/figs/{dataset}_dropout.png")
    plt.show()


def print_all(dataset, test_batch_size):
    """Plot per-task average accuracy for every method plus the best Drift run.

    One line is drawn for each entry of ``method_list``. An extra "Drift" line
    is read from ``drop_<best[dataset]>/Drift/``. The figure is saved to
    ``store/figs/<dataset>_all.png``.

    Args:
        dataset: "cifar10" (5 tasks) or any other name (treated as 10 tasks).
        test_batch_size: Batch size whose stored results are plotted.
    """
    n_tasks = 5 if dataset == "cifar10" else 10
    dataset_name = {"cifar10": "CIFAR10", "cifar100": "CIFAR100"}
    acc_dir = f"./store/acc/{dataset}"
    sns.set_context("talk")

    plt.figure(figsize=(11, 6))
    # right=0.825 leaves room for the legend anchored outside the axes below.
    plt.subplots_adjust(
        top=0.909, bottom=0.162, left=0.1, right=0.825, hspace=0.2, wspace=0.2
    )

    x = np.arange(1, n_tasks + 1)
    # (label, results directory, repeat). "Joint" reads task_idx_0 for every task.
    runs = [(method, f"{acc_dir}/{method}/", method == "Joint") for method in method_list]
    runs.append(("Drift", f"{acc_dir}/drop_{best[dataset]}/Drift/", False))

    for label, path, repeat in runs:
        overall_acc = load_acc(path, n_tasks, test_batch_size, repeat=repeat)
        linestyle, marker, color = print_options[label]
        sns.lineplot(
            x=x,
            y=overall_acc.numpy(),
            label=label,
            linestyle=linestyle,
            marker=marker,
            color=color,
        )

    plt.title(f"{dataset_name[dataset]}: average accuracy over time")
    plt.xlabel("Number of tasks")
    plt.ylabel("Accuracy")
    plt.legend(loc="upper right", bbox_to_anchor=(1.25, 1))
    plt.xticks(x)
    # Do not rely on print_dropout having created the output directory.
    os.makedirs("store/figs/", exist_ok=True)
    plt.savefig(f"store/figs/{dataset}_all.png")
    plt.show()


def print_batch_size(dataset, **kwargs):
    """Plot final accuracy against test batch size for every method and Drift.

    For each method in ``method_list``, and for the Drift run at
    ``drop_<best[dataset]>``, the accuracy stored for the last task index is
    loaded at every batch size in ``batch_size_list``. "Joint" reads task
    index 0 instead. The figure is saved to ``store/figs/<dataset>_batch.png``.

    Args:
        dataset: "cifar10" (5 tasks) or any other name (treated as 10 tasks).
        **kwargs: Accepted and ignored (e.g. ``test_batch_size`` from the CLI),
            since this plot sweeps over all batch sizes.
    """
    n_tasks = 5 if dataset == "cifar10" else 10
    dataset_name = {"cifar10": "CIFAR10", "cifar100": "CIFAR100"}
    acc_dir = f"./store/acc/{dataset}"
    # Defined up front: every curve, including Drift, uses it.
    batch_size_list = [1, 2, 4, 8, 16, 32, 64, 128]
    sns.set_context("talk")

    plt.figure(figsize=(11, 6))
    # right=0.825 leaves room for the legend anchored outside the axes below.
    plt.subplots_adjust(
        top=0.909, bottom=0.162, left=0.1, right=0.825, hspace=0.2, wspace=0.2
    )

    # (label, results directory, task index to read).
    runs = [
        (method, f"{acc_dir}/{method}/", 0 if method == "Joint" else n_tasks - 1)
        for method in method_list
    ]
    runs.append(("Drift", f"{acc_dir}/drop_{best[dataset]}/Drift/", n_tasks - 1))

    for label, path, task_idx in runs:
        final_acc = [load_final_acc(path, task_idx, bs) for bs in batch_size_list]
        linestyle, marker, color = print_options[label]
        sns.lineplot(
            x=batch_size_list,
            y=np.asarray(final_acc),
            label=label,
            linestyle=linestyle,
            marker=marker,
            color=color,
        )

    plt.title(f"{dataset_name[dataset]}: test accuracy with different test batch sizes")
    plt.xlabel("Test batch size")
    plt.ylabel("Accuracy")
    plt.xscale("log", base=2)
    plt.legend(loc="upper right", bbox_to_anchor=(1.25, 1))
    plt.xticks(batch_size_list, batch_size_list)

    # Do not rely on print_dropout having created the output directory.
    os.makedirs("store/figs/", exist_ok=True)
    plt.savefig(f"store/figs/{dataset}_batch.png")
    plt.show()


if __name__ == "__main__":
    # Command-line entry point: render the dropout sweep, per-task accuracy
    # and batch-size figures for the chosen dataset.
    cli = argparse.ArgumentParser(description="Choose dataset")
    cli.add_argument(
        "--dataset",
        default="cifar10",
        type=str,
        choices=["cifar10", "cifar100"],
        help="Choose a dataset: cifar10, cifar100",
    )
    cli.add_argument("--test_batch_size", default=64, type=int, help="Test batch size")
    options = vars(cli.parse_args())
    for plot_fn in (print_dropout, print_all, print_batch_size):
        plot_fn(**options)
