"""Common CLI utilities for training
"""

from __future__ import annotations

import os
import json
import logging
import subprocess
from typing import Optional, Dict, Any

import torch

# Games
from src.games.battlefield_duel import BattlefieldDuel, BattlefieldDuelSquad2, BattlefieldDuelSquadAmmo


# TB handler
from src.utils import BatchingTensorBoardHandler


def select_game(key: str):
    """Map a CLI game key (or one of its aliases) to the matching Game class.

    Matching is case-insensitive. A falsy key (``None``, ``""``) is treated as
    the empty string and therefore never matches.

    Args:
        key: Game identifier as passed on the command line.

    Returns:
        The Game class (not an instance) registered for ``key``.

    Raises:
        ValueError: If ``key`` matches no known game or alias.
    """
    wanted = key.lower() if key else ""

    if wanted in {"battlefield_duel", "duel", "bf_duel"}:
        return BattlefieldDuel

    # The plain "squad" aliases resolve to the BattlefieldDuelSquad2 class.
    if wanted in {"battlefield_duel_squad", "duel_squad", "bf_duel_squad"}:
        return BattlefieldDuelSquad2

    if wanted in {"battlefield_duel_squad_ammo", "duel_squad_ammo", "bf_duel_squad_ammo"}:
        return BattlefieldDuelSquadAmmo

    # The message reports the key exactly as the caller supplied it.
    raise ValueError(f"Unknown game '{key}'.")


def get_git_info() -> dict:
    """Collect best-effort Git metadata by invoking the ``git`` executable.

    The commands run in the process's current working directory.

    Returns:
        A dict with the keys below, or an empty dict when the full commit hash
        cannot be obtained (e.g. not a git repo, or git is not installed):

        - ``commit``: full hash of HEAD.
        - ``commit_short``: abbreviated hash; falls back to the first 7
          characters of ``commit`` if the git call fails.
        - ``branch``: output of ``git rev-parse --abbrev-ref HEAD``
          (omitted on failure).
        - ``describe``: output of ``git describe --tags --always --dirty``
          (omitted on failure).
        - ``dirty``: True if ``git status --porcelain`` prints anything,
          False if it prints nothing, None if the command fails.
    """

    def _git(*args):
        # stderr is discarded so failures stay silent on the console.
        raw = subprocess.check_output(['git', *args], stderr=subprocess.DEVNULL)
        return raw.decode().strip()

    # The full hash is the only mandatory field: without it nothing is reported.
    try:
        full_hash = _git('rev-parse', 'HEAD')
    except Exception:
        return {}

    meta: dict = {'commit': full_hash}

    try:
        meta['commit_short'] = _git('rev-parse', '--short', 'HEAD')
    except Exception:
        meta['commit_short'] = full_hash[:7]

    # Optional fields: simply left out when their command fails.
    optional_fields = (
        ('branch', ('rev-parse', '--abbrev-ref', 'HEAD')),
        ('describe', ('describe', '--tags', '--always', '--dirty')),
    )
    for field, git_args in optional_fields:
        try:
            meta[field] = _git(*git_args)
        except Exception:
            continue

    try:
        porcelain = subprocess.check_output(['git', 'status', '--porcelain'], stderr=subprocess.DEVNULL)
        # Any non-whitespace output means there are uncommitted changes.
        meta['dirty'] = bool(porcelain.strip())
    except Exception:
        meta['dirty'] = None

    return meta


def setup_logger_and_tb(
    name: str,
    logdir: str,
    level_str: str = "INFO",
    flush_interval_s: int = 60,
):
    """Configure the named logger and make sure it has TensorBoard and console handlers.

    Calling this repeatedly with the same ``name`` does not stack handlers: an
    existing BatchingTensorBoardHandler is reused as-is (its ``logdir`` and
    ``flush_interval_s`` are not updated), and no console handler is added if
    the logger already has a ``logging.StreamHandler``.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        logdir: Directory for TensorBoard output; created if missing
            (a falsy value creates/uses ``"."`` for the makedirs call only).
        level_str: Name of a ``logging`` level; unknown or empty names fall
            back to ``logging.INFO``.
        flush_interval_s: Forwarded to BatchingTensorBoardHandler when a new
            handler is created.

    Returns:
        Tuple ``(logger, writer)`` where ``writer`` is the ``writer`` attribute
        of the attached BatchingTensorBoardHandler.
    """
    line_format = "%(asctime)s [%(levelname)s] %(threadName)s: %(message)s"

    os.makedirs(logdir or ".", exist_ok=True)

    log = logging.getLogger(name)
    log.setLevel(getattr(logging, (level_str or "INFO").upper(), logging.INFO))
    # Keep records away from the root logger so they are not emitted twice.
    log.propagate = False

    # Attach the TensorBoard handler at most once per logger.
    tb = next((h for h in log.handlers if isinstance(h, BatchingTensorBoardHandler)), None)
    if tb is None:
        tb = BatchingTensorBoardHandler(flush_interval_s=flush_interval_s, log_dir=logdir)
        tb.setFormatter(logging.Formatter(line_format))
        log.addHandler(tb)
    tb_writer = tb.writer

    # Attach a console handler at most once per logger.
    # NOTE(review): the isinstance check also matches StreamHandler subclasses
    # (e.g. logging.FileHandler). If BatchingTensorBoardHandler derives from
    # StreamHandler, no console handler would ever be added -- confirm.
    has_stream = any(isinstance(h, logging.StreamHandler) for h in log.handlers)
    if not has_stream:
        stream = logging.StreamHandler()
        stream.setFormatter(logging.Formatter(line_format))
        log.addHandler(stream)

    return log, tb_writer


