import torch
import torch.optim as optim
from tqdm.auto import tqdm
from src.utils.utility import get_lr_scheduler


def calculate_test_loss(
    model, retain_val_loader, forget_val_loader, loss_fn, num_classes, device, epoch
):
    """Evaluate the model on the forget and retain validation loaders.

    The model is switched to eval mode and all work runs under
    ``torch.no_grad()``. Forget batches are scored against labels drawn at
    random from ``[0, num_classes)`` (the true labels are only used for their
    shape); retain batches are scored against their true labels.

    Args:
        model: Module whose forward output exposes a ``.logits`` attribute.
        retain_val_loader: Sized iterable of ``(data, target)`` batches.
        forget_val_loader: Sized iterable of ``(data, target)`` batches.
        loss_fn: Callable invoked as ``loss_fn(logits, target)``.
        num_classes: Exclusive upper bound for the random forget labels.
        device: Device the batches are moved to.
        epoch: Epoch number, used only in the printed summary.

    Returns:
        The mean forget loss plus the mean retain loss, as a Python float.
    """
    model.eval()
    forget_batch_losses, retain_batch_losses = [], []
    summary_template = "Val Epoch: {} \tLoss: {} \tForget Loss: {} \tRetain Loss: {}"

    with torch.no_grad():
        # Forget split: compare predictions with freshly sampled random labels.
        for inputs, labels in tqdm(forget_val_loader, total=len(forget_val_loader)):
            inputs = inputs.to(device)
            random_labels = torch.randint(0, num_classes, labels.shape).to(device)
            logits = model(inputs).logits
            forget_batch_losses.append(loss_fn(logits, random_labels))

        # Retain split: compare predictions with the real labels.
        for inputs, labels in tqdm(retain_val_loader, total=len(retain_val_loader)):
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs).logits
            retain_batch_losses.append(loss_fn(logits, labels))

        forget_mean = torch.tensor(forget_batch_losses).mean().item()
        retain_mean = torch.tensor(retain_batch_losses).mean().item()
        combined = forget_mean + retain_mean
        print(summary_template.format(epoch, combined, forget_mean, retain_mean))

    return combined


def train_step_rl(
    model,
    forget_loader,
    retain_loader,
    num_classes,
    optimizer,
    loss_fn,
    epoch,
    device,
    mask=None,
):
    """Run one training epoch: a pass over the forget set, then the retain set.

    Forget batches are trained against labels drawn at random from
    ``[0, num_classes)`` (the true labels are only used for their shape);
    retain batches are trained against their true labels. The optimizer is
    stepped once per batch.

    Args:
        model: Module whose forward output exposes a ``.logits`` attribute.
        forget_loader: Sized iterable of ``(data, target)`` batches.
        retain_loader: Sized iterable of ``(data, target)`` batches.
        num_classes: Exclusive upper bound for the random forget labels.
        optimizer: Optimizer that updates ``model``'s parameters.
        loss_fn: Callable invoked as ``loss_fn(logits, target)``; its result
            must support ``backward()``.
        epoch: Epoch number, used only in the printed summary.
        device: Device the batches (and mask entries) are moved to.
        mask: Optional object indexable by parameter name; when given, each
            existing gradient is multiplied in place by ``mask[name]`` before
            the optimizer step.

    Returns:
        A ``(model, total_loss)`` tuple, where ``total_loss`` is the mean
        forget loss plus the mean retain loss as a Python float.
    """
    model.train()

    def _optimize_on(data, target):
        """Take one optimizer step on a batch and return its detached loss."""
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output.logits, target)
        loss.backward()
        if mask is not None:
            for name, param in model.named_parameters():
                if param.grad is not None:
                    param.grad *= mask[name].to(device)
        optimizer.step()
        # Detach before storing: keeping the attached loss would hold a
        # reference to every batch's autograd graph until the epoch ends.
        return loss.detach()

    forget_losses = []
    for data, target in tqdm(forget_loader, total=len(forget_loader)):
        data = data.to(device)
        # Replace the true labels with random ones of the same shape.
        target = torch.randint(0, num_classes, target.shape).to(device)
        forget_losses.append(_optimize_on(data, target))

    retain_losses = []
    for data, target in tqdm(retain_loader, total=len(retain_loader)):
        data, target = data.to(device), target.to(device)
        retain_losses.append(_optimize_on(data, target))

    total_forget_loss = torch.mean(torch.tensor(forget_losses)).item()
    total_retain_loss = torch.mean(torch.tensor(retain_losses)).item()
    total_loss = total_forget_loss + total_retain_loss
    print(
        "Train Epoch: {} \tLoss: {} \tForget Loss: {} \tRetain Loss: {}".format(
            epoch, total_loss, total_forget_loss, total_retain_loss
        )
    )
    return model, total_loss


