from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional, Tuple, Iterable

import glob
import json
import logging
import os
import re
import time

import torch
from tabulate import tabulate


def flatten_params(module: torch.nn.Module) -> torch.Tensor:
    """Concatenate all parameters of ``module`` into a single 1-D tensor.

    Each parameter is detached and reshaped to one dimension, in the order
    produced by ``module.parameters()``.

    Args:
        module: Module whose parameters are flattened.

    Returns:
        The concatenated 1-D tensor, or ``torch.empty(0)`` when the module
        has no parameters.
    """
    flat_chunks = []
    for param in module.parameters():
        flat_chunks.append(param.detach().reshape(-1))
    if flat_chunks:
        return torch.cat(flat_chunks, dim=0)
    # torch.cat rejects an empty list, so hand back an empty tensor instead.
    return torch.empty(0)


def jacobian_fn(apply_fn, params: Tuple[torch.nn.Parameter, ...], *args) -> torch.Tensor:
    """Placeholder for a functional Jacobian helper; it is not implemented.

    Args:
        apply_fn: Unused.
        params: Unused.
        *args: Unused.

    Raises:
        NotImplementedError: Always, regardless of the arguments.
    """
    reason = "In torch, compute grads on a real forward/loss using autograd. Avoid passing raw params tuples."
    raise NotImplementedError(reason)


def ntk_fn(apply_fn, params: Tuple[torch.nn.Parameter, ...], *args) -> torch.Tensor:
    """Placeholder for a neural-tangent-kernel helper; it is not implemented.

    Args:
        apply_fn: Unused.
        params: Unused.
        *args: Unused.

    Raises:
        NotImplementedError: Always, regardless of the arguments.
    """
    reason = "In torch, compute NTK via per-sample Jacobians if needed; not included here."
    raise NotImplementedError(reason)


def count_params(model: torch.nn.Module) -> int:
    """Return the total number of scalar elements across ``model.parameters()``.

    Args:
        model: Module whose parameters are counted.

    Returns:
        Sum of ``numel()`` over every parameter; ``0`` if there are none.
    """
    total = 0
    for tensor in model.parameters():
        total += tensor.numel()
    return total


