import numpy as np
import os
import argparse
import os.path
import pandas as pd
import matplotlib.pyplot as plt
from os import path
import seaborn as sns
from matplotlib.lines import Line2D
import logging

# Make sure the output directory exists before logging opens a file inside it.
# os.makedirs is used instead of shelling out to `mkdir -p`: it needs no shell,
# works on every platform, and raises on real failures (e.g. permissions)
# instead of silently returning a non-zero exit status.
os.makedirs('./Plots', exist_ok=True)

# filemode='w' truncates the log file on every run; records are written as the
# bare message with no timestamp/level prefix.
logging.basicConfig(filename='./Plots/cifar100.log', filemode='w', format='%(message)s')

# getLogger() with no name returns the root logger, so INFO-level records from
# any module that propagates to the root end up in the file as well.
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Number of trailing entries of each metric curve that the script averages
# when it reports the summary mean / standard deviation.
TOP_N = 10

if __name__ == "__main__":
    # Script entry point.
    #
    # For every (corruption type, corruption rate) pair this script:
    #   1. loads the per-seed result dictionaries saved as ``.npy`` files,
    #   2. prints and logs the mean / std of the last TOP_N entries of each
    #      metric for every algorithm,
    #   3. draws one subplot per metric containing the metric curve of every
    #      algorithm and saves the figure under ./Plots.
    #
    # Raises FileNotFoundError as soon as an expected result file is missing.

    # No options are defined; parse_args() is still called so that `--help`
    # works and unexpected command-line arguments are rejected.
    parser = argparse.ArgumentParser(description="getresults")
    parser.parse_args()

    # Algorithms to compare.  Other names that appear in earlier revisions of
    # this list: PRL_FixEpsilon SPL_InnerMax PRL SPL BootStrap
    # BootStrap_InnerMax Standard GaussianAug PRL_InnerMax InnerBootstrap
    # InnerMax DoubleRobustPRL InnerPRL DoubleRobustSPL DoubleRobustBootstrap
    # InnerSPL.
    algorithm_list = "SimCLR".split(' ')
    for name in algorithm_list:
        print(name)

    # Curves are truncated to this many entries before plotting.
    max_epoch = 100

    head = ["*{0:<13}".format(" ")] + algorithm_list
    print(head)

    seed_list = [1]
    data = 'cifar10'

    # Metrics read from each result dictionary, and the summary rows
    # (average and standard deviation per metric) derived from them.
    # Both are loop-invariant, so they are built once here.
    metric_keys = ['acc_clean', 'acc_poison']
    summary_keys = [f'{key}_{stat}' for key in metric_keys for stat in ['avg', 'std']]

    for corrupted_type in ['patch', 'blend']:
        for corrupted_rate in [0.15, 0.25, 0.35, 0.45]:
            # Start every (type, rate) figure from a clean canvas.
            plt.clf()

            # Each summary row starts with a label describing the setting;
            # one value per algorithm is appended to it further below.
            keys_info = {}
            for row_key in summary_keys:
                info = 'Corrup Rate:{:<6.2f}  type:{}  Metric:{:<25s}'.format(
                    corrupted_rate, corrupted_type, row_key)
                keys_info[row_key] = '{:<40s}'.format(info)

            header = ' '.join(['{:<10s}'.format(alg) for alg in algorithm_list])
            print(header)
            logger.info(header)

            for algorithm in algorithm_list:
                # Per metric: one (truncated) curve per seed, and the last
                # TOP_N entries per seed.
                metrics = {key: [] for key in metric_keys}
                top_metrics = {key: [] for key in metric_keys}

                for seed in seed_list:
                    result_path = "/localscratch/liuboya2/DefenseBackDoorAttack/Result/Result/{}/{}/{}/{}/{}result.npy".format(
                        data, corrupted_rate, algorithm, corrupted_type, seed)
                    if not os.path.isfile(result_path):
                        raise FileNotFoundError('{} not found'.format(result_path))

                    # SECURITY: allow_pickle=True unpickles arbitrary objects;
                    # only load result files produced by trusted runs.
                    results = np.load(result_path, allow_pickle=True)
                    # The file holds a single pickled object, indexed by
                    # metric name below (presumably a dict — confirm against
                    # the code that writes these files).
                    results_dict = results.item()

                    for key in metric_keys:
                        metrics[key].append(results_dict[key][:max_epoch])
                        # The summary uses the last TOP_N entries of the full
                        # curve, i.e. it is not limited by max_epoch.
                        top_metrics[key].append(results_dict[key][-TOP_N:])

                for column, key in enumerate(metric_keys):
                    # Curves of all seeds laid end to end in one flat array.
                    all_curves = np.concatenate(metrics[key])
                    top_vals = np.concatenate(top_metrics[key])
                    avg_top, std_top = np.mean(top_vals), np.std(top_vals)

                    # The average is reported scaled by 100 and the standard
                    # deviation scaled by 1000.
                    keys_info[f'{key}_avg'] += '{:<15.2f}'.format(avg_top * 100)
                    keys_info[f'{key}_std'] += '{:<15.2f}'.format(std_top * 1000)

                    # One subplot per metric; every algorithm adds a line to it.
                    plt.subplot(1, len(metric_keys), column + 1)
                    # x repeats 1..length once per seed so that seaborn
                    # aggregates the seeds at each position of the curve.
                    length = len(all_curves) // len(seed_list)
                    # NOTE(review): newer seaborn releases deprecate `ci=` in
                    # favour of `errorbar="sd"`; kept as-is for compatibility
                    # with the installed version — confirm before upgrading.
                    sns.lineplot(
                        x=np.array(list(range(length)) * len(seed_list)) + 1,
                        y=all_curves.astype(float),
                        legend='brief',
                        label=algorithm,
                        ci="sd",
                    )

            for row_key in keys_info:
                print(keys_info[row_key])
                logger.info(keys_info[row_key])

            fig = plt.gcf()
            fig.set_size_inches(18.5, 10.5)
            plt.suptitle("{}: corruption rate:{}, corruption type:{}".format(data, corrupted_rate, corrupted_type))

            save_path = './Plots/{}-cp{}-{}.png'.format(data, corrupted_rate, corrupted_type)
            plt.savefig(save_path, bbox_inches='tight', pad_inches=0, dpi=100)
            print('file saved to {}'.format(save_path))
            plt.show()
    # result_alg = np.stack(result_alg)
    # sns.set_theme(style="whitegrid")
    # plt.subplot(1, 2, 1)
    # sns.lineplot(
    #     x=result_alg[:, 0],
    #     y=result_alg[:, 1].astype(float),
    #
    #     # markers=True,
    #     # style=result_alg[:, 0],
    #     # hue=result_alg[:, 0],
    #     # err_style="bars",
    # )
    # plt.xlabel("epoch")
    # plt.ylabel("acc")
    # plt.title("clean test data")
    # plt.subplot(1, 2, 2)
    # sns.lineplot(
    #     x=result_alg[:, 0],
    #     y=np.exp(result_alg[:, 2].astype(float)),
    #     # markers=True,
    #     # style=result_alg[:, 0],
    #     # hue=result_alg[:, 0],
    #     # err_style="bars",
    # )
    # # sns.lineplot(
    # #     x=result_alg[:, 3],
    # #     y=(result_alg[:, 2].astype(float)),
    # #     markers=True,
    # #     style=result_alg[:, 0],
    # #     hue=result_alg[:, 0],
    # #     err_style="bars",
    # # )
    # plt.title("adversarial attacked test data")
    # plt.xlabel("epoch")
    # plt.ylabel("acc")
    # fig = plt.gcf()
    # fig.set_size_inches(18.5, 10.5)
    # plt.show()
    # print("figure")
