import sys

from lightning_fabric import Fabric
from torch import Tensor, nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from torch.utils.tensorboard.writer import SummaryWriter

from smlm import utils
from smlm.loops.calibration_loop import calibration_loop
from smlm.loops.training_loop import training_loop
from smlm.loops.validation_loop import validation_loop


def training_step_loop(
    fabric: Fabric,
    model: nn.Module,
    training_module: nn.Module,
    opt: Optimizer,
    scheduler: LRScheduler,
    dl_train: DataLoader,
    dl_calib: DataLoader,
    dl_val: DataLoader,
    watched_metric: str,
    log_dir: str,
    begin_epoch: int = 0,
    begin_step: int = 0,
    n_epochs: int = -1,
    n_accum_steps: int = 1,
    patience: int = -1,
    first_epoch_done_callback: callable = None,
):
    """Run the epoch loop: train, optionally calibrate, validate, log, checkpoint.

    Each epoch runs ``training_loop`` (the training step itself is defined by
    ``training_module``), optionally ``calibration_loop``, then
    ``validation_loop``. Metrics are written to a TensorBoard ``SummaryWriter``
    and printed on the global-zero rank. A "last" checkpoint is saved every
    epoch, and a "best" checkpoint whenever the watched validation metric is
    at least as high as the best value seen so far.

    Args:
        fabric: Fabric instance forwarded to the loops and checkpoint helpers;
            ``fabric.is_global_zero`` gates logging and printing.
        model: Model passed to the calibration and validation loops. When
            calibration runs, ``model.set_thresholds`` is called with its
            result.
        training_module: Module passed to ``training_loop``; its state dict is
            stored in the checkpoints.
        opt: Optimizer passed to ``training_loop`` and checkpointed.
        scheduler: LR scheduler passed to ``training_loop`` and checkpointed.
        dl_train: Training dataloader. Its dataset must provide
            ``increment_seed()``, which is called after every training epoch.
        dl_calib: Calibration dataloader, or ``None`` to skip calibration.
        dl_val: Validation dataloader.
        watched_metric: Key of the validation metric used for model selection
            (higher is treated as better); also passed to ``calibration_loop``
            as ``metric``.
        log_dir: Directory given to ``SummaryWriter`` and to
            ``utils.logs.get_ckpts_path`` to derive the checkpoint paths.
        begin_epoch: Index of the first epoch to run.
        begin_step: Initial value of the step counter handed to
            ``training_loop``.
        n_epochs: Exclusive upper bound on the epoch index. A negative value
            means no practical limit.
        n_accum_steps: Forwarded to ``training_loop``.
        patience: If positive, stop once ``epoch > best_epoch + patience``.
            Non-positive values disable early stopping.
        first_epoch_done_callback: Optional zero-argument callable invoked once,
            after validation of the first epoch. ``None`` disables it.

    Returns:
        The highest watched validation metric observed, or ``float("-inf")``
        if no epoch was run.
    """
    step = begin_step
    best_metric = float("-inf")
    loss_history = []
    best_epoch = begin_epoch

    logger = None  # created lazily, and only on the global-zero rank
    path_last_ckpt, path_best_ckpt = utils.logs.get_ckpts_path(log_dir)

    def _save_checkpoint(path, epoch, step):
        """Save the current optimizer/scheduler/training-module state to path."""
        utils.checkpoint.save_training(
            fabric=fabric,
            path=path,
            epoch=epoch,
            step=step,
            optimizer_state_dict=opt.state_dict(),
            scheduler_state_dict=scheduler.state_dict(),
            training_module_state_dict=training_module.state_dict(),
        )

    # A negative n_epochs means "train until stopped"; range() accepts
    # arbitrarily large Python ints, so this is an effectively endless loop.
    n_epochs = n_epochs if n_epochs >= 0 else sys.maxsize**10
    try:
        for epoch in range(begin_epoch, n_epochs):
            # train, calibrate, validate
            train_metrics, step = training_loop(
                fabric=fabric,
                dl=dl_train,
                n_accum_steps=n_accum_steps,
                opt=opt,
                scheduler=scheduler,
                step=step,
                training_module=training_module,
            )
            loss_history.append(train_metrics["loss"])

            dl_train.dataset.increment_seed()
            if dl_calib is not None:
                thresholds = calibration_loop(
                    fabric=fabric, model=model, dl=dl_calib, metric=watched_metric
                )
                model.set_thresholds(thresholds)
            val_metrics = validation_loop(
                fabric=fabric, model=model, dl=dl_val, wooblecorr=True
            )

            # callback (the default is None, so it must be guarded)
            if epoch == begin_epoch and first_epoch_done_callback is not None:
                first_epoch_done_callback()

            # log
            logs = {"epoch": epoch, "step": step}
            if dl_calib is not None:
                logs["threshold"] = thresholds
            logs = logs | train_metrics | val_metrics
            if fabric.is_global_zero:
                if logger is None:
                    logger = SummaryWriter(log_dir)
                for key, value in logs.items():
                    # Multi-element tensors cannot be logged as scalars.
                    if isinstance(value, Tensor) and value.nelement() > 1:
                        logger.add_tensor(tag=key, tensor=value, global_step=step)
                    else:
                        logger.add_scalar(
                            tag=key, scalar_value=value, global_step=step
                        )
                print(utils.format.kwargs2string(**logs))

            # save
            _save_checkpoint(path_last_ckpt, epoch, step)
            if val_metrics[watched_metric] >= best_metric:
                best_epoch = epoch
                best_metric = val_metrics[watched_metric]
                _save_checkpoint(path_best_ckpt, epoch, step)

            # early stopping
            if patience > 0 and epoch > best_epoch + patience:
                if fabric.is_global_zero:
                    print(f"Exceeded patience of {patience} epochs; early stopping.")
                break

            # divergence
            if len(dl_train) != 0 and utils.divergence.detect_divergence(
                loss_history
            ):
                if fabric.is_global_zero:
                    print("Divergence detected. Reverting to previous best weights.")
                utils.checkpoint.load_training(
                    fabric=fabric,
                    ckpt_path=path_best_ckpt,
                    optimizer=opt,
                    scheduler=scheduler,
                    training_module=training_module,
                )
                # Start a fresh history so the pre-revert losses do not
                # immediately re-trigger the divergence check.
                loss_history = []
    finally:
        # Flush and release the TensorBoard writer even if an epoch raised.
        if logger is not None:
            logger.close()

    return best_metric
