import torch
import os
# dir_list=os.listdir("./total/")
# print(dir_list)
import math
import numpy as np
# Evaluation settings.
# ``mode`` picks which branch of the driver at the bottom of the file runs:
# 'ImageNet' prints per-run success rates, anything else only reports
# result files that are missing.
mode = 'ImageNet'
# Suffix appended to every result-file name.
save_suffix = ''
# alpha / beta are formatted (with %.1f) into the result-file names.
alpha = 0.0
beta = 0.0
def _load_log(savefile):
    """Load a per-run result dict that was written with ``torch.save``.

    The per-image logs are stored as NumPy arrays. PyTorch 2.6 changed the
    default of ``torch.load`` to ``weights_only=True``, which refuses to
    unpickle NumPy arrays, so ``weights_only=False`` is requested whenever
    the installed ``torch.load`` accepts that keyword (older releases do not).

    NOTE: this unpickles arbitrary objects -- only load result files you
    produced yourself.
    """
    import inspect  # local import: only needed for this version check

    if 'weights_only' in inspect.signature(torch.load).parameters:
        return torch.load(savefile, weights_only=False)
    return torch.load(savefile)


def _forward_fill_nan(values, num_rows, num_cols):
    """Return a float copy of ``values[:num_rows, :num_cols]`` with NaNs filled.

    Each row (image) is filled independently: a NaN takes the last non-NaN
    value seen earlier in the same row, and leading NaNs become 0. Leading
    NaNs presumably come from images that were misclassified before the
    attack started -- TODO confirm.
    """
    filled = np.array(values, dtype=float)[:num_rows, :num_cols]
    for row in filled:
        last = 0.0  # reset per image so values never leak between rows
        for j, value in enumerate(row):
            if math.isnan(value):
                row[j] = last
            else:
                last = value
    return filled


def _mean_last_within_budget(query_points, values, budget):
    """Average over images of the value at the last log slot within ``budget``.

    A slot counts when its query point lies in ``(0, budget]``. Images with
    no such slot contribute 0.
    """
    in_budget = (query_points > 0) & (query_points <= budget)
    picked = []
    for mask_row, value_row in zip(in_budget, values):
        hits = np.flatnonzero(mask_row)
        picked.append(value_row[hits[-1]] if hits.size else 0.0)
    return float(np.mean(picked))


def _success_rate(query_count, budget):
    """Percentage of images whose recorded query count is in ``(0, budget]``."""
    query_count = np.asarray(query_count)
    return 100 * np.mean((query_count > 0) & (query_count <= budget))


def load_and_print(savefile, attack, defense, sigma, alpha, beta, *,
                   print_norms=False):
    """Print the attack success rates stored in one result file.

    Prints one line. For each query budget (5000, 10000, 20000) it shows the
    percentage of images whose recorded query count is in ``(0, budget]``.

    For ``attack == 'SSA'`` the L-inf query counts are used and each entry is
    ``rate//(strict_rate)``. The strict rate additionally treats a query
    count of exactly 1 as a failure. Every other attack uses the L2 query
    counts.

    If ``savefile`` does not exist, its path is printed and nothing else
    happens.

    Args:
        savefile: path to a ``torch.save``-d dict of per-image logs.
        attack: attack name; ``'SSA'`` selects the L-inf statistics.
        defense, sigma, alpha, beta: not used here; accepted so existing
            call sites keep working.
        print_norms: if True, also print a second line with the mean
            perturbation norm (L-inf for SSA, L2 otherwise) reached within
            each budget, as ``[x.xx]`` entries.
    """
    if not os.path.exists(savefile):
        print(savefile)
        return
    log_data = _load_log(savefile)

    num_images = 250      # norm averages use only the first 250 images
    num_slots = 20        # logged checkpoints per image
    log_interval = 1000   # queries between consecutive checkpoints
    preset_idx = [4, 9, 19]
    budgets = [(k + 1) * log_interval for k in preset_idx]  # 5k, 10k, 20k

    l2_count = log_data['total_log_l_2_query_count']
    linf_count = np.asarray(log_data['total_log_l_inf_query_count'])
    # Strict variant: a query count of exactly 1 is counted as a failure.
    linf_count_strict = np.where(linf_count == 1, -1, linf_count)

    for budget in budgets:
        if attack == 'SSA':
            print('%.1f%%//(%.1f%%) ' % (_success_rate(linf_count, budget),
                                        _success_rate(linf_count_strict, budget)),
                  end=' ')
        else:
            print('%.1f%%' % _success_rate(l2_count, budget), end=' ')

    if print_norms:
        query_points = np.asarray(
            log_data['total_log_query_point'])[:num_images, :num_slots]
        norm_key = 'total_log_l_inf' if attack == 'SSA' else 'total_log_l_2'
        norms = _forward_fill_nan(log_data[norm_key], num_images, num_slots)
        print()
        for budget in budgets:
            print('[%.2f]' % _mean_last_within_budget(query_points, norms, budget),
                  end=' ')
    print()


def _result_path(result_dir, attack, model, defense, sigma):
    """Return the result-file path for one run.

    The name also encodes the module-level ``alpha``, ``beta`` and
    ``save_suffix`` settings.
    """
    return '%s/%s_%s_%s_%.3f_%.1f_%.1f_%s.pth' % (
        result_dir, attack, model, defense, sigma, alpha, beta, save_suffix)


result_dir = 'total'
model = 'resnet50'
defenses = ['gaussian', 'AT', 'rnp']
sigmas = [0, 0.01, 0.001]

if mode == 'ImageNet':
    # Print success rates for every run; the attack name heads its group.
    attacks = ['BA', 'SO', 'HSJA', 'GD', 'SB', 'SBD', 'BD', 'SA', 'SSA']
    for attack in attacks:
        print(attack)
        for defense in defenses:
            # The gaussian defense and the SSA attack are read for every
            # sigma; all other combinations only have a sigma = 0 file.
            if defense == 'gaussian' or attack == 'SSA':
                run_sigmas = sigmas
            else:
                run_sigmas = [0]
            for sigma in run_sigmas:
                savefile = _result_path(result_dir, attack, model, defense, sigma)
                load_and_print(savefile, attack, defense, sigma, alpha, beta)
else:
    # Only list result files that are missing. The previous version attached
    # the `else` to the sigma loop (a for-else), so AT/rnp files were never
    # checked and gaussian sigma = 0 was checked twice.
    attacks = ['BA', 'SO', 'HSJA', 'GD', 'SB', 'SBD', 'BD', 'SA']
    for attack in attacks:
        for defense in defenses:
            run_sigmas = sigmas if defense == 'gaussian' else [0]
            for sigma in run_sigmas:
                savefile = _result_path(result_dir, attack, model, defense, sigma)
                if not os.path.exists(savefile):
                    print(savefile)

