import argparse
import os
import torch
import numpy as np
import pandas as pd
from torchvision.models import (
    ResNet101_Weights,
    Swin_V2_B_Weights,
)
from src.data.utils import (
    get_mnist_test_loaders,
    create_mnist_dataset,
    get_cifar_test_loaders,
    create_cifar_dataset,
    get_ag_news_test_loaders,
    create_ag_news_dataset,
    build_dataloader_text,
    create_imagenet_dataset,
    get_imagenet_test_loaders,
    create_dbpedia_dataset,
    get_dbpedia_test_loaders,
)
from src.model.classification import (
    MnistConvNet,
    CIFARConvNet,
    CIFAR100ConvNet,
    ResNet101Classifier,
    SwinClassifier,
    BertClassifier,
)
from src.data.dataset import get_indices
from src.utils.metric import JSDivergence, l2_distance_between_models
from src.utils.utility import load_model, set_seed
from transformers import AutoTokenizer
from torchvision.models import Swin_V2_B_Weights


class EvalDataset(torch.utils.data.Dataset):
    """Expose only the entries of ``dataset`` whose positions are in ``indices``.

    Item ``i`` of this dataset is ``dataset[indices[i]]``; the length is the
    number of stored indices.
    """

    def __init__(self, dataset, indices):
        super().__init__()
        # Keep references only; nothing is copied or materialised here.
        self.indices = indices
        self.dataset = dataset

    def __len__(self):
        """Return how many indices are exposed."""
        return len(self.indices)

    def __getitem__(self, index):
        """Translate ``index`` through ``indices`` and fetch from the wrapped dataset."""
        return self.dataset[self.indices[index]]


def select_dataset(dataset_name: str, train=True):
    """Create the dataset object registered under ``dataset_name``.

    Args:
        dataset_name: One of the supported dataset identifiers (``mnist``,
            ``cifar10``, the ``cifar100*`` variants, ``ag_news``,
            ``imagenet``, ``imagenet_resnet``, ``dbpedia``).
        train: Forwarded as ``train=`` to the underlying ``create_*`` helper.

    Returns:
        Whatever the matching ``create_*_dataset`` helper returns.

    Raises:
        Exception: If ``dataset_name`` is not a supported identifier.
    """

    def _text_dataset(factory):
        # Text datasets are tokenised with the BERT tokenizer.
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        return factory(tokenizer=tokenizer, train=train)

    def _imagenet_dataset(weights):
        # The preprocessing pipeline comes from the backbone's default weights.
        return create_imagenet_dataset(
            preprocess_fn=weights.DEFAULT.transforms(), train=train
        )

    def _cifar100_dataset():
        return create_cifar_dataset(train=train, cifar_100=True)

    # Builders are wrapped in callables so that only the selected one runs.
    builders = {
        "mnist": lambda: create_mnist_dataset(train=train),
        "cifar10": lambda: create_cifar_dataset(train=train),
        "cifar100": _cifar100_dataset,
        "cifar100_resnet": _cifar100_dataset,
        "cifar100_conv": _cifar100_dataset,
        "ag_news": lambda: _text_dataset(create_ag_news_dataset),
        "imagenet": lambda: _imagenet_dataset(Swin_V2_B_Weights),
        "imagenet_resnet": lambda: _imagenet_dataset(ResNet101_Weights),
        "dbpedia": lambda: _text_dataset(create_dbpedia_dataset),
    }

    if dataset_name not in builders:
        raise Exception(f"dataset {dataset_name} is not defined")
    return builders[dataset_name]()


def get_loader(dataset_name, dataset, batch_size: int, shuffle=True):
    """Wrap ``dataset`` in the data loader appropriate for ``dataset_name``.

    Image datasets use a plain ``torch.utils.data.DataLoader``; the text
    datasets (``ag_news``, ``dbpedia``) go through ``build_dataloader_text``
    with the BERT tokenizer. The last partial batch is always kept.

    Args:
        dataset_name: Supported dataset identifier.
        dataset: The dataset object to iterate over.
        batch_size: Number of samples per batch.
        shuffle: Whether the loader shuffles the data.

    Returns:
        The constructed data loader.

    Raises:
        Exception: If ``dataset_name`` is not a supported identifier.
    """
    image_names = (
        "mnist",
        "cifar10",
        "cifar100",
        "cifar100_resnet",
        "cifar100_conv",
        "imagenet",
        "imagenet_resnet",
    )
    text_names = ("ag_news", "dbpedia")

    if dataset_name in image_names:
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle, drop_last=False
        )

    if dataset_name not in text_names:
        raise Exception(f"dataset {dataset_name} is not defined")

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    return build_dataloader_text(
        dataset,
        tokenizer,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=False,
    )


