import os
import torch
import argparse
from utils import default_args, tools, supervisor, tinyimagenet
import torchvision.transforms as transforms
import random
from PIL import Image
from torchvision.utils import save_image
import config

# Command-line options. Defaults are hard-wired to a BadNet attack at a 0.2
# poison rate; when -trigger is omitted, the dataset's configured default
# trigger is looked up from config.trigger_default further down.
parser = argparse.ArgumentParser()
parser.add_argument(
    '-poison_type',
    type=str,
    choices=default_args.parser_choices['poison_type'],
    default='badnet',
)
parser.add_argument(
    '-poison_rate',
    type=float,
    choices=default_args.parser_choices['poison_rate'],
    default=0.2,  # same float argparse would produce from the string '0.200'
)
parser.add_argument('-alpha', type=float, default=0.2)
parser.add_argument('-trigger', type=str, default=None)
args = parser.parse_args()

# This script only ever builds the TinyImageNet poisoned set.
args.dataset = 'tinyimagenet'


tools.setup_seed(0)

# Validate the attack type before anything is looked up for it, so an
# unsupported choice surfaces as this NotImplementedError rather than
# (possibly) a KeyError from the trigger-config lookup below.
if args.poison_type not in ['none', 'badnet', 'trojan', 'blend']:
    raise NotImplementedError('%s is not implemented on tinyimagenet' % args.poison_type)

if args.trigger is None:
    args.trigger = config.trigger_default[args.dataset][args.poison_type]

if args.poison_type == 'none':
    # Clean baseline: no samples are poisoned.
    args.poison_rate = 0

# os.makedirs(exist_ok=True) also creates missing parent directories
# (e.g. `poisoned_train_set/` on a fresh checkout), where a plain os.mkdir
# would raise FileNotFoundError.
os.makedirs(os.path.join('poisoned_train_set', 'tinyimagenet'), exist_ok=True)

poison_set_dir = supervisor.get_poison_set_dir(args)
os.makedirs(poison_set_dir, exist_ok=True)

# Poisoned images are written to <poison_set_dir>/data/<img_id_to_path[pid]>.
poison_imgs_dir = os.path.join(poison_set_dir, 'data')
os.makedirs(poison_imgs_dir, exist_ok=True)

# Expected number of images in the TinyImageNet training split. Poison ids are
# sampled from range(num_imgs) right after seeding, so the chosen ids depend
# only on the seed; the actual dataset size is checked against this below.
num_imgs = 100_000  # size of tinyimagenet training set

# random sampling: shuffle every id and keep the first `poison_rate` fraction,
# stored in increasing order.
id_set = list(range(num_imgs))
random.shuffle(id_set)
num_poison = int(num_imgs * args.poison_rate)
poison_indices = sorted(id_set[:num_poison])

train_set_dir = os.path.join(config.tinyimagenet_dir, "train")

classes, class_to_idx, idx_to_class = tinyimagenet.find_classes(train_set_dir)
expected_num_imgs = num_imgs
num_imgs, img_id_to_path, img_labels = tinyimagenet.assign_img_identifier(train_set_dir, classes)
if num_imgs != expected_num_imgs:
    # The ids above were drawn from range(expected_num_imgs). A different
    # dataset size would either index past the end of img_id_to_path or leave
    # part of the dataset un-sampled (skewing the poison rate), so fail loudly.
    raise ValueError(
        'expected %d training images in %s but found %d'
        % (expected_num_imgs, train_set_dir, num_imgs)
    )
print('num_imgs = %d' % num_imgs)
print('num_poison = %d' % num_poison)
# f-strings avoid '%' treating a tuple argument as multiple format values.
print(f'img_id_to_path[:5] = {img_id_to_path[:5]}')
print(f'img_labels[:5] = {img_labels[:5]}')

# Source images are turned into tensors with ToTensor only; no Normalize step
# is applied here.
transform_to_tensor = transforms.Compose([transforms.ToTensor()])

# Build the attack's poisoning transform for this dataset's target class.
# NOTE(review): is_normalized_input=True is passed, yet the images fed to this
# transform in this script come from transform_to_tensor (no normalization).
# Confirm this matches what supervisor.get_poison_transform expects.
poison_transform = supervisor.get_poison_transform(
    poison_type=args.poison_type,
    dataset_name=args.dataset,
    target_class=config.target_class[args.dataset],
    trigger_transform=transform_to_tensor,
    is_normalized_input=True,
    alpha=args.alpha,
    trigger_name=args.trigger,
    args=args,
)


# Write a poisoned copy of every sampled training image under poison_imgs_dir,
# using the same relative path it has inside train_set_dir.
tot = len(poison_indices)
print('# poison samples = %d' % tot)
for cnt, pid in enumerate(poison_indices, start=1):
    img_path = os.path.join(train_set_dir, img_id_to_path[pid])

    # The context manager closes the underlying file handle promptly;
    # convert() returns an independent image, so it is safe to use after.
    with Image.open(img_path) as src_img:
        ori_img = transform_to_tensor(src_img.convert("RGB"))
    # Only the poisoned image is kept. The label tensor is a placeholder whose
    # length is ori_img.shape[0] (presumably the channel count), and the
    # transformed labels are discarded.
    poison_img, _ = poison_transform.transform(ori_img, torch.zeros(ori_img.shape[0]))

    cls_path = os.path.join(poison_imgs_dir, idx_to_class[img_labels[pid]])
    os.makedirs(cls_path, exist_ok=True)

    dst_path = os.path.join(poison_imgs_dir, img_id_to_path[pid])
    # img_id_to_path may include sub-folders below the class folder; ensure
    # the destination's own parent directory exists before saving.
    os.makedirs(os.path.dirname(dst_path), exist_ok=True)
    save_image(poison_img, dst_path)
    print('save [%d/%d]: %s' % (cnt, tot, dst_path))



# Persist the sorted list of poisoned sample ids in poison_set_dir, alongside
# the data/ folder that holds the poisoned images.
poison_indices_path = os.path.join(poison_set_dir, 'poison_indices')
torch.save(poison_indices, poison_indices_path)
print(f'[Generate Poisoned Set] Save {poison_indices_path}')