# assp/utils/system.py
"""
System utilities: experiment dirs, logging, and small state I/O (rank0-only).
"""

from __future__ import annotations

import logging
import uuid
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, Optional

import torch

# --------------------------- Experiment directory ---------------------------


def _make_run_name(experiment_name: str, timestamp: str, short_id: str) -> str:
    """Return run name like EXP_YYYYMMDD-HHMMSS-ABC123."""
    return f"{experiment_name}_{timestamp}-{short_id}"


def setup_experiment_directory(
    root_dir: str | Path,
    experiment_name: str,
    *,
    accelerator,  # require accelerate.Accelerator
) -> Dict[str, Path]:
    """
    Build the per-run directory layout and return the same paths on every rank.

    Protocol:
    - The main process generates a fresh run name (timestamp plus a
      6-character uppercase hex id), creates the run directory together with
      its ``checkpoints``, ``logs`` and ``visualizations`` subdirectories, and
      writes the run name to ``root_dir/.latest_run``.
    - Every process passes a ``wait_for_everyone()`` barrier and then reads
      the run name back from that marker file.

    Args:
        root_dir: Parent directory under which the run directory is created.
        experiment_name: Prefix of the generated run name.
        accelerator: Object exposing ``is_main_process`` and
            ``wait_for_everyone()``.

    Returns:
        Dict with keys ``"root"``, ``"checkpoints"``, ``"logs"`` and
        ``"visualizations"`` mapping to the corresponding paths.
    """

    def _layout(run_root: Path) -> Dict[str, Path]:
        # Single definition of the directory layout, used both for creating
        # the directories and for the returned mapping.
        layout = {"root": run_root}
        for sub in ("checkpoints", "logs", "visualizations"):
            layout[sub] = run_root / sub
        return layout

    base = Path(root_dir).resolve()
    marker_file = base / ".latest_run"

    if accelerator.is_main_process:
        stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        short_id = uuid.uuid4().hex[:6].upper()
        fresh_name = _make_run_name(experiment_name, stamp, short_id)

        for directory in _layout(base / fresh_name).values():
            directory.mkdir(parents=True, exist_ok=True)

        base.mkdir(parents=True, exist_ok=True)
        marker_file.write_text(fresh_name, encoding="utf-8")

    # Barrier: the marker must be written before any rank reads it.
    accelerator.wait_for_everyone()

    # NOTE(review): this assumes every rank sees the same ``root_dir``
    # filesystem; confirm for multi-node runs.
    run_name = marker_file.read_text(encoding="utf-8").strip()
    paths = _layout(base / run_name)

    # Second barrier: presumably keeps a later call from rewriting the marker
    # before all ranks have read it.
    accelerator.wait_for_everyone()
    return paths


# --------------------------- Logging ---------------------------

# Cache of already-configured loggers so repeated calls do not stack duplicate
# handlers. Keys combine the resolved log directory, the process rank and,
# when one is given, the explicit logger name.
_LOGGER_CACHE: Dict[str, logging.Logger] = {}


def create_logger(
    log_directory: str | Path,
    *,
    logger_name: Optional[str] = None,
    accelerator,  # require accelerate.Accelerator
) -> logging.Logger:
    """
    Return a configured logger for ``log_directory`` on the current rank.

    The main process logs to ``<log_directory>/train.log`` (append mode) and
    to a stream handler; every other rank gets only a ``NullHandler`` and is
    therefore silent. Repeated calls with the same directory, rank and logger
    name return the same cached logger.

    Args:
        log_directory: Directory that receives ``train.log`` (created on the
            main process if missing).
        logger_name: Name passed to ``logging.getLogger``. When omitted or
            empty, a name derived from the directory and rank is used.
        accelerator: Object exposing ``process_index`` and ``is_main_process``.

    Returns:
        The configured ``logging.Logger`` (level INFO, propagation disabled).
    """
    log_directory = Path(log_directory).resolve()
    rank = int(accelerator.process_index)
    base_key = f"{str(log_directory)}|rank={rank}"
    name = logger_name or base_key

    # The requested name is part of the key; otherwise a second call for the
    # same directory with a different name would get the first logger back.
    cache_key = f"{base_key}|name={logger_name}" if logger_name else base_key

    cached = _LOGGER_CACHE.get(cache_key)
    if cached is not None:
        return cached

    logger = logging.getLogger(name)

    # logging.getLogger returns one shared object per name. If that object was
    # configured earlier for a different directory, its handlers are replaced
    # below, so the older cache entries no longer describe it and must go.
    stale_keys = [key for key, known in _LOGGER_CACHE.items() if known is logger]
    for key in stale_keys:
        del _LOGGER_CACHE[key]

    logger.setLevel(logging.INFO)
    logger.propagate = False

    # Start from a clean slate so handlers never accumulate on a reused logger.
    for handler in list(logger.handlers):
        try:
            handler.close()
        finally:
            logger.removeHandler(handler)

    fmt = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

    if accelerator.is_main_process:
        log_directory.mkdir(parents=True, exist_ok=True)
        file_handler = logging.FileHandler(log_directory / "train.log", mode="a", encoding="utf-8")
        file_handler.setFormatter(fmt)
        logger.addHandler(file_handler)

        stream_handler = logging.StreamHandler()
        stream_handler.setFormatter(fmt)
        logger.addHandler(stream_handler)
    else:
        logger.addHandler(logging.NullHandler())  # fully silent on non-main ranks

    _LOGGER_CACHE[cache_key] = logger
    return logger


# --------------------------- Small state I/O (rank0-only) ---------------------------


def save_state_on_master(
    state: Dict[str, Any],
    path: str | Path,
    *,
    accelerator,  # require accelerate.Accelerator
) -> None:
    """
    Save a small state dict with ``torch.save``, on the main process only.

    Intended for small payloads such as a metrics snapshot. The file is
    written in ``torch.save`` format (not JSON text), so read it back with
    ``torch.load``. Non-main processes return without touching the filesystem.

    Args:
        state: Object to serialize.
        path: Destination file; missing parent directories are created.
        accelerator: Object exposing ``is_main_process``.
    """
    if not accelerator.is_main_process:
        return

    target = Path(path)
    # torch.save does not create directories; make sure the parent exists so
    # saving into a not-yet-created subfolder does not fail.
    target.parent.mkdir(parents=True, exist_ok=True)
    torch.save(state, str(target))


# --------------------------- Formatting ---------------------------


def format_large_number(number: int | float) -> str:
    """
    Format a number with a K/M/G suffix.

    The unit is chosen from the magnitude of ``number``, so negative values
    are scaled the same way as positive ones (``-1500`` -> ``"-1.50 K"``).
    Scaled values are shown with two decimals; values whose magnitude is
    below 1000 are returned as a plain integer string (fractional part
    truncated).
    """
    magnitude = abs(number)
    # Ordered from largest to smallest so the first match is the right unit.
    for threshold, suffix in ((1e9, "G"), (1e6, "M"), (1e3, "K")):
        if magnitude >= threshold:
            return f"{number / threshold:.2f} {suffix}"
    return str(int(number))