def resolve_device(request: Optional[str], logger: logging.Logger) -> str:
    """Resolve a device request to a concrete device string, falling back to CPU.

    Supported requests (case-insensitive): ``'cpu'``, ``'cuda'``, ``'cuda:N'``,
    or a substring of a GPU name as reported by ``torch.cuda.get_device_name``.
    ``None``, empty and whitespace-only requests resolve to ``'cpu'``.

    Args:
        request: Device requested by the user; may be ``None`` or empty.
        logger: Receives a warning whenever the request cannot be honored and
            an info message when a GPU-name match succeeds.

    Returns:
        ``'cpu'`` or ``'cuda:<index>'``.

    Side effects:
        When a CUDA device is selected, ``torch.cuda.set_device`` is called
        with its index.
    """
    req = str(request).strip() if request else ""
    # A blank request means "no preference". Without this guard an empty
    # string would reach the GPU-name search below, where "" is a substring of
    # every device name and would silently select the first GPU.
    if not req or req.lower() == "cpu":
        return "cpu"

    if req.lower().startswith("cuda"):
        if not torch.cuda.is_available():
            logger.warning("CUDA requested ('%s') but not available; falling back to CPU.", req)
            return "cpu"

        # Parse the optional ":N" suffix; a malformed suffix falls through to
        # the default-device branch below.
        idx = None
        if ":" in req:
            try:
                idx = int(req.split(":", 1)[1])
            except ValueError:
                logger.warning("Invalid CUDA device string '%s'; will try default CUDA device.", req)

        if idx is not None:
            device_count = torch.cuda.device_count()
            if idx < 0 or idx >= device_count:
                logger.warning(
                    "Requested CUDA index %s out of range (count=%d); falling back to CPU.",
                    idx,
                    device_count,
                )
                return "cpu"
            try:
                torch.cuda.set_device(idx)
                return f"cuda:{idx}"
            except Exception as e:  # pragma: no cover - hardware dependent
                logger.warning("Failed to set CUDA device to index %s (%s); falling back to CPU.", idx, e)
                return "cpu"

        # No (valid) index given: use device 0.
        try:
            torch.cuda.set_device(0)
            return "cuda:0"
        except Exception as e:  # pragma: no cover - hardware dependent
            logger.warning("Failed to select CUDA device (%s); falling back to CPU.", e)
            return "cpu"

    # Anything else is treated as a (partial) GPU name.
    if not torch.cuda.is_available():
        logger.warning("GPU name '%s' requested but CUDA not available; falling back to CPU.", req)
        return "cpu"

    target = req.lower()
    try:
        for i in range(torch.cuda.device_count()):
            name = torch.cuda.get_device_name(i)
            if name and target in name.lower():
                try:
                    torch.cuda.set_device(i)
                    logger.info("Matched GPU name '%s' -> cuda:%d (%s)", req, i, name)
                    return f"cuda:{i}"
                except Exception as e:  # pragma: no cover
                    # Keep looking: another device may match the same name.
                    logger.warning(
                        "Matched GPU '%s' at index %d but failed to set device (%s); continuing search.",
                        req,
                        i,
                        e,
                    )
        logger.warning("No CUDA device matched name '%s'; falling back to CPU.", req)
    except Exception as e:  # pragma: no cover
        logger.warning("Error while searching CUDA devices for '%s' (%s); falling back to CPU.", req, e)
    return "cpu"


def log_hparams(writer, args_dict: Dict[str, Any], write_text: bool = True, add_hparams_plugin: bool = True):
    """Log CLI args into TensorBoard as scalars and text (best-effort).

    Every write is independent and non-fatal: a failure is logged at DEBUG
    level on this module's logger and the remaining writes still run.

    Args:
        writer: Object exposing ``add_text``, ``add_scalar`` and
            ``add_hparams`` (e.g. a TensorBoard SummaryWriter).
        args_dict: Mapping of argument name to value; ``None`` or empty is
            treated as an empty mapping.
        write_text: If True, write the whole mapping as a fenced JSON block
            under the ``args/json`` tag. Values that JSON cannot encode are
            written as ``str(value)`` instead of dropping the whole block.
        add_hparams_plugin: If True, also call ``writer.add_hparams``.

    Note: Enabling add_hparams_plugin may create an extra timestamped run directory
    (separate TB event subfolder) for the hparams session. Set add_hparams_plugin=False
    to keep all logs in the main run directory.
    """
    log = logging.getLogger(__name__)
    args = args_dict or {}

    if write_text:
        try:
            # default=str keeps the summary when a value is not JSON-serializable
            # (paths, devices, ...); this is display-only, so stringifying is fine.
            payload = json.dumps(args, indent=2, sort_keys=True, default=str)
            writer.add_text("args/json", "```json\n" + payload + "\n```", 0)
        except Exception:
            log.debug("Could not write 'args/json' text to TensorBoard.", exc_info=True)

    for k, v in args.items():
        try:
            # bool is a subclass of int, so it is logged as a 0.0/1.0 scalar.
            if isinstance(v, (int, float)):
                writer.add_scalar(f"hparams/{k}", float(v), 0)
            else:
                writer.add_text(f"hparams/{k}", str(v), 0)
        except Exception:
            # Non-fatal: keep logging the remaining arguments.
            log.debug("Could not log hparam '%s' to TensorBoard.", k, exc_info=True)
            continue

    if add_hparams_plugin:
        try:
            # Anything that is not a plain int/float/bool/str is passed as its string form.
            hparams = {k: (v if isinstance(v, (int, float, bool, str)) else str(v)) for k, v in args.items()}
            writer.add_hparams(hparams, {"init/step": 0})
        except Exception:
            # Optional feature: record the failure but never raise.
            log.debug("Could not write the hparams plugin session to TensorBoard.", exc_info=True)
