from __future__ import annotations

import os
import signal
import subprocess
import sys
import time
from typing import Any, Dict, Optional

from agent.state_manager import StateManager
from python_src.io_utils import read_json_or_default, write_json_atomic


# Fallback settings used by ensure_poller_running() when neither an explicit
# argument nor the matching ANUM_POLLER_* environment variable supplies a
# value. Units are seconds (per the ``_S`` suffix).
DEFAULT_POLL_INTERVAL_S = 5 * 60  # 300 s, passed as --interval_s
DEFAULT_RUNNING_RESUME_INTERVAL_S = 3 * 60 * 60  # 10800 s, passed as --running_resume_interval_s
DEFAULT_RESUME_TIMEOUT_S = 15 * 60  # 900 s, passed as --resume_timeout_s


def _exp_dir(state_manager: StateManager, exp_id: str) -> str:
    """Return the directory for *exp_id* under the state manager's experiments root."""
    experiments_root = state_manager.experiments_dir
    return os.path.join(experiments_root, exp_id)


def _pid_file(state_manager: StateManager, exp_id: str) -> str:
    """Return the path of the file that records the poller PID for *exp_id*."""
    experiment_dir = _exp_dir(state_manager, exp_id)
    return os.path.join(experiment_dir, "poller.pid")


def _log_file(state_manager: StateManager, exp_id: str) -> str:
    """Return the path of the log that receives the poller's stdout/stderr for *exp_id*."""
    experiment_dir = _exp_dir(state_manager, exp_id)
    return os.path.join(experiment_dir, "poller.log")


def _meta_file(state_manager: StateManager, exp_id: str) -> str:
    """Return the path of the JSON metadata file describing the poller for *exp_id*."""
    experiment_dir = _exp_dir(state_manager, exp_id)
    return os.path.join(experiment_dir, "poller.json")


def _safe_write_json(path: str, payload: Dict[str, Any]) -> None:
    """Write *payload* to *path* as ASCII-only JSON indented by 2 spaces.

    Delegates to ``write_json_atomic``. Presumably this writes a temporary
    file and renames it into place; see ``python_src.io_utils`` to confirm.
    """
    json_options = {"ensure_ascii": True, "indent": 2}
    write_json_atomic(path, payload, **json_options)


def _read_json(path: str) -> Dict[str, Any]:
    """Return the JSON object stored at *path*, or ``{}``.

    ``read_json_or_default`` is given ``{}`` as its fallback. Presumably the
    fallback is used when the file is missing or unparsable; confirm in
    ``python_src.io_utils``. A top-level value that is not a dict (list,
    scalar, ...) is also replaced by ``{}``.
    """
    loaded = read_json_or_default(path, {})
    if not isinstance(loaded, dict):
        return {}
    return loaded


def _read_pid(path: str) -> Optional[int]:
    """Return the positive integer PID stored in *path*, or ``None``.

    ``None`` is returned in any of these cases:
    - the file is missing or cannot be read;
    - the file is not valid UTF-8;
    - the file does not contain a single integer (surrounding whitespace is allowed);
    - the integer is <= 0.
    """
    try:
        with open(path, "r", encoding="utf-8") as handle:
            pid = int(handle.read().strip())
    except (OSError, ValueError):
        # OSError covers a missing or unreadable file (e.g. a directory).
        # ValueError covers non-integer contents and UnicodeDecodeError,
        # which is a ValueError subclass.
        return None
    return pid if pid > 0 else None


def _is_pid_alive(pid: Optional[int]) -> bool:
    """Return True if a process with id *pid* appears to exist.

    The probe is ``os.kill(pid, 0)``. On POSIX this performs the existence and
    permission checks without delivering a signal.

    ``None`` and non-positive ids return False. In ``kill(2)``, 0 and negative
    values address process *groups* (``-1`` means "every process we may
    signal"), not a single process.

    NOTE(review): on Windows ``signal.CTRL_C_EVENT`` is 0, so ``os.kill(pid, 0)``
    tries to send a console Ctrl+C to process group *pid* rather than probing
    it. For a detached, console-less poller this most likely fails and reports
    False. Confirm this, and consider an ``OpenProcess``-based check if Windows
    support matters.
    """
    if pid is None or pid <= 0:
        return False
    try:
        os.kill(pid, 0)
    except PermissionError:
        # EPERM: the process exists but we may not signal it (e.g. another user).
        return True
    except OSError:
        # ESRCH (ProcessLookupError) or any other failure: treat as not running.
        return False
    return True


