import argparse
import os
import time
import torch
from torch.nn import CrossEntropyLoss
import pandas as pd
import numpy as np
from transformers import AutoTokenizer
from torchvision.models import (
    ResNet101_Weights,
    Swin_V2_B_Weights,
)
from src.data.dataset import DiscriminativeDataset
from src.data.utils import (
    get_mnist_test_loaders,
    create_mnist_dataset,
    get_cifar_test_loaders,
    create_cifar_dataset,
    get_ag_news_test_loaders,
    create_ag_news_dataset,
    build_dataloader_text,
    create_imagenet_dataset,
    get_imagenet_test_loaders,
    create_dbpedia_dataset,
    get_dbpedia_test_loaders,
)
from src.model.classification import (
    MnistConvNet,
    CIFARConvNet,
    CIFAR100ConvNet,
    ResNet101Classifier,
    SwinClassifier,
    BertClassifier,
)
from src.utils.test import test_unlearn_classification
from src.utils.utility import load_model, set_seed
from src.utils.utility import set_seed, split_dataset, load_model
from src.utils.train import train_classification_model


def select_dataset(
    dataset_name: str,
    forget_labels: None | list,
    index_file: None | str,
    test_index_file: None | str,
    seed: int,
    batch_size: int,
):
    if dataset_name == "mnist":
        ds = DiscriminativeDataset(
            create_mnist_dataset(train=True),
            forget_labels=forget_labels,
            num_retains=-1,
            num_forgets=None,
            index_file=index_file,
            seed=seed,
        )
        train_ds, val_ds = split_dataset(ds, val_split=0.1, seed=42)
        train_loader = torch.utils.data.DataLoader(
            dataset=train_ds, batch_size=batch_size, shuffle=True, drop_last=False
        )
        val_loader = torch.utils.data.DataLoader(
            dataset=val_ds, batch_size=batch_size, shuffle=False, drop_last=False
        )
        dataloaders = get_mnist_test_loaders(
            forget_classes=forget_labels,
            batch_size=batch_size,
            index_file=test_index_file,
        )
        dataloaders.update({"train_loader": train_loader})
        dataloaders.update({"val_loader": val_loader})
        return dataloaders
    elif dataset_name == "cifar10":
        ds = DiscriminativeDataset(
            create_cifar_dataset(train=True),
            forget_labels=forget_labels,
            num_retains=-1,
            num_forgets=None,
            index_file=index_file,
            seed=seed,
        )
        train_ds, val_ds = split_dataset(ds, val_split=0.1, seed=42)
        train_loader = torch.utils.data.DataLoader(
            dataset=train_ds, batch_size=batch_size, shuffle=True, drop_last=False
        )
        val_loader = torch.utils.data.DataLoader(
            dataset=val_ds, batch_size=batch_size, shuffle=False, drop_last=False
        )
        dataloaders = get_cifar_test_loaders(
            forget_classes=forget_labels,
            batch_size=batch_size,
            index_file=test_index_file,
        )
        dataloaders.update({"train_loader": train_loader})
        dataloaders.update({"val_loader": val_loader})
        return dataloaders
    elif dataset_name in ["cifar100", "cifar100_resnet", "cifar100_conv"]:
        ds = DiscriminativeDataset(
            create_cifar_dataset(train=True, cifar_100=True),
            forget_labels=forget_labels,
            num_retains=-1,
            num_forgets=None,
            index_file=index_file,
            seed=seed,
        )
        train_ds, val_ds = split_dataset(ds, val_split=0.1, seed=42)
        train_loader = torch.utils.data.DataLoader(
            dataset=train_ds, batch_size=batch_size, shuffle=True, drop_last=False
        )
        val_loader = torch.utils.data.DataLoader(
            dataset=val_ds, batch_size=batch_size, shuffle=False, drop_last=False
        )
        dataloaders = get_cifar_test_loaders(
            forget_classes=forget_labels,
            batch_size=batch_size,
            cifar_100=True,
            index_file=test_index_file,
        )
        dataloaders.update({"train_loader": train_loader})
        dataloaders.update({"val_loader": val_loader})
        return dataloaders
    elif dataset_name == "ag_news":
        model_name = "bert-base-uncased"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        ds = DiscriminativeDataset(
            create_ag_news_dataset(tokenizer=tokenizer, train=True),
            forget_labels=forget_labels,
            num_retains=-1,
            num_forgets=None,
            index_file=index_file,
            seed=seed,
        )
        train_ds, val_ds = split_dataset(ds, val_split=0.1, seed=42)
        train_loader = build_dataloader_text(
            train_ds,
            tokenizer,
            batch_size=batch_size,
            shuffle=True,
            drop_last=False,
        )
        val_loader = build_dataloader_text(
            val_ds,
            tokenizer,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
        )
        dataloaders = get_ag_news_test_loaders(
            tokenizer=tokenizer,
            forget_classes=forget_labels,
            batch_size=batch_size,
            index_file=test_index_file,
        )
        dataloaders.update({"train_loader": train_loader})
        dataloaders.update({"val_loader": val_loader})
        return dataloaders
    elif dataset_name == "imagenet":
        ds = DiscriminativeDataset(
            create_imagenet_dataset(
                preprocess_fn=Swin_V2_B_Weights.DEFAULT.transforms(), train=True
            ),
            forget_labels=forget_labels,
            num_retains=-1,
            num_forgets=None,
            index_file=index_file,
            seed=seed,
        )
        train_ds, val_ds = split_dataset(ds, val_split=0.1, seed=42)
        train_loader = torch.utils.data.DataLoader(
            dataset=train_ds, batch_size=batch_size, shuffle=True, drop_last=False
        )
        val_loader = torch.utils.data.DataLoader(
            dataset=val_ds, batch_size=batch_size, shuffle=False, drop_last=False
        )
        dataloaders = get_imagenet_test_loaders(
            preprocess_fn=Swin_V2_B_Weights.DEFAULT.transforms(),
            forget_classes=forget_labels,
            batch_size=batch_size,
            index_file=test_index_file,
        )
        dataloaders.update({"train_loader": train_loader})
        dataloaders.update({"val_loader": val_loader})
        return dataloaders
    elif dataset_name == "imagenet_resnet":
        ds = DiscriminativeDataset(
            create_imagenet_dataset(
                preprocess_fn=ResNet101_Weights.DEFAULT.transforms(), train=True
            ),
            forget_labels=forget_labels,
            num_retains=-1,
            num_forgets=None,
            index_file=index_file,
            seed=seed,
        )
        train_ds, val_ds = split_dataset(ds, val_split=0.1, seed=42)
        train_loader = torch.utils.data.DataLoader(
            dataset=train_ds, batch_size=batch_size, shuffle=True, drop_last=False
        )
        val_loader = torch.utils.data.DataLoader(
            dataset=val_ds, batch_size=batch_size, shuffle=False, drop_last=False
        )
        dataloaders = get_imagenet_test_loaders(
            preprocess_fn=ResNet101_Weights.DEFAULT.transforms(),
            forget_classes=forget_labels,
            batch_size=batch_size,
            index_file=test_index_file,
        )
        dataloaders.update({"train_loader": train_loader})
        dataloaders.update({"val_loader": val_loader})
        return dataloaders
    elif dataset_name == "dbpedia":
        model_name = "bert-base-uncased"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        ds = DiscriminativeDataset(
            create_dbpedia_dataset(tokenizer=tokenizer, train=True),
            forget_labels=forget_labels,
            num_retains=-1,
            num_forgets=None,
            index_file=index_file,
            seed=seed,
        )
        train_ds, val_ds = split_dataset(ds, val_split=0.1, seed=42)
        train_loader = build_dataloader_text(
            train_ds,
            tokenizer,
            batch_size=batch_size,
            shuffle=True,
            drop_last=False,
        )
        val_loader = build_dataloader_text(
            val_ds,
            tokenizer,
            batch_size=batch_size,
            shuffle=False,
            drop_last=False,
        )
        dataloaders = get_dbpedia_test_loaders(
            tokenizer=tokenizer,
            forget_classes=forget_labels,
            batch_size=batch_size,
            index_file=test_index_file,
        )
        dataloaders.update({"train_loader": train_loader})
        dataloaders.update({"val_loader": val_loader})
        return dataloaders
    else:
        raise Exception(f"dataset {dataset_name} is not defined")


