
import os
import random
import pickle
from tqdm import tqdm
import torch


# Folder holding one pickled embedding object per file.
# NOTE(review): the "//" looks like a stripped path segment; it is harmless on
# POSIX, but confirm the intended directory.
emb_fold = "/data//mm-safety/visual_constrained_eps_32_hh_rlhf_embedding"

# Sort the listing so the load order (and therefore the seeded shuffle that
# follows) is reproducible across filesystems.
emb_files = sorted(os.listdir(emb_fold))

# Every directory entry is opened and unpickled, so the folder is assumed to
# contain nothing but embedding dumps.
# NOTE(review): pickle.load can execute arbitrary code embedded in a file; only
# point emb_fold at trusted, locally generated data.
all_emb = []
for fname in tqdm(emb_files):
    fpath = os.path.join(emb_fold, fname)
    with open(fpath, "rb") as fh:
        all_emb.append(pickle.load(fh))

# Reproducible 80/20 split over whole samples: shuffle the loaded list in place
# with a fixed seed, then cut it at the 80% mark (rounded down).
random.seed(0)
random.shuffle(all_emb)
train_len = int(0.8 * len(all_emb))
train_emb, test_emb = all_emb[:train_len], all_emb[train_len:]

# Fit a per-position "attack direction" and a decision margin on the training
# split. Variant index 0 is the attacked embedding and index 2 the raw one.
# NOTE(review): assumes each pickled item is a tensor shaped (variants, p, d)
# with at least 3 variants, all with identical shapes -- TODO confirm against
# the script that produced the dumps.
train_emb = torch.stack(train_emb).float()  # (n_train, variants, p, d)
attack_emb = train_emb[:, 0, ...].clone()  # (n_train, p, d)
raw_emb = train_emb[:, 2, ...].clone()  # (n_train, p, d)
# Reuse the copies instead of re-slicing the stacked tensor; the values are
# identical.
train_emb = attack_emb - raw_emb
direction = train_emb.mean(dim=0)  # (p, d): mean attacked-minus-raw shift
# Per-position norm of the direction. A position where attacked and raw
# embeddings coincide has an exactly-zero direction, which previously produced
# 0/0 = NaN projections and a NaN margin. That NaN made every later
# `judge.sum() > 0` test silently False. Clamping to the smallest normal float
# makes such positions project to 0 instead; all other positions are unchanged.
mag = torch.norm(direction, dim=-1).clamp_min(torch.finfo(direction.dtype).tiny)


def _mean_projection(emb, vec, vec_norm):
    """Return the per-position scalar projection of `emb` onto `vec`, averaged over samples.

    `emb` is (n, p, d), `vec` is (p, d) and `vec_norm` is (p,); the result is (p,).
    """
    return (torch.einsum('bpd,pd->bp', emb, vec) / vec_norm.unsqueeze(0)).mean(dim=0)


# The margin sits halfway between the mean attacked and mean raw projections
# at each position.
margin_attack = _mean_projection(attack_emb, direction, mag)
margin_emb = _mean_projection(raw_emb, direction, mag)
margin = (margin_attack + margin_emb) / 2

# Score the held-out samples. An attacked variant (index 0) counts as correct
# when its projections exceed the per-position margin in aggregate. A raw
# variant (index 2) counts as correct when its projections fall below it in
# aggregate.
acc1 = acc2 = 0
for sample in test_emb:
    proj = torch.einsum('bpd,pd->bp', sample.float(), direction) / mag.unsqueeze(0)  # (b, p)
    attack_score = (proj[0] - margin).sum()
    raw_score = (margin - proj[2]).sum()
    acc1 += 1 if attack_score > 0 else 0
    acc2 += 1 if raw_score > 0 else 0
total = len(test_emb)
print(f"train: {train_len}, test: {len(test_emb)}")
print(f"attacked acc: {acc1 / total * 100:.2f}%; raw_acc: {acc2 / total * 100:.2f}%")
