import numpy as np
import torch

from utils import get_folder_names


def load_results(
    method_list,
    cl_type="regular",
    dir="./store/results/",
    dataset_name="cifar10",
):
    """Load saved result objects for a set of methods.

    For every method name and every result type in
    ``("avg_acc", "adapt_acc", "test_acc_list")`` the file
    ``{dir}/{cl_type}/{dataset_name}/{name}_{result_type}.pt`` is loaded
    with ``torch.load``.

    Args:
        method_list: Method names whose results should be loaded. Ignored
            when ``cl_type == "ablation"``: the names are then taken from
            ``get_folder_names`` on the ablation/dataset directory, keeping
            only those that contain ``"disable"``.
        cl_type: Experiment-type sub-directory (e.g. ``"regular"`` or
            ``"ablation"``).
        dir: Root results directory. A trailing separator is optional.
        dataset_name: Dataset sub-directory.

    Returns:
        Nested dict ``{result_type: {method_name: loaded_object}}``.
    """
    # Make sure the root ends with a separator so it is not glued onto
    # cl_type (previously "./store/results" became "./store/resultsregular").
    # Roots that already end in a separator, and the empty root, are kept
    # exactly as given.
    root = str(dir)
    if root and not root.endswith(("/", "\\")):
        root += "/"
    base = f"{root}{cl_type}/{dataset_name}/"

    result_types = ["avg_acc", "adapt_acc", "test_acc_list"]
    results = {result_type: {} for result_type in result_types}
    if cl_type == "ablation":
        # NOTE(review): semantics of the "/" argument are defined in
        # utils.get_folder_names — confirm there.
        method_list = get_folder_names(base, "/")
        method_list = [name for name in method_list if "disable" in name]
    for name in method_list:
        for result_type in result_types:
            # torch.load unpickles the file: only use on trusted result files.
            results[result_type][name] = torch.load(f"{base}{name}_{result_type}.pt")

    return results


def print_results(
    method_list,
    cl_type="regular",
    dir="./store/results/",
    dataset_name="cifar10",
):
    r"""Print ``$mean \pm std_err$`` (in percent) for each method's results.

    Results are obtained from ``load_results`` with the same arguments and
    printed in three sections:

    * "Adaptive"       -- the whole ``adapt_acc`` result,
    * "Knowledge"      -- the whole ``test_acc_list`` result,
    * "Final accuracy" -- the last row (``[-1, :]``) of ``avg_acc``.

    List-valued results are stacked with ``torch.stack`` before any row
    selection. Values are multiplied by 100 and formatted with two decimals
    as LaTeX math (``$.. \pm ..$``).

    Args:
        method_list: Method names to report (see ``load_results``).
        cl_type: Experiment-type sub-directory.
        dir: Root results directory.
        dataset_name: Dataset sub-directory.
    """
    results = load_results(method_list, cl_type, dir, dataset_name)
    # Global torch print setting kept for backward compatibility; the
    # numbers below are formatted explicitly and do not depend on it.
    torch.set_printoptions(precision=2)

    # Section title -> (result type key, index applied after stacking).
    # An index of None means "use the whole tensor".
    sections = {
        "Adaptive": ("adapt_acc", None),
        "Knowledge": ("test_acc_list", None),
        "Final accuracy": ("avg_acc", (-1, slice(None))),  # == [-1, :]
    }

    for section_name, (result_type, index) in sections.items():
        print("-" * 30, section_name, "-" * 30)
        for name, result in results[result_type].items():
            # Stack first so row selection also works for list-valued results.
            values = torch.stack(result) if isinstance(result, list) else result
            if index is not None:
                values = values[index]
            # Out-of-place: do not mutate the loaded tensors (a slice is a view).
            values = values * 100
            mean = torch.mean(values).item()
            # NOTE(review): std is taken over all elements but divided by
            # sqrt(size of dim 0); these agree only for 1-D tensors — confirm
            # the intended standard-error definition for multi-dim results.
            std_err = torch.std(values).item() / np.sqrt(len(values))
            print(f"{name}: ${mean:.2f} \\pm {std_err:.2f}$")
