import argparse
import os
import random
import copy
import time
import torch
import torch.nn as nn
import pandas as pd
import numpy as np
from transformers import AutoTokenizer
from torchvision.models import (
    ResNet101_Weights,
    Swin_V2_B_Weights,
)
from src.data.dataset import (
    UnlearnDiscriminativeDataset,
    DiscriminativeDataset,
    get_indices,
)
from src.data.utils import (
    get_mnist_test_loaders,
    create_mnist_dataset,
    get_cifar_test_loaders,
    create_cifar_dataset,
    get_ag_news_test_loaders,
    create_ag_news_dataset,
    build_dataloader_text,
    create_imagenet_dataset,
    get_imagenet_test_loaders,
    create_dbpedia_dataset,
    get_dbpedia_test_loaders,
)
from src.model.classification import (
    MnistConvNet,
    CIFARConvNet,
    CIFAR100ConvNet,
    ResNet101Classifier,
    SwinClassifier,
    BertClassifier,
)
from src.model.unlearn import NoisyUnlearnModel
from src.utils.loss import KLDivContrastiveLoss, MSEContrastiveLoss
from src.unlearn import (
    uul,
    finetune,
    gradient_ascent,
    neg_grad_plus,
    l1_sparse,
    random_label,
    wood_fisher,
    scrub,
    salun,
)
from src.utils.test import test_unlearn_classification
from src.utils.utility import load_model, set_seed


def proportional_disjoint_split(a, b, per_value, seed):
    """Carve a proportional held-out sample out of two index pools.

    ``int(per_value * (len(a) + len(b)))`` elements are sampled in total and
    split between the two pools in proportion to their sizes.  The held-out
    portion therefore keeps roughly the same ``a``/``b`` ratio as the inputs.

    Args:
        a: First pool (e.g. forget indices). Must be a sequence accepted by
            ``random.sample`` (e.g. a list) with hashable elements.
        b: Second pool (e.g. retain indices), same requirements as ``a``.
        per_value: Fraction of the combined pool size to sample.
        seed: Seed applied to the module-level ``random`` generator before
            each of the two draws.

    Returns:
        Tuple ``(u, v, y, z)``:

        - ``u`` is the sample drawn from ``a``.
        - ``v`` holds the remaining elements of ``a``, in their original order.
        - ``y`` and ``z`` are the same for ``b``.

        If nothing would be sampled (empty pools or a too-small fraction),
        ``([], a, [], b)`` is returned.
    """
    total = len(a) + len(b)
    sample_total = int(per_value * total)
    if total == 0 or sample_total == 0:
        return [], a, [], b

    num_u = int(round(len(a) / total * sample_total))
    num_y = sample_total - num_u

    # The global generator is reseeded before each draw so that each sample
    # is reproducible on its own (pre-existing behavior, kept on purpose).
    random.seed(seed)
    u = random.sample(a, min(num_u, len(a)))
    # Set membership keeps the complement O(n) instead of O(n * k).
    u_set = set(u)
    v = [x for x in a if x not in u_set]

    random.seed(seed)
    y = random.sample(b, min(num_y, len(b)))
    y_set = set(y)
    z = [x for x in b if x not in y_set]

    return u, v, y, z


