"""An additional trainer implementation for more efficient training on a GPU.

The full train and test sets are cached on the training device up front, the
learning rate is warmed up linearly over the first epochs, and classification
accuracy is reported alongside the loss. The script is set up specifically for
training the MNIST models.

For general training, see the trainer in /trainer/loop.py.
"""
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from torch.nn import Module
    from torch import Tensor
    from collections.abc import Callable
    from setup.configuration import TrainArgs

import wandb
import torch

from helpers.logger import get_logger

def train(
        train_args: TrainArgs,
) -> None:
    """Run training epochs until a stopping criterion is met.

    Two stopping modes are selected by ``train_args.max_epochs``:

    * ``max_epochs != -1``: train epochs ``0..max_epochs`` inclusive. Whenever
      the monitored loss improves (except on the last epoch),
      ``<ckpt_name>_best.pt`` is saved. If ``train_args.save_final`` is set,
      ``<ckpt_name>_final.pt`` is also saved at the end.
    * ``max_epochs == -1``: early stopping. Training ends once the monitored
      loss has failed to improve for ``patience + 1`` consecutive epochs. The
      model at that point (not the best one) is saved as ``<ckpt_name>.pt``.

    No checkpoints are written when ``train_args.sweep`` is set.

    During the first ``train_args.warm_up`` epochs, the learning rate is scaled
    linearly from ``lr / warm_up`` up to ``lr``. The optional LR scheduler is
    only stepped once warmup has finished.

    Args:
        train_args: Training configuration (model, dataloaders, optimizer and
            scheduler factories, loss function, wandb settings, ...).

    Raises:
        ValueError: If ``train_args.monitor`` is not ``"train"`` or
            ``"validation"``.
    """
    # Validate up front so a bad value fails before a run is started or any
    # epoch is trained.
    if train_args.monitor not in ("train", "validation"):
        raise ValueError(
            "train_args.monitor must be 'train' or 'validation', "
            f"got {train_args.monitor!r}"
        )

    logger = get_logger()
    logger.level = 20  # logging.INFO

    logger.info(f"Training: {train_args.ckpt_name}")

    if not train_args.sweep:
        run = wandb.init(**train_args.wandb_kwargs)
        # NOTE(review): this rebinds the ``wandb.config`` module attribute to a
        # plain dict and does not appear to update the config of ``run``;
        # confirm whether ``run.config.update(...)`` was intended.
        wandb.config = train_args.__dict__
    else:
        # In sweep mode the run is supplied by the caller.
        run = train_args.sw_run

    model = train_args.model.to(device=train_args.device)
    train_loader, test_loader = train_args.dataloaders
    # Materialise both datasets on the device once, so epochs avoid
    # re-running the original loading pipeline.
    train_loader = cache_data(
        train_loader,
        device=train_args.device,
        shuffle=True,
        num_workers=0,
    )
    test_loader = cache_data(
        test_loader,
        device=train_args.device,
        shuffle=False,
        num_workers=0,
    )
    # The whole test set is expected to be evaluated as a single batch.
    assert len(test_loader) == 1

    optim_kwargs = train_args.optim_kwargs if train_args.optim_kwargs is not None else {}
    optim = train_args.optim(model.parameters(), lr=train_args.lr, **optim_kwargs)

    scheduler = None
    if train_args.lr_sched:
        sched_kwargs = (
            train_args.lr_sched_kwargs if train_args.lr_sched_kwargs is not None else {}
        )
        scheduler = train_args.lr_sched(optim, **sched_kwargs)

    base_lr = train_args.lr  # unscaled LR, the reference point for warmup
    fixed_epochs = train_args.max_epochs != -1
    epoch = 0
    counter = 0  # consecutive non-improving epochs (early-stopping mode only)
    best = float("inf")  # best monitored loss seen so far

    while True:

        # Linear warmup: scale the LR from base_lr / warm_up up to base_lr.
        if epoch < train_args.warm_up:
            warmup_factor = (epoch + 1) / train_args.warm_up
            for param_group in optim.param_groups:
                param_group["lr"] = base_lr * warmup_factor

        avg_train_loss, train_acc = train_loop(
            model=model,
            optim=optim,
            fn_loss=train_args.fn_loss,
            dataloader=train_loader,
            device=train_args.device,
        )

        avg_test_loss, val_acc = test_loop(
            model=model,
            fn_loss=train_args.fn_loss,
            dataloader=test_loader,
            device=train_args.device,
        )

        run.log({
            "train_loss": avg_train_loss,
            "train_acc": train_acc,
            "test_loss": avg_test_loss,
            "val_acc": val_acc,
            "lr": optim.param_groups[0]["lr"],  # log current learning rate
        })

        # Implicit concatenation keeps the source indentation out of the message.
        logger.info(
            f"Epoch: {epoch}. Average train loss: {avg_train_loss:.4f}, "
            f"Train acc: {train_acc:.4f}. Average test loss: "
            f"{avg_test_loss:.4f}, Val acc: {val_acc:.4f}"
        )

        current = avg_train_loss if train_args.monitor == "train" else avg_test_loss

        if current < best:
            best = current
            if fixed_epochs and epoch != train_args.max_epochs:
                _save_checkpoint(model, train_args, "_best")
            else:
                counter = 0
        elif not fixed_epochs:
            logger.info("Incrementing counter by one.")
            counter += 1
            # Stops after patience + 1 consecutive non-improving epochs.
            if counter - 1 == train_args.patience:
                logger.info(f"Model {train_args.monitor} loss no longer decreasing!")
                logger.info("Saving and exiting training loop.")
                _save_checkpoint(model, train_args, "")
                run.finish()
                return

        if fixed_epochs and epoch == train_args.max_epochs:
            logger.info("Max epochs reached.")
            if train_args.save_final:
                logger.info("Saving final model set to True.")
                _save_checkpoint(model, train_args, "_final")
            run.finish()
            return

        epoch += 1

        # Step the scheduler only after warmup.
        if scheduler is not None and epoch > train_args.warm_up:
            scheduler.step()


