from __future__ import annotations

import fcntl
import json
import math
import os
import time
from typing import Any, Dict, List, Optional

from agent.state import AgentState
from python_src.io_utils import read_json, write_json_atomic


class StateManager:
    """Manages persistent storage of agent state for async workflows.

    Each experiment lives in ``<experiments_dir>/<exp_id>/`` and may contain:

    * ``state.json``       -- serialized ``AgentState`` (guarded by ``state.lock``)
    * ``config.json``      -- write-once experiment configuration
    * ``summary.json``     -- derived progress summary, rebuilt after each state write
    * ``final.json``       -- final result payload
    * ``supervision.json`` -- poller supervision state (guarded by ``supervision.lock``)

    State and supervision writes take an exclusive ``fcntl.flock`` on a sibling
    lock file and replace the target file atomically; reads take a shared lock.
    The use of ``fcntl`` makes this class POSIX-only.
    """

    def __init__(self, experiments_dir: Optional[str] = None, repo_root: Optional[str] = None):
        # Default repo root: two directory levels above this module file.
        if repo_root is None:
            repo_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        self.repo_root = repo_root
        self.experiments_dir = experiments_dir or os.path.join(repo_root, "experiments")

    # ------------------------------------------------------------------ paths

    def _exp_dir(self, exp_id: str) -> str:
        return os.path.join(self.experiments_dir, exp_id)

    def _state_path(self, exp_id: str) -> str:
        return os.path.join(self._exp_dir(exp_id), "state.json")

    def _lock_path(self, exp_id: str) -> str:
        return os.path.join(self._exp_dir(exp_id), "state.lock")

    def _config_path(self, exp_id: str) -> str:
        return os.path.join(self._exp_dir(exp_id), "config.json")

    def _summary_path(self, exp_id: str) -> str:
        return os.path.join(self._exp_dir(exp_id), "summary.json")

    def _final_path(self, exp_id: str) -> str:
        return os.path.join(self._exp_dir(exp_id), "final.json")

    def _supervision_path(self, exp_id: str) -> str:
        return os.path.join(self._exp_dir(exp_id), "supervision.json")

    def _supervision_lock_path(self, exp_id: str) -> str:
        return os.path.join(self._exp_dir(exp_id), "supervision.lock")

    def _resolve_op_root(self) -> str:
        """Return ``$ANUM_OP_ROOT`` when set and non-empty, else ``<repo_root>/op``."""
        return os.getenv("ANUM_OP_ROOT") or os.path.join(self.repo_root, "op")

    def resolve_op_root(self) -> str:
        """Public accessor for the op output root (see ``_resolve_op_root``)."""
        return self._resolve_op_root()

    # -------------------------------------------------------------- JSON I/O

    @staticmethod
    def _safe_write_json(path: str, payload: Dict[str, Any]) -> None:
        """Write ``payload`` via the project's ``write_json_atomic`` helper."""
        write_json_atomic(path, payload, ensure_ascii=True, indent=2)

    @staticmethod
    def _read_json(path: str) -> Dict[str, Any]:
        """Read a JSON file via the project's ``read_json`` helper."""
        return read_json(path)

    @staticmethod
    def _dump_json_atomic(path: str, payload: Any) -> None:
        """Write ``payload`` as JSON to ``path`` via ``<path>.tmp`` + ``os.replace``.

        Output uses ``indent=2`` and ASCII escaping. If serialization or the
        rename fails, the temp file is removed before the exception propagates,
        so a failed write neither leaves ``<path>.tmp`` behind nor touches an
        existing ``path``.
        """
        tmp_path = f"{path}.tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as handle:
                json.dump(payload, handle, indent=2, ensure_ascii=True)
            os.replace(tmp_path, path)
        except BaseException:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def _read_status_json(self, task_tag: str) -> Optional[Dict[str, Any]]:
        """Return ``<op_root>/<task_tag>/status.json`` as a dict, or None.

        Missing, unreadable, or non-object files all yield None so callers can
        treat task status as optional.
        """
        status_path = os.path.join(self._resolve_op_root(), task_tag, "status.json")
        if not os.path.exists(status_path):
            return None
        try:
            data = self._read_json(status_path)
        except Exception:
            return None
        return data if isinstance(data, dict) else None

    # ---------------------------------------------------------- normalization

    @staticmethod
    def _normalize_error(value: Any) -> Optional[float]:
        """Return ``value`` as a finite float, or None if missing/non-numeric/NaN/inf."""
        if value is None:
            return None
        try:
            num = float(value)
        except (TypeError, ValueError):
            return None
        if not math.isfinite(num):
            return None
        return num

    @staticmethod
    def _normalize_ops(value: Any) -> Optional[int]:
        """Return ``int(value)``, or None if missing or not convertible.

        ``OverflowError`` is caught as well: JSON may carry ``Infinity``, and
        ``int(float("inf"))`` raises it.
        """
        if value is None:
            return None
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return None

    @staticmethod
    def _count_ops_from_candidate(candidate: Dict[str, Any]) -> Optional[int]:
        """Return a candidate's op count, or None if it cannot be determined.

        Prefers an int-convertible ``ops`` field; otherwise counts the dict
        entries of ``nodes`` whose ``type`` is 2-6. (Presumably those are the
        operation node kinds -- confirm against the candidate schema.)
        """
        ops = StateManager._normalize_ops(candidate.get("ops"))
        if ops is not None:
            return ops
        node_list = candidate.get("nodes")
        if isinstance(node_list, list):
            return sum(
                1
                for node in node_list
                if isinstance(node, dict) and node.get("type") in (2, 3, 4, 5, 6)
            )
        return None

    @staticmethod
    def _is_improvement(
        error: float,
        ops: Optional[int],
        best_error: Optional[float],
        best_ops: Optional[int],
    ) -> bool:
        """Return True if ``(error, ops)`` beats the current best.

        Lower error wins. On an exact error tie, a known op count beats an
        unknown one and fewer ops beat more.
        """
        if best_error is None or error < best_error:
            return True
        if error != best_error or ops is None:
            return False
        return best_ops is None or ops < best_ops

    @staticmethod
    def _stale_seconds(now: int, last_progress_at: Any) -> Optional[int]:
        """Return seconds elapsed since ``last_progress_at`` (epoch seconds), clamped at 0.

        Accepts ints, floats, and numeric strings, including fractional ones
        such as ``"1700000000.5"``. Returns None for falsy or unparseable values.
        """
        if not last_progress_at:
            return None
        try:
            return max(0, now - int(float(last_progress_at)))
        except (TypeError, ValueError, OverflowError):
            return None

    # ------------------------------------------------------------ op readers

    def _read_bestof_best(self, task_tag: str) -> Optional[Dict[str, Any]]:
        """Return the best ``{"error", "ops"}`` among ``<op_root>/<task_tag>/bestof/*.json``.

        Files that are unreadable, not JSON objects, or lack a finite
        ``optimization_error`` are skipped. Files are visited in sorted name
        order so that ties are resolved deterministically. Returns None when
        the directory is missing or nothing usable is found.
        """
        bestof_dir = os.path.join(self._resolve_op_root(), task_tag, "bestof")
        if not os.path.isdir(bestof_dir):
            return None
        try:
            names = sorted(name for name in os.listdir(bestof_dir) if name.endswith(".json"))
        except OSError:
            return None

        best_error: Optional[float] = None
        best_ops: Optional[int] = None
        for name in names:
            try:
                payload = self._read_json(os.path.join(bestof_dir, name))
            except Exception:
                continue
            if not isinstance(payload, dict):
                continue
            error = self._normalize_error(payload.get("optimization_error"))
            if error is None:
                continue
            ops = self._count_ops_from_candidate(payload)
            if self._is_improvement(error, ops, best_error, best_ops):
                best_error, best_ops = error, ops

        if best_error is None:
            return None
        return {"error": best_error, "ops": best_ops}

    @staticmethod
    def _collect_task_tags(state: AgentState) -> List[str]:
        """Return unique non-empty task tags known to ``state``, in first-seen order.

        Sources, in order: pending callbacks, run records, piece statuses.
        """
        tags = [tag for tag in state.pending_callbacks if tag]
        tags.extend(run.task_tag for run in state.runs if run.task_tag)
        tags.extend(
            piece.task_tag for piece in state.piece_statuses.values() if piece.task_tag
        )
        return list(dict.fromkeys(tags))

    @staticmethod
    def _piece_id_by_task_tag(state: AgentState) -> Dict[str, str]:
        """Map each piece's (stripped, non-empty) task tag to its piece id."""
        mapping: Dict[str, str] = {}
        for piece_id, piece in state.piece_statuses.items():
            task_tag = str(piece.task_tag or "").strip()
            if task_tag:
                mapping[task_tag] = str(piece_id)
        return mapping

    # --------------------------------------------------------------- summary

    def build_summary(self, exp_id: str, state: AgentState) -> Dict[str, Any]:
        """Build the ``summary.json`` payload for an experiment.

        Error/ops candidates come from four sources:

        * ``state.best_solution["best_candidate"]``;
        * each piece's ``result``;
        * for every known task tag, the best file under ``<op_root>/<tag>/bestof``;
        * for every known task tag, ``best_summary`` in ``status.json``.

        When any per-piece error exists, ``best_error``/``best_ops`` report the
        *worst* piece, because the piecewise objective minimizes the worst-piece
        error. Otherwise they report the best candidate seen.
        """
        from agent.event_log import EventLog

        now = int(time.time())
        event_log = EventLog(repo_root=self.repo_root, experiments_dir=self.experiments_dir)
        latest_event = event_log.read_latest_event(exp_id)
        last_event_at = latest_event.get("timestamp") if latest_event else None

        # Most recent task tag: prefer the latest event, then fall back to state.
        last_task_tag = None
        if latest_event and latest_event.get("task_tag"):
            last_task_tag = latest_event.get("task_tag")
        elif state.pending_callbacks:
            last_task_tag = state.pending_callbacks[-1]
        elif state.runs:
            last_task_tag = state.runs[-1].task_tag
        elif state.piece_statuses:
            last_task_tag = next(iter(state.piece_statuses.values())).task_tag

        best_error: Optional[float] = None
        best_ops: Optional[int] = None
        piece_best: Dict[str, Dict[str, Any]] = {}
        task_to_piece = self._piece_id_by_task_tag(state)

        def consider_candidate(error_value: Any, ops_value: Any) -> None:
            nonlocal best_error, best_ops
            error_num = self._normalize_error(error_value)
            if error_num is None:
                return
            ops_num = self._normalize_ops(ops_value)
            if self._is_improvement(error_num, ops_num, best_error, best_ops):
                best_error, best_ops = error_num, ops_num

        def consider_piece(piece_id: str, error_value: Any, ops_value: Any) -> None:
            error_num = self._normalize_error(error_value)
            if error_num is None:
                return
            ops_num = self._normalize_ops(ops_value)
            current = piece_best.get(piece_id) or {}
            if self._is_improvement(error_num, ops_num, current.get("error"), current.get("ops")):
                piece_best[piece_id] = {"error": error_num, "ops": ops_num}

        if state.best_solution:
            # `or {}` also covers an explicit `"best_candidate": null`.
            best_candidate = state.best_solution.get("best_candidate") or {}
            consider_candidate(
                best_candidate.get("optimization_error"),
                best_candidate.get("ops"),
            )

        for piece in state.piece_statuses.values():
            result = piece.result
            if not result:
                continue
            consider_candidate(result.get("optimization_error"), result.get("ops"))
            consider_piece(
                str(piece.piece_id),
                result.get("optimization_error"),
                result.get("ops"),
            )

        for task_tag in self._collect_task_tags(state):
            piece_id = task_to_piece.get(task_tag)
            bestof_best = self._read_bestof_best(task_tag)
            if bestof_best:
                consider_candidate(bestof_best["error"], bestof_best["ops"])
                if piece_id is not None:
                    consider_piece(piece_id, bestof_best["error"], bestof_best["ops"])
            status_data = self._read_status_json(task_tag)
            best_summary = status_data.get("best_summary") if status_data else None
            # status.json is written externally; ignore a malformed best_summary.
            if isinstance(best_summary, dict):
                consider_candidate(best_summary.get("error"), best_summary.get("ops"))
                if piece_id is not None:
                    consider_piece(piece_id, best_summary.get("error"), best_summary.get("ops"))

        best_piece_id = best_piece_error = best_piece_ops = None
        worst_piece_id = worst_piece_error = worst_piece_ops = None
        if piece_best:
            ranked = sorted(piece_best.items(), key=lambda item: item[1]["error"])
            best_piece_id, best_payload = ranked[0]
            worst_piece_id, worst_payload = ranked[-1]
            best_piece_error, best_piece_ops = best_payload["error"], best_payload["ops"]
            worst_piece_error, worst_piece_ops = worst_payload["error"], worst_payload["ops"]
            # Global objective in piecewise mode: minimize worst-piece error.
            best_error, best_ops = worst_piece_error, worst_piece_ops

        last_progress_at = None
        if last_task_tag:
            status_data = self._read_status_json(last_task_tag)
            if status_data:
                last_progress_at = status_data.get("updated_at")

        return {
            "exp_id": exp_id,
            "phase": state.phase,
            "best_error": best_error,
            "best_ops": best_ops,
            "best_piece_error": best_piece_error,
            "best_piece_ops": best_piece_ops,
            "best_piece_id": best_piece_id,
            "worst_piece_error": worst_piece_error,
            "worst_piece_ops": worst_piece_ops,
            "worst_piece_id": worst_piece_id,
            "piece_best_errors": {pid: payload["error"] for pid, payload in piece_best.items()},
            "objective_metric": "worst_piece_error" if piece_best else "best_error",
            "last_task_tag": last_task_tag,
            "last_event_at": last_event_at,
            "last_progress_at": last_progress_at,
            "stale_for_s": self._stale_seconds(now, last_progress_at),
            "retry_count": state.retry_state.get("retry_count", 0),
            "updated_at": now,
            # Copy so that mutating the summary cannot mutate the live state.
            "pending_callbacks": list(state.pending_callbacks),
        }

    def refresh_summary(self, exp_id: str, state: Optional[AgentState] = None) -> Dict[str, Any]:
        """Rebuild the summary, persist it to ``summary.json``, and return it.

        Loads state from disk when ``state`` is None. Writing the file is
        best-effort: a failed write is ignored and the summary is still
        returned.
        """
        if state is None:
            state = self.load(exp_id)
        summary = self.build_summary(exp_id, state)
        try:
            self._safe_write_json(self._summary_path(exp_id), summary)
        except Exception:
            # The summary is derived data; never fail the caller over it.
            pass
        return summary

    # ----------------------------------------------------------------- final

    def load_final(self, exp_id: str) -> Optional[Dict[str, Any]]:
        """Return ``final.json`` contents, or None if it is missing or unreadable."""
        path = self._final_path(exp_id)
        if not os.path.exists(path):
            return None
        try:
            return self._read_json(path)
        except Exception:
            return None

    def save_final(self, exp_id: str, payload: Dict[str, Any]) -> str:
        """Persist ``payload`` as ``final.json`` and return its path."""
        os.makedirs(self._exp_dir(exp_id), exist_ok=True)
        final_path = self._final_path(exp_id)
        self._safe_write_json(final_path, payload)
        return final_path

    # ----------------------------------------------------------------- state

    def exists(self, exp_id: str) -> bool:
        """Check if an experiment exists (i.e. it has a ``state.json``)."""
        return os.path.exists(self._state_path(exp_id))

    def save(self, exp_id: str, state: AgentState) -> str:
        """Save agent state to disk.

        Stamps ``state.updated_at`` and ``state.exp_id``, atomically replaces
        ``state.json`` under an exclusive lock, then refreshes the summary.

        Args:
            exp_id: Experiment identifier
            state: AgentState to persist (mutated: ``updated_at``, ``exp_id``)

        Returns:
            Path to the state file
        """
        os.makedirs(self._exp_dir(exp_id), exist_ok=True)

        state.updated_at = int(time.time())
        state.exp_id = exp_id

        state_path = self._state_path(exp_id)
        with open(self._lock_path(exp_id), "w") as lock_file:
            fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
            try:
                self._dump_json_atomic(state_path, state.to_dict())
            finally:
                fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)

        self.refresh_summary(exp_id, state)
        return state_path

    def load(self, exp_id: str) -> AgentState:
        """Load agent state from disk under a shared lock.

        Args:
            exp_id: Experiment identifier

        Returns:
            Reconstructed AgentState

        Raises:
            FileNotFoundError: If experiment doesn't exist
        """
        state_path = self._state_path(exp_id)
        lock_path = self._lock_path(exp_id)

        if not os.path.exists(state_path):
            raise FileNotFoundError(f"Experiment {exp_id} not found at {state_path}")

        os.makedirs(os.path.dirname(lock_path), exist_ok=True)
        with open(lock_path, "w") as lock_file:
            fcntl.flock(lock_file.fileno(), fcntl.LOCK_SH)
            try:
                with open(state_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
            finally:
                fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)

        return AgentState.from_dict(data)

    # ---------------------------------------------------------------- config

    def save_config(self, exp_id: str, config: Dict[str, Any]) -> str:
        """Save initial experiment configuration (write-once).

        If ``config.json`` already exists it is left untouched. Otherwise the
        built-in defaults are written, overlaid with ``config``.

        Args:
            exp_id: Experiment identifier
            config: Configuration dict (user request, parameters, etc.)

        Returns:
            Path to config file
        """
        os.makedirs(self._exp_dir(exp_id), exist_ok=True)

        config_path = self._config_path(exp_id)
        if os.path.exists(config_path):
            return config_path

        payload: Dict[str, Any] = {
            "timeout_s": 172800,
            "retry_backoff_s": 600,
            "max_retries": 2,
        }
        payload.update(config)
        self._dump_json_atomic(config_path, payload)
        return config_path

    def load_config(self, exp_id: str) -> Dict[str, Any]:
        """Load experiment configuration.

        Args:
            exp_id: Experiment identifier

        Returns:
            Configuration dict, or ``{}`` if no config has been saved
        """
        config_path = self._config_path(exp_id)
        if not os.path.exists(config_path):
            return {}
        with open(config_path, "r", encoding="utf-8") as f:
            return json.load(f)

    # ----------------------------------------------------------- supervision

    @staticmethod
    def _normalize_supervision_payload(payload: Any) -> Dict[str, Any]:
        """Coerce ``payload`` into ``{"version", "tasks", "global", "updated_at"}``.

        Non-dict ``tasks``/``global`` values become ``{}``. ``updated_at`` is
        always stamped with the current time, on load as well as on save.
        """
        tasks: Dict[str, Any] = {}
        global_state: Dict[str, Any] = {}
        if isinstance(payload, dict):
            if isinstance(payload.get("tasks"), dict):
                tasks = payload["tasks"]
            if isinstance(payload.get("global"), dict):
                global_state = payload["global"]
        return {
            "version": 1,
            "tasks": tasks,
            "global": global_state,
            "updated_at": int(time.time()),
        }

    def load_supervision(self, exp_id: str) -> Dict[str, Any]:
        """Load poller supervision state; unreadable or missing files yield an empty payload."""
        path = self._supervision_path(exp_id)
        lock_path = self._supervision_lock_path(exp_id)
        if not os.path.exists(path):
            return self._normalize_supervision_payload({})

        os.makedirs(os.path.dirname(lock_path), exist_ok=True)
        with open(lock_path, "w") as lock_file:
            fcntl.flock(lock_file.fileno(), fcntl.LOCK_SH)
            try:
                with open(path, "r", encoding="utf-8") as f:
                    data = json.load(f)
            except Exception:
                data = {}
            finally:
                fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
        return self._normalize_supervision_payload(data)

    def save_supervision(self, exp_id: str, payload: Dict[str, Any]) -> str:
        """Persist normalized poller supervision state under an exclusive lock."""
        os.makedirs(self._exp_dir(exp_id), exist_ok=True)
        path = self._supervision_path(exp_id)
        normalized = self._normalize_supervision_payload(payload)

        with open(self._supervision_lock_path(exp_id), "w") as lock_file:
            fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
            try:
                self._safe_write_json(path, normalized)
            finally:
                fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)
        return path

    # ------------------------------------------------------------ management

    def list_experiments(self, status_filter: Optional[str] = None) -> List[Dict[str, Any]]:
        """List all experiments with optional status filter.

        Directories without a readable JSON-object ``state.json`` are skipped,
        so one corrupt experiment cannot break the listing.

        Args:
            status_filter: Optional filter by phase (init, planning, running, evaluating, finalized)

        Returns:
            List of experiment summaries, sorted by experiment id
        """
        experiments: List[Dict[str, Any]] = []
        if not os.path.isdir(self.experiments_dir):
            return experiments

        for name in sorted(os.listdir(self.experiments_dir)):
            exp_dir = os.path.join(self.experiments_dir, name)
            if not os.path.isdir(exp_dir):
                continue

            state_path = os.path.join(exp_dir, "state.json")
            if not os.path.exists(state_path):
                continue

            try:
                with open(state_path, "r", encoding="utf-8") as f:
                    data = json.load(f)
            except Exception:
                continue
            if not isinstance(data, dict):
                continue

            phase = data.get("phase", "unknown")
            if status_filter is not None and phase != status_filter:
                continue

            experiments.append({
                "exp_id": name,
                "session_id": data.get("session_id", ""),
                "phase": phase,
                "user_request": data.get("user_request"),
                "pending_callbacks": len(data.get("pending_callbacks") or []),
                "parallel_mode": data.get("parallel_mode", False),
                "created_at": data.get("created_at"),
                "updated_at": data.get("updated_at"),
            })

        return experiments

    def delete(self, exp_id: str) -> bool:
        """Delete an experiment and all its data.

        Ids that resolve to the experiments root itself (``""``, ``"."``) or to
        anything outside it (``".."``, absolute paths) are refused. Recursively
        deleting such a path would destroy other experiments or unrelated data.

        Args:
            exp_id: Experiment identifier

        Returns:
            True if deleted, False if not found or the id is not a path inside
            the experiments directory
        """
        import shutil

        exp_dir = self._exp_dir(exp_id)
        root = os.path.abspath(self.experiments_dir)
        target = os.path.abspath(exp_dir)
        if target == root or os.path.commonpath([root, target]) != root:
            return False
        if not os.path.isdir(exp_dir):
            return False

        shutil.rmtree(exp_dir)
        return True

    def update_piece_status(
        self,
        exp_id: str,
        piece_id: str,
        status: str,
        result: Optional[Dict[str, Any]] = None,
        error: Optional[str] = None,
        task_tag: Optional[str] = None,
        event_id: Optional[str] = None,
        run_status: Optional[str] = None,
        update_pending: bool = True,
    ) -> AgentState:
        """Atomically update a piece's status.

        The read-modify-write of ``state.json`` happens under an exclusive
        ``flock`` on ``state.lock``, so concurrent callback handlers serialize.
        An unknown ``piece_id`` leaves piece statuses untouched; the other
        updates still apply.

        Args:
            exp_id: Experiment identifier
            piece_id: Piece identifier
            status: New status (pending, running, completed, failed)
            result: Optional result data
            error: Optional error message
            task_tag: Task tag to update pending callbacks and run records
            event_id: Event ID to store in state
            run_status: Optional run status to update (first run with ``task_tag``)
            update_pending: Whether to remove task_tag from pending_callbacks

        Returns:
            Updated AgentState

        Raises:
            FileNotFoundError: If the experiment has no ``state.json``
        """
        lock_path = self._lock_path(exp_id)
        os.makedirs(os.path.dirname(lock_path), exist_ok=True)
        state_path = self._state_path(exp_id)

        with open(lock_path, "w") as lock_file:
            fcntl.flock(lock_file.fileno(), fcntl.LOCK_EX)
            try:
                # Existence is checked under the lock to avoid racing a delete.
                if not os.path.exists(state_path):
                    raise FileNotFoundError(f"Experiment {exp_id} not found at {state_path}")

                with open(state_path, "r", encoding="utf-8") as handle:
                    state = AgentState.from_dict(json.load(handle))

                if piece_id in state.piece_statuses:
                    piece = state.piece_statuses[piece_id]
                    piece.status = status
                    if result is not None:
                        piece.result = result
                    if error is not None:
                        piece.error = error

                if update_pending and task_tag and task_tag in state.pending_callbacks:
                    state.pending_callbacks.remove(task_tag)

                if run_status and task_tag:
                    for run in state.runs:
                        if run.task_tag == task_tag:
                            run.status = run_status
                            break

                if event_id:
                    state.last_event_id = event_id

                state.updated_at = int(time.time())
                self._dump_json_atomic(state_path, state.to_dict())
            finally:
                fcntl.flock(lock_file.fileno(), fcntl.LOCK_UN)

        self.refresh_summary(exp_id, state)
        return state
