"""Run a small training loop and save models.

Runs a short training loop and saves the model either right before each batch's
optimisation step or at the end of each completed epoch.
"""

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from torch import Tensor
    from setup.configuration import TrainArgs

import os
import torch

from env.directories import create_folder

from helpers.logger import get_logger

# Module-level logger obtained from the project's logging helper.
logger = get_logger()
# NOTE(review): 20 is the numeric value of the stdlib ``logging.INFO`` level;
# this presumably sets INFO verbosity — confirm that get_logger() returns a
# stdlib-compatible logger whose ``level`` attribute is honoured when set
# directly.
logger.level = 20

def _save_checkpoint(model: torch.nn.Module, folder: str, file_stem: str) -> None:
    """Pickle ``model`` to ``<folder>/<file_stem>.pt``, creating ``folder`` if missing."""
    if not os.path.exists(folder):
        create_folder(path=folder, safe_mode=False)
    torch.save(model, os.path.join(folder, file_stem) + ".pt")


def tiny_train_loop(
    train_args: TrainArgs,
    max_batch: int,
    model_folder: str,
    *,
    save_on_epoch: bool = False,
) -> None:
    """Train for at most ``max_batch`` optimisation steps, checkpointing the model.

    The model, dataloaders, optimizer factory, learning rate, optional
    optimizer keyword arguments, loss function, device and checkpoint name are
    all read from ``train_args``. The train loader is re-iterated (one pass is
    one epoch) until ``max_batch`` batches have been processed in total.

    Checkpoints are whole-model pickles written with ``torch.save``:

    * ``save_on_epoch=False``: before every optimisation step the model is
      saved to ``<model_folder>/batch_<i>/batch-<n>_trial-<ckpt_name>.pt``,
      where ``i`` is the batch index within the current epoch and ``n`` is the
      number of batches processed so far across all epochs.
    * ``save_on_epoch=True``: after every fully consumed pass over the loader
      the model is saved to
      ``<model_folder>/epoch_<e>/epoch-<e>_trial-<ckpt_name>.pt``. If
      ``max_batch`` is reached in the middle of a pass, the function returns
      without writing a checkpoint for that partial epoch.

    Args:
        train_args: Training configuration; the attributes ``model``,
            ``device``, ``dataloaders``, ``optim``, ``lr``, ``optim_kwargs``,
            ``fn_loss`` and ``ckpt_name`` are used.
        max_batch: Total number of batches to train on. Values ``<= 0`` make
            the function return immediately after setup.
        model_folder: Directory under which checkpoint sub-folders are created.
        save_on_epoch: Checkpoint once per completed epoch instead of once per
            batch.

    Returns:
        None. Results are only observable through the trained model and the
        checkpoint files.

    Raises:
        ValueError: If the train loader yields no batches, since the batch
            budget could then never be reached.
    """
    device = train_args.device
    model = train_args.model.to(device=device)
    # Only the train loader is used; the second loader is unpacked to keep the
    # two-element contract of ``dataloaders`` explicit.
    train_loader, _test_loader = train_args.dataloaders
    extra_optim_kwargs = (
        train_args.optim_kwargs if train_args.optim_kwargs is not None else {}
    )
    optim = train_args.optim(
        model.parameters(),
        lr=train_args.lr,
        **extra_optim_kwargs,
    )
    fn_loss = train_args.fn_loss

    outer_batch = 0
    epoch = 0
    model.train()

    while outer_batch < max_batch:
        batches_this_epoch = 0
        for batch, (x, y) in enumerate(train_loader):
            if outer_batch >= max_batch:
                # Budget exhausted mid-epoch: stop without an epoch checkpoint.
                return

            if not save_on_epoch:
                # Folder is keyed by the within-epoch index, the file name by
                # the global batch count (naming kept for compatibility).
                _save_checkpoint(
                    model,
                    os.path.join(model_folder, f"batch_{batch}"),
                    f"batch-{outer_batch}_trial-{train_args.ckpt_name}",
                )

            # Call the module itself rather than ``forward`` so that any
            # registered hooks run.
            pred = model(x.to(device))
            loss = fn_loss(pred, y.to(device))

            loss.backward()
            optim.step()
            optim.zero_grad()

            outer_batch += 1
            batches_this_epoch += 1

        if batches_this_epoch == 0:
            # An empty loader would never advance ``outer_batch`` and the
            # surrounding ``while`` would spin forever.
            raise ValueError(
                "train loader yielded no batches; cannot reach max_batch"
            )

        if save_on_epoch:
            _save_checkpoint(
                model,
                os.path.join(model_folder, f"epoch_{epoch}"),
                f"epoch-{epoch}_trial-{train_args.ckpt_name}",
            )
        epoch += 1