def select_model(dataset_name: str, checkpoint: str | None = None):
    model = None
    if dataset_name == "mnist":
        model = MnistConvNet()
    elif dataset_name == "cifar10":
        model = CIFARConvNet()
    elif dataset_name == "cifar100":
        model = SwinClassifier(num_labels=100)
    elif dataset_name == "cifar100_resnet":
        model = ResNet101Classifier(num_labels=100)
    elif dataset_name == "cifar100_conv":
        model = CIFAR100ConvNet()
    elif dataset_name == "ag_news":
        model = BertClassifier(num_labels=4, model_name="bert-base-uncased")
    elif dataset_name == "imagenet":
        model = SwinClassifier(num_labels=1000)
    elif dataset_name == "imagenet_resnet":
        model = ResNet101Classifier(num_labels=1000)
    elif dataset_name == "dbpedia":
        model = BertClassifier(num_labels=14, model_name="bert-base-uncased")
    else:
        raise Exception(f"dataset {dataset_name} is not defined")
    if checkpoint is not None:
        model = load_model(model=model, checkpoint_path=checkpoint)
    return model


def _build_arg_parser(dataset_choices: list[str]) -> argparse.ArgumentParser:
    """Create the command-line parser for the retraining script.

    Args:
        dataset_choices: Accepted values for ``--dataset``. They are also
            rendered into that option's help text, so the two cannot drift
            apart.
    """
    parser = argparse.ArgumentParser("Testing Unlearning")
    parser.add_argument(
        "-g",
        "--gpu",
        type=int,
        choices=[0, 1, 2, 3, 4, 5, 6, 7],
        default=0,
        help="Specify GPU index to use",
    )
    parser.add_argument(
        "-dt",
        "--dataset",
        type=str,
        choices=dataset_choices,
        required=True,
        help=(
            "Specify the dataset to test unlearning. Choices are: "
            f"[{', '.join(dataset_choices)}]"
        ),
    )
    parser.add_argument(
        "-forgets",
        "--forget_labels",
        nargs="*",
        type=int,
        required=False,
        default=None,
        help="List of forget labels in integer",
    )
    parser.add_argument(
        "-bs",
        "--batch_size",
        type=int,
        default=1024,
        help="Dataloader batch size",
    )
    parser.add_argument(
        "-epochs",
        "--num_epochs",
        type=int,
        default=10000,
        help="Number of Epochs",
    )
    parser.add_argument(
        "-p",
        "--patience",
        type=int,
        default=5,
        help="Early stopping patience level",
    )
    parser.add_argument(
        "-lr",
        "--learning_rate",
        type=float,
        default=1e-3,
        help="Learning rate for training",
    )
    parser.add_argument(
        "-wd",
        "--weight_decay",
        type=float,
        default=1e-6,
        help="L2 weight decay",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=42,
        help="Seed value",
    )
    parser.add_argument(
        "-index",
        "--index_file",
        type=str,
        default=None,
        help="Specify the path to dataset index file",
    )
    parser.add_argument(
        "-test_index",
        "--test_index_file",
        type=str,
        default=None,
        help="Specify the path to test dataset index file",
    )
    parser.add_argument(
        "-lrs",
        "--lr_scheduler",
        type=str,
        choices=["cswr", "rlrp"],
        default=None,
        help="Specify the learning rate scheduler to use. Choices are: [cswr, rlrp]",
    )
    parser.add_argument(
        "-ckpt",
        "--checkpoint",
        type=str,
        default=None,
        help="Specify the path to trained model checkpoint",
    )
    parser.add_argument(
        "-sp",
        "--save_path",
        type=str,
        default=os.path.join("outputs", "retrained"),
        help="Specify the path to save checkpoint",
    )
    return parser


