"""Run training loop for Pytorch models.

Implements a basic training loop to train Pytorch models. The main method uses
the TrainArgs dataclass (found in setup.configuration.py) as input parameters
for the trainer. """

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from torch.nn import Module
    from torch import Tensor
    from collections.abc import Callable
    from setup.configuration import TrainArgs

import wandb
import torch

from helpers.logger import get_logger

# Module-level logger shared by every function in this file.
logger = get_logger()
# 20 is the numeric value of logging.INFO in the standard library; this assumes
# get_logger() returns a logging.Logger-compatible object — TODO confirm.
logger.level = 20

def train(
    train_args: TrainArgs,
) -> None:
    """Run training loops until a stopping criterion is met.

    Builds the optimizer (and the optional learning-rate scheduler) from
    ``train_args``, then runs one pass of ``train_loop`` followed by one pass
    of ``test_loop`` per epoch, logging both average losses and all metrics
    to the wandb run after every epoch.

    The stopping mode is selected by ``train_args.max_epochs``:

    * ``-1`` (early stopping): training ends once the monitored loss has
      failed to improve for more than ``train_args.patience`` consecutive
      post-warm-up epochs. The model as it is at that moment is saved to
      ``<save_folder>/<ckpt_name>.pt``.
    * any other value (fixed length): training ends after the epoch whose
      zero-based index reaches ``max_epochs``. Whenever the monitored loss
      improves before that last epoch the model is saved to
      ``<save_folder>/<ckpt_name>_best.pt``; if ``train_args.save_final`` is
      set, the last model is saved to ``<save_folder>/<ckpt_name>_final.pt``.

    While ``epoch < train_args.warm_up`` only training, evaluation and
    logging happen: no stopping checks, no checkpoints and no scheduler
    steps are performed.

    When ``train_args.sweep`` is truthy the existing run ``train_args.sw_run``
    is used instead of calling ``wandb.init`` and no checkpoint is written.

    Args:
        train_args: Training configuration (model, dataloaders, optimizer
            and scheduler factories, loss function, stopping settings).

    Raises:
        ValueError: If ``train_args.monitor`` is neither ``"train"`` nor
            ``"validation"``.
    """
    logger.info(f"Training: {train_args.ckpt_name}")

    # Fail before a wandb run is opened and an epoch is wasted; otherwise an
    # unknown value only surfaces as an unbound local after the warm-up.
    if train_args.monitor not in ("train", "validation"):
        raise ValueError(
            "train_args.monitor must be 'train' or 'validation', "
            f"got {train_args.monitor!r}"
        )

    if not train_args.sweep:
        run = wandb.init(**train_args.wandb_kwargs)
        # NOTE(review): this rebinds the module attribute ``wandb.config``
        # rather than updating the run's config object — confirm that the
        # hyperparameters actually reach the run.
        wandb.config = train_args.__dict__
    else:
        run = train_args.sw_run

    model = train_args.model.to(device=train_args.device)
    train_loader, test_loader = train_args.dataloaders
    optim = train_args.optim(
        model.parameters(),
        lr=train_args.lr,
        **(train_args.optim_kwargs or {}),
    )
    scheduler = None
    if train_args.lr_sched:
        scheduler = train_args.lr_sched(
            optim,
            **(train_args.lr_sched_kwargs or {}),
        )

    early_stopping = train_args.max_epochs == -1
    epoch = 0
    stale_epochs = 0
    # Start from +inf so the first monitored epoch always counts as an
    # improvement, whatever the scale of the loss.
    best = torch.tensor(float("inf"), device=train_args.device)
    while True:
        avg_train_loss, train_metrics = train_loop(
            model=model,
            optim=optim,
            fn_loss=train_args.fn_loss,
            perform_metrics=train_args.perform_metrics,
            dataloader=train_loader,
            device=train_args.device,
        )

        avg_test_loss, test_metrics = test_loop(
            model=model,
            fn_loss=train_args.fn_loss,
            perform_metrics=train_args.perform_metrics,
            dataloader=test_loader,
            device=train_args.device,
        )

        run.log(
            {
                "train_loss": avg_train_loss,
                "test_loss": avg_test_loss,
                **train_metrics,
                **test_metrics,
            }
        )

        logger.info(
            f"Epoch: {epoch}. Average train loss: {avg_train_loss}. "
            f"Average test loss: {avg_test_loss}"
        )

        if epoch < train_args.warm_up:
            epoch += 1
            continue

        # value to monitor for stopping condition
        current = avg_train_loss if train_args.monitor == "train" else avg_test_loss

        if current < best:
            best = current
            if early_stopping:
                stale_epochs = 0
            elif epoch < train_args.max_epochs:
                # save current best model checkpoint
                _save_model(train_args, model, "_best")
        elif early_stopping:
            logger.info("Incrementing counter by one.")
            stale_epochs += 1
            # Tolerate `patience` stale epochs and stop on the next one.
            if stale_epochs > train_args.patience:
                logger.info(f"Model {train_args.monitor} loss no longer decreasing!")
                logger.info("Saving and exiting training loop.")
                _save_model(train_args, model, "")
                run.finish()
                return

        # `>=` rather than `==`: the warm-up may already have advanced the
        # epoch counter past max_epochs, which must still end training.
        if not early_stopping and epoch >= train_args.max_epochs:
            logger.info("Max epochs reached.")
            if train_args.save_final:
                logger.info(
                    "Saving final model set to True.",
                )
                _save_model(train_args, model, "_final")
            run.finish()
            return

        epoch += 1
        if scheduler is not None:
            scheduler.step()