def _save_checkpoint(model: Module, train_args: TrainArgs, suffix: str) -> None:
    """Save ``model`` to ``<save_folder>/<ckpt_name><suffix>.pt`` unless sweeping."""
    if not train_args.sweep:
        torch.save(model, f"{train_args.save_folder}/{train_args.ckpt_name}{suffix}.pt")


def train_loop(
        model: Module,
        optim: torch.optim.Optimizer,
        fn_loss: Callable[[Tensor, Tensor], Tensor],
        dataloader: torch.utils.data.DataLoader,
        device: str,
) -> tuple[Tensor, float]:
    """Run one training epoch over ``dataloader``.

    For each batch, runs a forward pass, backward pass and optimizer step,
    then accumulates the loss and the number of correct arg-max predictions.

    Args:
        model: Model to train; it is switched to train mode.
        optim: Optimizer stepping ``model``'s parameters.
        fn_loss: Loss function, called as ``fn_loss(pred, y)``.
        dataloader: Iterable of ``(x, y)`` batches. Batches are fed to the
            model as-is; no device transfer happens here.
        device: Unused; kept for interface compatibility.

    Returns:
        ``(avg_train_loss, train_acc)``. The first is the mean of the
        per-batch losses, as a detached 0-dim tensor. The second is the
        fraction of samples whose arg-max over dim 1 equals ``y``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    batch_losses = []
    correct = 0
    total = 0

    model.train()

    for x, y in dataloader:
        pred = model(x)
        loss = fn_loss(pred, y)

        loss.backward()
        optim.step()
        optim.zero_grad()

        # Detach so the stored losses do not keep each batch's autograd graph
        # alive until the end of the epoch.
        batch_losses.append(loss.detach())

        # Accumulate on the tensors' device; a single int() after the loop
        # avoids a device-to-host sync on every batch.
        correct = correct + (torch.argmax(pred, dim=1) == y).sum()
        total += y.size(0)

    if not batch_losses:
        raise ValueError("train_loop received an empty dataloader.")

    avg_train_loss = torch.mean(torch.stack(batch_losses))
    train_acc = int(correct) / total
    return avg_train_loss, train_acc


def test_loop(
        model: Module,
        fn_loss: Callable[[Tensor, Tensor], Tensor],
        dataloader: torch.utils.data.DataLoader,
        device: str,
) -> tuple[Tensor, float]:
    """Evaluate ``model`` on every batch of ``dataloader``.

    Model outputs from all batches are concatenated, and ``fn_loss`` is
    applied once to the full set. The result therefore matches a single-batch
    evaluation of the whole dataset, however it is batched.

    Args:
        model: Model to evaluate; it is switched to eval mode and left there.
        fn_loss: Loss function, called as ``fn_loss(pred, y)``.
        dataloader: Iterable of ``(x, y)`` batches, fed to the model as-is.
        device: Unused; kept for interface compatibility.

    Returns:
        ``(loss, val_acc)``. The first is ``fn_loss`` over all samples. The
        second is the fraction of samples whose arg-max over dim 1 equals
        ``y``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    outputs = []
    targets = []
    with torch.no_grad():
        for x, y in dataloader:
            outputs.append(model(x))
            targets.append(y)
        if not outputs:
            raise ValueError("test_loop received an empty dataloader.")
        pred = torch.cat(outputs)
        target = torch.cat(targets)
        loss = fn_loss(pred, target)
        correct = (torch.argmax(pred, dim=1) == target).sum().item()
    val_acc = correct / target.size(0)
    return loss, val_acc