def get_test_loader(dataset_name, batch_size: int):
    """Return the full test-set loader for ``dataset_name``.

    Each ``get_*_test_loaders`` helper is called with no forget classes and no
    index file, and the ``"test_full_loader"`` entry of its result is returned.

    Args:
        dataset_name: Supported dataset identifier.
        batch_size: Number of samples per batch.

    Returns:
        The ``"test_full_loader"`` entry produced by the matching helper.

    Raises:
        Exception: If ``dataset_name`` is not a supported identifier.
    """
    # Keyword arguments shared by every helper call below.
    shared = {"forget_classes": None, "batch_size": batch_size, "index_file": None}

    if dataset_name == "mnist":
        bundle = get_mnist_test_loaders(**shared)
    elif dataset_name == "cifar10":
        bundle = get_cifar_test_loaders(**shared)
    elif dataset_name in ("cifar100", "cifar100_resnet", "cifar100_conv"):
        bundle = get_cifar_test_loaders(cifar_100=True, **shared)
    elif dataset_name in ("ag_news", "dbpedia"):
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        if dataset_name == "ag_news":
            bundle = get_ag_news_test_loaders(tokenizer=tokenizer, **shared)
        else:
            bundle = get_dbpedia_test_loaders(tokenizer=tokenizer, **shared)
    elif dataset_name in ("imagenet", "imagenet_resnet"):
        # Preprocessing follows the default weights of the evaluated backbone.
        if dataset_name == "imagenet":
            weights = Swin_V2_B_Weights
        else:
            weights = ResNet101_Weights
        bundle = get_imagenet_test_loaders(
            preprocess_fn=weights.DEFAULT.transforms(), **shared
        )
    else:
        raise Exception(f"dataset {dataset_name} is not defined")

    return bundle["test_full_loader"]


def select_model(dataset_name: str, checkpoint: str | None = None):
    model = None
    if dataset_name == "mnist":
        model = MnistConvNet()
    elif dataset_name == "cifar10":
        model = CIFARConvNet()
    elif dataset_name == "cifar100":
        model = SwinClassifier(num_labels=100)
    elif dataset_name == "cifar100_resnet":
        model = ResNet101Classifier(num_labels=100)
    elif dataset_name == "cifar100_conv":
        model = CIFAR100ConvNet()
    elif dataset_name == "ag_news":
        model = BertClassifier(num_labels=4, model_name="bert-base-uncased")
    elif dataset_name == "imagenet":
        model = SwinClassifier(num_labels=1000)
    elif dataset_name == "imagenet_resnet":
        model = ResNet101Classifier(num_labels=1000)
    elif dataset_name == "dbpedia":
        model = BertClassifier(num_labels=14, model_name="bert-base-uncased")
    else:
        raise Exception(f"dataset {dataset_name} is not defined")
    if checkpoint is not None:
        model = load_model(model=model, checkpoint_path=checkpoint)
    return model


