import numpy as np
import os
import torch
import pandas as pd


def get_acc(output_root, data_name, backbone_name, model_names, eval_result_file):
    """Collect the stored accuracy of each unlearning method for one dataset/backbone.

    Each result is read from
    ``<output_root>/<backbone_name>_<data_name>/<model_name>/<prefix><eval_result_file>``,
    where ``prefix`` equals the model name, except ``"FF"`` -> ``"fisher"`` and
    ``"IU"`` -> ``"wfisher"``.

    Args:
        output_root: Root directory containing one ``<backbone>_<dataset>``
            sub-directory per experiment.
        data_name: Dataset name, e.g. ``"cifar10"``.
        backbone_name: Backbone name, e.g. ``"resnet18"``.
        model_names: Method names to collect, in output order.
        eval_result_file: File-name suffix of the saved evaluation result.

    Returns:
        A list with exactly one entry per item of ``model_names``, in order:
        the ``"accuracy"`` value of the loaded result; the hard-coded
        placeholder ``[0, 0, 0, 0, 0]`` for ``"FF"`` on ``vgg16``; or ``None``
        when the result file does not exist.
    """
    # Methods whose result files are stored under a different name prefix.
    file_prefixes = {"FF": "fisher", "IU": "wfisher"}
    experiment_dir = os.path.join(output_root, backbone_name + "_" + data_name)

    eval_results = []
    for model_name in model_names:
        if backbone_name == "vgg16" and model_name == "FF":
            # Hard-coded placeholder row; FF results are not read for vgg16.
            eval_results.append([0, 0, 0, 0, 0])
            continue

        prefix = file_prefixes.get(model_name, model_name)
        eval_result_path = os.path.join(
            experiment_dir, model_name, prefix + eval_result_file
        )
        if not os.path.exists(eval_result_path):
            # Keep one entry per model so the result stays aligned with
            # model_names (stopping early would leave a ragged table).
            print(eval_result_path)
            print("eval file does not exist!")
            eval_results.append(None)
            continue

        # map_location="cpu" lets results saved from GPU runs load on CPU-only hosts.
        eval_result = torch.load(eval_result_path, map_location="cpu")
        eval_results.append(eval_result["accuracy"])

    return eval_results


if __name__ == "__main__":

    # Root directory holding the per-experiment output folders.
    # Normal experiments:
    path = "/nvme/data/3ai/lips/outputs-after"
    # Ablation experiments (swap in to aggregate those results instead):
    # path = '/nvme/data/3ai/lips/outputs-after-ablation-2-step04'

    # Eval-file suffix and output CSV for the unlearned model BEFORE fine-tuning.
    eval_file_unlearn_before = "eval_result.pth.tar"
    csv_file_unlearn_before = "unlearn_acc_before.csv"

    # Eval-file suffix and output CSV for the unlearned model AFTER fine-tuning.
    eval_file_unlearn_after = "eval_result_ft.pth.tar"
    csv_file_unlearn_after = "unlearn_acc_after.csv"

    unlearn_model_names = ["retrain", "FT", "FF", "GA", "IU", "FT_prune"]
    data_list = ["cifar10", "cifar100", "tinyimg", "fmnist"]
    backbones = ["resnet18", "vgg16"]

    # Aggregate the after-fine-tune results; point these at the *_before pair
    # to aggregate the pre-fine-tune results instead.
    eval_file = eval_file_unlearn_after
    csv_file = csv_file_unlearn_after

    # One column per "<dataset>_<backbone>" pair; the dataset varies slowest.
    eval_all = {
        f"{data}_{backbone}": get_acc(
            path, data, backbone, unlearn_model_names, eval_file
        )
        for data in data_list
        for backbone in backbones
    }

    # Rows are the unlearning methods, in the order they were queried.
    df_eval = pd.DataFrame.from_dict(eval_all)
    df_eval.index = unlearn_model_names
    df_eval.to_csv(csv_file)
    print("save unlearn model's acc done!")
