import argparse
import os
import random
import torch
import numpy as np
import pandas as pd
from torchvision.models import (
    ResNet101_Weights,
    Swin_V2_B_Weights,
)
from src.data.utils import (
    get_mnist_test_loaders,
    create_mnist_dataset,
    get_cifar_test_loaders,
    create_cifar_dataset,
    get_ag_news_test_loaders,
    create_ag_news_dataset,
    build_dataloader_text,
    create_imagenet_dataset,
    get_imagenet_test_loaders,
    create_dbpedia_dataset,
    get_dbpedia_test_loaders,
)
from src.model.classification import (
    MnistConvNet,
    CIFARConvNet,
    CIFAR100ConvNet,
    ResNet101Classifier,
    SwinClassifier,
    BertClassifier,
)
from src.data.dataset import get_indices
from src.utils.metric import ConfidenceMIA, JSDivergence, ToWAccuracy, ToWMIA
from src.utils.utility import load_model, set_seed
from src.model.unlearn import NoisyUnlearnModel
from transformers import AutoTokenizer
from torchvision.models import Swin_V2_B_Weights


class EvalDataset(torch.utils.data.Dataset):
    """Index-selected view over another dataset.

    Position ``i`` of the view maps to ``dataset[indices[i]]``. Both the
    wrapped dataset and the ``indices`` sequence are stored by reference
    (neither is copied), so later mutation of ``indices`` by the caller is
    visible through the view.
    """

    def __init__(self, dataset, indices):
        super().__init__()
        self.dataset, self.indices = dataset, indices

    def __len__(self):
        """Return how many samples the selection contains."""
        return len(self.indices)

    def __getitem__(self, index):
        """Return the underlying sample selected at position ``index``."""
        return self.dataset[self.indices[index]]


def select_dataset(dataset_name: str, train=True):
    """Instantiate the train (``train=True``) or test split of ``dataset_name``.

    The three ``cifar100*`` names share the same CIFAR-100 data. Text datasets
    are tokenized with a ``bert-base-uncased`` tokenizer, and the ImageNet
    variants use the default preprocessing of the backbone named in the key
    (Swin-V2-B for ``imagenet``, ResNet-101 for ``imagenet_resnet``).

    Raises:
        Exception: if ``dataset_name`` is not one of the supported names.
    """

    def bert_tokenizer():
        return AutoTokenizer.from_pretrained("bert-base-uncased")

    # Lambdas keep construction lazy: only the requested dataset (and its
    # tokenizer / transforms) is ever built.
    factories = {
        "mnist": lambda: create_mnist_dataset(train=train),
        "cifar10": lambda: create_cifar_dataset(train=train),
        "cifar100": lambda: create_cifar_dataset(train=train, cifar_100=True),
        "cifar100_resnet": lambda: create_cifar_dataset(train=train, cifar_100=True),
        "cifar100_conv": lambda: create_cifar_dataset(train=train, cifar_100=True),
        "ag_news": lambda: create_ag_news_dataset(
            tokenizer=bert_tokenizer(), train=train
        ),
        "imagenet": lambda: create_imagenet_dataset(
            preprocess_fn=Swin_V2_B_Weights.DEFAULT.transforms(), train=train
        ),
        "imagenet_resnet": lambda: create_imagenet_dataset(
            preprocess_fn=ResNet101_Weights.DEFAULT.transforms(), train=train
        ),
        "dbpedia": lambda: create_dbpedia_dataset(
            tokenizer=bert_tokenizer(), train=train
        ),
    }

    factory = factories.get(dataset_name)
    if factory is None:
        raise Exception(f"dataset {dataset_name} is not defined")
    return factory()


def get_loader(dataset_name, dataset, batch_size: int, shuffle=True):
    """Wrap ``dataset`` in the loader type matching ``dataset_name``'s modality.

    Image datasets get a plain ``torch.utils.data.DataLoader``. Text datasets
    (``ag_news``, ``dbpedia``) are routed through ``build_dataloader_text``
    together with a freshly loaded ``bert-base-uncased`` tokenizer. In both
    cases the final partial batch is kept (``drop_last=False``).

    Raises:
        Exception: if ``dataset_name`` is not one of the supported names.
    """
    image_datasets = (
        "mnist",
        "cifar10",
        "cifar100",
        "cifar100_resnet",
        "cifar100_conv",
        "imagenet",
        "imagenet_resnet",
    )
    text_datasets = ("ag_news", "dbpedia")

    if dataset_name in image_datasets:
        return torch.utils.data.DataLoader(
            dataset, batch_size=batch_size, shuffle=shuffle, drop_last=False
        )
    if dataset_name not in text_datasets:
        raise Exception(f"dataset {dataset_name} is not defined")

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    return build_dataloader_text(
        dataset,
        tokenizer,
        batch_size=batch_size,
        shuffle=shuffle,
        drop_last=False,
    )


