"""
Script to craft Gaussian noise poisons for CIFAR10.
"""
import os
import torch
import random
import argparse
import torchvision.transforms as transforms
import torchvision.datasets as dset

def parse_args():
    """Parse the command-line options for the poison-crafting script.

    Options:
        --seed: seed applied to both ``random`` and ``torch`` (default 42).
        --p_ratio: fraction of the training set to poison (required).
        --eps: L-inf bound on the noise, in 0-255 pixel units (default 16).
        --data_dir: root directory for the CIFAR10 download (default "../data").
        --save_dir: directory the poison file is written to
            (default "./poisons/noise").

    Returns:
        argparse.Namespace holding the parsed options.
    """
    option_table = (
        ('--seed', {'type': int, 'default': 42}),
        ('--p_ratio', {'type': float, 'required': True}),
        ('--eps', {'type': int, 'default': 16, 'help': 'update bound (L-inf)'}),
        ('--data_dir', {'type': str, 'default': "../data"}),
        ('--save_dir', {'type': str, 'default': "./poisons/noise"}),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def _add_bounded_gaussian_noise(images, eps):
    """Return ``images`` perturbed by Gaussian noise clipped to an L-inf ball.

    Noise is drawn with standard deviation ``eps / 255``, clipped element-wise
    to ``[-eps / 255, eps / 255]``, added to ``images``, and the result is
    clipped to ``[0, 1]``.

    Args:
        images: float tensor of stacked images. Values are assumed to lie in
            [0, 1], as produced by ``transforms.ToTensor``.
        eps: L-inf bound on the perturbation, in 0-255 pixel units.
    """
    bound = eps / 255
    noise = torch.randn_like(images) * bound
    noise = torch.clamp(noise, -bound, bound)
    return torch.clamp(images + noise, 0, 1)


def main():
    """Craft Gaussian-noise poisons for a random subset of CIFAR10 train images.

    A seeded random sample of ``p_ratio`` of the training set is perturbed with
    L-inf-bounded Gaussian noise. The results are converted back to PIL images.
    A dict ``{"indices": ..., "poisoned_images": ...}`` is written with
    ``torch.save`` into ``--save_dir``, which is created if missing.

    Raises:
        ValueError: if ``--p_ratio`` is outside ``[0, 1]``.
    """
    args = parse_args()
    # Fail fast, before downloading the dataset.
    if not 0.0 <= args.p_ratio <= 1.0:
        raise ValueError(f"--p_ratio must be in [0, 1], got {args.p_ratio}")

    random.seed(args.seed)
    torch.manual_seed(args.seed)

    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    train_data = dset.CIFAR10(root=args.data_dir, train=True, download=True, transform=transform)

    n_poisons = int(len(train_data) * args.p_ratio)
    indices = random.sample(range(len(train_data)), k=n_poisons)

    print(f"Adding gaussian noise to {n_poisons} images in CIFAR10")
    if n_poisons > 0:
        clean_images = torch.stack([train_data[i][0] for i in indices], dim=0)
        poisoned = _add_bounded_gaussian_noise(clean_images, args.eps)
        # Convert each image back to a PIL Image for saving.
        to_pil = transforms.ToPILImage()
        poisoned_images = [to_pil(img) for img in poisoned]
    else:
        # torch.stack rejects an empty list; save an empty poison set instead.
        poisoned_images = []

    save_str = f"noise-cifar10-{args.p_ratio * 100:.1f}%.pth"
    # torch.save does not create missing parent directories.
    os.makedirs(args.save_dir, exist_ok=True)
    save_path = os.path.join(args.save_dir, save_str)
    print("Saving poisons to {}".format(save_path))
    torch.save({
        "indices": indices,
        "poisoned_images": poisoned_images
    }, save_path)

if __name__ == '__main__':
    # Run the crafting pipeline only when executed as a script, not on import.
    main()