def _evaluate_splits(model, datasets: dict, device: torch.device) -> dict:
    """Evaluate ``model`` on the full, retain and forget test loaders.

    ``datasets["test_full_loader"]`` is always evaluated. The retain and
    forget loaders may be ``None``; a missing split is recorded as loss 0 and
    accuracy 0.

    Returns:
        ``{split: {"loss": [loss], "acc": [acc]}}`` for the splits
        ``"full"``, ``"retain"`` and ``"forget"``, in that order.
    """
    unlearn_data = {}
    for split in ("full", "retain", "forget"):
        loader = datasets[f"test_{split}_loader"]
        if split != "full" and loader is None:
            test_loss, accuracy = 0, 0
        else:
            test_loss, accuracy = test_unlearn_classification(
                model=model,
                device=device,
                test_loader=loader,
            )
        unlearn_data[split] = {"loss": [test_loss], "acc": [accuracy]}
    return unlearn_data


def _summarize(unlearn_data: dict) -> pd.DataFrame:
    """Tabulate per-split metrics as mean/std columns.

    Returns one row per split (``full``, ``retain``, ``forget``) with the
    columns ``dataset``, ``cross_entropy-mean``, ``cross_entropy-std``,
    ``accuracy-mean`` and ``accuracy-std``.
    """
    rows = [
        {
            "dataset": split,
            "cross_entropy-mean": np.mean(unlearn_data[split]["loss"]),
            "cross_entropy-std": np.std(unlearn_data[split]["loss"]),
            "accuracy-mean": np.mean(unlearn_data[split]["acc"]),
            "accuracy-std": np.std(unlearn_data[split]["acc"]),
        }
        for split in ("full", "retain", "forget")
    ]
    return pd.DataFrame(rows)


