from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple


@dataclass
class RunRecord:
    """Record of a single run: its id, task tag, status and a summary dict."""

    run_id: str
    task_tag: str
    status: str
    summary: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (``summary`` is returned as-is, not copied)."""
        return {
            "run_id": self.run_id,
            "task_tag": self.task_tag,
            "status": self.status,
            "summary": self.summary,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> RunRecord:
        """Rebuild a ``RunRecord`` from a dict such as one produced by :meth:`to_dict`.

        Missing ids default to ``""`` and a missing status to ``"unknown"``.
        A missing *or* ``None`` summary becomes an empty dict, so the
        ``Dict[str, Any]`` annotation on ``summary`` always holds.
        """
        return cls(
            run_id=data.get("run_id", ""),
            task_tag=data.get("task_tag", ""),
            status=data.get("status", "unknown"),
            # `or {}` also normalizes an explicit null stored in the payload.
            summary=data.get("summary") or {},
        )


@dataclass
class Issue:
    """A reported issue with a level, a code, a message and optional details."""

    level: str
    code: str
    message: str
    details: Dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict (``details`` is returned as-is, not copied)."""
        return {
            "level": self.level,
            "code": self.code,
            "message": self.message,
            "details": self.details,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> Issue:
        """Rebuild an ``Issue`` from a dict such as one produced by :meth:`to_dict`.

        A missing level defaults to ``"error"``, and missing code/message
        default to ``""``. Missing *or* ``None`` details become an empty dict.
        """
        return cls(
            level=data.get("level", "error"),
            code=data.get("code", ""),
            message=data.get("message", ""),
            # `or {}` also normalizes an explicit null stored in the payload.
            details=data.get("details") or {},
        )


@dataclass
class PieceStatus:
    """Tracks status of a single domain piece in parallel execution.

    Holds the piece's identity, its current status label, and either a
    result dict or an error string once it has finished.
    """

    piece_id: str
    task_tag: str
    status: str  # "pending", "running", "completed", "failed"
    result: Optional[Dict[str, Any]] = None
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the fields as a plain dict, in declaration order (values not copied)."""
        ordered_keys = ("piece_id", "task_tag", "status", "result", "error")
        return {key: getattr(self, key) for key in ordered_keys}

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> PieceStatus:
        """Build a ``PieceStatus`` from a dict; ids default to "" and status to "pending"."""
        lookup = data.get
        # Positional order matches the dataclass field order above.
        return cls(
            lookup("piece_id", ""),
            lookup("task_tag", ""),
            lookup("status", "pending"),
            lookup("result"),
            lookup("error"),
        )


def _default_retry_state() -> Dict[str, Any]:
    """Return a fresh retry-state dict with every counter at its initial value."""
    return {
        "max_retries": 0,
        "retry_count": 0,
        "last_retry_at": None,
    }


@dataclass
class AgentState:
    """Agent state with full serialization support for async workflows.

    ``to_dict`` produces a plain-dict snapshot. ``from_dict`` rebuilds an
    instance from such a snapshot. It tolerates missing keys and explicit
    ``None`` values for collection-valued fields.
    """

    session_id: str
    current_spec: Optional[Dict[str, Any]] = None
    runs: List[RunRecord] = field(default_factory=list)
    best_solution: Optional[Dict[str, Any]] = None
    open_issues: List[Issue] = field(default_factory=list)
    # Each entry is a pair of dicts; from_dict restores the tuple form.
    history: List[Tuple[Dict[str, Any], Dict[str, Any]]] = field(default_factory=list)

    # Fields for the async workflow
    exp_id: Optional[str] = None
    phase: str = "init"  # init, planning, running, evaluating, finalized, failed, stopped
    pending_callbacks: List[str] = field(default_factory=list)  # task_tags awaiting completion
    conversation_history: List[Dict[str, str]] = field(default_factory=list)  # LLM messages

    # Parallel piece execution
    parallel_mode: bool = False
    piece_statuses: Dict[str, PieceStatus] = field(default_factory=dict)
    piece_solutions: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    piece_attempts: Dict[str, int] = field(default_factory=dict)

    # User request context
    user_request: Optional[str] = None
    request_constraints: Dict[str, Any] = field(default_factory=dict)
    created_at: Optional[int] = None
    updated_at: Optional[int] = None
    last_event_id: Optional[str] = None
    # Single source of truth for the default shape, shared with from_dict.
    retry_state: Dict[str, Any] = field(default_factory=_default_retry_state)

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the state to a plain dict.

        Runs, issues and piece statuses are converted through their own
        ``to_dict``. All other values are placed in the result as-is, without
        copying, so nested containers are shared with this instance.
        """
        return {
            "session_id": self.session_id,
            "current_spec": self.current_spec,
            "runs": [r.to_dict() for r in self.runs],
            "best_solution": self.best_solution,
            "open_issues": [i.to_dict() for i in self.open_issues],
            "history": self.history,
            "exp_id": self.exp_id,
            "phase": self.phase,
            "pending_callbacks": self.pending_callbacks,
            "conversation_history": self.conversation_history,
            "parallel_mode": self.parallel_mode,
            "piece_statuses": {k: v.to_dict() for k, v in self.piece_statuses.items()},
            "piece_solutions": self.piece_solutions,
            "piece_attempts": self.piece_attempts,
            "user_request": self.user_request,
            "request_constraints": self.request_constraints,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
            "last_event_id": self.last_event_id,
            "retry_state": self.retry_state,
        }

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> AgentState:
        """Rebuild an ``AgentState`` from a dict such as one produced by :meth:`to_dict`.

        - Missing keys fall back to the field defaults.
        - Collection-valued keys that are present but ``None`` are treated as empty.
        - ``history`` entries are converted back to tuples. A JSON round trip
          turns tuples into lists.
        - A stored ``retry_state`` is merged over the default counters, so
          every default key is present even if the stored dict is partial.
        """
        runs = [RunRecord.from_dict(r) for r in data.get("runs") or []]
        open_issues = [Issue.from_dict(i) for i in data.get("open_issues") or []]
        piece_statuses = {
            k: PieceStatus.from_dict(v)
            for k, v in (data.get("piece_statuses") or {}).items()
        }
        # Restore the declared (request, response)-pair tuple shape.
        history = [tuple(entry) for entry in data.get("history") or []]

        retry_state = _default_retry_state()
        retry_state.update(data.get("retry_state") or {})

        return cls(
            session_id=data.get("session_id", ""),
            current_spec=data.get("current_spec"),
            runs=runs,
            best_solution=data.get("best_solution"),
            open_issues=open_issues,
            history=history,
            exp_id=data.get("exp_id"),
            phase=data.get("phase", "init"),
            pending_callbacks=data.get("pending_callbacks") or [],
            conversation_history=data.get("conversation_history") or [],
            parallel_mode=data.get("parallel_mode", False),
            piece_statuses=piece_statuses,
            piece_solutions=data.get("piece_solutions") or {},
            piece_attempts=data.get("piece_attempts") or {},
            user_request=data.get("user_request"),
            request_constraints=data.get("request_constraints") or {},
            created_at=data.get("created_at"),
            updated_at=data.get("updated_at"),
            last_event_id=data.get("last_event_id"),
            retry_state=retry_state,
        )

    def all_pieces_done(self) -> bool:
        """Check if all parallel pieces have completed (success or failure).

        Returns True when no pieces are tracked at all.
        """
        if not self.piece_statuses:
            return True
        return all(p.status in ("completed", "failed") for p in self.piece_statuses.values())

    def any_piece_failed(self) -> bool:
        """Check if any parallel piece has failed."""
        return any(p.status == "failed" for p in self.piece_statuses.values())

    def get_successful_pieces(self) -> List[PieceStatus]:
        """Get list of successfully completed pieces."""
        return [p for p in self.piece_statuses.values() if p.status == "completed"]

    def get_failed_pieces(self) -> List[PieceStatus]:
        """Get list of failed pieces."""
        return [p for p in self.piece_statuses.values() if p.status == "failed"]