def train_rl(
    model,
    forget_loader,
    retain_loader,
    loss_fn,
    num_classes,
    device,
    lr,
    num_epochs,
    mask=None,
    val_retain_loader=None,
    val_forget_loader=None,
    lr_scheduler: None | str = None,
    patience: int = 5,
    weight_decay: float = 1e-6,
    save_path: None | str = None,
):
    """Train ``model`` for up to ``num_epochs`` epochs with ``train_step_rl``.

    An ``AdamW`` optimizer is created for the model's parameters. When both
    validation loaders are given, each epoch is followed by a validation pass
    that drives the learning-rate scheduler, best-checkpoint saving and early
    stopping. Otherwise a checkpoint is written after every epoch with the
    training loss.

    Args:
        model: Model to train; moved to ``device``.
        forget_loader: Training loader for the forget set.
        retain_loader: Training loader for the retain set.
        loss_fn: Loss callable forwarded to the train/validation steps.
        num_classes: Number of classes, forwarded to the train/validation steps.
        device: Device to train on.
        lr: Learning rate for ``AdamW``.
        num_epochs: Maximum number of epochs.
        mask: Optional gradient mask forwarded to ``train_step_rl``.
        val_retain_loader: Optional validation loader for the retain set.
        val_forget_loader: Optional validation loader for the forget set.
        lr_scheduler: Scheduler selector passed to ``get_lr_scheduler``; the
            returned object must expose ``.scheduler`` and
            ``.is_batch_scheduler``.
        patience: Number of consecutive non-improving validation epochs
            tolerated after an improvement before training stops.
        weight_decay: Weight decay for ``AdamW``.
        save_path: Optional path the checkpoint dict is written to.

    Returns:
        The model as it is after the last executed epoch (not necessarily the
        best-scoring checkpoint).
    """
    model.to(device)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    lrs = get_lr_scheduler(lr_scheduler=lr_scheduler, optimizer=optimizer)
    has_validation = val_forget_loader is not None and val_retain_loader is not None

    def _save_checkpoint(epoch, loss):
        """Write model/optimizer state to ``save_path`` if one was given."""
        if save_path is None:
            return
        torch.save(
            {
                # Epochs are already 1-based here, so store the epoch as-is
                # (previously stored as ``e + 1``, one past the logged epoch).
                "best_epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": loss,
            },
            save_path,
        )

    best_loss = torch.inf
    counter = 0
    for e in range(1, num_epochs + 1):
        model, epoch_loss = train_step_rl(
            model=model,
            forget_loader=forget_loader,
            retain_loader=retain_loader,
            num_classes=num_classes,
            optimizer=optimizer,
            loss_fn=loss_fn,
            epoch=e,
            device=device,
            mask=mask,
        )

        if not has_validation:
            # No validation signal: overwrite the checkpoint every epoch.
            _save_checkpoint(e, epoch_loss)
            continue

        val_loss = calculate_test_loss(
            model,
            val_retain_loader,
            val_forget_loader,
            loss_fn,
            num_classes,
            device,
            epoch=e,
        )

        # NOTE(review): the scheduler is stepped only on this validation path
        # and is always handed the validation loss, which presumably targets a
        # ReduceLROnPlateau-style scheduler; batch-level schedulers are not
        # stepped anywhere in this function. Confirm against get_lr_scheduler.
        if lrs.scheduler is not None and (not lrs.is_batch_scheduler):
            lrs.scheduler.step(val_loss)

        if best_loss > val_loss:
            best_loss = val_loss
            counter = 0
            _save_checkpoint(e, best_loss)

        # The counter is bumped even on an improving epoch (it reads 1 right
        # after an improvement), so ``> patience`` triggers on the
        # ``patience``-th consecutive non-improving epoch.
        counter += 1

        if counter > patience:
            print(
                "Early Stopping: Epoch {} \t Best Val Loss {}".format(e, best_loss)
            )
            break

    return model


def train_unlearn(
    model,
    forget_loader,
    retain_loader,
    num_classes,
    loss_fn,
    device,
    lr,
    num_epochs,
    val_retain_loader=None,
    val_forget_loader=None,
    lr_scheduler: None | str = None,
    patience: int = 5,
    weight_decay: float = 1e-6,
    save_path: None | str = None,
):
    """Delegate to ``train_rl`` with the gradient mask fixed to ``None``.

    Every argument is forwarded unchanged by keyword; the return value is
    whatever ``train_rl`` returns.
    """
    forwarded = {
        "model": model,
        "forget_loader": forget_loader,
        "retain_loader": retain_loader,
        "num_classes": num_classes,
        "loss_fn": loss_fn,
        "device": device,
        "lr": lr,
        "num_epochs": num_epochs,
        "val_forget_loader": val_forget_loader,
        "val_retain_loader": val_retain_loader,
        "lr_scheduler": lr_scheduler,
        "patience": patience,
        "weight_decay": weight_decay,
        "save_path": save_path,
    }
    # This entry point never applies a gradient mask.
    return train_rl(mask=None, **forwarded)