def _save_model(train_args: TrainArgs, model: Module, suffix: str) -> None:
    """Save ``model`` as ``<save_folder>/<ckpt_name><suffix>.pt`` unless in a sweep."""
    if train_args.sweep:
        return
    torch.save(model, f"{train_args.save_folder}/{train_args.ckpt_name}{suffix}.pt")

def train_loop(
    model: Module,
    optim: torch.optim.Optimizer,
    fn_loss: Callable[[Tensor, Tensor], Tensor],
    perform_metrics: None | dict[str, type],
    dataloader: torch.utils.data.DataLoader,
    device: str,
) -> tuple[Tensor, dict[str, Tensor]]:
    """Run through one epoch of training.

    Puts the model in training mode and, for every ``(x, y)`` batch yielded by
    ``dataloader``, moves the batch to ``device``, computes the loss, updates
    the metric objects, back-propagates and takes one optimizer step.

    Args:
        model: Model to optimize; must already live on ``device``.
        optim: Optimizer holding the model's parameters.
        fn_loss: Callable mapping ``(prediction, target)`` to a loss tensor.
        perform_metrics: Optional mapping from metric name to a zero-argument
            factory. Each created object must provide ``compute(pred, y)``
            (called once per batch) and ``finalize()`` (called once at the
            end of the epoch).
        dataloader: Iterable yielding ``(x, y)`` batches.
        device: Device the batches are moved to.

    Returns:
        The mean of the per-batch losses (detached from the autograd graph)
        and a dict mapping ``"train_<name>"`` to each metric's ``finalize()``
        result. The dict is empty when ``perform_metrics`` is falsy.

    Raises:
        RuntimeError: If ``dataloader`` yields no batches (``torch.stack``
            rejects an empty list).
    """
    # initialize performance metric objects
    metrics = {
        f"train_{name}": factory()
        for name, factory in (perform_metrics or {}).items()
    }

    batch_losses = []
    model.train()
    for x, y in dataloader:
        x, y = x.to(device), y.to(device)

        # Call the module itself rather than .forward() so registered hooks run.
        pred = model(x)
        loss = fn_loss(pred, y)
        # Keep only the value: storing the attached loss would hold on to
        # every batch's autograd graph for the whole epoch.
        batch_losses.append(loss.detach())

        # update internal state of metric objects per batch
        for metric in metrics.values():
            metric.compute(pred.detach(), y)

        loss.backward()
        optim.step()
        optim.zero_grad()

    avg_train_loss = torch.mean(torch.stack(batch_losses))
    # finalize metric objects - finalize should return tensor
    finalized = {key: metric.finalize() for key, metric in metrics.items()}

    return avg_train_loss, finalized

def test_loop(
    model: Module,
    fn_loss: Callable[[Tensor, Tensor], Tensor],
    perform_metrics: None|dict[str, type],
    dataloader: Module,
    device: str,
) -> tuple[Tensor, dict[str, Tensor]]:
    "Run through one epoch of the test set."
    avg_val_loss = []
    # initialize performance metric objects
    metrics = {}
    if perform_metrics:
        for m in perform_metrics.keys():
            test_key = "test_" + m
            metrics[test_key] = perform_metrics[m]()

    model.eval()
    with torch.no_grad():
        for batch, (x, y) in enumerate(dataloader):
            pred = model.forward(x.to(device))
            loss = fn_loss(pred, y.to(device))
            avg_val_loss.append(loss)

            # update internal state of metric objects per batch
            if perform_metrics:
                for m in metrics.keys():
                    metrics[m].compute(pred, y.to(device))

        avg_val_loss = torch.mean(torch.stack(avg_val_loss))

        # finalize metric objects - finalize should return tensor
        if perform_metrics:
            for m in metrics.keys():
                metrics[m] = metrics[m].finalize()
    
    return avg_val_loss, metrics
