import numpy as np
import os
import argparse
import os.path
import pandas as pd
import matplotlib.pyplot as plt
from os import path
import seaborn as sns
from matplotlib.lines import Line2D

def load_seed_curves(result_root, data, rate, algorithm, corrupted_type, seeds, max_epoch):
    """Load the clean and poisoned accuracy curves of every seed whose result file exists.

    Files are expected at
    ``<result_root>/<data>/<rate>/<algorithm>/<corrupted_type>/<seed>result.npy`` and
    must hold a pickled dict with ``"acc_clean"`` and ``"acc_poison"`` entries.
    Seeds whose file is missing are skipped silently.

    Returns:
        ``(clean_curves, poison_curves)``: two lists with one entry per seed that was
        found, each curve truncated to its first ``max_epoch`` values.
    """
    clean_curves, poison_curves = [], []
    for seed in seeds:
        result_path = os.path.join(
            result_root, str(data), str(rate), algorithm, corrupted_type,
            "{}result.npy".format(seed))
        print(result_path)
        if not os.path.isfile(result_path):
            continue
        # allow_pickle is required because the file stores a dict. Only point
        # result_root at result files you produced yourself.
        results_dict = np.load(result_path, allow_pickle=True).item()
        acc_poison = results_dict["acc_poison"][:max_epoch]
        clean_curves.append(results_dict["acc_clean"][:max_epoch])
        poison_curves.append(acc_poison)
        print(acc_poison)
    return clean_curves, poison_curves


def last_epochs_stats(curves, last_n=10):
    """Return ``(mean, std)`` of the final ``last_n`` epochs, pooled across seeds.

    ``curves`` is a sequence of equal-length per-seed curves. The std is the
    population std (``ddof=0``) over every (seed, epoch) value in the window.
    If the curves are shorter than ``last_n``, every epoch is used. This avoids
    the empty window (and NaN mean) that a fixed ``max_epoch - 10`` start gives.
    """
    stacked = np.stack([np.asarray(curve, dtype=float) for curve in curves])
    start = max(stacked.shape[1] - last_n, 0)
    tail = stacked[:, start:]
    return tail.mean(), tail.std()


def epoch_axis(n_epochs, n_seeds):
    """Return 1-based epoch indices ``1..n_epochs`` repeated once per seed.

    The result lines up element-wise with the seed-major flattening of a
    ``(n_seeds, n_epochs)`` curve array, as seaborn's ``x``/``y`` pairing expects.
    """
    return np.tile(np.arange(1, n_epochs + 1), n_seeds)


def main():
    """Plot and summarize clean / poisoned test accuracy for each corruption rate.

    Behaviour per (dataset, corruption rate):
      - Draws each algorithm's mean +/- sd accuracy curve over the loaded seeds.
        Clean accuracy goes in the left subplot, poisoned accuracy in the right.
      - Prints each algorithm's mean/std over the last ten epochs.
      - Shows one figure.
    All defaults reproduce the original hard-coded configuration.
    """
    parser = argparse.ArgumentParser(description="getresults")
    parser.add_argument(
        "--result-root",
        default="/localscratch/liuboya2/DefenseBackDoorAttack/Result/Result",
        help="directory containing <data>/<rate>/<algorithm>/<type>/<seed>result.npy")
    parser.add_argument("--datasets", nargs="+", default=["cifar10"])
    parser.add_argument("--algorithms", nargs="+", default=["PRL"],
                        help="algorithms to compare (e.g. PRL PRL_FixEpsilon)")
    parser.add_argument("--corrupted-type", default="patch")
    parser.add_argument("--corrupted-rates", nargs="+", type=float,
                        default=[0.15, 0.25, 0.35, 0.45])
    parser.add_argument("--seeds", nargs="+", type=int, default=[1, 2, 3])
    parser.add_argument("--max-epoch", type=int, default=50)
    config = parser.parse_args()

    head = ["*{0:<13}".format(" ")] + config.algorithms
    print(head)

    for data in config.datasets:
        # Use a distinct loop name: rebinding the list name to one of its own
        # elements would break any later pass over the rates.
        for rate in config.corrupted_rates:
            plotted = False
            for algorithm in config.algorithms:
                clean_curves, poison_curves = load_seed_curves(
                    config.result_root, data, rate, algorithm,
                    config.corrupted_type, config.seeds, config.max_epoch)
                if not clean_curves:
                    print("no results for {} at corruption rate {}; skipping".format(
                        algorithm, rate))
                    continue
                clean_all = np.stack(clean_curves).astype(float)
                poison_all = np.stack(poison_curves).astype(float)

                # Build x from the seeds actually loaded (not from len(seeds)).
                # Missing result files would otherwise make x and y differ in length.
                plt.subplot(1, 2, 1)
                print("plotting clean acc of {}".format(algorithm))
                sns.lineplot(
                    x=epoch_axis(clean_all.shape[1], clean_all.shape[0]),
                    y=clean_all.ravel(),
                    legend="brief",
                    label=algorithm,
                    ci="sd",  # seaborn >= 0.12 spells this errorbar="sd"
                )
                plt.ylim([0, 1])

                print("plotting poison acc of {}".format(algorithm))
                plt.subplot(1, 2, 2)
                sns.lineplot(
                    x=epoch_axis(poison_all.shape[1], poison_all.shape[0]),
                    y=poison_all.ravel(),
                    legend="brief",
                    label=algorithm,
                    ci="sd",
                )
                plt.ylim([0, 1])
                plotted = True

                clean_mean, clean_std = last_epochs_stats(clean_curves)
                poison_mean, poison_std = last_epochs_stats(poison_curves)
                print('clean acc of {}:{}/{}'.format(algorithm, clean_mean, clean_std))
                print('poison acc of {}:{}/{}'.format(algorithm, poison_mean, poison_std))

            if not plotted:
                # Nothing was loaded for this rate; don't pop up an empty figure.
                continue
            fig = plt.gcf()
            fig.set_size_inches(18.5, 10.5)
            plt.suptitle("corruption rate:{}".format(rate))
            plt.show()


if __name__ == "__main__":
    main()