def gen_loaders(
    dataset,
    dataset_name,
    method_name,
    forget_labels,
    num_retains,
    num_forgets,
    index_file,
    batch_size,
    seed,
    tokenizer=None,
    shuffle=True,
    need_val=False,
):
    """Build the training (and optionally validation) loaders for unlearning.

    Forget/retain indices are obtained from ``get_indices``. When ``need_val``
    is true, 10% of them are held out with ``proportional_disjoint_split`` to
    form a validation split.

    For each split, three datasets are created through ``get_method_dataset``:

    - an "unlearn" dataset using both ``num_retains`` and ``num_forgets``;
    - a "retain" dataset built with ``num_forgets=None``;
    - a "forget" dataset built with ``num_retains=None``.

    ``ag_news``/``dbpedia`` datasets are wrapped with ``build_dataloader_text``
    (using ``tokenizer``). Everything else uses ``torch.utils.data.DataLoader``.
    Training loaders honour ``shuffle``; validation loaders never shuffle.

    Returns:
        dict with keys ``unlearn_loader``, ``retain_loader``,
        ``forget_loader``, ``val_loader``, ``val_forget_loader`` and
        ``val_retain_loader``. The three ``val_*`` values are ``None`` unless
        ``need_val`` is true.
    """
    forget_indices, retain_indices = get_indices(
        dataset=dataset, forget_labels=forget_labels, index_file=index_file, seed=seed
    )
    if need_val:
        val_forget, train_forget, val_retain, train_retain = (
            proportional_disjoint_split(forget_indices, retain_indices, 0.1, seed=seed)
        )
    else:
        train_forget, train_retain = forget_indices, retain_indices

    # (num_retains, num_forgets) caps for the unlearn, retain and forget
    # datasets, in that order.
    size_caps = (
        (num_retains, num_forgets),
        (num_retains, None),
        (None, num_forgets),
    )

    def build_datasets(forget_part, retain_part):
        """Return the (unlearn, retain, forget) datasets for one index split."""
        return tuple(
            get_method_dataset(
                method_name=method_name,
                dataset=dataset,
                forget_labels=forget_labels,
                num_retains=retain_cap,
                num_forgets=forget_cap,
                index_file=None,
                indices_dict={
                    "forget_indices": forget_part,
                    "retain_indices": retain_part,
                },
                seed=seed,
            )
            for retain_cap, forget_cap in size_caps
        )

    is_text = dataset_name in ["ag_news", "dbpedia"]

    def build_loader(ds, do_shuffle):
        """Wrap ``ds`` in the loader type matching ``dataset_name``."""
        if is_text:
            return build_dataloader_text(
                ds,
                tokenizer,
                batch_size=batch_size,
                shuffle=do_shuffle,
                drop_last=False,
            )
        return torch.utils.data.DataLoader(
            ds, batch_size=batch_size, shuffle=do_shuffle, drop_last=False
        )

    # All datasets are created before any loader, as before.
    unlearn_ds, retain_ds, forget_ds = build_datasets(train_forget, train_retain)
    val_datasets = build_datasets(val_forget, val_retain) if need_val else None

    result = {
        "unlearn_loader": build_loader(unlearn_ds, shuffle),
        "retain_loader": build_loader(retain_ds, shuffle),
        "forget_loader": build_loader(forget_ds, shuffle),
        "val_loader": None,
        "val_forget_loader": None,
        "val_retain_loader": None,
    }
    if val_datasets is not None:
        val_unlearn_ds, val_retain_ds, val_forget_ds = val_datasets
        result["val_loader"] = build_loader(val_unlearn_ds, False)
        result["val_forget_loader"] = build_loader(val_forget_ds, False)
        result["val_retain_loader"] = build_loader(val_retain_ds, False)
    return result


