import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from tqdm.auto import tqdm
from src.utils.utility import get_lr_scheduler

"""Used for SCRUB; adapted from https://github.com/HobbitLong/RepDistiller"""


class DistillKL(nn.Module):
    """Temperature-scaled KL-divergence distillation loss.

    Implements the soft-target loss from "Distilling the Knowledge in a
    Neural Network": student and teacher logits are both divided by the
    temperature ``T`` before being compared.
    """

    def __init__(self, T):
        super().__init__()
        # Temperature used to soften both the student and teacher logits.
        self.T = T

    def forward(self, y_s, y_t):
        """Return the distillation loss between student and teacher logits.

        Args:
            y_s: Student logits; the softmax is taken over ``dim=1``.
            y_t: Teacher logits with the same layout as ``y_s``.

        Returns:
            The ``batchmean``-reduced KL divergence of the softened
            distributions, multiplied by ``T ** 2``.
        """
        temperature = self.T
        student_log_probs = F.log_softmax(y_s / temperature, dim=1)
        teacher_probs = F.softmax(y_t / temperature, dim=1)
        divergence = F.kl_div(
            student_log_probs, teacher_probs, reduction="batchmean"
        )
        return divergence * temperature**2


def calculate_test_loss(
    teacher_model,
    student_model,
    val_retain_loader,
    val_forget_loader,
    loss_fn_classification,
    loss_fn_kl,
    beta,
    gamma,
    device,
    epoch,
):
    """Compute the combined validation loss on the retain and forget sets.

    Both models are switched to eval mode and evaluated without gradients.
    Model outputs are expected to expose a ``.logits`` attribute.

    Args:
        teacher_model: Reference model whose logits the student is compared to.
        student_model: Model being evaluated.
        val_retain_loader: Loader yielding ``(data, target)`` retain batches.
        val_forget_loader: Loader yielding ``(data, target)`` forget batches.
        loss_fn_classification: Callable ``(student_logits, target) -> loss``.
        loss_fn_kl: Callable ``(student_logits, teacher_logits) -> loss``.
        beta: Weight of the divergence term on the retain set.
        gamma: Weight of the classification term on the retain set.
        device: Device the batches are moved to.
        epoch: Epoch number, used only in the printed summary.

    Returns:
        float: mean retain loss ``gamma * cls + beta * div`` plus mean forget
        loss ``-div``. An empty loader contributes NaN, as the mean of an
        empty tensor.
    """
    teacher_model.eval()
    student_model.eval()

    # Per-batch losses are stored as plain floats so no tensors are kept
    # around (or re-wrapped) once a batch has been processed.
    retain_losses = []
    forget_losses = []

    with torch.no_grad():
        for data, target in tqdm(val_retain_loader, total=len(val_retain_loader)):
            data, target = data.to(device), target.to(device)
            student_logits = student_model(data).logits
            teacher_logits = teacher_model(data).logits

            loss_cls = loss_fn_classification(student_logits, target)
            loss_div = loss_fn_kl(student_logits, teacher_logits)
            retain_losses.append(float((gamma * loss_cls) + (beta * loss_div)))

        # The forget objective is the negated divergence only, so the targets
        # and the classification loss are not needed here.
        for data, _ in tqdm(val_forget_loader, total=len(val_forget_loader)):
            data = data.to(device)
            student_logits = student_model(data).logits
            teacher_logits = teacher_model(data).logits

            loss_div = loss_fn_kl(student_logits, teacher_logits)
            forget_losses.append(float(-loss_div))

    retain_loss = torch.tensor(retain_losses).mean().item()
    forget_loss = torch.tensor(forget_losses).mean().item()
    total_loss = retain_loss + forget_loss
    print("Val Epoch: {} \tLoss: {}".format(epoch, total_loss))
    return total_loss