def _atomic_save(obj: Any, path: Path):
    """Serialize ``obj`` to ``path`` via a temporary sibling file.

    The object is first written with ``torch.save`` to ``<path>.tmp`` and
    then moved over ``path`` with ``os.replace``, so ``path`` is only ever
    replaced by a completely written file.

    Args:
        obj: Object handed to ``torch.save``.
        path: Destination file; missing parent directories are created.

    Raises:
        Exception: Whatever ``torch.save`` or ``os.replace`` raises. The
            temporary file is removed before the error propagates, and an
            existing ``path`` is left untouched.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    tmp = path.with_suffix(path.suffix + ".tmp")
    try:
        torch.save(obj, tmp)
        os.replace(tmp, path)
    except BaseException:
        # Do not leave a partial ".tmp" file next to the real checkpoint.
        try:
            tmp.unlink(missing_ok=True)
        except OSError:
            # Cleanup is best-effort; the original error is the one to report.
            pass
        raise


def save_checkpoint(
    *,
    workdir: str | Path,
    step: int,
    model: torch.nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    state: Optional[dict] = None,
    keep: int = 5,
    name: str = "state.pt",
):
    """Write a checkpoint for ``step`` and prune older checkpoint directories.

    The payload is saved to ``<workdir>/checkpoints/<step>/<name>`` and
    contains ``"step"`` and ``"model"`` (the model ``state_dict``), plus
    ``"optimizer"`` (its ``state_dict``) and ``"state"`` when those arguments
    are provided.

    Args:
        workdir: Run directory that holds the ``checkpoints`` folder.
        step: Training step; converted with ``int`` and used as directory name.
        model: Module whose ``state_dict`` is stored.
        optimizer: Optional optimizer whose ``state_dict`` is stored.
        state: Optional extra dictionary stored under ``"state"``.
        keep: Number of highest-numbered step directories to retain. A falsy
            or non-positive value disables pruning.
        name: File name of the checkpoint inside the step directory.
    """
    import shutil  # function-scope import: only the pruning path needs it

    root = Path(workdir) / "checkpoints"
    ckpt_dir = root / str(int(step))

    payload: Dict[str, Any] = {
        "step": int(step),
        "model": model.state_dict(),
    }
    if optimizer is not None:
        payload["optimizer"] = optimizer.state_dict()
    if state is not None:
        payload["state"] = state

    _atomic_save(payload, ckpt_dir / name)

    if not keep or keep <= 0:
        return

    # Only purely numeric directory names are treated as checkpoints.
    numbered = [p for p in root.iterdir() if p.is_dir() and p.name.isdigit()]
    numbered.sort(key=lambda p: int(p.name))
    for stale in numbered[:-keep]:
        if stale == ckpt_dir:
            # Never delete the checkpoint that was just written, even when
            # its step is lower than existing ones (e.g. a restarted run).
            continue
        try:
            # rmtree also handles step directories that contain subfolders.
            shutil.rmtree(stale)
        except OSError as err:
            # Pruning stays best-effort, but failures are no longer silent.
            logging.getLogger(__name__).warning(
                "Could not prune checkpoint %s: %s", stale, err
            )


def restore_checkpoint(
    *,
    workdir: str | Path,
    model: torch.nn.Module,
    optimizer: Optional[torch.optim.Optimizer] = None,
    step: int | str | None = "latest",
    name: str = "state.pt",
    map_location: Optional[str | torch.device] = None,
) -> Dict[str, Any]:
    """Load a checkpoint from ``<workdir>/checkpoints/<step>/<name>``.

    The model weights are loaded with ``strict=True``. The optimizer state is
    loaded only when an optimizer is passed and the payload has an
    ``"optimizer"`` entry.

    Args:
        workdir: Run directory that holds the ``checkpoints`` folder.
        model: Module that receives ``payload["model"]``.
        optimizer: Optional optimizer that receives ``payload["optimizer"]``.
        step: Step to restore; ``None`` or ``"latest"`` selects the highest
            numeric directory name. Any other value is converted with ``int``.
        name: File name of the checkpoint inside the step directory.
        map_location: Forwarded to ``torch.load``.

    Returns:
        The full payload dictionary read from the checkpoint file.

    Raises:
        FileNotFoundError: If the ``checkpoints`` folder is missing, contains
            no numeric step directory, or the requested file does not exist.
    """
    wd = Path(workdir) / "checkpoints"
    # A missing folder gets the same explicit error as an empty one instead
    # of the raw OSError that iterdir() would raise.
    if not wd.is_dir():
        raise FileNotFoundError(f"No checkpoints found in {wd}")

    dirs = [p for p in wd.iterdir() if p.is_dir() and p.name.isdigit()]
    if not dirs:
        raise FileNotFoundError(f"No checkpoints found in {wd}")

    if step is None or step == "latest":
        step = max(int(p.name) for p in dirs)

    ckpt_path = wd / str(int(step)) / name
    if not ckpt_path.exists():
        raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")

    # NOTE(review): torch.load unpickles the file, so only restore checkpoints
    # from trusted sources. Consider weights_only=True where the installed
    # torch version and the stored "state" entry allow it.
    payload = torch.load(ckpt_path, map_location=map_location)
    model.load_state_dict(payload["model"], strict=True)

    if optimizer is not None and "optimizer" in payload:
        optimizer.load_state_dict(payload["optimizer"])

    return payload


def peek_checkpoint(path: str | Path):
    """Print a one-line summary for each entry of a checkpoint's ``"model"`` dict.

    The file is loaded onto the CPU. Tensor entries are printed with their
    shape and dtype; any other entry is printed with its Python type. A
    payload without a ``"model"`` key prints nothing.

    Args:
        path: Location of the checkpoint file.
    """
    checkpoint = torch.load(Path(path), map_location="cpu")
    entries = checkpoint.get("model", {})
    prefix = ""  # top-level keys are printed without any prefix
    for k, v in entries.items():
        if not isinstance(v, torch.Tensor):
            print(f"{prefix}{k}: {type(v)}")
            continue
        print(f"{prefix}{k}: {list(v.shape)} {v.dtype}")


class CustomJSONEncoder(json.JSONEncoder):
    """JSON encoder that additionally serializes tensors and paths."""

    def default(self, obj):
        """Convert objects the base encoder cannot handle.

        ``pathlib.Path`` becomes its string form and ``torch.Tensor`` becomes
        a (nested) Python list after being detached and moved to the CPU.
        Everything else is delegated to ``json.JSONEncoder.default``, which
        raises ``TypeError``.
        """
        if isinstance(obj, Path):
            return str(obj)
        if not isinstance(obj, torch.Tensor):
            return super().default(obj)
        return obj.detach().cpu().tolist()


def save_config(config: Any, workdir: str | Path, name: Optional[str] = None):
    """Write ``config`` as indented JSON to ``<workdir>/<name>.json``.

    Args:
        config: Either an object exposing ``to_dict()`` (checked first) or a
            ``dict``.
        workdir: Output directory; created if it does not exist.
        name: Base file name without extension; ``"config"`` when falsy.

    Raises:
        TypeError: If ``config`` is neither a ``dict`` nor has ``to_dict``.
    """
    out_dir = Path(workdir)
    out_dir.mkdir(parents=True, exist_ok=True)
    target = out_dir / ((name or "config") + ".json")

    if hasattr(config, "to_dict"):
        data = config.to_dict()
    else:
        if not isinstance(config, dict):
            raise TypeError("config must be dict-like or have .to_dict()")
        data = config

    with open(target, "w") as handle:
        json.dump(data, handle, cls=CustomJSONEncoder, indent=4)


def get_log_keys(log_dict: Dict[str, Any]):
    """Select the keys of ``log_dict`` that should appear in iteration logs.

    A key is selected when it ends with ``"_loss"`` or ``"_error"``, or
    starts with ``"scale_"``. The original key order is preserved.

    Args:
        log_dict: Mapping of metric names to values.

    Returns:
        List of the selected keys.
    """
    return [
        key
        for key in log_dict.keys()
        if key.endswith(("_loss", "_error")) or key.startswith("scale_")
    ]


class Logger:
    """Thin wrapper around a named ``logging.Logger`` with one stream handler.

    Construction clears any handlers already attached to the named logger,
    disables propagation, and installs a single ``StreamHandler`` whose
    records are formatted as ``[HH:MM:SS] message``.
    """

    def __init__(self, name: str = "main", level=logging.INFO):
        """Configure the logger called ``name`` at the given ``level``."""
        self.logger = logging.getLogger(name)
        # Drop handlers left by earlier instances so lines are not duplicated.
        self.logger.handlers.clear()
        self.logger.setLevel(level)
        self.logger.propagate = False

        stream = logging.StreamHandler()
        stream.setFormatter(
            logging.Formatter("[%(asctime)s] %(message)s", datefmt="%H:%M:%S")
        )
        stream.setLevel(level)
        self.logger.addHandler(stream)

    def info(self, message: str):
        """Emit ``message`` at INFO level."""
        self.logger.info(message)

    def log_iter(self, step: int, start_time: float, end_time: float, log_dict: Dict[str, Any]):
        """Log a small table with the selected metrics of one iteration.

        The rows are the keys chosen by ``get_log_keys``, each value converted
        with ``float`` and formatted as ``{:.3e}``. The header shows the step
        and the elapsed time ``end_time - start_time``. A dashed rule two
        characters wider than the header line is logged first, then every
        table line is logged separately at INFO level.
        """
        rows = []
        for key in get_log_keys(log_dict):
            rows.append([key, "{:.3e}".format(float(log_dict[key]))])

        table = tabulate(
            rows,
            headers=[f"Iter: {step:3d}", f"Time: {end_time - start_time:.3f}"],
            tablefmt="simple",
            numalign="right",
            disable_numparse=True,
        )

        table_lines = table.split("\n")
        rule = "-" * (len(table_lines[0]) + 2)

        self.logger.info(rule)
        for line in table_lines:
            self.logger.info(line)


class Collection(dict):
    """A ``dict`` whose items can also be read, set and deleted as attributes.

    Values assigned through attribute syntax, or passed through
    ``from_dict``, are wrapped recursively so that nested dicts (including
    those inside lists) become ``Collection`` instances as well.
    """

    __slots__ = ()

    def __getattr__(self, key):
        """Return ``self[key]``; a missing key raises ``AttributeError``."""
        try:
            value = self[key]
        except KeyError:
            raise AttributeError(f"No such config key: {key}")
        return value

    def __setattr__(self, key, value):
        """Store ``value`` under ``key``, wrapping nested dicts.

        Names starting with an underscore are passed to the regular attribute
        machinery instead of being stored as items.
        """
        if not key.startswith("_"):
            self[key] = self._wrap(value)
            return None
        return super().__setattr__(key, value)

    def __delattr__(self, key):
        """Delete ``self[key]``; a missing key raises ``AttributeError``."""
        try:
            del self[key]
        except KeyError:
            raise AttributeError(f"No such config key: {key}")

    @staticmethod
    def _wrap(value):
        """Recursively convert dicts to ``Collection``; other values pass through."""
        if isinstance(value, dict):
            wrapped_items = ((k, Collection._wrap(v)) for k, v in value.items())
            return Collection(wrapped_items)
        if isinstance(value, list):
            return list(map(Collection._wrap, value))
        return value

    @classmethod
    def from_dict(cls, d: dict):
        """Build an instance from ``d`` with every value wrapped recursively."""
        return cls((k, cls._wrap(v)) for k, v in d.items())


@dataclass(frozen=True)
class LoadedRuns:
    """Immutable record describing what was loaded from one run directory.

    The dataclass is frozen, so fields cannot be reassigned after creation.
    Every field except ``run_root`` defaults to ``None``.
    """

    # Path of the run directory this record refers to.
    run_root: str
    # Configuration of the run; load_checkpoint_as_run stores the trainer's
    # ``raw_cfg`` here.
    cfg: Any = None
    # Restored model; load_checkpoint_as_run stores the trainer's ``model`` here.
    model: Any = None
    # Checkpoint step the record was loaded from.
    step: Any = None


def _parse_step(path: str) -> Optional[int]:
    """Extract the last run of digits from the final component of ``path``.

    Trailing ``"/"`` characters are stripped before the base name is taken.

    Args:
        path: File-system path of a checkpoint entry.

    Returns:
        The last digit run of the base name as an ``int``, or ``None`` when
        the base name contains no digits.
    """
    base_name = os.path.basename(path.rstrip("/"))
    # The negative lookahead rejects any digit run that is followed by more
    # digits later in the name, leaving only the final run.
    found = re.search(r"(\d+)(?!.*\d)", base_name)
    if found is None:
        return None
    return int(found.group(1))


def list_checkpoint_steps(run_root: str) -> list[int]:
    """List the distinct step numbers found under ``<run_root>/checkpoints``.

    Every entry matched by ``checkpoints/*`` is passed to ``_parse_step``;
    entries without a parsable step are ignored.

    Args:
        run_root: Run directory that holds the ``checkpoints`` folder.

    Returns:
        Sorted list of unique step numbers; empty when nothing matches.
    """
    pattern = os.path.join(run_root, "checkpoints", "*")
    found = {_parse_step(entry) for entry in glob.glob(pattern)}
    # Entries whose name carries no digits parse to None and are dropped.
    found.discard(None)
    return sorted(found)


def load_checkpoint_as_run(trainer_cls, run_root: str, step: int, device: Optional[torch.device] = None):
    """Rebuild a trainer from ``run_root`` and restore it to ``step``.

    The trainer is created with ``trainer_cls.from_run_dir``, its ``run_dir``
    is set to ``run_root``, then ``build()`` and ``restore(run_root, step=step)``
    are called in that order.

    Args:
        trainer_cls: Class providing ``from_run_dir``; the instance it returns
            must offer ``build``, ``restore``, ``raw_cfg`` and ``model``.
        run_root: Run directory to load from.
        step: Checkpoint step forwarded to ``restore``.
        device: Forwarded to ``from_run_dir``.

    Returns:
        A ``LoadedRuns`` record with the trainer's ``raw_cfg`` and ``model``.
    """
    trainer = trainer_cls.from_run_dir(run_root, device=device)
    trainer.run_dir = run_root
    trainer.build()
    trainer.restore(run_root, step=step)
    record = LoadedRuns(
        run_root=run_root,
        cfg=trainer.raw_cfg,
        model=trainer.model,
        step=step,
    )
    return record


def load_all_checkpoints_as_runs(trainer_cls, run_root: str, device: Optional[torch.device] = None, steps: Optional[list[int]] = None):
    """Load one ``LoadedRuns`` record per checkpoint step of a run.

    Args:
        trainer_cls: Trainer class forwarded to ``load_checkpoint_as_run``.
        run_root: Run directory to load from.
        device: Forwarded to ``load_checkpoint_as_run``.
        steps: Steps to load, in the given order; when ``None`` the steps
            returned by ``list_checkpoint_steps(run_root)`` are used.

    Returns:
        List of loaded records, one per step.
    """
    selected = list_checkpoint_steps(run_root) if steps is None else steps
    runs = []
    for s in selected:
        runs.append(load_checkpoint_as_run(trainer_cls, run_root, s, device=device))
    return runs
