import torch
import os
# dir_list=os.listdir("./total/")
# print(dir_list)
import math
import numpy as np
# Which experiment grid the driver at the bottom of the file evaluates:
# 'ImageNet' selects the ImageNet configuration; any other value (such as
# 'cifar') selects the CIFAR configuration.
mode = 'cifar'
# Text inserted at the end of every result filename, just before ".pth".
save_suffix = ''
# Formatted with "%.1f" into every result filename and forwarded to
# load_and_print.
alpha = 0.0
beta = 0.0
def _load_log(savefile):
    """Load a result file written with ``torch.save``.

    The logs appear to hold NumPy arrays (``load_and_print`` calls ``.copy()``
    on them), which PyTorch >= 2.6 refuses to unpickle under its default
    ``weights_only=True``. Full unpickling is therefore requested whenever the
    installed ``torch.load`` accepts the flag (older versions do not have it
    and always unpickle fully). Only use this on result files you produced
    yourself: unpickling can execute arbitrary code.
    """
    import inspect

    kwargs = {}
    if 'weights_only' in inspect.signature(torch.load).parameters:
        kwargs['weights_only'] = False
    return torch.load(savefile, **kwargs)


def _forward_fill_nan(primary, companions=()):
    """Forward-fill NaNs in each row of the 2-D array ``primary``, in place.

    Every NaN in ``primary`` is replaced by the most recent non-NaN entry to
    its left in the same row, or by 0 if the row has none yet. Each array in
    ``companions`` (same shape) is overwritten at those same positions with
    the value it held at that most-recent non-NaN position (again 0 if none).
    The carried values are reset at the start of every row, so one row never
    leaks into the next.
    """
    for i in range(primary.shape[0]):
        last_primary = 0
        last_companions = [0] * len(companions)
        for j in range(primary.shape[1]):
            if math.isnan(primary[i, j]):
                primary[i, j] = last_primary
                for arr, value in zip(companions, last_companions):
                    arr[i, j] = value
            else:
                last_primary = primary[i, j]
                last_companions = [arr[i, j] for arr in companions]


def _mean_min_successful_l2(query_point, l2, acc, budget):
    """Return the mean, over images, of the smallest L2 norm at a successful log point.

    A log point of image ``i`` (row ``i``) counts as successful when its query
    count lies in ``(0, budget]`` and its accuracy is below 0.1. Images with no
    successful point, or whose successful norms are all >= 1e4, are excluded.
    Returns ``None`` when no image qualifies.
    """
    query_point = np.asarray(query_point)
    l2 = np.asarray(l2)
    acc = np.asarray(acc)
    success = (query_point <= budget) & (query_point > 0) & (acc < 0.1)
    # 1e4 is the "nothing found" sentinel: non-successful points are masked
    # to it, and ``initial`` keeps the reduction defined for zero-width rows.
    per_image = np.where(success, l2, 1e4).min(axis=1, initial=1e4)
    found = per_image[per_image < 1e4]
    if found.size == 0:
        return None
    return float(np.mean(found))


def _success_rate(query_count, budget):
    """Return the percentage of entries of ``query_count`` in ``(0, budget]``."""
    query_count = np.asarray(query_count)
    return 100 * np.mean((query_count <= budget) & (query_count > 0))