def train_distill(
    loader,
    student_model,
    teacher_model,
    optimizer,
    loss_fn_classification,
    loss_fn_kl,
    gamma,
    beta,
    epoch,
    device,
    lr_scheduler=None,
    objective="minimize",
):
    """Train the student for one pass over ``loader`` against the teacher.

    The teacher is kept in eval mode and queried without gradients; the
    student is put in train mode and updated once per batch. Model outputs
    are expected to expose a ``.logits`` attribute.

    Args:
        loader: Loader yielding ``(data, target)`` batches.
        student_model: Model being optimized.
        teacher_model: Frozen reference model.
        optimizer: Optimizer over the student's parameters.
        loss_fn_classification: Callable ``(student_logits, target) -> loss``.
        loss_fn_kl: Callable ``(student_logits, teacher_logits) -> loss``.
        gamma: Weight of the classification loss ("minimize" only).
        beta: Weight of the divergence loss ("minimize" only).
        epoch: Current epoch number; used for logging and scheduler stepping.
        device: Device the batches are moved to.
        lr_scheduler: Optional scheduler stepped after every batch with the
            fractional epoch ``epoch + batch_idx / len(loader)``.
        objective: ``"minimize"`` optimizes ``gamma * cls + beta * div``;
            ``"maximize"`` optimizes ``-div``, pushing the student away from
            the teacher.

    Returns:
        tuple: ``(student_model, mean_loss)`` where ``mean_loss`` is the
        float mean of the per-batch losses (NaN for an empty loader).

    Raises:
        ValueError: If ``objective`` is neither "minimize" nor "maximize".
    """
    # Fail fast, before touching the models or running a forward pass.
    if objective not in ("minimize", "maximize"):
        raise ValueError("objective can be either maximize or minimize")

    teacher_model.eval()
    student_model.train()

    kd_losses = []
    num_batches = len(loader)

    for batch_idx, (data, target) in enumerate(tqdm(loader, total=num_batches)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        student_logits = student_model(data).logits

        with torch.no_grad():
            teacher_logits = teacher_model(data).logits

        loss_div = loss_fn_kl(student_logits, teacher_logits)

        if objective == "minimize":
            loss_cls = loss_fn_classification(student_logits, target)
            loss = (gamma * loss_cls) + (beta * loss_div)
        else:
            loss = -loss_div

        loss.backward()
        optimizer.step()
        if lr_scheduler is not None:
            lr_scheduler.step(epoch + batch_idx / num_batches)

        # Record a plain float: keeping the loss tensor itself would hold a
        # reference to its autograd graph for the rest of the epoch.
        kd_losses.append(loss.item())

    total_loss = torch.tensor(kd_losses).mean().item()
    print("Train Epoch: {} \tLoss: {}".format(epoch, total_loss))
    return student_model, total_loss


def _save_checkpoint(save_path, epoch, student_model, optimizer, loss):
    """Write the student/optimizer state for ``epoch`` to ``save_path``."""
    torch.save(
        {
            "best_epoch": epoch,
            "model_state_dict": student_model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
        },
        save_path,
    )


def train_unlearn(
    student_model,
    teacher_model,
    forget_loader,
    retain_loader,
    device,
    lr,
    num_epochs,
    beta,
    gamma,
    forget_steps,
    val_retain_loader=None,
    val_forget_loader=None,
    lr_scheduler: None | str = None,
    patience=5,
    weight_decay: float = 1e-6,
    save_path: None | str = None,
):
    """Run the unlearning loop: optional forget pass, then retain pass, per epoch.

    Each epoch optionally runs a "maximize" pass over ``forget_loader``
    followed by a "minimize" pass over ``retain_loader``. When both
    validation loaders are given, the validation loss drives early stopping,
    best-checkpoint saving and the epoch-level LR scheduler; otherwise a
    checkpoint is written after every epoch.

    Args:
        student_model: Model to be trained; moved to ``device``.
        teacher_model: Reference model; moved to ``device``.
        forget_loader: Training loader for the data to forget.
        retain_loader: Training loader for the data to retain.
        device: Device used for models and batches.
        lr: Learning rate for AdamW.
        num_epochs: Maximum number of epochs (epochs are numbered from 1).
        beta: Weight of the divergence loss in the retain objective.
        gamma: Weight of the classification loss in the retain objective.
        forget_steps: The forget pass runs while ``epoch < forget_steps``.
        val_retain_loader: Optional validation loader for retained data.
        val_forget_loader: Optional validation loader for forgotten data.
        lr_scheduler: Scheduler name forwarded to ``get_lr_scheduler``.
        patience: Number of consecutive non-improving validation epochs
            tolerated before stopping.
        weight_decay: Weight decay for AdamW.
        save_path: Where to write checkpoints; nothing is saved when ``None``.

    Returns:
        None.
    """
    student_model.to(device)
    teacher_model.to(device)
    optimizer = optim.AdamW(
        student_model.parameters(), lr=lr, weight_decay=weight_decay
    )
    lrs = get_lr_scheduler(lr_scheduler=lr_scheduler, optimizer=optimizer)

    # Loop invariants: the loss modules hold no per-epoch state and the
    # scheduler wrapper is not modified below.
    loss_fn_classification = nn.CrossEntropyLoss()
    loss_fn_kl = DistillKL(T=2)
    has_scheduler = lrs.scheduler is not None
    batch_scheduler = (
        lrs.scheduler if has_scheduler and lrs.is_batch_scheduler else None
    )
    has_validation = val_retain_loader is not None and val_forget_loader is not None

    shared_kwargs = dict(
        teacher_model=teacher_model,
        optimizer=optimizer,
        loss_fn_classification=loss_fn_classification,
        loss_fn_kl=loss_fn_kl,
        gamma=gamma,
        beta=beta,
        device=device,
        lr_scheduler=batch_scheduler,
    )

    best_loss = torch.inf
    counter = 0
    for e in range(1, num_epochs + 1):
        # NOTE(review): epochs are 1-based, so `e < forget_steps` runs the
        # forget pass for forget_steps - 1 epochs (none when forget_steps is
        # 1). Left as-is to keep existing runs reproducible; confirm whether
        # `<=` was intended.
        if e < forget_steps:
            student_model, _ = train_distill(
                loader=forget_loader,
                student_model=student_model,
                epoch=e,
                objective="maximize",
                **shared_kwargs,
            )
        student_model, epoch_loss = train_distill(
            loader=retain_loader,
            student_model=student_model,
            epoch=e,
            objective="minimize",
            **shared_kwargs,
        )

        if not has_validation:
            # No validation signal: keep the latest weights on disk.
            if save_path is not None:
                _save_checkpoint(save_path, e, student_model, optimizer, epoch_loss)
            continue

        val_loss = calculate_test_loss(
            teacher_model,
            student_model,
            val_retain_loader,
            val_forget_loader,
            loss_fn_classification,
            loss_fn_kl,
            beta,
            gamma,
            device,
            epoch=e,
        )

        # Epoch-level schedulers are stepped with the validation loss.
        # NOTE(review): presumably a metric-driven (plateau-style) scheduler;
        # verify against get_lr_scheduler.
        if has_scheduler and not lrs.is_batch_scheduler:
            lrs.scheduler.step(val_loss)

        if best_loss > val_loss:
            best_loss = val_loss
            counter = 0
            if save_path is not None:
                _save_checkpoint(save_path, e, student_model, optimizer, best_loss)

        # The counter is bumped even after an improvement (it restarts at 1),
        # so `counter > patience` fires after `patience` consecutive
        # non-improving epochs.
        counter += 1

        if counter > patience:
            print(
                "Early Stopping: Epoch {} \t Best Val Loss {}".format(e, best_loss)
            )
            break