def get_test_loader(dataset_name, batch_size: int):
    """Return the full test-set loader (``"test_full_loader"``) for ``dataset_name``.

    Every dataset-specific helper is called with ``forget_classes=None`` and
    ``index_file=None``. The ``cifar100*`` names request CIFAR-100, text
    datasets receive a ``bert-base-uncased`` tokenizer, and the ImageNet
    variants receive the default transforms of their backbone's weights.

    Raises:
        Exception: if ``dataset_name`` is not one of the supported names.
    """
    shared_kwargs = dict(forget_classes=None, batch_size=batch_size, index_file=None)

    if dataset_name == "mnist":
        loaders = get_mnist_test_loaders(**shared_kwargs)
    elif dataset_name == "cifar10":
        loaders = get_cifar_test_loaders(**shared_kwargs)
    elif dataset_name in ("cifar100", "cifar100_resnet", "cifar100_conv"):
        loaders = get_cifar_test_loaders(cifar_100=True, **shared_kwargs)
    elif dataset_name in ("ag_news", "dbpedia"):
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        fetch = (
            get_ag_news_test_loaders
            if dataset_name == "ag_news"
            else get_dbpedia_test_loaders
        )
        loaders = fetch(tokenizer=tokenizer, **shared_kwargs)
    elif dataset_name in ("imagenet", "imagenet_resnet"):
        weights = Swin_V2_B_Weights if dataset_name == "imagenet" else ResNet101_Weights
        loaders = get_imagenet_test_loaders(
            preprocess_fn=weights.DEFAULT.transforms(), **shared_kwargs
        )
    else:
        raise Exception(f"dataset {dataset_name} is not defined")

    return loaders["test_full_loader"]


def get_mia_loaders(
    train_ds,
    test_ds,
    retain_indices,
    seed,
    *,
    dataset_name=None,
    samples_per_member=None,
    val_percent=None,
    batch_size=None,
):
    """Build the member / non-member / validation loaders for the confidence MIA.

    Members are drawn from ``retain_indices`` (indices into ``train_ds``) and
    non-members from all of ``test_ds``. Both pools are shuffled with a private
    RNG seeded by ``seed`` and truncated to a common size. Each pool is then
    split into a validation part of ``int(samples_per_member * val_percent)``
    samples and a training part of at most ``samples_per_member`` samples.
    Sampling is without replacement, so no index is repeated and the
    validation and training splits of the attack model are disjoint.

    Args:
        train_ds: indexable training dataset that ``retain_indices`` refer to.
        test_ds: indexable held-out dataset used as the non-member pool.
        retain_indices: sequence of indices into ``train_ds``. It is copied,
            never modified.
        seed: seed for the sampling RNG. The global ``random`` state is left
            untouched.
        dataset_name: optional override; defaults to ``args.dataset``.
        samples_per_member: optional override; defaults to
            ``args.samples_per_member``.
        val_percent: optional override; defaults to ``args.val_percent``.
        batch_size: optional override; defaults to ``args.batch_size``.
        (All four fall back to the module-level ``args`` that the
        ``__main__`` block parses.)

    Returns:
        ``(member_loader, non_member_loader, eval_loader)``. The first two
        shuffle. ``eval_loader`` does not shuffle and yields the validation
        members followed by the validation non-members.

    Raises:
        ValueError: if the validation split leaves no samples to train the
            attack model.
    """
    if dataset_name is None:
        dataset_name = args.dataset
    if samples_per_member is None:
        samples_per_member = args.samples_per_member
    if val_percent is None:
        val_percent = args.val_percent
    if batch_size is None:
        batch_size = args.batch_size

    rng = random.Random(seed)
    # Copy so the caller's list (which may back another EvalDataset) is not
    # reordered as a side effect.
    member_pool = list(retain_indices)
    non_member_pool = list(range(len(test_ds)))
    rng.shuffle(member_pool)
    rng.shuffle(non_member_pool)

    # Balance the two pools. After shuffling, taking a prefix is a uniform
    # sample *without* replacement. The previous random.choices drew with
    # replacement and could put one index into both the validation and the
    # training split.
    n_shadow = min(len(member_pool), len(non_member_pool))
    member_pool = member_pool[:n_shadow]
    non_member_pool = non_member_pool[:n_shadow]

    n_val = int(samples_per_member * val_percent)
    val_members, train_members = member_pool[:n_val], member_pool[n_val:]
    val_non_members, train_non_members = (
        non_member_pool[:n_val],
        non_member_pool[n_val:],
    )
    train_members = train_members[:samples_per_member]
    train_non_members = train_non_members[:samples_per_member]

    if not train_members or not train_non_members:
        raise ValueError(
            f"No samples left to train the attack model: {n_shadow} samples per "
            f"side available, {n_val} reserved for validation "
            f"(samples_per_member={samples_per_member}, val_percent={val_percent})"
        )

    member_train_ds = EvalDataset(train_ds, train_members)
    non_member_train_ds = EvalDataset(test_ds, train_non_members)
    eval_ds = torch.utils.data.ConcatDataset(
        [EvalDataset(train_ds, val_members), EvalDataset(test_ds, val_non_members)]
    )

    member_loader = get_loader(
        dataset_name, member_train_ds, batch_size=batch_size, shuffle=True
    )
    non_member_loader = get_loader(
        dataset_name, non_member_train_ds, batch_size=batch_size, shuffle=True
    )
    eval_loader = get_loader(dataset_name, eval_ds, batch_size=batch_size, shuffle=False)
    return member_loader, non_member_loader, eval_loader