def load_and_print(savefile, attack, defense, sigma, alpha, beta, *,
                   show_mean_l2=False):
    """Print a two-line summary of one saved attack log.

    The first line is ``attack defense sigma``. The second line holds:

    * the attack success rate at 2000, 5000 and 10000 queries;
    * the mean of ``total_log_ne_count / total_log_query_count`` in brackets.

    For ``attack == 'SSA'`` the success rate comes from the L-inf query counts.
    It is shown twice: as logged, and with entries equal to 1 excluded. Every
    other attack uses the L2 query counts.

    ``Not found`` is printed (before the header) for every budget at which no
    image has a successful L2 log point.

    Args:
        savefile: Path of the ``.pth`` log to read. If it does not exist, the
            path is printed and the function returns immediately.
        attack: Attack name; ``'SSA'`` switches to the L-inf statistics.
        defense: Only echoed in the header line.
        sigma: Only echoed in the header line.
        alpha: Unused; kept so existing call sites keep working.
        beta: Unused; kept so existing call sites keep working.
        show_mean_l2: If true and ``attack != 'SSA'``, also print the mean
            minimal successful L2 norm (``[%.2f]``) at the same three budgets.
    """
    if not os.path.exists(savefile):
        # Missing result files are reported by name and skipped.
        print(savefile)
        return

    log_data = _load_log(savefile)
    query_point = log_data['total_log_query_point']
    l2 = log_data['total_log_l_2']
    acc = log_data['total_log_acc']
    l2_query_count = log_data['total_log_l_2_query_count']
    linf_query_count = log_data['total_log_l_inf_query_count']
    ne_count = log_data['total_log_ne_count']
    query_count = log_data['total_log_query_count']

    total_queries = 10000
    log_interval = 200
    max_images = 1000
    num_logs = total_queries // log_interval
    budgets = (2000, 5000, 10000)
    # Log column whose cumulative query budget equals each reported budget
    # (9, 24, 49 for the defaults).
    preset_idx = [budget // log_interval - 1 for budget in budgets]
    # Never index past the rows actually stored in the log.
    num_images = min(max_images, len(l2))

    l2_view = l2[:num_images, :num_logs]
    acc_view = acc[:num_images, :num_logs]
    # NaN log entries (per the original note, initially misclassified images)
    # take the previous value of the same image, or L2 = 0 / acc = 0 when the
    # image has no earlier entry.
    _forward_fill_nan(l2_view, (acc_view,))

    mean_l2 = np.zeros(num_logs)
    for k in preset_idx:
        value = _mean_min_successful_l2(query_point[:num_images, :num_logs],
                                        l2_view, acc_view,
                                        (k + 1) * log_interval)
        if value is None:
            print('Not found')
        else:
            mean_l2[k] = value

    # Second L-inf variant: entries whose query count is exactly 1 become -1,
    # so the "> 0" test drops them. NOTE(review): presumably these are images
    # already adversarial before attacking - confirm.
    linf_query_count_excl = linf_query_count.copy()
    linf_query_count_excl[linf_query_count == 1] = -1
    # NOTE(review): a zero in total_log_query_count makes this ratio inf/nan,
    # and the printed mean becomes nan.
    ne_ratio = ne_count / query_count

    print(attack, defense, sigma)
    for budget in budgets:
        if attack == 'SSA':
            print('%.1f%%//(%.1f%%) ' % (
                _success_rate(linf_query_count, budget),
                _success_rate(linf_query_count_excl, budget)), end=' ')
        else:
            print('%.1f%%' % _success_rate(l2_query_count, budget), end=' ')
    print('[%.3f]' % np.mean(ne_ratio), end=' ')
    if show_mean_l2 and attack != 'SSA':
        for k in preset_idx:
            print('[%.2f]' % mean_l2[k], end=' ')
    print()


# Pick the experiment grid for the selected dataset. CIFAR result files carry
# a "CIFAR_" filename prefix; ImageNet files have none.
if mode == 'ImageNet':
    result_dir = 'total'
    model = 'resnet50'
    attacks = ['BA', 'SO', 'HSJA', 'GD', 'SB', 'SBD', 'BD', 'SA', 'SSA']
    defenses = ['gaussian', 'AT', 'rnp']
    file_prefix = ''
else:
    result_dir = 'total'
    model = 'resnet20'
    attacks = ['BA', 'SO', 'HSJA', 'GD']
    defenses = ['gaussian', 'RSE', 'PNI']
    file_prefix = 'CIFAR_'
sigmas = [0, 0.01, 0.001]

# Only the 'gaussian' defense (or the SSA attack) is swept over every sigma;
# every other defense/attack pair is evaluated once, at sigma = 0.
for attack in attacks:
    print(attack)
    for defense in defenses:
        sweep_sigmas = defense == 'gaussian' or attack == 'SSA'
        for sigma in (sigmas if sweep_sigmas else [0]):
            savefile = '%s/%s%s_%s_%s_%.3f_%.1f_%.1f_%s.pth' % (
                result_dir, file_prefix, attack, model, defense, sigma,
                alpha, beta, save_suffix)
            load_and_print(savefile, attack, defense, sigma, alpha, beta)