def select_dataset(
    dataset_name: str,
    method_name: str,
    forget_labels: list,
    num_retains: int,
    num_forgets: int,
    index_file: None | str,
    test_index_file: None | str,
    batch_size: int,
    shuffle: bool = True,
    need_val: bool = False,
    seed: int = 42,
):
    if dataset_name == "mnist":
        dataset = create_mnist_dataset(train=True)
        loaders = gen_loaders(
            dataset=dataset,
            dataset_name=dataset_name,
            method_name=method_name,
            forget_labels=forget_labels,
            num_retains=num_retains,
            num_forgets=num_forgets,
            index_file=index_file,
            batch_size=batch_size,
            seed=seed,
            tokenizer=None,
            shuffle=shuffle,
            need_val=need_val,
        )
        dataloaders = get_mnist_test_loaders(
            forget_classes=forget_labels,
            batch_size=batch_size,
            index_file=test_index_file,
        )
        dataloaders.update(loaders)
        dataloaders.update(
            {
                "num_classes": 10,
            }
        )
        return dataloaders
    elif dataset_name == "cifar10":
        dataset = create_cifar_dataset(train=True)
        loaders = gen_loaders(
            dataset=dataset,
            dataset_name=dataset_name,
            method_name=method_name,
            forget_labels=forget_labels,
            num_retains=num_retains,
            num_forgets=num_forgets,
            index_file=index_file,
            batch_size=batch_size,
            seed=seed,
            tokenizer=None,
            shuffle=shuffle,
            need_val=need_val,
        )
        dataloaders = get_cifar_test_loaders(
            forget_classes=forget_labels,
            batch_size=batch_size,
            index_file=test_index_file,
        )
        dataloaders.update(loaders)
        dataloaders.update(
            {
                "num_classes": 10,
            }
        )
        return dataloaders
    elif dataset_name in ["cifar100", "cifar100_resnet", "cifar100_conv"]:
        dataset = create_cifar_dataset(train=True, cifar_100=True)
        loaders = gen_loaders(
            dataset=dataset,
            dataset_name=dataset_name,
            method_name=method_name,
            forget_labels=forget_labels,
            num_retains=num_retains,
            num_forgets=num_forgets,
            index_file=index_file,
            batch_size=batch_size,
            seed=seed,
            tokenizer=None,
            shuffle=shuffle,
            need_val=need_val,
        )
        dataloaders = get_cifar_test_loaders(
            forget_classes=forget_labels,
            batch_size=batch_size,
            cifar_100=True,
            index_file=test_index_file,
        )
        dataloaders.update(loaders)
        dataloaders.update(
            {
                "num_classes": 100,
            }
        )
        return dataloaders
    elif dataset_name == "ag_news":
        model_name = "bert-base-uncased"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        dataset = create_ag_news_dataset(tokenizer=tokenizer, train=True)
        loaders = gen_loaders(
            dataset=dataset,
            dataset_name=dataset_name,
            method_name=method_name,
            forget_labels=forget_labels,
            num_retains=num_retains,
            num_forgets=num_forgets,
            index_file=index_file,
            batch_size=batch_size,
            seed=seed,
            tokenizer=tokenizer,
            shuffle=shuffle,
            need_val=need_val,
        )
        dataloaders = get_ag_news_test_loaders(
            tokenizer=tokenizer,
            forget_classes=forget_labels,
            batch_size=batch_size,
            index_file=test_index_file,
        )
        dataloaders.update(loaders)
        dataloaders.update(
            {
                "num_classes": 4,
            }
        )
        return dataloaders
    elif dataset_name == "imagenet":
        dataset = create_imagenet_dataset(
            preprocess_fn=Swin_V2_B_Weights.DEFAULT.transforms(), train=True
        )
        loaders = gen_loaders(
            dataset=dataset,
            dataset_name=dataset_name,
            method_name=method_name,
            forget_labels=forget_labels,
            num_retains=num_retains,
            num_forgets=num_forgets,
            index_file=index_file,
            batch_size=batch_size,
            seed=seed,
            tokenizer=None,
            shuffle=shuffle,
            need_val=need_val,
        )
        dataloaders = get_imagenet_test_loaders(
            preprocess_fn=Swin_V2_B_Weights.DEFAULT.transforms(),
            forget_classes=forget_labels,
            batch_size=batch_size,
            index_file=test_index_file,
        )
        dataloaders.update(loaders)
        dataloaders.update(
            {
                "num_classes": 1000,
            }
        )
        return dataloaders
    elif dataset_name == "imagenet_resnet":
        dataset = create_imagenet_dataset(
            preprocess_fn=ResNet101_Weights.DEFAULT.transforms(), train=True
        )
        loaders = gen_loaders(
            dataset=dataset,
            dataset_name=dataset_name,
            method_name=method_name,
            forget_labels=forget_labels,
            num_retains=num_retains,
            num_forgets=num_forgets,
            index_file=index_file,
            batch_size=batch_size,
            seed=seed,
            tokenizer=None,
            shuffle=shuffle,
            need_val=need_val,
        )
        dataloaders = get_imagenet_test_loaders(
            preprocess_fn=ResNet101_Weights.DEFAULT.transforms(),
            forget_classes=forget_labels,
            batch_size=batch_size,
            index_file=test_index_file,
        )
        dataloaders.update(loaders)
        dataloaders.update(
            {
                "num_classes": 1000,
            }
        )
        return dataloaders
    elif dataset_name == "dbpedia":
        model_name = "bert-base-uncased"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        dataset = create_dbpedia_dataset(tokenizer=tokenizer, train=True)
        loaders = gen_loaders(
            dataset=dataset,
            dataset_name=dataset_name,
            method_name=method_name,
            forget_labels=forget_labels,
            num_retains=num_retains,
            num_forgets=num_forgets,
            index_file=index_file,
            batch_size=batch_size,
            seed=seed,
            tokenizer=tokenizer,
            shuffle=shuffle,
            need_val=need_val,
        )
        dataloaders = get_dbpedia_test_loaders(
            tokenizer=tokenizer,
            forget_classes=forget_labels,
            batch_size=batch_size,
            index_file=test_index_file,
        )
        dataloaders.update(loaders)
        dataloaders.update(
            {
                "num_classes": 14,
            }
        )
        return dataloaders
    else:
        raise Exception(f"dataset {dataset_name} is not defined")


