import argparse
import os
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer
from torchvision.models import Swin_V2_B_Weights
from src.data.utils import (
    create_mnist_dataset,
    create_cifar_dataset,
    create_ag_news_dataset,
    create_imagenet_dataset,
    create_dbpedia_dataset,
)


def select_dataset(dataset_name: str, train=True):
    """Create the dataset registered under ``dataset_name``.

    Args:
        dataset_name: One of ``"mnist"``, ``"cifar10"``, ``"cifar100"``,
            ``"ag_news"``, ``"imagenet"`` or ``"dbpedia"``.
        train: Forwarded unchanged as the ``train`` keyword of the
            underlying ``create_*_dataset`` helper.

    Returns:
        Whatever the matching ``create_*_dataset`` helper returns.

    Raises:
        Exception: If ``dataset_name`` is not one of the names above.
    """
    if dataset_name == "mnist":
        return create_mnist_dataset(train=train)

    if dataset_name == "cifar10":
        return create_cifar_dataset(train=train)

    if dataset_name == "cifar100":
        return create_cifar_dataset(train=train, cifar_100=True)

    if dataset_name == "imagenet":
        # Images are preprocessed with the default Swin-V2-B weight transforms.
        preprocess = Swin_V2_B_Weights.DEFAULT.transforms()
        return create_imagenet_dataset(preprocess_fn=preprocess, train=train)

    if dataset_name in ("ag_news", "dbpedia"):
        # Both text datasets are built with the same BERT tokenizer.
        bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        if dataset_name == "ag_news":
            return create_ag_news_dataset(tokenizer=bert_tokenizer, train=train)
        return create_dbpedia_dataset(tokenizer=bert_tokenizer, train=train)

    raise Exception(f"dataset {dataset_name} is not defined")


def main(argv=None):
    """Write ``<dataset>-<split>-index.csv`` for the train and test splits.

    Each CSV holds one row per sample with two columns: ``index`` (the
    sample's position when iterating the dataset) and ``class`` (the label
    yielded alongside the sample).

    Args:
        argv: Command-line arguments to parse. ``None`` makes argparse read
            ``sys.argv[1:]``.

    Raises:
        SystemExit: Raised by argparse when the arguments are invalid, e.g.
            when ``--dataset`` or ``--save_path`` is missing.
    """
    dataset_choices = [
        "mnist",
        "cifar10",
        "cifar100",
        "ag_news",
        "imagenet",
        "dbpedia",
    ]
    # The first positional parameter of ArgumentParser is ``prog``; the text
    # is passed as ``description`` so the usage line shows the script name.
    parser = argparse.ArgumentParser(
        description="Generate Index file for faster processing of datasets"
    )
    parser.add_argument(
        "-dt",
        "--dataset",
        type=str,
        choices=dataset_choices,
        required=True,
        help=f"Specify the dataset to test unlearning. Choices are: {dataset_choices}",
    )
    # NOTE(review): --seed is parsed but nothing below reads it; it is kept so
    # existing invocations that pass it keep working.
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=42,
        help="Seed value",
    )
    # Required: without a path the script used to fail inside
    # os.makedirs(None) with a TypeError instead of a usage message.
    parser.add_argument(
        "-sp",
        "--save_path",
        type=str,
        required=True,
        help="Specify the folder path to save dataset index file",
    )

    args = parser.parse_args(argv)

    os.makedirs(args.save_path, exist_ok=True)

    for train in (True, False):
        split = "train" if train else "test"
        index_path = os.path.join(
            args.save_path, f"{args.dataset}-{split}-index.csv"
        )
        dataset = select_dataset(args.dataset, train=train)

        # Only the label of each (sample, label) pair is recorded.
        labels = [label for _, label in tqdm(dataset, total=len(dataset))]

        df = pd.DataFrame({"index": list(range(len(labels))), "class": labels})
        df.to_csv(index_path, index=False)
        print("Index saved at: ", index_path)


if __name__ == "__main__":
    main()
