"""
Script to generate random label flips (RLFs) for CIFAR10.
"""
import os
import torch
import random
import argparse
import torchvision.datasets as dset

# Number of classes in CIFAR10.
CIFAR_CLASSES = 10

def generate_random_label(a: int, b: int, c: int) -> int:
    """Pick a label uniformly at random from the closed range [a, b], skipping c.

    The candidate list keeps ascending order, so a given RNG state always maps
    to the same label. If no candidate remains (e.g. a == b == c),
    ``random.choice`` raises IndexError.
    """
    candidates = list(range(a, b + 1))
    # range() yields each value at most once, so a single remove() drops c entirely.
    if c in candidates:
        candidates.remove(c)
    return random.choice(candidates)

def _unit_interval(value: str) -> float:
    """argparse ``type=`` callable: parse *value* as a float in the closed range [0, 1]."""
    try:
        ratio = float(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"expected a float, got {value!r}") from None
    # The chained comparison is False for NaN, so NaN is rejected too.
    if not 0.0 <= ratio <= 1.0:
        raise argparse.ArgumentTypeError(f"must be in [0, 1], got {value!r}")
    return ratio


def parse_args(argv=None):
    """Parse command-line options for the random-label-flip generator.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, matching the original behaviour.

    Returns:
        argparse.Namespace with ``seed`` (int), ``p_ratio`` (float in [0, 1]),
        ``data_dir`` (str) and ``save_dir`` (str).

    Exits via ``parser.error`` (SystemExit) on invalid input, including a
    ``--p_ratio`` outside [0, 1]. Such a value would otherwise only fail later,
    inside ``random.sample``.
    """
    parser = argparse.ArgumentParser(
        description="Generate random label flips (RLFs) for CIFAR10.")
    parser.add_argument('--seed', type=int, default=42,
                        help="seed for Python's random module (default: 42)")
    parser.add_argument('--p_ratio', type=_unit_interval, required=True,
                        help="fraction of training samples to relabel, in [0, 1]")
    parser.add_argument('--data_dir', type=str, default="../data",
                        help="CIFAR10 download/cache root (default: ../data)")
    parser.add_argument('--save_dir', type=str, default="./poisons/rlf",
                        help="directory for the output .pth file (default: ./poisons/rlf)")

    return parser.parse_args(argv)

def main():
    """Sample CIFAR10 training indices, flip their labels, and save the result.

    Reads CLI options via ``parse_args`` and seeds Python's ``random`` module
    with ``--seed``. It then picks ``int(len(train_data) * p_ratio)`` distinct
    training indices and, for each one, draws a different label uniformly from
    the other CIFAR10 classes. The chosen indices and their new labels are
    written with ``torch.save`` to ``<save_dir>/rlf-cifar10-<percent>%.pth``.
    ``save_dir`` is created if missing.
    """
    args = parse_args()
    # Index sampling and label choice both use the stdlib ``random`` module,
    # so this one seed makes the output reproducible.
    random.seed(args.seed)

    train_data = dset.CIFAR10(root=args.data_dir, train=True, download=True)

    n_poisons = int(len(train_data) * args.p_ratio)
    # Sampling without replacement: every poisoned index is distinct.
    indices = random.sample(range(len(train_data)), k=n_poisons)

    print(f"Randomly flipping {n_poisons} labels in CIFAR10")
    clean_labels = [train_data.targets[i] for i in indices]
    poisoned_labels = [
        generate_random_label(0, CIFAR_CLASSES - 1, label) for label in clean_labels
    ]

    save_str = f"rlf-cifar10-{args.p_ratio * 100:.1f}%.pth"
    save_path = os.path.join(args.save_dir, save_str)
    print(f"Saving poisons to {save_path}")
    # torch.save does not create parent directories, and the default save_dir
    # is a relative path that may not exist yet.
    os.makedirs(args.save_dir, exist_ok=True)
    torch.save({
        "indices": indices,
        "poisoned_labels": poisoned_labels,
    }, save_path)


if __name__ == "__main__":
    main()