def poller_status(exp_id: str, state_manager: StateManager) -> Dict[str, Any]:
    """Report the recorded poller for *exp_id* without changing anything.

    Returns a dict with these keys:
    - ``pid``: the PID read from the PID file, or None;
    - ``alive``: whether that PID currently looks alive;
    - ``pid_file``, ``log_file``, ``meta_file``: the three file paths;
    - ``meta``: the parsed metadata JSON, or ``{}`` when it is not a JSON object.
    """
    pid_path = _pid_file(state_manager, exp_id)
    meta_path = _meta_file(state_manager, exp_id)
    recorded_pid = _read_pid(pid_path)
    is_alive = _is_pid_alive(recorded_pid)
    metadata = _read_json(meta_path)
    return dict(
        pid=recorded_pid,
        alive=is_alive,
        pid_file=pid_path,
        log_file=_log_file(state_manager, exp_id),
        meta_file=meta_path,
        meta=metadata,
    )


def _ignore_terminal_signals() -> None:
    """Set SIGHUP's disposition to "ignore" in the calling process.

    ``_spawn_detached_poller`` passes this as ``preexec_fn``, so it runs in the
    child just before ``exec``. POSIX keeps ignored dispositions across
    ``exec``, so a hang-up signal will not terminate the poller.
    POSIX only: ``signal.SIGHUP`` is not defined on Windows.
    """
    hangup_signal = signal.SIGHUP
    signal.signal(hangup_signal, signal.SIG_IGN)


def _load_env_file(path: str, env: dict[str, str]) -> None:
    """Load ``KEY=VALUE`` lines from a dotenv-style file into *env*.

    Existing keys in *env* are never overwritten (``setdefault``). Values
    already present, e.g. from the real process environment or from a file
    loaded earlier, take precedence.

    Supported syntax:
    - blank lines and lines starting with ``#`` are skipped;
    - an optional leading ``export`` keyword is ignored;
    - a value wrapped in a matching pair of ``"`` or ``'`` quotes is unquoted;
    - lines without ``=`` or with an empty key are ignored.

    Inline comments and escape sequences are NOT interpreted.

    Loading is best-effort. A missing path is a no-op. If the file cannot be
    read or is not valid UTF-8, loading stops silently; keys parsed before the
    failure stay in *env*.
    """
    if not os.path.isfile(path):
        return
    try:
        # utf-8-sig drops a leading byte-order mark (added by some Windows
        # editors). Plain utf-8 would glue it onto the first key.
        with open(path, "r", encoding="utf-8-sig") as fh:
            for raw_line in fh:
                line = raw_line.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue
                key, _, value = line.partition("=")
                key = key.strip()
                # Accept shell-style "export KEY=VALUE" lines.
                words = key.split(None, 1)
                if len(words) == 2 and words[0] == "export":
                    key = words[1]
                if not key:
                    continue
                value = value.strip()
                if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
                    value = value[1:-1]
                env.setdefault(key, value)
    except (OSError, UnicodeDecodeError):
        # Best-effort: a broken env file must not abort the caller
        # (poller startup).
        pass


