import argparse
import os
import torch
from torch.nn import CrossEntropyLoss
from transformers import AutoTokenizer
from torchvision.models import (
    ResNet101_Weights,
    Swin_V2_B_Weights,
)
from src.data.utils import (
    create_mnist_dataset,
    create_cifar_dataset,
    create_ag_news_dataset,
    build_dataloader_text,
    create_imagenet_dataset,
    create_dbpedia_dataset,
)
from src.model.classification import (
    MnistConvNet,
    CIFARConvNet,
    CIFAR100ConvNet,
    ResNet101Classifier,
    SwinClassifier,
    BertClassifier,
)
from src.utils.utility import set_seed, split_dataset, load_model
from src.utils.train import train_classification_model
from src.utils.test import test_classification_model


# Dataset names that all share the CIFAR-100 data (they differ only in the
# model chosen elsewhere in this file).
_CIFAR100_DATASET_NAMES = ("cifar100", "cifar100_resnet", "cifar100_conv")
# Tokenizer checkpoint used for every text dataset handled here.
_TEXT_TOKENIZER_NAME = "bert-base-uncased"


def _image_dataset_factory(dataset_name: str):
    """Return a ``factory(train=...)`` callable for an image dataset, or ``None``.

    Project helpers are only referenced inside the matching branch, so an
    unknown name falls through to ``None`` without touching any of them.
    """
    if dataset_name == "mnist":
        return create_mnist_dataset
    if dataset_name == "cifar10":
        return create_cifar_dataset
    if dataset_name in _CIFAR100_DATASET_NAMES:
        return lambda train: create_cifar_dataset(train=train, cifar_100=True)
    if dataset_name == "imagenet":
        # The transform is rebuilt on every call, i.e. once per split.
        return lambda train: create_imagenet_dataset(
            preprocess_fn=Swin_V2_B_Weights.DEFAULT.transforms(), train=train
        )
    if dataset_name == "imagenet_resnet":
        return lambda train: create_imagenet_dataset(
            preprocess_fn=ResNet101_Weights.DEFAULT.transforms(), train=train
        )
    return None


def _text_dataset_factory(dataset_name: str):
    """Return a ``factory(tokenizer=..., train=...)`` for a text dataset, or ``None``."""
    if dataset_name == "ag_news":
        return create_ag_news_dataset
    if dataset_name == "dbpedia":
        return create_dbpedia_dataset
    return None


def _build_splits(make_dataset, make_loader, val_split: float, split_seed: int):
    """Create the train/val/test datasets and their loaders in a fixed order.

    Args:
        make_dataset: Callable taking ``train`` (bool) and returning a dataset.
        make_loader: Callable taking ``(dataset, shuffle)`` and returning a loader.
        val_split: Forwarded to ``split_dataset`` as ``val_split``.
        split_seed: Forwarded to ``split_dataset`` as ``seed``.

    Returns:
        ``(train_ds, train_loader, val_ds, val_loader, test_ds, test_loader)``.
    """
    full_train_ds = make_dataset(train=True)
    train_ds, val_ds = split_dataset(full_train_ds, val_split=val_split, seed=split_seed)
    # Only the training loader shuffles; validation and test keep dataset order.
    train_loader = make_loader(train_ds, shuffle=True)
    val_loader = make_loader(val_ds, shuffle=False)
    test_ds = make_dataset(train=False)
    test_loader = make_loader(test_ds, shuffle=False)
    return train_ds, train_loader, val_ds, val_loader, test_ds, test_loader


def select_dataset(
    dataset_name: str,
    batch_size: int,
    *,
    val_split: float = 0.1,
    split_seed: int = 42,
):
    """Build the train/validation/test datasets and loaders for ``dataset_name``.

    The ``train=True`` dataset is divided by ``split_dataset`` into a training
    and a validation part; the ``train=False`` dataset is used as the test set.
    Image datasets are wrapped in ``torch.utils.data.DataLoader``; text
    datasets (``ag_news``, ``dbpedia``) are tokenized with the
    ``bert-base-uncased`` tokenizer and wrapped via ``build_dataloader_text``.

    Args:
        dataset_name: One of ``mnist``, ``cifar10``, ``cifar100``,
            ``cifar100_resnet``, ``cifar100_conv``, ``ag_news``, ``imagenet``,
            ``imagenet_resnet`` or ``dbpedia``.
        batch_size: Batch size used for all three loaders.
        val_split: Value passed to ``split_dataset`` as ``val_split``.
        split_seed: Value passed to ``split_dataset`` as ``seed``.

    Returns:
        The tuple ``(train_ds, train_loader, val_ds, val_loader, test_ds,
        test_loader)``.

    Raises:
        ValueError: If ``dataset_name`` is not one of the supported names.
    """
    image_factory = _image_dataset_factory(dataset_name)
    if image_factory is not None:

        def make_image_dataset(train):
            return image_factory(train=train)

        def make_image_loader(dataset, shuffle):
            return torch.utils.data.DataLoader(
                dataset=dataset,
                batch_size=batch_size,
                shuffle=shuffle,
                drop_last=False,
            )

        return _build_splits(
            make_image_dataset, make_image_loader, val_split, split_seed
        )

    text_factory = _text_dataset_factory(dataset_name)
    if text_factory is None:
        raise ValueError(f"dataset {dataset_name} is not defined")

    # Load the tokenizer only once the name is known to be a text dataset.
    tokenizer = AutoTokenizer.from_pretrained(_TEXT_TOKENIZER_NAME)

    def make_text_dataset(train):
        return text_factory(tokenizer=tokenizer, train=train)

    def make_text_loader(dataset, shuffle):
        return build_dataloader_text(
            dataset,
            tokenizer,
            batch_size=batch_size,
            shuffle=shuffle,
            drop_last=False,
        )

    return _build_splits(make_text_dataset, make_text_loader, val_split, split_seed)


