import argparse
import os
import torch
from src.unlearn.salun import generate_mask
from unlearning import select_dataset, select_model


# Dataset names accepted by --dataset; forwarded unchanged to select_dataset /
# select_model.
DATASET_CHOICES = [
    "mnist",
    "cifar10",
    "cifar100",
    "cifar100_resnet",
    "cifar100_conv",
    "ag_news",
    "imagenet",
    "imagenet_resnet",
    "dbpedia",
]


def build_parser():
    """Build the command-line parser for SalUn saliency-mask generation.

    Returns:
        An ``argparse.ArgumentParser`` with all options used by ``main``.
    """
    # `description` must be passed by keyword: the first positional parameter
    # of ArgumentParser is `prog`, which would replace the program name shown
    # in the usage line instead of describing the tool.
    parser = argparse.ArgumentParser(
        description="Generate Saliency Masks for SalUn method"
    )
    parser.add_argument(
        "-g",
        "--gpu",
        type=int,
        choices=[0, 1, 2, 3, 4, 5, 6, 7],
        default=0,
        help="Specify GPU index to use",
    )
    parser.add_argument(
        "-dt",
        "--dataset",
        type=str,
        choices=DATASET_CHOICES,
        required=True,
        help=f"Specify the dataset to test unlearning. Choices are: {DATASET_CHOICES}",
    )
    parser.add_argument(
        "-forgets",
        "--forget_labels",
        nargs="*",
        type=int,
        required=False,
        default=None,
        help="List of forget labels in integer",
    )
    parser.add_argument(
        "-nr",
        "--num_retains",
        type=int,
        default=2000,
        help="Number of retain samples",
    )
    parser.add_argument(
        "-nf",
        "--num_forgets",
        type=int,
        default=1000,
        help="Number of forget samples",
    )
    parser.add_argument(
        "-lr",
        "--learning_rate",
        type=float,
        default=1e-3,
        help="Learning rate for training",
    )
    parser.add_argument(
        "-bs",
        "--batch_size",
        type=int,
        default=1024,
        help="Dataloader batch size",
    )
    parser.add_argument(
        "-momentum",
        "--momentum",
        type=float,
        default=0.1,
        help="Momentum Value for training",
    )
    parser.add_argument(
        "-wd",
        "--weight_decay",
        type=float,
        default=1e-3,
        help="L2 weight decay",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=42,
        help="Seed value",
    )
    parser.add_argument(
        "-index",
        "--index_file",
        type=str,
        default=None,
        help="Specify the path to dataset index file",
    )
    parser.add_argument(
        "-test_index",
        "--test_index_file",
        type=str,
        default=None,
        help="Specify the path to test dataset index file",
    )
    parser.add_argument(
        "-ckpt",
        "--checkpoint",
        type=str,
        default=None,
        help="Specify the path to trained model checkpoint",
    )
    parser.add_argument(
        "-sp",
        "--save_path",
        type=str,
        default=os.path.join("outputs", "masks"),
        help="Specify the directory path to save masks",
    )
    return parser


def resolve_device(gpu):
    """Return the torch device to run on.

    Args:
        gpu: CUDA device index requested on the command line.

    Returns:
        ``torch.device("cuda:<gpu>")`` when CUDA is available. Otherwise it
        returns ``torch.device("cpu")`` and prints a notice, so the run does
        not fail later with an opaque CUDA error.
    """
    if torch.cuda.is_available():
        return torch.device(f"cuda:{gpu}")
    print("CUDA is not available; falling back to CPU.")
    return torch.device("cpu")


def main(argv=None):
    """Parse CLI arguments, load the forget data and model, and build masks.

    Args:
        argv: Optional list of argument strings. When ``None``, argparse
            reads ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    # Masks are grouped per dataset under the chosen save directory.
    save_path_dir = os.path.join(args.save_path, args.dataset)
    os.makedirs(save_path_dir, exist_ok=True)
    device = resolve_device(args.gpu)

    # "salun" is passed as select_dataset's second positional argument;
    # presumably it selects the SalUn data split — confirm in select_dataset.
    datasets = select_dataset(
        args.dataset,
        "salun",
        args.forget_labels,
        args.num_retains,
        args.num_forgets,
        index_file=args.index_file,
        test_index_file=args.test_index_file,
        batch_size=args.batch_size,
        seed=args.seed,
    )

    model = select_model(dataset_name=args.dataset, checkpoint=args.checkpoint)

    # Only the forget split is needed to compute the saliency masks here.
    generate_mask(
        forget_loader=datasets["forget_loader"],
        model=model,
        device=device,
        lr=args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay,
        save_path_dir=save_path_dir,
    )


if __name__ == "__main__":
    main()