def _spawn_detached_poller(
    cmd: list[str],
    *,
    log_path: str,
    repo_root: str,
) -> subprocess.Popen:
    """Start *cmd* as a background process and return its ``Popen`` handle.

    Child environment: the current environment, plus any missing keys from
    ``<repo_root>/.env.local`` and then ``<repo_root>/.env``, plus
    ``PYTHONUNBUFFERED=1`` unless already set.

    Child I/O: it runs in *repo_root*, stdin is ``subprocess.DEVNULL``, and
    stdout and stderr are both appended to *log_path*.

    Detachment: on Windows the child gets a new process group and no console.
    Elsewhere it starts a new session and ignores SIGHUP.

    The caller does not wait on the process.
    """
    child_env = dict(os.environ)
    # Order matters: .env.local is loaded first, so its keys win over .env.
    for env_name in (".env.local", ".env"):
        _load_env_file(os.path.join(repo_root, env_name), child_env)
    child_env.setdefault("PYTHONUNBUFFERED", "1")

    if os.name == "nt":
        detach_flags = getattr(subprocess, "CREATE_NEW_PROCESS_GROUP", 0) | getattr(
            subprocess, "DETACHED_PROCESS", 0
        )
        platform_kwargs: Dict[str, Any] = {"creationflags": detach_flags}
    else:
        # NOTE: the subprocess docs warn that preexec_fn is unsafe when the
        # parent process has other threads running.
        platform_kwargs = {
            "start_new_session": True,
            "preexec_fn": _ignore_terminal_signals,
        }

    # The child gets its own duplicated descriptors for the log, so closing
    # the parent's handle when the ``with`` exits is safe.
    with open(log_path, "ab") as log_handle:
        return subprocess.Popen(
            cmd,
            stdin=subprocess.DEVNULL,
            stdout=log_handle,
            stderr=log_handle,
            cwd=repo_root,
            env=child_env,
            close_fds=True,
            **platform_kwargs,
        )


def _resolve_positive_int(explicit: Optional[int], env_var: str, default: int) -> int:
    """Pick *explicit* if truthy, else ``$env_var``, else *default*; clamp to >= 1."""
    # ``or`` (not ``is None``) is deliberate: an explicit 0 falls back to the
    # environment/default, matching the historical behaviour.
    return max(1, int(explicit or os.getenv(env_var, default)))