def main() -> None:
    """Retrain a model on the selected dataset, then save its metrics and timing."""
    dataset_choices = [
        "mnist",
        "cifar10",
        "cifar100",
        "cifar100_resnet",
        "cifar100_conv",
        "ag_news",
        "imagenet",
        "imagenet_resnet",
        "dbpedia",
    ]
    args = _build_arg_parser(dataset_choices).parse_args()

    results_dir = os.path.join(args.save_path, "results")
    os.makedirs(args.save_path, exist_ok=True)
    os.makedirs(os.path.join(args.save_path, "models"), exist_ok=True)
    os.makedirs(results_dir, exist_ok=True)

    set_seed(args.seed)

    datasets = select_dataset(
        args.dataset,
        args.forget_labels,
        index_file=args.index_file,
        test_index_file=args.test_index_file,
        seed=args.seed,
        batch_size=args.batch_size,
    )
    model = select_model(args.dataset, checkpoint=args.checkpoint)
    device = torch.device(f"cuda:{args.gpu}")

    file_name = f"{args.dataset}_retrained"
    save_path = os.path.join(args.save_path, "models", f"{file_name}.pt")

    # perf_counter is monotonic, so the measured duration is immune to
    # wall-clock adjustments during long training runs.
    start_time = time.perf_counter()
    train_classification_model(
        model=model,
        train_loader=datasets["train_loader"],
        loss_fn=CrossEntropyLoss(),
        device=device,
        lr=args.learning_rate,
        num_epochs=args.num_epochs,
        val_loader=datasets["val_loader"],
        lr_scheduler=args.lr_scheduler,
        patience=args.patience,
        weight_decay=args.weight_decay,
        save_path=save_path,
    )
    time_taken = time.perf_counter() - start_time

    # Reload the checkpoint that train_classification_model wrote to
    # save_path before evaluating (presumably the best-validation weights —
    # TODO confirm against train_classification_model).
    model = load_model(model, save_path)

    unlearn_data = _evaluate_splits(model, datasets, device)

    with open(os.path.join(results_dir, f"{file_name}-time.txt"), "w") as handle:
        # A single run is timed, so the reported variance is always 0.
        handle.write(f"Time (mean): {time_taken}s  Time (var): {0}")

    _summarize(unlearn_data).to_csv(
        os.path.join(results_dir, f"{file_name}.csv"),
        index=False,
    )
    print(f"Results saved in {file_name}.csv")

    print("==" * 20)
    print(f"Successfully Trained model on {args.dataset}")
    print("==" * 20)


if __name__ == "__main__":
    main()
