
import os
import pickle
from tqdm import tqdm
import torch
import argparse


def parse_args(argv=None):
    """Parse the command-line options of this evaluation script.

    Args:
        argv: Optional list of argument strings; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``direction_file`` and ``test_fold``.
    """
    parser = argparse.ArgumentParser()
    # Both options are mandatory: without them the script used to crash later
    # with an obscure TypeError from open(None) / os.listdir(None).
    parser.add_argument("--direction_file", type=str, required=True)
    parser.add_argument("--test_fold", type=str, required=True)
    return parser.parse_args(argv)


def load_direction(path):
    """Load the pickled detector parameters stored at *path*.

    The pickle must be a mapping with the keys ``'direction'``, ``'mag'`` and
    ``'margin'``.

    Returns:
        Tuple ``(direction, mag, margin)``.

    Raises:
        KeyError: If one of the expected keys is missing.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
    with open(path, "rb") as f:
        params = pickle.load(f)
    return params["direction"], params["mag"], params["margin"]


def load_embeddings(folder):
    """Unpickle every file in *folder*, in sorted filename order.

    Returns:
        List of the unpickled objects (one per file).
    """
    embeddings = []
    for name in tqdm(sorted(os.listdir(folder))):
        with open(os.path.join(folder, name), "rb") as f:
            embeddings.append(pickle.load(f))
    return embeddings


def evaluate(test_emb, direction, mag, margin):
    """Return the overall detection accuracy, in percent.

    Each sample is projected onto ``direction`` patch by patch and divided by
    ``mag``. Rows are ordered ``[attack, raw, random_noise]``. A row is scored
    by the mean over patches of ``projection - margin``:

    * the attack row (0) is correct when that mean is > 0;
    * the raw (1) and random-noise (2) rows are correct when it is < 0.

    A mean of exactly 0 counts as wrong for every row.

    Args:
        test_emb: Sequence of tensors. The einsum spec ``'bpd,pd->bp'`` requires
            each to be rank 3 ``(b, p, d)`` with ``b >= 3``; only rows 0-2 are used.
        direction: Rank-2 tensor ``(p, d)`` matching the embeddings' last two dims.
        mag: Tensor broadcastable against ``(1, p)`` after ``unsqueeze(0)``;
            presumably the per-patch norm of ``direction`` — TODO confirm.
        margin: Scalar or tensor broadcastable against a ``(p,)`` row.

    Returns:
        ``100 * correct / (3 * len(test_emb))`` as a float.

    Raises:
        ValueError: If ``test_emb`` is empty (previously a ZeroDivisionError).
    """
    if len(test_emb) == 0:
        raise ValueError("no test embeddings to evaluate")

    # Embeddings are cast to float32 below; cast the direction to match so
    # einsum never sees mixed dtypes. Both conversions are hoisted out of the loop.
    direction = direction.float()
    mag_row = mag.unsqueeze(0)

    acc0, acc1, acc2 = 0, 0, 0
    for t_emb in test_emb:
        # Rows: [attack, raw, random_noise].
        # Per-patch projection onto `direction`, normalised by `mag` -> (b, p).
        projection = torch.einsum("bpd,pd->bp", t_emb.float(), direction) / mag_row
        judge_0 = projection[0] - margin
        judge_1 = projection[1] - margin
        judge_2 = projection[2] - margin
        # Aggregate each row by the mean of its per-patch margin differences.
        if judge_0.mean() > 0:  # attack image: correct when above the margin
            acc0 += 1
        if judge_1.mean() < 0:  # raw image: correct when below the margin
            acc1 += 1
        if judge_2.mean() < 0:  # random-noise image: correct when below the margin
            acc2 += 1

    return (acc0 + acc1 + acc2) / (3 * len(test_emb)) * 100


def main(argv=None):
    """Load the detector and test embeddings, then print the overall accuracy."""
    args = parse_args(argv)
    direction, mag, margin = load_direction(args.direction_file)
    test_emb = load_embeddings(args.test_fold)
    acc = evaluate(test_emb, direction, mag, margin)
    print(f"acc: {acc:.2f}%")


if __name__ == "__main__":
    main()