def select_model(dataset_name: str, checkpoint: str | None = None):
    model = None
    if dataset_name == "mnist":
        model = MnistConvNet()
    elif dataset_name == "cifar10":
        model = CIFARConvNet()
    elif dataset_name == "cifar100":
        model = SwinClassifier(num_labels=100)
    elif dataset_name == "cifar100_resnet":
        model = ResNet101Classifier(num_labels=100)
    elif dataset_name == "cifar100_conv":
        model = CIFAR100ConvNet()
    elif dataset_name == "ag_news":
        model = BertClassifier(num_labels=4, model_name="bert-base-uncased")
    elif dataset_name == "imagenet":
        model = SwinClassifier(num_labels=1000)
    elif dataset_name == "imagenet_resnet":
        model = ResNet101Classifier(num_labels=1000)
    elif dataset_name == "dbpedia":
        model = BertClassifier(num_labels=14, model_name="bert-base-uncased")
    else:
        raise Exception(f"dataset {dataset_name} is not defined")
    if checkpoint is not None:
        model = load_model(model=model, checkpoint_path=checkpoint)
    return model


def select_unlearn_model(
    dataset_name: str,
    method_name: str,
    unlearn_checkpoint: str | None = None,
):
    un_model = None
    model = select_model(dataset_name=dataset_name, checkpoint=None)
    if model is None:
        raise Exception("Model cannot be None")

    if method_name == "uul":
        un_model = NoisyUnlearnModel(
            repr_model=model.repr_model,
            out_model=model.classifier,
            input_dim=model.in_features,
            hidden_dim=512,
            attn_hidden_dim=64,
            loss_fn=None,
            use_attn_loss=False,
            latch=True,
        )
    elif args.method in [
        "ft",
        "ga",
        "neg_grad",
        "w_fisher",
        "l1_sparse",
        "random_label",
        "scrub",
        "salun",
        "ind_uul",
    ]:
        un_model = select_model(
            dataset_name=dataset_name, checkpoint=unlearn_checkpoint
        )

    else:
        raise Exception(f"Method {args.method} is not defined")

    if unlearn_checkpoint is not None:
        un_model.load_state_dict(torch.load(unlearn_checkpoint)["model_state_dict"])
    return un_model


