import argparse
import os
import time
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from src.model.classification import (
    MnistConvNet,
    CIFARConvNet,
    CIFAR100ConvNet,
    ResNet101Classifier,
    SwinClassifier,
    BertClassifier,
)
from src.model.unlearn import NoisyUnlearnModel
from src.utils.utility import load_model
from src.utils.utility import get_lr_scheduler
from src.utils.test import test_unlearn_classification
from unlearning import select_dataset
from tqdm.auto import tqdm
from src.utils.metric import l2_distance_between_models


def select_model(dataset_name: str, checkpoint: str | None = None):
    model = None
    if dataset_name == "mnist":
        model = MnistConvNet()
    elif dataset_name == "cifar10":
        model = CIFARConvNet()
    elif dataset_name == "cifar100":
        model = SwinClassifier(num_labels=100)
    elif dataset_name == "cifar100_resnet":
        model = ResNet101Classifier(num_labels=100)
    elif dataset_name == "cifar100_conv":
        model = CIFAR100ConvNet()
    elif dataset_name == "ag_news":
        model = BertClassifier(num_labels=4, model_name="bert-base-uncased")
    elif dataset_name == "imagenet":
        model = SwinClassifier(num_labels=1000)
    elif dataset_name == "imagenet_resnet":
        model = ResNet101Classifier(num_labels=1000)
    elif dataset_name == "dbpedia":
        model = BertClassifier(num_labels=14, model_name="bert-base-uncased")
    else:
        raise Exception(f"dataset {dataset_name} is not defined")
    if checkpoint is not None:
        model = load_model(model=model, checkpoint_path=checkpoint)
    return model


def select_unlearn_model(dataset_name, model_checkpoint, noisy_checkpoint):
    """Create the student model and its noisy teacher for one experiment.

    Two independent copies of the base model are loaded from
    ``model_checkpoint``: the first is returned as the student, the second
    donates its ``repr_model``, ``classifier`` and ``in_features`` to a
    ``NoisyUnlearnModel`` that is then loaded from ``noisy_checkpoint``.

    Args:
        dataset_name: Dataset key understood by ``select_model``.
        model_checkpoint: Checkpoint path for the base model.
        noisy_checkpoint: Checkpoint path for the ``NoisyUnlearnModel``.

    Returns:
        A ``(model, un_model)`` tuple: the student and the noisy teacher.

    Raises:
        Exception: If the student model turns out to be ``None``.
    """
    student = select_model(dataset_name=dataset_name, checkpoint=model_checkpoint)
    # A second, separate instance so the teacher does not share modules
    # with the student that is about to be trained.
    donor = select_model(dataset_name=dataset_name, checkpoint=model_checkpoint)
    if student is None:
        raise Exception("Model cannot be None")

    teacher_config = dict(
        repr_model=donor.repr_model,
        out_model=donor.classifier,
        input_dim=donor.in_features,
        hidden_dim=512,
        attn_hidden_dim=64,
        loss_fn=None,
        use_attn_loss=False,
        latch=True,
    )
    teacher = load_model(NoisyUnlearnModel(**teacher_config), noisy_checkpoint)
    return student, teacher


class DistillKL(nn.Module):
    """Temperature-scaled KL loss from "Distilling the Knowledge in a Neural Network".

    Both logit tensors are divided by the temperature ``T`` before the
    (log-)softmax over ``dim=1``; the batch-mean KL divergence is then
    multiplied by ``T**2``.
    """

    def __init__(self, T):
        super().__init__()
        self.T = T

    def forward(self, y_s, y_t):
        """Return the distillation loss of student logits ``y_s`` w.r.t. teacher logits ``y_t``."""
        temperature = self.T
        student_log_probs = F.log_softmax(y_s / temperature, dim=1)
        teacher_probs = F.softmax(y_t / temperature, dim=1)
        divergence = F.kl_div(
            student_log_probs, teacher_probs, reduction="batchmean"
        )
        return divergence * (temperature**2)