def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the unlearning evaluation script."""
    dataset_choices = [
        "mnist",
        "cifar10",
        "cifar100",
        "cifar100_resnet",
        "cifar100_conv",
        "ag_news",
        "imagenet",
        "imagenet_resnet",
        "dbpedia",
    ]
    method_choices = [
        "uul",
        "ft",
        "ga",
        "neg_grad",
        "w_fisher",
        "l1_sparse",
        "random_label",
        "scrub",
        "salun",
    ]
    parser = argparse.ArgumentParser("Evaluating Unlearning Methods")
    parser.add_argument(
        "-g",
        "--gpu",
        type=int,
        choices=[0, 1, 2, 3, 4, 5, 6, 7],
        default=0,
        help="Specify GPU index to use",
    )
    parser.add_argument(
        "-dt",
        "--dataset",
        type=str,
        choices=dataset_choices,
        required=True,
        # Derived from the list so the help text cannot drift from the choices.
        help=f"Specify the dataset to test unlearning. Choices are: {dataset_choices}",
    )
    parser.add_argument(
        "-forgets",
        "--forget_labels",
        nargs="*",
        type=int,
        required=False,
        default=None,
        help="List of forget labels in integer",
    )
    parser.add_argument(
        "-bs",
        "--batch_size",
        type=int,
        default=1024,
        help="Dataloader batch size",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=42,
        help="Seed value",
    )
    parser.add_argument(
        "-index",
        "--index_file",
        type=str,
        default=None,
        help="Specify the path to dataset index file",
    )
    parser.add_argument(
        "-test_index",
        "--test_index_file",
        type=str,
        default=None,
        help="Specify the path to test dataset index file",
    )
    parser.add_argument(
        "-ckpts",
        "--checkpoints",
        type=str,
        default=None,
        help="Specify the path to folder with unlearned model for evaluation",
    )
    parser.add_argument(
        "-retr_ckpt",
        "--retrained_checkpoint",
        type=str,
        default=None,
        help="Specify the path to retrained model",
    )
    parser.add_argument(
        "-sp",
        "--save_path",
        type=str,
        default=os.path.join("outputs", "eval_ind"),
        help="Specify the path to save evaluation results",
    )
    parser.add_argument(
        "-m",
        "--method",
        type=str,
        choices=method_choices,
        required=True,
        help=f"Specify the unlearning method. Choices are: {method_choices}",
    )
    return parser


def _list_checkpoint_files(folder: str) -> list:
    """Return the sorted ``.pt`` file names located directly inside ``folder``."""
    # Sorting makes the evaluation order reproducible; os.listdir order is arbitrary.
    return sorted(name for name in os.listdir(folder) if name.endswith(".pt"))


def main(argv=None) -> None:
    """Evaluate every unlearned checkpoint in a folder against a retrained model.

    For each ``.pt`` file in ``--checkpoints`` the script computes the
    JS-divergence score on the test loader and the L2 distance between the
    checkpoint and the retrained model, prints them, and finally writes the
    mean and variance of both metrics to
    ``<save_path>/<method>/<dataset>.csv``.

    Args:
        argv: Optional list of command-line arguments. ``None`` (the default)
            makes ``argparse`` read ``sys.argv``.

    Raises:
        SystemExit: With status 2 when the arguments are invalid, including a
            missing or non-existent ``--checkpoints`` folder or a folder that
            contains no ``.pt`` files.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # Validate the checkpoint folder before any expensive dataset loading.
    # os.listdir(None) would silently scan the current working directory.
    if args.checkpoints is None:
        parser.error("--checkpoints is required: path to the folder of unlearned models")
    if not os.path.isdir(args.checkpoints):
        parser.error(f"checkpoint folder does not exist: {args.checkpoints}")
    ckpt_names = _list_checkpoint_files(args.checkpoints)
    if not ckpt_names:
        # Averaging an empty list would only produce NaN rows in the CSV.
        parser.error(f"no .pt checkpoints found in: {args.checkpoints}")

    if args.retrained_checkpoint is None:
        print(
            "Warning: --retrained_checkpoint not given; "
            "the reference model is used without loading a checkpoint"
        )

    os.makedirs(os.path.join(args.save_path, args.method), exist_ok=True)
    set_seed(args.seed)
    train_ds = select_dataset(dataset_name=args.dataset, train=True)
    test_ds = select_dataset(dataset_name=args.dataset, train=False)

    test_loader = get_loader(args.dataset, test_ds, args.batch_size, shuffle=False)
    forget_indices, retain_indices = get_indices(
        train_ds,
        forget_labels=args.forget_labels,
        index_file=args.index_file,
        seed=args.seed,
    )
    # NOTE(review): the forget/retain loaders are built but not consumed by the
    # metrics below; they are kept to preserve the script's existing behaviour.
    forget_dataset = EvalDataset(train_ds, forget_indices)
    forget_loader = get_loader(
        args.dataset, forget_dataset, batch_size=args.batch_size, shuffle=False
    )
    retain_dataset = EvalDataset(train_ds, retain_indices)
    retain_loader = get_loader(
        args.dataset, retain_dataset, batch_size=args.batch_size, shuffle=False
    )

    retrained_model = select_model(
        dataset_name=args.dataset, checkpoint=args.retrained_checkpoint
    )

    if torch.cuda.is_available():
        device = torch.device(f"cuda:{args.gpu}")
    else:
        print("CUDA is not available; falling back to CPU")
        device = torch.device("cpu")

    js_scores = []
    l2_inds = []
    for ckpt in ckpt_names:
        target_model = select_model(
            args.dataset, checkpoint=os.path.join(args.checkpoints, ckpt)
        )
        js = JSDivergence(retrained_model, target_model, device=device)
        js_score = js.calculate(test_loader)
        js_scores.append(js_score)
        print("JS Score", js_score)
        l2_score = l2_distance_between_models(
            model_a=target_model, model_b=retrained_model
        )
        l2_inds.append(l2_score)
        print("L2 Norm between models", l2_score)

    avg_jsd = np.mean(js_scores)
    avg_l2 = np.mean(l2_inds)
    var_jsd = np.var(js_scores)
    var_l2 = np.var(l2_inds)
    print("Results: ")
    print("JSD Score Average: ", avg_jsd)
    print("JSD Score Variance: ", var_jsd)
    print("L2 Score Average: ", avg_l2)
    print("L2 Score Variance: ", var_l2)

    df_dict = {
        "jsd_average": [avg_jsd],
        "jsd_variance": [var_jsd],
        "l2_average": [avg_l2],
        "l2_variance": [var_l2],
    }

    filename = f"{args.dataset}.csv"
    save_path = os.path.join(args.save_path, args.method, filename)
    df = pd.DataFrame.from_dict(df_dict)
    df.to_csv(save_path, index=False)
    print("Results saved at: ", save_path)


if __name__ == "__main__":
    main()