if __name__ == "__main__":
    # Evaluate every unlearned checkpoint in --checkpoints against a retrained
    # reference model. The script computes MIA score, MIA gap, JS divergence,
    # ToW accuracy and ToW MIA per checkpoint, then writes their mean and
    # variance to <save_path>/<method>/<dataset>.csv.
    dataset_choices = [
        "mnist",
        "cifar10",
        "cifar100",
        "cifar100_resnet",
        "cifar100_conv",
        "ag_news",
        "imagenet",
        "imagenet_resnet",
        "dbpedia",
    ]
    method_choices = [
        "uul",
        "ind_uul",
        "ft",
        "ga",
        "neg_grad",
        "w_fisher",
        "l1_sparse",
        "random_label",
        "scrub",
        "salun",
    ]
    parser = argparse.ArgumentParser("Evaluating Unlearning Methods")
    parser.add_argument(
        "-g",
        "--gpu",
        type=int,
        choices=[0, 1, 2, 3, 4, 5, 6, 7],
        default=0,
        help="Specify GPU index to use",
    )
    parser.add_argument(
        "-dt",
        "--dataset",
        type=str,
        choices=dataset_choices,
        required=True,
        help=f"Specify the dataset to test unlearning. Choices are: {dataset_choices}",
    )
    parser.add_argument(
        "-forgets",
        "--forget_labels",
        nargs="*",
        type=int,
        required=False,
        default=None,
        help="List of forget labels in integer",
    )
    parser.add_argument(
        "-sm",
        "--samples_per_member",
        type=int,
        default=5000,
        help="Number of samples from member and non-member set for training attack model",
    )
    parser.add_argument(
        "-vp",
        "--val_percent",
        type=float,
        default=0.1,
        help="Fraction of samples for validation of attack model",
    )
    parser.add_argument(
        "-bs",
        "--batch_size",
        type=int,
        default=1024,
        help="Dataloader batch size",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=42,
        help="Seed value",
    )
    parser.add_argument(
        "-index",
        "--index_file",
        type=str,
        default=None,
        help="Specify the path to dataset index file",
    )
    parser.add_argument(
        "-test_index",
        "--test_index_file",
        type=str,
        default=None,
        help="Specify the path to test dataset index file",
    )
    parser.add_argument(
        "-ckpts",
        "--checkpoints",
        type=str,
        default=None,
        help="Specify the path to folder with unlearned model for evaluation",
    )
    parser.add_argument(
        "-retr_ckpt",
        "--retrained_checkpoint",
        type=str,
        default=None,
        help="Specify the path to retrained model",
    )
    parser.add_argument(
        "-sp",
        "--save_path",
        type=str,
        default=os.path.join("outputs", "eval"),
        help="Specify the path to save evaluation results",
    )
    parser.add_argument(
        "-m",
        "--method",
        type=str,
        choices=method_choices,
        required=True,
        help=f"Specify the unlearning method. Choices are: {method_choices}",
    )
    args = parser.parse_args()

    # Fail fast, before any dataset or model is loaded. os.listdir(None)
    # silently lists the current working directory, and an empty checkpoint
    # set would make every np.mean/np.var below NaN and still write the CSV.
    if args.checkpoints is None or not os.path.isdir(args.checkpoints):
        parser.error(
            "--checkpoints must be an existing directory containing unlearned .pt checkpoints"
        )
    # Sorted so checkpoints are evaluated (and reported) in a reproducible order.
    checkpoint_files = sorted(
        name for name in os.listdir(args.checkpoints) if name.endswith(".pt")
    )
    if not checkpoint_files:
        parser.error(f"no .pt checkpoints found in {args.checkpoints}")

    device = torch.device(f"cuda:{args.gpu}")
    os.makedirs(os.path.join(args.save_path, args.method), exist_ok=True)
    set_seed(args.seed)
    train_ds = select_dataset(dataset_name=args.dataset, train=True)
    test_ds = select_dataset(dataset_name=args.dataset, train=False)

    test_loader = get_loader(args.dataset, test_ds, args.batch_size, shuffle=False)
    forget_indices, retain_indices = get_indices(
        train_ds,
        forget_labels=args.forget_labels,
        index_file=args.index_file,
        seed=args.seed,
    )
    forget_dataset = EvalDataset(train_ds, forget_indices)
    forget_loader = get_loader(
        args.dataset, forget_dataset, batch_size=args.batch_size, shuffle=False
    )
    retain_dataset = EvalDataset(train_ds, retain_indices)
    retain_loader = get_loader(
        args.dataset, retain_dataset, batch_size=args.batch_size, shuffle=False
    )
    # Pass a copy: retain_dataset holds retain_indices by reference, and the
    # MIA sampler must not reorder it underneath that dataset.
    member_loader, non_member_loader, eval_loader = get_mia_loaders(
        train_ds, test_ds, list(retain_indices), args.seed
    )

    retrained_model = select_model(
        dataset_name=args.dataset, checkpoint=args.retrained_checkpoint
    )

    mia_retrained = ConfidenceMIA(target_model=retrained_model, device=device)
    mia_retrained.fit(member_loader, non_member_loader, eval_loader)
    mia_score_retrained = mia_retrained.calculate(forget_loader)

    print("MIA Score Retrained", mia_score_retrained)

    mia_scores = []
    mia_gap_scores = []
    js_scores = []
    tow_acc_scores = []
    tow_mia_scores = []
    for ckpt in checkpoint_files:
        target_model = select_unlearn_model(
            args.dataset, args.method, os.path.join(args.checkpoints, ckpt)
        )
        mia = ConfidenceMIA(target_model=target_model, device=device)
        mia.fit(member_loader, non_member_loader, eval_loader)
        mia_score_method = mia.calculate(forget_loader)
        mia_scores.append(mia_score_method)
        print("MIA Score Unlearned", mia_score_method)

        mia_gap_score = abs(mia_score_retrained - mia_score_method)
        mia_gap_scores.append(mia_gap_score)
        print("MIA Gap: ", mia_gap_score)

        js = JSDivergence(retrained_model, target_model, device=device)
        js_score = js.calculate(test_loader)
        js_scores.append(js_score)
        print("JS Score", js_score)

        tow_acc = ToWAccuracy(target_model, retrained_model, device=device)
        tow_acc_score = tow_acc.calculate(
            forget_loader=forget_loader,
            retain_loader=retain_loader,
            test_loader=test_loader,
        )
        tow_acc_scores.append(tow_acc_score)
        print("ToW Accuracy Score: ", tow_acc_score)

        tow_mia = ToWMIA(
            target_model,
            retrained_model,
            mia_unlearned=mia,
            mia_retrained=mia_retrained,
            device=device,
        )
        tow_mia_score = tow_mia.calculate(
            forget_loader=forget_loader,
            retain_loader=retain_loader,
            test_loader=test_loader,
        )
        tow_mia_scores.append(tow_mia_score)
        print("ToW MIA Score: ", tow_mia_score)

    avg_mia, var_mia = np.mean(mia_scores), np.var(mia_scores)
    avg_mia_gap, var_mia_gap = np.mean(mia_gap_scores), np.var(mia_gap_scores)
    avg_jsd, var_jsd = np.mean(js_scores), np.var(js_scores)
    avg_tow_acc, var_tow_acc = np.mean(tow_acc_scores), np.var(tow_acc_scores)
    avg_tow_mia, var_tow_mia = np.mean(tow_mia_scores), np.var(tow_mia_scores)
    print("Results: ")
    print("MIA Score Average: ", avg_mia)
    print("MIA Score Variance: ", var_mia)
    print("MIA Gap Score Average: ", avg_mia_gap)
    print("MIA Gap Score Variance: ", var_mia_gap)
    print("JSD Score Average: ", avg_jsd)
    print("JSD Score Variance: ", var_jsd)
    print("ToW Accuracy Average: ", avg_tow_acc)
    print("ToW Accuracy Variance: ", var_tow_acc)
    print("ToW MIA Average: ", avg_tow_mia)
    print("ToW MIA Variance: ", var_tow_mia)

    df_dict = {
        "retrained_mia": [mia_score_retrained],
        "mia_average": [avg_mia],
        "mia_variance": [var_mia],
        "mia_gap_average": [avg_mia_gap],
        "mia_gap_var": [var_mia_gap],
        "jsd_average": [avg_jsd],
        "jsd_variance": [var_jsd],
        "tow_acc_average": [avg_tow_acc],
        "tow_acc_variance": [var_tow_acc],
        "tow_mia_average": [avg_tow_mia],
        "tow_mia_variance": [var_tow_mia],
    }

    filename = f"{args.dataset}.csv"
    save_path = os.path.join(args.save_path, args.method, filename)
    df = pd.DataFrame.from_dict(df_dict)
    df.to_csv(save_path, index=False)
    print("Results saved at: ", save_path)