def get_method_dataset(
    method_name: str,
    dataset,
    forget_labels: list,
    num_retains: None | int,
    num_forgets: None | int,
    index_file: None | str,
    indices_dict: None | dict,
    seed: int = 42,
):
    ds = None
    if method_name == "uul":
        ds = UnlearnDiscriminativeDataset(
            dataset=dataset,
            forget_labels=forget_labels,
            num_retains=num_retains,
            num_forgets=num_forgets,
            index_file=index_file,
            indices_dict=indices_dict,
            seed=seed,
        )
    elif method_name in [
        "ft",
        "ga",
        "neg_grad",
        "w_fisher",
        "l1_sparse",
        "random_label",
        "scrub",
        "salun",
    ]:
        ds = DiscriminativeDataset(
            dataset=dataset,
            forget_labels=forget_labels,
            num_retains=num_retains,
            num_forgets=num_forgets,
            index_file=index_file,
            indices_dict=indices_dict,
            seed=seed,
        )
    else:
        raise Exception(f"Method {method_name} is not defined")
    return ds


def select_model(dataset_name: str, checkpoint: str | None = None):
    model = None
    if dataset_name == "mnist":
        model = MnistConvNet()
    elif dataset_name == "cifar10":
        model = CIFARConvNet()
    elif dataset_name == "cifar100":
        model = SwinClassifier(num_labels=100)
    elif dataset_name == "cifar100_resnet":
        model = ResNet101Classifier(num_labels=100)
    elif dataset_name == "cifar100_conv":
        model = CIFAR100ConvNet()
    elif dataset_name == "ag_news":
        model = BertClassifier(num_labels=4, model_name="bert-base-uncased")
    elif dataset_name == "imagenet":
        model = SwinClassifier(num_labels=1000)
    elif dataset_name == "imagenet_resnet":
        model = ResNet101Classifier(num_labels=1000)
    elif dataset_name == "dbpedia":
        model = BertClassifier(num_labels=14, model_name="bert-base-uncased")
    else:
        raise Exception(f"dataset {dataset_name} is not defined")
    if checkpoint is not None:
        model = load_model(model=model, checkpoint_path=checkpoint)
    return model


def select_unlearn_model(args):
    """Build the original model and the model that will be unlearned.

    Args:
        args: Parsed CLI namespace. Reads ``dataset``, ``checkpoint`` and
            ``method``; for ``uul`` it also reads ``loss`` and
            ``use_attn_loss``.

    Returns:
        ``(model, un_model)``.

        - For ``uul``, ``un_model`` is a ``NoisyUnlearnModel`` built on
          ``model.repr_model`` and ``model.classifier``. These are the same
          module objects as in ``model``, not copies.
        - For every other method, ``un_model`` is an independent deep copy of
          ``model``.

    Raises:
        Exception: if ``args.loss`` is unknown (``uul`` only) or
            ``args.method`` is unknown.
    """
    model = select_model(dataset_name=args.dataset, checkpoint=args.checkpoint)
    if model is None:
        raise Exception("Model cannot be None")

    if args.method == "uul":
        if args.loss == "kl":
            loss_fn = KLDivContrastiveLoss()
        elif args.loss == "mse":
            loss_fn = MSEContrastiveLoss(margin=100)
        else:
            raise Exception("Contrastive loss not defined, choices are [kl, mse]")
        un_model = NoisyUnlearnModel(
            repr_model=model.repr_model,
            out_model=model.classifier,
            input_dim=model.in_features,
            hidden_dim=512,
            attn_hidden_dim=64,
            loss_fn=loss_fn,
            use_attn_loss=args.use_attn_loss,
            latch=True,
        )
    elif args.method in [
        "ft",
        "ga",
        "neg_grad",
        "w_fisher",
        "l1_sparse",
        "random_label",
        "scrub",
        "salun",
    ]:
        # Copy rather than calling select_model() again. A second call
        # re-builds and re-loads the model. Without a checkpoint it would also
        # get a *different* random initialization, so the baseline and the
        # unlearned model would not start from the same weights.
        un_model = copy.deepcopy(model)
    else:
        raise Exception(f"Method {args.method} is not defined")
    return model, un_model


