import numpy as np
import pickle
import os
import argparse
import torch
from torch.utils.data import DataLoader

def build_label_mask(loader, att_class):
    """Return a per-sample indicator of which samples carry label ``att_class``.

    Iterates ``loader`` (an iterable of ``(x, y)`` batches, e.g. a DataLoader)
    in order and concatenates the per-batch comparisons ``y.int() == att_class``.

    Args:
        loader: iterable yielding ``(x, y)`` pairs where ``y`` supports
            ``.int()`` (a torch tensor of labels). ``x`` is ignored.
        att_class: the integer class label to select.

    Returns:
        A 1-D ``float64`` numpy array of 0.0/1.0 values, one entry per sample,
        in iteration order. An empty ``float64`` array if ``loader`` yields
        nothing. The float dtype matches what earlier versions of this script
        pickled; use ``.astype(bool)`` to index with it.
    """
    # Collect per-batch masks and concatenate once. Growing an array inside
    # the loop would re-copy all previous entries on every batch.
    masks = [np.asarray(y.int() == att_class) for _, y in loader]
    if not masks:
        return np.array([])
    return np.concatenate(masks).astype(np.float64)


def main(argv=None):
    """Build and pickle the attack-label indicator for one class of a test set.

    Reads ``<prefix>_fixed_testds.pkl``, where the prefix is ``cifar`` for
    ``--data cifar10`` and otherwise the ``--data`` value itself. Writes the
    indicator from ``build_label_mask`` to
    ``attack_label_subsets/<data>_label_<att_class>.pkl``.

    Args:
        argv: optional argument list for argparse; ``None`` uses ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-d', '--data', type=str, default='cifar10')
    parser.add_argument('-c', '--att_class', type=int, default=0, help='class label to attack on')
    args = parser.parse_args(argv)

    # The CIFAR-10 fixed test split is stored under the shorter prefix 'cifar'.
    prefix = 'cifar' if args.data == 'cifar10' else args.data
    # NOTE(review): pickle.load can execute arbitrary code; only load files
    # produced by this project, never untrusted input.
    with open(f"{prefix}_fixed_testds.pkl", 'rb') as fp:
        test_ds = pickle.load(fp)

    # shuffle=False keeps the mask aligned with the dataset's sample order.
    test_dl = DataLoader(test_ds, batch_size=256, shuffle=False)
    subset = build_label_mask(test_dl, args.att_class)

    folder = 'attack_label_subsets'
    os.makedirs(folder, exist_ok=True)
    with open(f'{folder}/{args.data}_label_{args.att_class}.pkl', 'wb') as f:
        pickle.dump(subset, f)


if __name__ == "__main__":
    main()
