import torch
import torch.optim as optim
from tqdm.auto import tqdm
from src.utils.utility import get_lr_scheduler


def calculate_test_loss(model, val_loader, loss_fn, device, epoch):
    """Evaluate ``model`` on ``val_loader`` and return the mean per-batch loss.

    The model is switched to eval mode and every batch is run under
    ``torch.no_grad()``. For each batch the loss is taken from ``out.loss``
    when the model provides one; otherwise it is computed as
    ``loss_fn(out.noisy_repr, out.true_repr, target)``.

    Args:
        model: Callable as ``model(data, y=target)``; its output must expose
            ``loss`` (and ``noisy_repr`` / ``true_repr`` when ``loss`` is None).
        val_loader: Sized iterable yielding ``(data, target)`` pairs.
        loss_fn: Fallback loss callable, or None to rely on ``out.loss``.
        device: Device the batches are moved to before the forward pass.
        epoch: Epoch number, used only in the printed summary line.

    Returns:
        The mean of the per-batch losses as a Python float.

    Raises:
        Exception: If the model output has no loss and ``loss_fn`` is None.
    """
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, total=len(val_loader)):
            inputs = inputs.to(device)
            labels = labels.to(device)
            result = model(inputs, y=labels)

            # Prefer the loss reported by the model; fall back to loss_fn.
            batch_loss = result.loss
            if batch_loss is None:
                if loss_fn is None:
                    raise Exception(
                        "No loss defined, either should be output from model or provide loss function in argument"
                    )
                batch_loss = loss_fn(result.noisy_repr, result.true_repr, labels)

            batch_losses.append(batch_loss)
        total_loss = torch.tensor(batch_losses).mean().item()

    print("Val Epoch: {} \tLoss: {}".format(epoch, total_loss))
    return total_loss


def train_unlearn_step(
    model, unlearn_loader, optimizer, device, epoch, loss_fn=None, scheduler=None
):
    """Run one training epoch over ``unlearn_loader`` and return the mean loss.

    For every batch the gradients are zeroed, the model is called as
    ``model(data, y=target)``, the loss is back-propagated and the optimizer
    is stepped. The loss is taken from ``out.loss`` when the model provides
    one; otherwise it is computed as
    ``loss_fn(out.noisy_repr, out.true_repr, target)``.

    Args:
        model: Model to train; put into train mode by this function.
        unlearn_loader: Sized iterable yielding ``(data, target)`` pairs.
        optimizer: Optimizer whose ``zero_grad`` / ``step`` are called per batch.
        device: Device the batches are moved to before the forward pass.
        epoch: Current epoch index; used for the fractional scheduler step
            and the printed summary line.
        loss_fn: Fallback loss callable, or None to rely on ``out.loss``.
        scheduler: Optional per-batch scheduler; when given it is stepped
            after every batch with ``epoch + idx / len(unlearn_loader)``.

    Returns:
        A ``(model, total_loss)`` tuple where ``total_loss`` is the mean of
        the per-batch losses as a Python float.

    Raises:
        Exception: If the model output has no loss and ``loss_fn`` is None.
    """
    model.train()
    num_batches = len(unlearn_loader)
    losses = []
    for idx, (data, target) in enumerate(tqdm(unlearn_loader, total=num_batches)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        out = model(data, y=target)

        # Prefer the loss reported by the model; fall back to loss_fn.
        loss = out.loss
        if loss is None:
            if loss_fn is None:
                raise Exception(
                    "No loss defined, either should be output from model or provide loss function in argument"
                )
            loss = loss_fn(out.noisy_repr, out.true_repr, target)

        loss.backward()
        optimizer.step()
        if scheduler is not None:
            # Fractional epoch so the scheduler advances within the epoch.
            scheduler.step(epoch + idx / num_batches)

        # Detach before storing: only the value is needed for the epoch mean,
        # and keeping the graph-attached tensor would hold autograd history
        # alive for the whole epoch.
        losses.append(loss.detach())

    total_loss = torch.mean(torch.tensor(losses)).item()
    print("Train Epoch: {} \tLoss: {}".format(epoch, total_loss))
    return model, total_loss


def _save_checkpoint(save_path, model, optimizer, epoch, loss):
    """Write a checkpoint dict (epoch, model/optimizer state, loss) to ``save_path``."""
    torch.save(
        {
            "best_epoch": epoch + 1,  # stored 1-based
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "loss": loss,
        },
        save_path,
    )


def train_unlearn(
    model,
    unlearn_loader,
    device,
    lr,
    num_epochs,
    loss_fn=None,
    val_loader=None,
    lr_scheduler: None | str = None,
    patience: int = 5,
    weight_decay: float = 1e-6,
    save_path: None | str = None,
):
    """Train ``model`` on ``unlearn_loader`` with AdamW for up to ``num_epochs``.

    Each epoch runs ``train_unlearn_step``. When ``val_loader`` is given, the
    validation loss from ``calculate_test_loss`` drives three things: stepping
    a non-batch scheduler, saving a checkpoint whenever the validation loss
    improves, and early stopping. Without ``val_loader`` a checkpoint is
    written after every epoch (when ``save_path`` is set) and no early
    stopping takes place.

    Args:
        model: Model to train; moved to ``device``.
        unlearn_loader: Training loader passed to ``train_unlearn_step``.
        device: Device for the model and batches.
        lr: Learning rate for AdamW.
        num_epochs: Maximum number of epochs.
        loss_fn: Optional fallback loss forwarded to the step/eval functions.
        val_loader: Optional validation loader.
        lr_scheduler: Scheduler identifier forwarded to ``get_lr_scheduler``.
        patience: Early-stopping threshold; see the inline comment in the
            loop for the exact counting rule.
        weight_decay: Weight decay for AdamW.
        save_path: Optional path for ``torch.save`` checkpoints.

    Returns:
        The model as of the last executed epoch (not reloaded from the best
        checkpoint).
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    lrs = get_lr_scheduler(lr_scheduler=lr_scheduler, optimizer=optimizer)
    best_loss = torch.inf
    counter = 0
    for e in range(num_epochs):
        # Only per-batch schedulers are handed to the training step.
        batch_scheduler = (
            lrs.scheduler
            if lrs.scheduler is not None and lrs.is_batch_scheduler
            else None
        )
        model, epoch_loss = train_unlearn_step(
            model,
            unlearn_loader,
            optimizer,
            device,
            e,
            loss_fn,
            batch_scheduler,
        )

        if val_loader is None:
            # NOTE(review): a non-batch scheduler is never stepped on this
            # path because there is no validation loss to feed it - confirm
            # this is intended.
            if save_path is not None:
                # Record this epoch's training loss. ``best_loss`` is never
                # updated without validation and would always be saved as inf.
                _save_checkpoint(save_path, model, optimizer, e, epoch_loss)
            continue

        val_loss = calculate_test_loss(model, val_loader, loss_fn, device, epoch=e)

        if lrs.scheduler is not None and (not lrs.is_batch_scheduler):
            lrs.scheduler.step(val_loss)

        if best_loss > val_loss:
            best_loss = val_loss
            counter = 0
            if save_path is not None:
                _save_checkpoint(save_path, model, optimizer, e, best_loss)

        # The counter is incremented on every epoch, so it equals 1 right
        # after an improvement; training stops once ``patience`` consecutive
        # epochs have passed without a new best validation loss.
        counter += 1

        if counter > patience:
            print("Early Stopping: Epoch {} \t Best Loss {}".format(e, best_loss))
            break

    return model