class IDLoss(nn.Module):
    """Weighted sum of a feature MSE, a label cross-entropy and a distillation KL term.

    ``alpha`` scales the MSE between student and noisy representations,
    ``beta`` the cross-entropy against the targets, and ``gamma`` the
    ``DistillKL(T=2)`` term between student and noisy logits. Extra
    positional/keyword arguments are forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, alpha=1 / 3, beta=1 / 3, gamma=1 / 3, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.mse = nn.MSELoss()
        self.ce = nn.CrossEntropyLoss()
        self.kl = DistillKL(T=2)
        self.alpha, self.beta, self.gamma = alpha, beta, gamma

    def forward(self, repr_feat, repr_noisy, y_pred, y_noisy, y_target):
        """Return ``alpha * MSE + beta * CE + gamma * KL`` for one batch."""
        feature_term = self.alpha * self.mse(repr_feat, repr_noisy)
        label_term = self.beta * self.ce(y_pred, y_target)
        distill_term = self.gamma * self.kl(y_pred, y_noisy)
        return feature_term + label_term + distill_term


def get_student_pred(model, data):
    """Run the student forward pass on ``data``.

    Args:
        model: Object exposing ``repr_model`` and ``classifier`` callables.
        data: Input batch passed to ``model.repr_model``.

    Returns:
        ``(representation, predictions)`` where the predictions are
        ``model.classifier`` applied to the representation.
    """
    features = model.repr_model(data)
    return features, model.classifier(features)


def get_teacher_pred(model, data):
    """Run the noisy teacher forward pass on ``data`` with gradients disabled.

    Args:
        model: Object exposing ``repr_model``, ``latch_model`` (whose output
            carries a ``noisy_repr`` attribute) and ``out_model``.
        data: Input batch passed to ``model.repr_model``.

    Returns:
        ``(noisy_repr, y_noisy)``: the noisy representation and the output of
        ``model.out_model`` on it.
    """
    with torch.no_grad():
        latched = model.latch_model(model.repr_model(data))
        perturbed = latched.noisy_repr
        teacher_out = model.out_model(perturbed)
    return perturbed, teacher_out


def test_id_model_step(
    model,
    noisy_model,
    loader,
    loss_fn,
    epoch,
    device,
):
    """Evaluate the student against the noisy teacher over ``loader``.

    Both models are switched to eval mode and every batch is processed under
    ``torch.no_grad()``.

    Args:
        model: Student model (see ``get_student_pred``).
        noisy_model: Teacher model (see ``get_teacher_pred``).
        loader: Iterable of ``(data, target)`` batches supporting ``len()``.
        loss_fn: Callable taking ``(repr_feat, repr_noisy, y_pred, y_noisy, target)``.
        epoch: Epoch number, used only in the printed message.
        device: Device the batches are moved to.

    Returns:
        The mean of the per-batch losses as a Python float.
    """
    noisy_model.eval()
    model.eval()

    batch_losses = []
    with torch.no_grad():
        for inputs, labels in tqdm(loader, total=len(loader)):
            inputs = inputs.to(device)
            labels = labels.to(device)
            student_repr, student_out = get_student_pred(model=model, data=inputs)
            teacher_repr, teacher_out = get_teacher_pred(
                model=noisy_model, data=inputs
            )
            batch_losses.append(
                loss_fn(student_repr, teacher_repr, student_out, teacher_out, labels)
            )

    total_loss = torch.tensor(batch_losses).mean().item()
    print("Test Epoch: {} \tLoss: {}".format(epoch, total_loss))
    return total_loss


def train_id_model_step(
    model,
    noisy_model,
    train_loader,
    optimizer,
    loss_fn,
    epoch,
    device,
    lr_scheduler=None,
):
    """Train the student for one epoch against the noisy teacher.

    The teacher is kept in eval mode and only queried (its predictions come
    from ``get_teacher_pred``, which disables gradients); the student is put
    in train mode and updated once per batch.

    Args:
        model: Student model (see ``get_student_pred``).
        noisy_model: Teacher model (see ``get_teacher_pred``).
        train_loader: Iterable of ``(data, target)`` batches supporting ``len()``.
        optimizer: Optimizer stepping the student's parameters.
        loss_fn: Callable taking ``(repr_feat, repr_noisy, y_pred, y_noisy, target)``.
        epoch: Current epoch number; used for logging and as the integer part
            of the fractional epoch passed to ``lr_scheduler``.
        device: Device the batches are moved to.
        lr_scheduler: Optional per-batch scheduler whose ``step`` accepts a
            fractional epoch value.

    Returns:
        ``(model, total_loss)`` where ``total_loss`` is the mean per-batch
        loss of the epoch as a Python float.
    """
    noisy_model.eval()
    model.train()
    num_batches = len(train_loader)
    losses = []
    for batch_idx, (data, target) in enumerate(
        tqdm(train_loader, total=num_batches)
    ):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        repr_feat, y_pred = get_student_pred(model=model, data=data)
        repr_noisy, y_noisy = get_teacher_pred(model=noisy_model, data=data)
        loss = loss_fn(repr_feat, repr_noisy, y_pred, y_noisy, target)
        loss.backward()
        optimizer.step()
        if lr_scheduler is not None:
            lr_scheduler.step(epoch + batch_idx / num_batches)
        # Keep a plain float rather than the tensor: storing the tensor would
        # hold every batch's grad-tracking loss (and its autograd node) alive
        # until the end of the epoch.
        losses.append(loss.item())
    total_loss = torch.tensor(losses).mean().item()
    print("Train Epoch: {} \tLoss: {}".format(epoch, total_loss))
    return model, total_loss


def _save_id_checkpoint(save_path, epoch, model, optimizer, loss):
    """Write the model/optimizer state reached at ``epoch`` to ``save_path``."""
    torch.save(
        {
            "best_epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
        },
        save_path,
    )


def train_id_model(
    model,
    noisy_model,
    train_loader,
    device,
    lr,
    num_epochs,
    alpha=1 / 3,
    beta=1 / 3,
    gamma=1 / 3,
    val_loader=None,
    lr_scheduler: None | str = None,
    patience: int = 5,
    weight_decay: float = 1e-6,
    save_path: None | str = None,
):
    """Train the student ``model`` to imitate ``noisy_model`` using ``IDLoss``.

    With a ``val_loader`` the best-validation checkpoint is saved and training
    stops early once ``patience`` consecutive epochs fail to improve the
    validation loss. Without one, the checkpoint is overwritten after every
    epoch and all ``num_epochs`` epochs are run.

    Args:
        model: Student model to optimise.
        noisy_model: Teacher model providing targets.
        train_loader: Training batches of ``(data, target)``.
        device: Device both models are moved to.
        lr: Learning rate for ``AdamW``.
        num_epochs: Maximum number of epochs (epochs are numbered from 1).
        alpha: Weight of the representation MSE term in ``IDLoss``.
        beta: Weight of the cross-entropy term in ``IDLoss``.
        gamma: Weight of the distillation KL term in ``IDLoss``.
        val_loader: Optional validation batches.
        lr_scheduler: Scheduler name forwarded to ``get_lr_scheduler``.
        patience: Number of non-improving validation epochs tolerated.
        weight_decay: Weight decay for ``AdamW``.
        save_path: Optional checkpoint destination; nothing is saved if ``None``.

    Returns:
        The model in its state after the last executed epoch (not necessarily
        the best checkpoint; reload from ``save_path`` for that).
    """
    model.to(device)
    noisy_model.to(device)
    loss_fn = IDLoss(alpha=alpha, beta=beta, gamma=gamma)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    lrs = get_lr_scheduler(lr_scheduler=lr_scheduler, optimizer=optimizer)

    # Split the (optional) scheduler by the granularity at which it steps.
    has_scheduler = lrs.scheduler is not None
    batch_scheduler = (
        lrs.scheduler if has_scheduler and lrs.is_batch_scheduler else None
    )
    epoch_scheduler = (
        lrs.scheduler if has_scheduler and not lrs.is_batch_scheduler else None
    )

    best_loss = torch.inf
    counter = 0
    for e in range(1, num_epochs + 1):
        model, epoch_loss = train_id_model_step(
            model=model,
            noisy_model=noisy_model,
            train_loader=train_loader,
            optimizer=optimizer,
            loss_fn=loss_fn,
            epoch=e,
            device=device,
            lr_scheduler=batch_scheduler,
        )

        if val_loader is None:
            # No validation data: fall back to the training loss so an
            # epoch-level scheduler is not silently left unstepped.
            if epoch_scheduler is not None:
                epoch_scheduler.step(epoch_loss)
            if save_path is not None:
                # `e` is already 1-based, so it is stored as-is.
                _save_id_checkpoint(save_path, e, model, optimizer, epoch_loss)
            continue

        val_loss = test_id_model_step(
            model=model,
            noisy_model=noisy_model,
            loader=val_loader,
            loss_fn=loss_fn,
            epoch=e,
            device=device,
        )

        if epoch_scheduler is not None:
            epoch_scheduler.step(val_loss)

        if best_loss > val_loss:
            best_loss = val_loss
            counter = 0
            if save_path is not None:
                _save_id_checkpoint(save_path, e, model, optimizer, best_loss)

        # The counter is incremented even on an improving epoch, so it equals
        # (non-improving epochs + 1); `> patience` therefore triggers after
        # exactly `patience` epochs without improvement.
        counter += 1

        if counter > patience:
            print("Early Stopping: Epoch {} \t Best Val Loss {}".format(e, best_loss))
            break

    return model


def freeze_student(model):
    """Freeze every parameter except those of the last parametrised ``repr_model`` layer.

    All parameters of ``model`` are first marked as not requiring gradients.
    ``model.repr_model`` is then walked back-to-front and the first layer that
    owns any parameters is re-enabled; if no layer has parameters, everything
    stays frozen.

    Args:
        model: Module with a reversible ``repr_model`` container of layers.

    Returns:
        The same ``model`` instance, modified in place.
    """
    for weight in model.parameters():
        weight.requires_grad = False

    # Lazily scan from the end and stop at the first layer that has parameters.
    per_layer_params = (list(block.parameters()) for block in reversed(model.repr_model))
    last_trainable = next((group for group in per_layer_params if group), [])
    for weight in last_trainable:
        weight.requires_grad = True

    return model


def unfreeze_student(model):
    """Mark every parameter of ``model`` as trainable and return the model.

    Args:
        model: Object whose ``parameters()`` yields the tensors to unfreeze.

    Returns:
        The same ``model`` instance, modified in place.
    """
    for weight in model.parameters():
        weight.requires_grad = True
    return model


if __name__ == "__main__":
    # Script entry point: for every noisy (teacher) checkpoint found in
    # --noisy_checkpoints, train a student to imitate it, evaluate teacher and
    # student on the full/retain/forget test splits, and write timing, L2
    # distance and aggregated metrics under --save_path.

    dataset_choices = [
        "mnist",
        "cifar10",
        "cifar100",
        "cifar100_resnet",
        "cifar100_conv",
        "ag_news",
        "imagenet",
        "imagenet_resnet",
        "dbpedia",
    ]
    parser = argparse.ArgumentParser("Testing Unlearning")
    parser.add_argument(
        "-g",
        "--gpu",
        type=int,
        choices=[0, 1, 2, 3, 4, 5, 6, 7],
        default=0,
        help="Specify GPU index to use",
    )
    parser.add_argument(
        "-dt",
        "--dataset",
        type=str,
        choices=dataset_choices,
        required=True,
        help=f"Specify the dataset to test unlearning. Choices are: {dataset_choices}",
    )
    parser.add_argument(
        "-forgets",
        "--forget_labels",
        nargs="+",
        type=int,
        required=False,
        default=None,
        help="List of forget labels in integer",
    )
    parser.add_argument(
        "-nr",
        "--num_retains",
        type=int,
        default=2000,
        help="Number of retain samples",
    )
    parser.add_argument(
        "-nf",
        "--num_forgets",
        type=int,
        default=1000,
        help="Number of forget samples",
    )
    parser.add_argument(
        "-bs",
        "--batch_size",
        type=int,
        default=1024,
        help="Dataloader batch size",
    )
    parser.add_argument(
        "-p",
        "--patience",
        type=int,
        default=5,
        help="Early stopping patience level",
    )
    parser.add_argument(
        "-lr",
        "--learning_rate",
        type=float,
        default=1e-3,
        help="Learning rate for training",
    )
    parser.add_argument(
        "-alpha",
        "--alpha",
        type=float,
        default=1 / 3,
        help="Hyperparameter alpha for loss function",
    )
    parser.add_argument(
        "-beta",
        "--beta",
        type=float,
        default=1 / 3,
        help="Hyperparameter beta for loss function",
    )
    parser.add_argument(
        "-gamma",
        "--gamma",
        type=float,
        default=1 / 3,
        help="Hyperparameter gamma for loss function",
    )
    parser.add_argument(
        "-wd",
        "--weight_decay",
        type=float,
        default=1e-3,
        help="L2 weight decay",
    )
    parser.add_argument(
        "-s",
        "--seed",
        type=int,
        default=42,
        help="Seed value",
    )
    parser.add_argument(
        "-epochs",
        "--num_epochs",
        type=int,
        default=10000,
        help="Number of Epochs",
    )
    parser.add_argument(
        "-index",
        "--index_file",
        type=str,
        default=None,
        help="Specify the path to dataset index file",
    )
    parser.add_argument(
        "-test_index",
        "--test_index_file",
        type=str,
        default=None,
        help="Specify the path to test dataset index file",
    )
    parser.add_argument(
        "-lrs",
        "--lr_scheduler",
        type=str,
        choices=["cswr", "rlrp"],
        default=None,
        help="Specify the learning rate scheduler to use. Choices are: [cswr, rlrp]",
    )
    parser.add_argument(
        "-ckpt",
        "--checkpoint",
        type=str,
        default=None,
        help="Specify the path to scratch model",
    )
    parser.add_argument(
        "-nc",
        "--noisy_checkpoints",
        type=str,
        default=None,
        help="Specify the folder path to trained unlearned model checkpoints",
    )
    parser.add_argument(
        "-sp",
        "--save_path",
        type=str,
        default=os.path.join("outputs", "indistinct"),
        help="Specify the path to save checkpoint",
    )
    args = parser.parse_args()

    # Fail fast with a clear message instead of crashing later with a
    # TypeError (missing folder argument) or NameError (empty folder).
    if args.noisy_checkpoints is None:
        parser.error("-nc/--noisy_checkpoints is required")
    # Sorted so the run index <-> checkpoint mapping is reproducible;
    # os.listdir returns entries in arbitrary order.
    noisy_ckpt_names = sorted(os.listdir(args.noisy_checkpoints))
    if not noisy_ckpt_names:
        parser.error(f"no checkpoints found in {args.noisy_checkpoints}")

    exp_name = f"labels:{args.forget_labels}-retain:{args.num_retains}"
    model_path = os.path.join(args.save_path, args.dataset, "models")
    result_path = os.path.join(args.save_path, args.dataset, "results")
    os.makedirs(model_path, exist_ok=True)
    os.makedirs(result_path, exist_ok=True)

    datasets = select_dataset(
        dataset_name=args.dataset,
        method_name="ft",
        forget_labels=args.forget_labels,
        num_retains=args.num_retains,
        num_forgets=args.num_forgets,
        index_file=args.index_file,
        test_index_file=args.test_index_file,
        batch_size=args.batch_size,
        shuffle=True,
        need_val=False,
        seed=args.seed,
    )

    splits = ("full", "retain", "forget")
    # orig_data collects the noisy teacher's metrics, unlearn_data the trained
    # student's metrics; one entry per checkpoint and split.
    orig_data = {split: {"loss": [], "acc": []} for split in splits}
    unlearn_data = {split: {"loss": [], "acc": []} for split in splits}

    device = torch.device(f"cuda:{args.gpu}")

    def record_metrics(eval_model, store):
        """Evaluate ``eval_model`` on every test split and append to ``store``."""
        for split in splits:
            loader = datasets[f"test_{split}_loader"]
            if split != "full" and loader is None:
                # Optional split is absent: append a placeholder to *this*
                # split so all per-split lists stay the same length.
                test_loss, accuracy = 0, 0
            else:
                test_loss, accuracy = test_unlearn_classification(
                    model=eval_model,
                    device=device,
                    test_loader=loader,
                )
            store[split]["loss"].append(test_loss)
            store[split]["acc"].append(accuracy)

    orig_model = select_model(args.dataset, checkpoint=args.checkpoint)
    etas = []
    l2_norms = []

    for i, noisy_ckpt in enumerate(noisy_ckpt_names):
        model, noisy_model = select_unlearn_model(
            dataset_name=args.dataset,
            model_checkpoint=args.checkpoint,
            noisy_checkpoint=os.path.join(args.noisy_checkpoints, noisy_ckpt),
        )

        model = unfreeze_student(model=model)

        filename = f"{args.dataset}-{exp_name}-{i+1}"
        save_path = os.path.join(model_path, f"{filename}.pt")

        start_time = time.time()
        train_id_model(
            model=model,
            noisy_model=noisy_model,
            train_loader=datasets["unlearn_loader"],
            device=device,
            lr=args.learning_rate,
            num_epochs=args.num_epochs,
            alpha=args.alpha,
            beta=args.beta,
            gamma=args.gamma,
            val_loader=None,
            lr_scheduler=args.lr_scheduler,
            patience=args.patience,
            weight_decay=args.weight_decay,
            save_path=save_path,
        )

        # Reload the checkpoint written by training before evaluating.
        model = load_model(model, save_path)

        record_metrics(noisy_model, orig_data)
        record_metrics(model, unlearn_data)

        l2_norms.append(l2_distance_between_models(model_a=model, model_b=orig_model))
        # NOTE(review): the measured interval also covers the evaluation and
        # L2 computation above, not only training — confirm this is intended.
        end_time = time.time()
        etas.append(end_time - start_time)

    mean_time = np.mean(etas)
    var_time = np.var(etas)

    mean_l2 = np.mean(l2_norms)
    var_l2 = np.var(l2_norms)

    # `filename` is the name of the last run; the aggregate files reuse it.
    with open(os.path.join(result_path, f"{filename}-time.txt"), "w") as handle:
        handle.write(f"Time (mean): {mean_time}s  Time (var): {var_time}")

    with open(os.path.join(result_path, f"{filename}-l2.txt"), "w") as handle:
        handle.write(f"L2 (mean): {mean_l2}  L2 (var): {var_l2}")

    df_dict = {
        "noisy": [],
        "dataset": [],
        "cross_entropy-mean": [],
        "cross_entropy-std": [],
        "accuracy-mean": [],
        "accuracy-std": [],
    }
    # NOTE(review): teacher (noisy_model) metrics are emitted with noisy=0 and
    # student metrics with noisy=1 — confirm this labelling is intended.
    for noisy_flag, metrics in ((0, orig_data), (1, unlearn_data)):
        for split in splits:
            df_data = metrics[split]
            df_dict["noisy"].append(noisy_flag)
            df_dict["dataset"].append(split)
            df_dict["cross_entropy-mean"].append(np.mean(df_data["loss"]))
            df_dict["cross_entropy-std"].append(np.std(df_data["loss"]))
            df_dict["accuracy-mean"].append(np.mean(df_data["acc"]))
            df_dict["accuracy-std"].append(np.std(df_data["acc"]))

    df = pd.DataFrame.from_dict(df_dict)
    df.to_csv(
        os.path.join(result_path, f"{filename}.csv"),
        index=False,
    )
    print(f"Results saved in {filename}.csv")

    print("==" * 20)
    print(f"Successfully Trained model on {args.dataset}")
    print("==" * 20)
