"""Training utilities for time-series experiments."""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, List, Tuple

import torch
import torch.nn as nn
from torch.utils.data import DataLoader


def _classification_epoch(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    optimizer: torch.optim.Optimizer | None = None,
) -> Tuple[float, float]:
    criterion = nn.CrossEntropyLoss()
    total_loss = 0.0
    total_correct = 0
    total = 0

    train_mode = optimizer is not None
    model.train(train_mode)

    with torch.enable_grad() if train_mode else torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            loss = criterion(logits, targets)

            if train_mode:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = targets.size(0)
            total_loss += loss.item() * batch_size
            preds = logits.argmax(dim=-1)
            total_correct += (preds == targets).sum().item()
            total += batch_size

    avg_loss = total_loss / max(total, 1)
    avg_acc = total_correct / max(total, 1)
    return avg_loss, avg_acc


def train_classification(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    epochs: int = 10,
    lr: float = 1e-3,
    verbose: bool = False,
) -> Dict[str, List[Dict[str, float]]]:
    """Train a classifier with AdamW and record per-epoch metrics.

    Each epoch runs one training pass over ``train_loader`` followed by one
    evaluation pass over ``val_loader``.

    Args:
        model: Module to train; it is moved to ``device`` in place.
        train_loader: Batches used for the optimisation pass.
        val_loader: Batches used for the evaluation pass.
        device: Device the model and batches are placed on.
        epochs: Number of train/validate rounds to run.
        lr: Learning rate passed to ``torch.optim.AdamW``.
        verbose: When true, print one summary line per epoch.

    Returns:
        ``{"history": [...]}`` where each entry holds ``epoch``,
        ``train_loss``, ``train_acc``, ``val_loss`` and ``val_acc``.
    """
    # Move the model before building the optimizer so the optimizer holds
    # the parameters as they exist on the target device. Some backends and
    # settings replace parameter objects during conversion, which would
    # leave an optimizer built earlier updating stale tensors.
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    history: List[Dict[str, float]] = []

    for epoch in range(1, epochs + 1):
        train_loss, train_acc = _classification_epoch(model, train_loader, device, optimizer)
        val_loss, val_acc = _classification_epoch(model, val_loader, device, optimizer=None)
        history.append(
            {
                "epoch": epoch,
                "train_loss": train_loss,
                "train_acc": train_acc,
                "val_loss": val_loss,
                "val_acc": val_acc,
            }
        )
        if verbose:
            print(
                f"[train] epoch={epoch}/{epochs} "
                f"train_loss={train_loss:.4f} train_acc={train_acc:.4f} "
                f"val_loss={val_loss:.4f} val_acc={val_acc:.4f}",
                flush=True,
            )

    return {"history": history}


def _regression_epoch(
    model: nn.Module,
    loader: DataLoader,
    device: torch.device,
    optimizer: torch.optim.Optimizer | None = None,
) -> float:
    criterion = nn.MSELoss()
    total_loss = 0.0
    total = 0

    train_mode = optimizer is not None
    model.train(train_mode)
    dtype = next(model.parameters()).dtype

    with torch.enable_grad() if train_mode else torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device=device, dtype=dtype)
            targets = targets.to(device=device, dtype=dtype)

            preds = model(inputs)
            loss = criterion(preds, targets)

            if train_mode:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            batch_size = targets.size(0)
            total_loss += loss.item() * batch_size
            total += batch_size

    return total_loss / max(total, 1)


def train_regression(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    device: torch.device,
    epochs: int = 200,
    lr: float = 1e-3,
    verbose: bool = False,
) -> Dict[str, List[Dict[str, float]]]:
    """Train a regressor with AdamW and record per-epoch losses.

    Each epoch runs one training pass over ``train_loader`` followed by one
    evaluation pass over ``val_loader``.

    Args:
        model: Module to train; it is moved to ``device`` in place.
        train_loader: Batches used for the optimisation pass.
        val_loader: Batches used for the evaluation pass.
        device: Device the model and batches are placed on.
        epochs: Number of train/validate rounds to run.
        lr: Learning rate passed to ``torch.optim.AdamW``.
        verbose: When true, print one summary line per epoch.

    Returns:
        ``{"history": [...]}`` where each entry holds ``epoch``,
        ``train_loss`` and ``val_loss``.
    """
    # Move the model before building the optimizer so the optimizer holds
    # the parameters as they exist on the target device. Some backends and
    # settings replace parameter objects during conversion, which would
    # leave an optimizer built earlier updating stale tensors.
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    history: List[Dict[str, float]] = []

    for epoch in range(1, epochs + 1):
        train_loss = _regression_epoch(model, train_loader, device, optimizer)
        val_loss = _regression_epoch(model, val_loader, device, optimizer=None)
        history.append({"epoch": epoch, "train_loss": train_loss, "val_loss": val_loss})
        if verbose:
            print(
                f"[train] epoch={epoch}/{epochs} train_loss={train_loss:.6f} val_loss={val_loss:.6f}",
                flush=True,
            )

    return {"history": history}


def log_run(path: str, payload: Dict[str, Any]) -> None:
    """Append ``payload`` to the file at ``path`` as one JSON line.

    Missing parent directories are created first. The file is opened in
    append mode with UTF-8 encoding, so earlier records are preserved.

    Args:
        path: Destination file for the JSON-lines log.
        payload: Mapping serialised with ``json.dumps``.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("a", encoding="utf-8") as handle:
        handle.write(json.dumps(payload) + "\n")
