
"""training_utils.py

Utilities for a generic training loop, evaluation, checkpointing, and seeding for reproducible runs.
Designed to be robust and to fall back to simple models if the user's model import fails.
"""
import inspect
import json
import os
import random
import time
from pathlib import Path
from typing import Dict, Any, Optional

import numpy as np
import torch

def set_seed(seed: int):
    """Seed every RNG this module relies on so runs are repeatable.

    Covers Python's ``random`` module, NumPy's global generator, the PyTorch
    CPU generator and the generators of all CUDA devices. cuDNN determinism
    flags are not touched here.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

def save_checkpoint(state: Dict[str, Any], out_dir: str, name: str = "checkpoint.pt"):
    """Serialize ``state`` with ``torch.save`` to ``out_dir/name`` and return that path.

    ``out_dir`` is created if needed. The data is first written to a temporary
    sibling file (``<path>.tmp``) and then moved over the final name with
    ``os.replace``. An interrupted or failed save therefore never leaves a
    truncated checkpoint under ``name``, and any previous file there is kept.

    Raises:
        Whatever ``torch.save`` raises (e.g. a pickling error); the temporary
        file is removed in that case.
    """
    os.makedirs(out_dir, exist_ok=True)
    path = os.path.join(out_dir, name)
    tmp_path = f"{path}.tmp"
    try:
        torch.save(state, tmp_path)
        # Same directory => same filesystem, so the rename replaces atomically.
        os.replace(tmp_path, path)
    finally:
        # Only present if torch.save or os.replace failed.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    return path

def load_checkpoint(path: str, device: str = "cpu"):
    """Load a checkpoint written by ``torch.save``, remapping tensors to ``device``.

    Raises:
        FileNotFoundError: if ``path`` does not exist (message is the path).

    NOTE(review): ``torch.load`` is pickle-based; depending on the installed
    torch version's ``weights_only`` default, only load trusted files.
    """
    if os.path.exists(path):
        return torch.load(path, map_location=device)
    raise FileNotFoundError(path)

def _unpack_batch(batch, device):
    """Split a loader batch into ``(inputs, targets)`` moved to ``device``.

    Lists/tuples are treated TensorDataset-style, as ``(x, y)`` or ``(x,)``.
    Anything else is treated as a bare input tensor. ``targets`` is None when
    absent.
    """
    if isinstance(batch, (list, tuple)):
        x = batch[0].to(device)
        y = batch[1].to(device) if len(batch) > 1 else None
        return x, y
    return batch.to(device), None


def _forward_accepts_target(model) -> bool:
    """Return True if ``model.forward`` can take the target as a 2nd positional arg."""
    forward = getattr(model, "forward", model)
    try:
        params = inspect.signature(forward).parameters.values()
    except (TypeError, ValueError):
        # Not introspectable: keep the historical behaviour of passing the target.
        return True
    positional = 0
    for param in params:
        if param.kind == inspect.Parameter.VAR_POSITIONAL:
            return True
        if param.kind in (inspect.Parameter.POSITIONAL_ONLY,
                          inspect.Parameter.POSITIONAL_OR_KEYWORD):
            positional += 1
    return positional >= 2


def _extract_pred(out):
    """Return the prediction from a model output (raw tensor, or dict with 'recon'/'pred')."""
    if not isinstance(out, dict):
        return out
    if "recon" in out:
        return out["recon"]
    if "pred" in out:
        return out["pred"]
    raise KeyError("model output dict must contain a 'recon' or 'pred' entry")


def _batch_size(x) -> int:
    """Number of samples in ``x`` (its length), or 1 for unsized inputs."""
    return len(x) if hasattr(x, "__len__") else 1


def _scheduler_wants_metric(scheduler) -> bool:
    """Return True if ``scheduler.step`` needs a metric (e.g. ReduceLROnPlateau)."""
    if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
        return True
    try:
        params = inspect.signature(scheduler.step).parameters.values()
    except (TypeError, ValueError):
        return False
    # Epoch-based schedulers expose ``step(epoch=None)``, whose argument is optional.
    # Passing the loss there would be misread as an epoch number.
    return any(
        p.default is inspect.Parameter.empty
        and p.kind in (inspect.Parameter.POSITIONAL_ONLY,
                       inspect.Parameter.POSITIONAL_OR_KEYWORD)
        for p in params
    )


def _run_epoch(model, loader, criterion, device, pass_target, optimizer=None) -> float:
    """Run one pass over ``loader`` and return the sample-weighted mean loss.

    Weights are updated only when ``optimizer`` is given. The caller is
    responsible for ``model.train()``/``model.eval()`` and ``torch.no_grad()``.
    An empty loader yields 0.0.
    """
    running = 0.0
    n_seen = 0
    for batch in loader:
        x, y = _unpack_batch(batch, device)
        if optimizer is not None:
            optimizer.zero_grad()
        out = model(x, y) if (y is not None and pass_target) else model(x)
        pred = _extract_pred(out)
        # Target-less batches are scored as a reconstruction of the input.
        loss = criterion(pred, y) if y is not None else criterion(pred, x)
        if optimizer is not None:
            loss.backward()
            optimizer.step()
        n = _batch_size(x)
        running += float(loss.item()) * n
        n_seen += n
    return running / max(1, n_seen)


def train(
    model: torch.nn.Module,
    train_loader,
    val_loader,
    optimizer,
    criterion,
    device: str = "cpu",
    epochs: int = 50,
    scheduler=None,
    out_dir: Optional[str] = None,
    metrics_fn=None,
    verbose: bool = True,
):
    """Generic training loop that records losses and (optional) metrics.

    Each epoch runs these steps:

    1. Train on ``train_loader``.
    2. Evaluate on ``val_loader`` under ``torch.no_grad()``.
    3. Step ``scheduler``, if one is given.
    4. Record metrics.
    5. If ``out_dir`` is set and the validation loss improved, save a
       ``best_epoch_{epoch}.pt`` checkpoint.

    Batches may be ``(x, y)``, ``(x,)`` or a bare tensor. When a target is
    present, it is passed to the model only if ``model.forward`` accepts a
    second positional argument. The loss is ``criterion(pred, y)``, or
    ``criterion(pred, x)`` for target-less (reconstruction) batches. Dict
    outputs are read from their ``"recon"`` (preferred) or ``"pred"`` entry.

    Args:
        model: module to train; moved to ``device``.
        train_loader: iterable of training batches.
        val_loader: iterable of validation batches (an empty one gives 0.0 loss).
        optimizer: torch optimizer over ``model``'s parameters.
        criterion: loss callable ``criterion(pred, target)``.
        device: device string the model and batches are moved to.
        epochs: number of epochs to run.
        scheduler: optional LR scheduler, stepped once per epoch. Schedulers
            whose ``step`` requires a metric (e.g. ``ReduceLROnPlateau``)
            receive the validation loss. All others are stepped with no
            arguments.
        out_dir: directory for best-validation checkpoints; None disables them.
        metrics_fn: optional object with ``metrics_keys()`` and
            ``compute(model, val_loader, device=...)`` returning a dict.
        verbose: print a one-line summary per epoch.

    Returns:
        history: dict with lists for train_loss, val_loss, and metrics if provided.

    Raises:
        KeyError: if the model returns a dict with neither 'recon' nor 'pred'.
    """
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    model.to(device)
    history = {"train_loss": [], "val_loss": []}
    if metrics_fn is not None:
        # NOTE(review): presumably a mapping of metric name -> initial list; verify.
        history.update(metrics_fn.metrics_keys())
    # Both checks depend only on the objects' signatures, so compute them once.
    pass_target = _forward_accepts_target(model)
    step_with_metric = scheduler is not None and _scheduler_wants_metric(scheduler)
    best_val = float("inf")
    for epoch in range(1, epochs + 1):
        t0 = time.time()

        model.train()
        train_loss = _run_epoch(model, train_loader, criterion, device, pass_target, optimizer)
        history["train_loss"].append(train_loss)

        model.eval()
        with torch.no_grad():
            val_loss = _run_epoch(model, val_loader, criterion, device, pass_target)
        history["val_loss"].append(val_loss)

        if scheduler is not None:
            if step_with_metric:
                scheduler.step(val_loss)
            else:
                scheduler.step()

        if metrics_fn is not None:
            for k, v in metrics_fn.compute(model, val_loader, device=device).items():
                history.setdefault(k, []).append(v)

        if verbose:
            print(f"Epoch {epoch}/{epochs} — train_loss={train_loss:.4f} val_loss={val_loss:.4f} time={time.time()-t0:.1f}s")

        if out_dir and val_loss < best_val:
            best_val = val_loss
            state = {
                "epoch": epoch,
                "model_state": model.state_dict(),
                "optimizer_state": optimizer.state_dict(),
                "history": history,
            }
            # Needed to resume LR schedules faithfully; extra key is backward-compatible.
            if scheduler is not None and hasattr(scheduler, "state_dict"):
                state["scheduler_state"] = scheduler.state_dict()
            save_checkpoint(state, out_dir, name=f"best_epoch_{epoch}.pt")
    return history