# The ``test_`` prefix would make pytest collect this function as a test in any
# test module that imports it; opt out explicitly.
test_loop.__test__ = False

def cache_data(
        data_loader: torch.utils.data.DataLoader,
        device: str,
        shuffle: bool = True,
        num_workers: int = 0,
) -> torch.utils.data.DataLoader:
    """Materialise a dataloader's full dataset on ``device``.

    Iterates ``data_loader`` once and concatenates every ``(x, y)`` batch. The
    result is wrapped in a new ``DataLoader`` over a ``TensorDataset``, so
    later epochs read the cached tensors instead of re-running the original
    loading pipeline.

    Args:
        data_loader: Source loader yielding ``(x, y)`` tensor batches.
        device: Device that will hold the cached tensors.
        shuffle: Whether the returned loader reshuffles every epoch.
        num_workers: Worker processes for the returned loader (``train``
            passes 0).

    Returns:
        A ``DataLoader`` over the cached tensors. Its ``batch_size`` is that
        of ``data_loader``, or the size of the first batch when
        ``data_loader.batch_size`` is ``None`` (e.g. a loader built from a
        ``batch_sampler``).

    Raises:
        ValueError: If ``data_loader`` yields no batches.
    """
    images = []
    labels = []
    for x, y in data_loader:
        images.append(x)
        labels.append(y)

    if not images:
        raise ValueError("cache_data received an empty dataloader.")

    # Concatenate before transferring, so ``device`` only holds one copy of the
    # data. Moving each batch first and then concatenating would need
    # roughly twice the data size in device memory.
    x = torch.cat(images).to(device)
    y = torch.cat(labels).to(device)

    batch_size = data_loader.batch_size
    if batch_size is None:
        # A loader built with a batch_sampler reports batch_size=None. Passing
        # that on would disable automatic batching and yield single samples,
        # so fall back to the size of the first batch.
        batch_size = len(images[0])

    dataset = torch.utils.data.TensorDataset(x, y)
    return torch.utils.data.DataLoader(
        dataset,
        shuffle=shuffle,
        batch_size=batch_size,
        num_workers=num_workers,
    )