if __name__ == "__main__":
    dataset_choices = [
        "mnist",
        "cifar10",
        "cifar100",
        "cifar100_resnet",
        "cifar100_conv",
        "ag_news",
        "imagenet",
        "imagenet_resnet",
        "dbpedia",
    ]
    method_choices = [
        "uul",
        "ft",
        "ga",
        "neg_grad",
        "w_fisher",
        "l1_sparse",
        "random_label",
        "scrub",
        "salun",
    ]
    parser = argparse.ArgumentParser("Testing Unlearning")
    parser.add_argument(
        "-g",
        "--gpu",
        type=int,
        choices=[0, 1, 2, 3, 4, 5, 6, 7],
        default=0,
        help="Specify GPU index to use",
    )
    parser.add_argument(
        "-m",
        "--method",
        type=str,
        choices=method_choices,
        required=True,
        help=f"Specify the unlearning method. Choices are: {method_choices}",
    )
    parser.add_argument(
        "-dt",
        "--dataset",
        type=str,
        choices=dataset_choices,
        required=True,
        help=f"Specify the dataset to test unlearning. Choices are: {dataset_choices}",
    )
    parser.add_argument(
        "-l",
        "--loss",
        type=str,
        choices=["kl", "mse"],
        default="kl",
        help="Specify the contrastive loss for unlearning. Choices are: [kl, mse]",
    )
    parser.add_argument(
        "-forgets",
        "--forget_labels",
        nargs="*",
        type=int,
        required=False,
        default=None,
        help="List of forget labels in integer",
    )
    parser.add_argument(
        "-nr",
        "--num_retains",
        type=int,
        default=2000,
        help="Number of retain samples",
    )
    parser.add_argument(
        "-nf",
        "--num_forgets",
        type=int,
        default=1000,
        help="Number of forget samples",
    )
    parser.add_argument(
        "-bs",
        "--batch_size",
        type=int,
        default=1024,
        help="Dataloader batch size",
    )
    parser.add_argument(
        "-p",
        "--patience",
        type=int,
        default=5,
        help="Early stopping patience level",
    )
    parser.add_argument(
        "-lr",
        "--learning_rate",
        type=float,
        default=1e-3,
        help="Learning rate for training",
    )
    parser.add_argument(
        "-alpha",
        "--alpha",
        type=float,
        default=0.1,
        help="Hyperparameter alpha for loss function",
    )
    parser.add_argument(
        "-beta",
        "--beta",
        type=float,
        default=0.1,
        help="Hyperparameter alpha for loss function",
    )
    parser.add_argument(
        "-gamma",
        "--gamma",
        type=float,
        default=0.1,
        help="Hyperparameter alpha for loss function",
    )
    parser.add_argument(
        "-wd",
        "--weight_decay",
        type=float,
        default=1e-3,
        help="L2 weight decay",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=42,
        help="Seed value",
    )
    parser.add_argument(
        "-f_step",
        "--forget_steps",
        type=int,
        default=10,
        help="Number of forget steps for scrub",
    )
    parser.add_argument(
        "-epochs",
        "--num_epochs",
        type=int,
        default=10000,
        help="Number of Epochs",
    )
    parser.add_argument(
        "-index",
        "--index_file",
        type=str,
        default=None,
        help="Specify the path to dataset index file",
    )
    parser.add_argument(
        "-test_index",
        "--test_index_file",
        type=str,
        default=None,
        help="Specify the path to test dataset index file",
    )
    parser.add_argument(
        "-lrs",
        "--lr_scheduler",
        type=str,
        choices=["cswr", "rlrp"],
        default=None,
        help="Specify the learning rate scheduler to use. Choices are: [cswr, rlrp]",
    )
    parser.add_argument(
        "-ckpt",
        "--checkpoint",
        type=str,
        default=None,
        help="Specify the path to trained model checkpoint",
    )
    parser.add_argument(
        "-uckpt",
        "--unlearn_checkpoint",
        type=str,
        default=None,
        help="Specify the path to trained unlearned model checkpoint",
    )
    parser.add_argument(
        "-sp",
        "--save_path",
        type=str,
        default=os.path.join("outputs", "unlearn"),
        help="Specify the path to save checkpoint",
    )
    parser.add_argument(
        "-mp",
        "--mask_path",
        type=str,
        default=None,
        help="Specify the path to load masks for SalUn",
    )
    parser.add_argument(
        "-runs",
        "--num_runs",
        type=int,
        default=10,
        help="Number of experiment runs",
    )
    parser.add_argument(
        "-attn_loss",
        "--use_attn_loss",
        action="store_true",
        help="Use selection loss for selector network or not",
    )
    args = parser.parse_args()

    os.makedirs(args.save_path, exist_ok=True)
    os.makedirs(os.path.join(args.save_path, args.method, args.dataset), exist_ok=True)
    os.makedirs(
        os.path.join(args.save_path, args.method, args.dataset, "models"), exist_ok=True
    )
    os.makedirs(
        os.path.join(args.save_path, args.method, args.dataset, "results"),
        exist_ok=True,
    )

    exp_name = f"labels:{args.forget_labels}-retain:{args.num_retains}"

    orig_data = {
        "full": {"loss": [], "acc": []},
        "retain": {"loss": [], "acc": []},
        "forget": {"loss": [], "acc": []},
    }
    unlearn_data = {
        "full": {"loss": [], "acc": []},
        "retain": {"loss": [], "acc": []},
        "forget": {"loss": [], "acc": []},
    }

    datasets = select_dataset(
        args.dataset,
        args.method,
        args.forget_labels,
        args.num_retains,
        args.num_forgets,
        index_file=args.index_file,
        test_index_file=args.test_index_file,
        batch_size=args.batch_size,
        seed=args.seed,
    )

    path_gen = lambda folder, file_name: os.path.join(
        args.save_path,
        args.method,
        args.dataset,
        folder,
        file_name,
    )

    with open(
        path_gen(folder="results", file_name=f"{exp_name}.txt"),
        "w",
    ) as ds_file:
        ds_file.writelines(
            [
                f'Forget Samples: {datasets["unlearn_loader"].dataset.num_forgets}\n',
                f'Retain Samples: {datasets["unlearn_loader"].dataset.num_retains}\n',
                f'Total Samples: {len(datasets["unlearn_loader"].dataset)}',
            ]
        )

    etas = []

    for run in range(args.num_runs):

        seed = args.seed + run
        set_seed(seed)

        device = torch.device(f"cuda:{args.gpu}")
        orig_model, un_model = select_unlearn_model(args=args)
        start_time = time.time()
        if args.method == "uul":
            file_name = f"{exp_name}-forget:{args.num_forgets}-attn_loss:{args.use_attn_loss}-name:{orig_model.name}-seed:{seed}"
            save_path = path_gen(folder="models", file_name=f"{file_name}.pt")
            uul.train_unlearn(
                un_model,
                datasets["unlearn_loader"],
                device=device,
                lr=args.learning_rate,
                num_epochs=args.num_epochs,
                lr_scheduler=args.lr_scheduler,
                patience=args.patience,
                weight_decay=args.weight_decay,
                save_path=save_path,
                val_loader=datasets["val_loader"],
            )
        elif args.method == "ft":
            loss_fn = None
            if args.loss == "kl":
                loss_fn = nn.CrossEntropyLoss()
            elif args.loss == "mse":
                loss_fn = nn.MSELoss()
            file_name = f"{exp_name}-name:{orig_model.name}-seed:{seed}"
            save_path = path_gen(folder="models", file_name=f"{file_name}.pt")
            finetune.train_unlearn(
                model=un_model,
                train_loader=datasets["retain_loader"],
                device=device,
                lr=args.learning_rate,
                num_epochs=args.num_epochs,
                loss_fn=loss_fn,
                lr_scheduler=args.lr_scheduler,
                patience=args.patience,
                weight_decay=args.weight_decay,
                save_path=save_path,
                val_loader=datasets["val_retain_loader"],
            )
        elif args.method == "l1_sparse":
            loss_fn = None
            if args.loss == "kl":
                loss_fn = nn.CrossEntropyLoss()
            elif args.loss == "mse":
                loss_fn = nn.MSELoss()
            file_name = f"{exp_name}-name:{orig_model.name}-seed:{seed}"
            save_path = path_gen(folder="models", file_name=f"{file_name}.pt")
            l1_sparse.train_unlearn(
                model=un_model,
                train_loader=datasets["retain_loader"],
                device=device,
                lr=args.learning_rate,
                alpha=args.alpha,
                num_epochs=args.num_epochs,
                loss_fn=loss_fn,
                lr_scheduler=args.lr_scheduler,
                patience=args.patience,
                weight_decay=args.weight_decay,
                save_path=save_path,
                val_loader=datasets["val_retain_loader"],
            )
        elif args.method == "ga":
            loss_fn = None
            if args.loss == "kl":
                loss_fn = nn.CrossEntropyLoss()
            elif args.loss == "mse":
                loss_fn = nn.MSELoss()
            file_name = f"{exp_name}-name:{orig_model.name}-seed:{seed}"
            save_path = path_gen(folder="models", file_name=f"{file_name}.pt")
            gradient_ascent.train_unlearn(
                model=un_model,
                train_loader=datasets["forget_loader"],
                device=device,
                lr=args.learning_rate,
                num_epochs=args.num_epochs,
                loss_fn=loss_fn,
                lr_scheduler=args.lr_scheduler,
                patience=args.patience,
                weight_decay=args.weight_decay,
                save_path=save_path,
                val_loader=None,
            )
        elif args.method == "neg_grad":
            loss_fn = None
            if args.loss == "kl":
                loss_fn = nn.CrossEntropyLoss()
            elif args.loss == "mse":
                loss_fn = nn.MSELoss()
            file_name = f"{exp_name}-name:{orig_model.name}-seed:{seed}"
            save_path = path_gen(folder="models", file_name=f"{file_name}.pt")
            neg_grad_plus.train_unlearn(
                model=un_model,
                retain_loader=datasets["retain_loader"],
                forget_loader=datasets["forget_loader"],
                device=device,
                lr=args.learning_rate,
                alpha=args.alpha,
                num_epochs=args.num_epochs,
                loss_fn=loss_fn,
                lr_scheduler=args.lr_scheduler,
                patience=args.patience,
                weight_decay=args.weight_decay,
                save_path=save_path,
                val_forget_loader=None,
                val_retain_loader=None,
            )
        elif args.method == "w_fisher":
            loss_fn = None
            if args.loss == "kl":
                loss_fn = nn.CrossEntropyLoss()
            elif args.loss == "mse":
                loss_fn = nn.MSELoss()
            file_name = f"{exp_name}-name:{orig_model.name}-seed:{seed}"
            save_path = path_gen(folder="models", file_name=f"{file_name}.pt")
            datasets = select_dataset(
                args.dataset,
                args.method,
                args.forget_labels,
                args.num_retains,
                args.num_forgets,
                index_file=args.index_file,
                test_index_file=args.test_index_file,
                batch_size=args.batch_size,
                shuffle=False,
                seed=args.seed,
            )
            batch_1_retain_loader = select_dataset(
                args.dataset,
                args.method,
                args.forget_labels,
                args.num_retains,
                args.num_forgets,
                index_file=args.index_file,
                test_index_file=args.test_index_file,
                batch_size=1,
                shuffle=False,
                seed=args.seed,
            )["retain_loader"]
            wood_fisher.train_unlearn(
                forget_loader=datasets["forget_loader"],
                retain_loader=datasets["retain_loader"],
                batch_1_retain_loader=batch_1_retain_loader,
                model=un_model,
                loss_fn=loss_fn,
                device=device,
                alpha=args.alpha,
                mask=None,
                save_path=save_path,
            )
        elif args.method == "random_label":
            loss_fn = None
            if args.loss == "kl":
                loss_fn = nn.CrossEntropyLoss()
            elif args.loss == "mse":
                loss_fn = nn.MSELoss()
            file_name = f"{exp_name}-name:{orig_model.name}-seed:{seed}"
            save_path = path_gen(folder="models", file_name=f"{file_name}.pt")
            random_label.train_unlearn(
                model=un_model,
                retain_loader=datasets["retain_loader"],
                forget_loader=datasets["forget_loader"],
                num_classes=datasets["num_classes"],
                loss_fn=loss_fn,
                device=device,
                num_epochs=args.num_epochs,
                lr=args.learning_rate,
                lr_scheduler=args.lr_scheduler,
                patience=args.patience,
                weight_decay=args.weight_decay,
                save_path=save_path,
                val_forget_loader=datasets["val_forget_loader"],
                val_retain_loader=datasets["val_retain_loader"],
            )
        elif args.method == "scrub":
            file_name = f"{exp_name}-name:{orig_model.name}-seed:{seed}"
            save_path = path_gen(folder="models", file_name=f"{file_name}.pt")
            scrub.train_unlearn(
                student_model=un_model,
                teacher_model=copy.deepcopy(un_model),
                forget_loader=datasets["forget_loader"],
                retain_loader=datasets["retain_loader"],
                device=device,
                lr=args.learning_rate,
                num_epochs=args.num_epochs,
                beta=args.beta,
                gamma=args.gamma,
                forget_steps=args.forget_steps,
                lr_scheduler=args.lr_scheduler,
                patience=args.patience,
                weight_decay=args.weight_decay,
                save_path=save_path,
                val_forget_loader=None,
                val_retain_loader=None,
            )
        elif args.method == "salun":
            loss_fn = None
            if args.loss == "kl":
                loss_fn = nn.CrossEntropyLoss()
            elif args.loss == "mse":
                loss_fn = nn.MSELoss()
            file_name = f"{exp_name}-name:{orig_model.name}-seed:{seed}"
            save_path = path_gen(folder="models", file_name=f"{file_name}.pt")
            if args.mask_path is None:
                raise Exception("Please provide mask path for SalUn")
            salun.train_unlearn(
                model=un_model,
                forget_loader=datasets["forget_loader"],
                retain_loader=datasets["retain_loader"],
                num_classes=datasets["num_classes"],
                loss_fn=loss_fn,
                device=device,
                lr=args.learning_rate,
                num_epochs=args.num_epochs,
                mask_path=args.mask_path,
                lr_scheduler=args.lr_scheduler,
                patience=args.patience,
                weight_decay=args.weight_decay,
                save_path=save_path,
                val_forget_loader=datasets["val_forget_loader"],
                val_retain_loader=datasets["val_retain_loader"],
            )
        else:
            raise Exception(f"Unlearning method {args.method} not defined")
        end_time = time.time()
        etas.append(end_time - start_time)

        un_model = load_model(un_model, save_path)

        test_loss, accuracy = test_unlearn_classification(
            model=orig_model,
            device=torch.device(f"cuda:{args.gpu}"),
            test_loader=datasets["test_full_loader"],
        )
        orig_data["full"]["loss"].append(test_loss)
        orig_data["full"]["acc"].append(accuracy)

        if datasets["test_retain_loader"] is not None:
            test_loss, accuracy = test_unlearn_classification(
                model=orig_model,
                device=torch.device(f"cuda:{args.gpu}"),
                test_loader=datasets["test_retain_loader"],
            )
            orig_data["retain"]["loss"].append(test_loss)
            orig_data["retain"]["acc"].append(accuracy)
        else:
            orig_data["retain"]["loss"].append(0)
            orig_data["retain"]["acc"].append(0)

        if datasets["test_forget_loader"] is not None:
            test_loss, accuracy = test_unlearn_classification(
                model=orig_model,
                device=torch.device(f"cuda:{args.gpu}"),
                test_loader=datasets["test_forget_loader"],
            )
            orig_data["forget"]["loss"].append(test_loss)
            orig_data["forget"]["acc"].append(accuracy)
        else:
            orig_data["forget"]["loss"].append(0)
            orig_data["forget"]["acc"].append(0)

        test_loss, accuracy = test_unlearn_classification(
            model=un_model,
            device=torch.device(f"cuda:{args.gpu}"),
            test_loader=datasets["test_full_loader"],
        )
        unlearn_data["full"]["loss"].append(test_loss)
        unlearn_data["full"]["acc"].append(accuracy)

        if datasets["test_retain_loader"] is not None:
            test_loss, accuracy = test_unlearn_classification(
                model=un_model,
                device=torch.device(f"cuda:{args.gpu}"),
                test_loader=datasets["test_retain_loader"],
            )
            unlearn_data["retain"]["loss"].append(test_loss)
            unlearn_data["retain"]["acc"].append(accuracy)
        else:
            unlearn_data["retain"]["loss"].append(0)
            unlearn_data["retain"]["acc"].append(0)

        if datasets["test_forget_loader"] is not None:
            test_loss, accuracy = test_unlearn_classification(
                model=un_model,
                device=torch.device(f"cuda:{args.gpu}"),
                test_loader=datasets["test_forget_loader"],
            )
            unlearn_data["forget"]["loss"].append(test_loss)
            unlearn_data["forget"]["acc"].append(accuracy)
        else:
            unlearn_data["forget"]["loss"].append(0)
            unlearn_data["forget"]["acc"].append(0)

    mean_time = np.mean(etas)
    var_time = np.var(etas)

    with open(
        path_gen(folder="results", file_name=f"{file_name}-time.txt"), "w"
    ) as handle:
        handle.write(f"Time (mean): {mean_time}s  Time (var): {var_time}")

    df_dict = {
        "unlearned": [],
        "dataset": [],
        "cross_entropy-mean": [],
        "cross_entropy-std": [],
        "accuracy-mean": [],
        "accuracy-std": [],
    }
    for i in ["full", "retain", "forget"]:
        df_data = orig_data[i]
        df_dict["unlearned"].append(0)
        df_dict["dataset"].append(i)
        df_dict["cross_entropy-mean"].append(np.mean(df_data["loss"]))
        df_dict["cross_entropy-std"].append(np.std(df_data["loss"]))
        df_dict["accuracy-mean"].append(np.mean(df_data["acc"]))
        df_dict["accuracy-std"].append(np.std(df_data["acc"]))

    for i in ["full", "retain", "forget"]:
        df_data = unlearn_data[i]
        df_dict["unlearned"].append(1)
        df_dict["dataset"].append(i)
        df_dict["cross_entropy-mean"].append(np.mean(df_data["loss"]))
        df_dict["cross_entropy-std"].append(np.std(df_data["loss"]))
        df_dict["accuracy-mean"].append(np.mean(df_data["acc"]))
        df_dict["accuracy-std"].append(np.std(df_data["acc"]))

    df = pd.DataFrame.from_dict(df_dict)
    df.to_csv(
        path_gen(folder="results", file_name=f"{file_name}.csv"),
        index=False,
    )
    print(f"Results saved in {file_name}.csv")

    print("==" * 20)
    print(f"Successfully Trained model on {args.dataset} using loss {args.loss}")
    print("==" * 20)
