from __future__ import annotations
import logging
from pathlib import Path
from typing import List, Optional, Sequence, Union, Iterable

import torch
from torch import nn
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader


def train_model(
    model: nn.Module,
    dataset,
    loss_fn: nn.Module,
    *,
    epochs: int = 5,
    lr: float = 1e-2,
    weight_decay: float = 0.0,
    batch_size: int = 128,
    num_workers: int = 0,
    max_grad_norm: Optional[float] = 5,
    stop_on_nonfinite: bool = True,
    logger: Optional[logging.Logger] = None,
    checkpoint_dir: Optional[Union[str, Path]] = None,
    checkpoint_epochs: Optional[Union[int, Sequence[int]]] = None,
    checkpoint_fractions: Optional[Union[float, Sequence[float]]] = None,
    save_initial: bool = False,
) -> dict:
    """Train ``model`` with plain SGD, optional grad clipping and a NaN/Inf guard.

    The model is put in train mode for the loop and switched to eval mode
    before returning (on both the normal and the early-abort path).

    Args:
        model: Module to train. Batches are moved to the device of its first
            parameter, so it must have at least one parameter.
        dataset: Passed straight to ``DataLoader``; each batch must unpack
            into an ``(inputs, targets)`` pair.
        loss_fn: Called as ``loss_fn(model(inputs), targets)``.
        epochs: Number of passes over ``dataset``.
        lr: SGD learning rate.
        weight_decay: SGD weight decay.
        batch_size: ``DataLoader`` batch size (the loader shuffles).
        num_workers: ``DataLoader`` worker count.
        max_grad_norm: If not ``None``, gradients are clipped to this total
            norm on every step; steps whose pre-clip norm exceeded it are
            counted in ``grad_clip_count``.
        stop_on_nonfinite: If true, training aborts at the first non-finite
            loss; otherwise the offending batch is skipped (no backward pass
            or optimizer step) and training continues.
        logger: Optional logger for per-epoch loss and non-finite warnings.
        checkpoint_dir: Directory for checkpoints; created if missing. When
            ``None`` no checkpoint is ever written.
        checkpoint_epochs: 1-based epoch number(s) after which to save.
            Non-positive entries are ignored.
        checkpoint_fractions: Fraction(s) of ``epochs`` after which to save;
            each maps to ``max(1, round(epochs * fraction))``. Non-positive
            entries are ignored.
        save_initial: If true, also save the untrained weights as epoch 0.

    Returns:
        A dict with ``epochs_run`` (fully completed epochs),
        ``grad_clip_count``, ``nan_batches`` and ``checkpoint_paths`` (list of
        written checkpoint files as strings, in save order). The same keys
        are present whether training finished or aborted early.

    Raises:
        StopIteration: If ``model`` has no parameters.
    """

    device = next(model.parameters()).device
    model.train()
    opt = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)

    grad_clip_count = 0
    nan_batches = 0
    epochs_run = 0

    checkpoint_paths: List[str] = []
    ckpt_dir_path: Optional[Path] = None
    if checkpoint_dir is not None:
        ckpt_dir_path = Path(checkpoint_dir)
        ckpt_dir_path.mkdir(parents=True, exist_ok=True)

    def _as_list(x: Optional[Union[int, float, Sequence[Union[int, float]]]]) -> List[Union[int, float]]:
        # Normalize "scalar or collection or None" into a list.
        if x is None:
            return []
        if isinstance(x, (str, bytes)):
            # Strings are iterable but are meant as a single value here.
            return [x]
        try:
            # Accept any iterable (list, tuple, range, set, ...), not only list/tuple.
            return list(x)
        except TypeError:
            # Not iterable: a plain scalar.
            return [x]

    ckpt_epochs_set = {int(e) for e in _as_list(checkpoint_epochs) if int(e) > 0}
    for frac in _as_list(checkpoint_fractions):
        frac = float(frac)
        if frac > 0.0:
            # Clamp to epoch 1 so tiny fractions still yield a reachable epoch.
            ckpt_epochs_set.add(int(max(1, round(float(epochs) * frac))))

    def _save_ckpt(ep: int) -> None:
        # Write the model weights for epoch ``ep`` and record the file path.
        if ckpt_dir_path is None:
            return
        out_path = ckpt_dir_path / f"checkpoint_epoch_{int(ep):04d}.pt"
        torch.save(
            {
                "epoch": int(ep),
                "model_state_dict": model.state_dict(),
            },
            out_path,
        )
        checkpoint_paths.append(str(out_path))

    def _result() -> dict:
        # Single place that builds the return value so the early-abort and
        # normal paths always expose the same keys.
        return {
            "epochs_run": epochs_run,
            "grad_clip_count": grad_clip_count,
            "nan_batches": nan_batches,
            "checkpoint_paths": checkpoint_paths,
        }

    if save_initial:
        _save_ckpt(0)

    for ep in range(epochs):
        running = 0.0
        n = 0
        for b_idx, (xb, yb) in enumerate(loader):
            xb, yb = xb.to(device), yb.to(device)
            opt.zero_grad(set_to_none=True)
            loss = loss_fn(model(xb), yb)
            if not torch.isfinite(loss):
                nan_batches += 1
                if stop_on_nonfinite:
                    if logger is not None:
                        logger.warning("Non-finite loss at epoch %d batch %d; stopping.", ep + 1, b_idx + 1)
                    model.eval()
                    return _result()
                if logger is not None:
                    logger.warning("Non-finite loss at epoch %d batch %d; skipping batch.", ep + 1, b_idx + 1)
                # Skip backward/step so the bad batch cannot poison the weights.
                continue
            loss.backward()
            if max_grad_norm is not None:
                # clip_grad_norm_ returns the total norm measured before clipping.
                total_norm = clip_grad_norm_(model.parameters(), max_grad_norm)
                if float(total_norm) > float(max_grad_norm):
                    grad_clip_count += 1
            opt.step()
            # Weight by batch size so the epoch average is per-sample.
            running += float(loss.item()) * xb.shape[0]
            n += xb.shape[0]
        epochs_run += 1
        if logger is not None:
            logger.info("train epoch %d/%d loss=%.4f", ep + 1, epochs, running / max(n, 1))

        if (ep + 1) in ckpt_epochs_set:
            _save_ckpt(ep + 1)

    model.eval()
    return _result()