def _write_pid_file(pid_path: str, pid: int) -> None:
    """Write *pid* followed by a newline to *pid_path*, atomically.

    The content goes to a sibling temp file that is then ``os.replace``d into
    place. A concurrent reader therefore never sees a truncated or empty PID
    file. ``_read_pid`` would report such a file as "no poller", which could
    lead to a second poller being spawned.
    """
    tmp_path = f"{pid_path}.{os.getpid()}.tmp"
    try:
        with open(tmp_path, "w", encoding="utf-8") as handle:
            handle.write(f"{pid}\n")
        os.replace(tmp_path, pid_path)
    except OSError:
        # Do not leave the temp file behind; the original error still propagates.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def ensure_poller_running(
    exp_id: str,
    state_manager: StateManager,
    interval_s: Optional[int] = None,
    running_resume_interval_s: Optional[int] = None,
    resume_timeout_s: Optional[int] = None,
) -> Dict[str, Any]:
    """Start the background poller for *exp_id* unless one is already alive.

    Each setting is resolved in this order, then clamped to at least 1:
    1. the explicit argument, if truthy (0 counts as "not given");
    2. the environment variable (``ANUM_POLLER_INTERVAL_S``,
       ``ANUM_POLLER_RUNNING_RESUME_INTERVAL_S``, ``ANUM_POLLER_RESUME_TIMEOUT_S``);
    3. the module ``DEFAULT_*`` constant.

    Args:
        exp_id: Experiment id; its directory under
            ``state_manager.experiments_dir`` is created if missing.
        state_manager: Supplies ``experiments_dir`` and ``repo_root``.
        interval_s: Passed to the poller as ``--interval_s``.
        running_resume_interval_s: Passed as ``--running_resume_interval_s``.
        resume_timeout_s: Passed as ``--resume_timeout_s``.

    Returns:
        If the PID in the PID file is alive:
        ``{"started": False, "pid", "reason": "already_running", "pid_file", "log_file", "meta_file"}``.

        Otherwise ``python -m agent.poller`` is spawned detached, and its PID
        and metadata are recorded. The result is
        ``{"started": True, "pid", "pid_file", "log_file", "meta_file",
        "interval_s", "running_resume_interval_s", "resume_timeout_s"}``.

    Raises:
        ValueError: if an environment override is not an integer.
        OSError: if the process cannot be spawned or its files cannot be written.
    """
    exp_dir = _exp_dir(state_manager, exp_id)
    os.makedirs(exp_dir, exist_ok=True)

    pid_path = _pid_file(state_manager, exp_id)
    log_path = _log_file(state_manager, exp_id)
    meta_path = _meta_file(state_manager, exp_id)

    existing_pid = _read_pid(pid_path)
    if _is_pid_alive(existing_pid):
        return {
            "started": False,
            "pid": existing_pid,
            "reason": "already_running",
            "pid_file": pid_path,
            "log_file": log_path,
            "meta_file": meta_path,
        }

    poll_interval = _resolve_positive_int(
        interval_s, "ANUM_POLLER_INTERVAL_S", DEFAULT_POLL_INTERVAL_S
    )
    running_interval = _resolve_positive_int(
        running_resume_interval_s,
        "ANUM_POLLER_RUNNING_RESUME_INTERVAL_S",
        DEFAULT_RUNNING_RESUME_INTERVAL_S,
    )
    resume_timeout = _resolve_positive_int(
        resume_timeout_s, "ANUM_POLLER_RESUME_TIMEOUT_S", DEFAULT_RESUME_TIMEOUT_S
    )

    cmd = [
        sys.executable,
        "-m",
        "agent.poller",
        "--exp_id",
        exp_id,
        "--interval_s",
        str(poll_interval),
        "--running_resume_interval_s",
        str(running_interval),
        "--resume_timeout_s",
        str(resume_timeout),
    ]

    proc = _spawn_detached_poller(
        cmd,
        log_path=log_path,
        repo_root=state_manager.repo_root,
    )

    # Written atomically, like the meta file below, so concurrent status checks
    # never observe a half-written PID file.
    _write_pid_file(pid_path, proc.pid)

    _safe_write_json(
        meta_path,
        {
            "exp_id": exp_id,
            "pid": proc.pid,
            "cmd": cmd,
            "pid_file": pid_path,
            "log_file": log_path,
            "interval_s": poll_interval,
            "running_resume_interval_s": running_interval,
            "resume_timeout_s": resume_timeout,
            "repo_root": state_manager.repo_root,
            "started_at": int(time.time()),
        },
    )

    return {
        "started": True,
        "pid": proc.pid,
        "pid_file": pid_path,
        "log_file": log_path,
        "meta_file": meta_path,
        "interval_s": poll_interval,
        "running_resume_interval_s": running_interval,
        "resume_timeout_s": resume_timeout,
    }


def stop_poller(exp_id: str, state_manager: StateManager) -> Dict[str, Any]:
    """Send SIGTERM to the poller recorded for *exp_id*, if it is alive.

    Returns one of:
        ``{"stopped": False, "reason": "not_running", ...}`` when no live
        process is recorded, including when it exits between the liveness
        check and the signal. Any stale PID file is deleted.

        ``{"stopped": False, "reason": "kill_failed", ...}`` when the signal
        cannot be delivered (e.g. permission denied).

        ``{"stopped": True, ...}`` once SIGTERM has been sent. The process may
        still be shutting down, so the embedded ``status`` can report
        ``alive: True`` immediately afterwards.

    Every result also carries ``pid`` and a fresh ``poller_status`` snapshot
    under ``status``.
    """
    pid_path = _pid_file(state_manager, exp_id)
    pid = _read_pid(pid_path)

    if _is_pid_alive(pid):
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            # The process exited after the liveness check; fall through to the
            # "not running" path instead of misreporting a kill failure.
            pass
        except OSError:
            return {"stopped": False, "pid": pid, "reason": "kill_failed", "status": poller_status(exp_id, state_manager)}
        else:
            return {"stopped": True, "pid": pid, "status": poller_status(exp_id, state_manager)}

    # No live process: drop any stale PID file so later checks start clean.
    try:
        os.remove(pid_path)
    except OSError:
        # Already gone (FileNotFoundError) or not removable; best-effort.
        pass
    return {"stopped": False, "pid": pid, "reason": "not_running", "status": poller_status(exp_id, state_manager)}