def select_model(dataset_name: str, checkpoint: str | None = None):
    model = None
    if dataset_name == "mnist":
        model = MnistConvNet()
    elif dataset_name == "cifar10":
        model = CIFARConvNet()
    elif dataset_name == "cifar100":
        model = SwinClassifier(num_labels=100)
    elif dataset_name == "cifar100_resnet":
        model = ResNet101Classifier(num_labels=100)
    elif dataset_name == "cifar100_conv":
        model = CIFAR100ConvNet()
    elif dataset_name == "ag_news":
        model = BertClassifier(num_labels=4, model_name="bert-base-uncased")
    elif dataset_name == "imagenet":
        model = SwinClassifier(num_labels=1000)
    elif dataset_name == "imagenet_resnet":
        model = ResNet101Classifier(num_labels=1000)
    elif dataset_name == "dbpedia":
        model = BertClassifier(num_labels=14, model_name="bert-base-uncased")
    else:
        raise Exception(f"dataset {dataset_name} is not defined")
    if checkpoint is not None:
        model = load_model(model=model, checkpoint_path=checkpoint)
    return model


def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this training script."""
    dataset_choices = [
        "mnist",
        "cifar10",
        "cifar100",
        "cifar100_resnet",
        "cifar100_conv",
        "ag_news",
        "imagenet",
        "imagenet_resnet",
        "dbpedia",
    ]
    # NOTE: the first positional argument of ArgumentParser is ``prog``, so
    # this string is used as the program name in usage/help output.
    cli = argparse.ArgumentParser("Testing Unlearning")
    cli.add_argument(
        "-g",
        "--gpu",
        type=int,
        choices=list(range(8)),
        default=0,
        help="Specify GPU index to use",
    )
    cli.add_argument(
        "-dt",
        "--dataset",
        type=str,
        choices=dataset_choices,
        required=True,
        help=f"Specify the dataset to test unlearning. Choices are: {dataset_choices}",
    )
    cli.add_argument(
        "-bs", "--batch_size", type=int, default=1024, help="Dataloader batch size"
    )
    cli.add_argument(
        "-p", "--patience", type=int, default=5, help="Early stopping patience level"
    )
    cli.add_argument(
        "-lr",
        "--learning_rate",
        type=float,
        default=1e-3,
        help="Learning rate for training",
    )
    cli.add_argument(
        "-wd", "--weight_decay", type=float, default=1e-6, help="L2 weight decay"
    )
    cli.add_argument("-s", "--seed", type=int, default=42, help="Seed value")
    cli.add_argument(
        "-lrs",
        "--lr_scheduler",
        type=str,
        choices=["cswr", "rlrp"],
        default=None,
        help="Specify the learning rate scheduler to use. Choices are: [cswr, rlrp]",
    )
    cli.add_argument(
        "-ckpt",
        "--checkpoint",
        type=str,
        default=None,
        help="Specify the path to trained model checkpoint",
    )
    cli.add_argument(
        "-sp",
        "--save_path",
        type=str,
        default=os.path.join("outputs", "in_models"),
        help="Specify the path to save checkpoint",
    )
    return cli


def main() -> None:
    """Train a classifier on the chosen dataset, then evaluate it on the test set.

    Steps, in order: parse arguments, create the output directory, seed via
    ``set_seed``, build loaders and model, train with
    ``train_classification_model`` (which receives the output path as
    ``save_path``), reload the model from that path with ``load_model`` and
    run ``test_classification_model`` on the test loader.
    """
    args = _build_arg_parser().parse_args()

    os.makedirs(args.save_path, exist_ok=True)
    set_seed(args.seed)

    # Only the loaders are needed here; the dataset objects are discarded.
    _, train_loader, _, val_loader, _, test_loader = select_dataset(
        args.dataset,
        args.batch_size,
    )
    model = select_model(args.dataset, checkpoint=args.checkpoint)

    checkpoint_file = os.path.join(args.save_path, f"{args.dataset}_scratch.pt")
    # The same CUDA device is used for both training and evaluation.
    device = torch.device(f"cuda:{args.gpu}")

    train_classification_model(
        model,
        train_loader,
        val_loader,
        loss_fn=CrossEntropyLoss(),
        device=device,
        lr=args.learning_rate,
        num_epochs=10000,
        lr_scheduler=args.lr_scheduler,
        patience=args.patience,
        weight_decay=args.weight_decay,
        save_path=checkpoint_file,
    )

    # Evaluate the weights stored at ``checkpoint_file`` rather than the
    # in-memory state left after the final epoch.
    model = load_model(model, checkpoint_file)
    test_classification_model(
        model=model,
        device=device,
        test_loader=test_loader,
        loss_fn=CrossEntropyLoss(),
    )

    banner = "==" * 20
    print(banner)
    print(f"Successfully Trained model on {args.dataset}")
    print(banner)


if __name__ == "__main__":
    main()
