from __future__ import annotations

import ast
import json
import re
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict


class ActionType(str, Enum):
    """Closed set of action kinds that a parsed response may request.

    Mixes in ``str`` so each member compares equal to its string value
    (``ActionType.VERIFY == "VERIFY"``). Every member's value is identical to
    its name; ``parse_action`` resolves incoming names to members by *name*
    lookup (``ActionType[...]``) after applying ``ACTION_ALIASES``.
    """

    WRITE_SPEC = "WRITE_SPEC"
    UPDATE_SPEC = "UPDATE_SPEC"
    STOP_RUN = "STOP_RUN"
    VERIFY = "VERIFY"
    TRADITIONAL = "TRADITIONAL"
    CODEGEN = "CODEGEN"
    FINALIZE = "FINALIZE"

    # Second group. SUBMIT_SEARCH_ASYNC and CHANGE_STRATEGY are also reachable
    # through aliases in ACTION_ALIASES (SUBMIT_SEARCH, SPLIT_PIECE).
    SUBMIT_SEARCH_ASYNC = "SUBMIT_SEARCH_ASYNC"
    CHANGE_STRATEGY = "CHANGE_STRATEGY"
    SUGGEST_STRATEGY = "SUGGEST_STRATEGY"


# Alternative action names accepted from model output, each mapped to the
# canonical ``ActionType`` member name it stands for. Keys are matched after
# ``parse_action`` has normalized the incoming name (upper-case, ``_`` separators).
ACTION_ALIASES = dict(
    SUBMIT_SEARCH="SUBMIT_SEARCH_ASYNC",
    SPLIT_PIECE="CHANGE_STRATEGY",
    ROUTE_TO_BASELINE="TRADITIONAL",
    EMIT_CODE="CODEGEN",
)


@dataclass
class Action:
    """A parsed action request: which action to run and with what arguments."""

    type: ActionType  # the requested action kind
    args: Dict[str, Any]  # arguments for the action, keyed by name

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to the ``{"action": <value>, "args": <dict>}`` shape.

        ``args`` is placed in the result as-is (not copied), so mutating the
        returned ``"args"`` dict also mutates this action.
        """
        payload: Dict[str, Any] = {"action": self.type.value}
        payload["args"] = self.args
        return payload


def _extract_json(text: str) -> Dict[str, Any]:
    """Return the first dict that can be decoded from a model response.

    Candidate substrings come from ``_json_candidates`` (the whole text,
    fenced code blocks, balanced ``{...}`` spans) and are tried in order with
    ``_parse_candidate``; the first one that yields a dict wins.

    Raises:
        ValueError: if no candidate decodes to a dict.
    """
    decoded = (_parse_candidate(chunk) for chunk in _json_candidates(text.strip()))
    found = next((obj for obj in decoded if isinstance(obj, dict)), None)
    if found is None:
        raise ValueError("No JSON object found in action response.")
    return found


def _json_candidates(text: str) -> list[str]:
    candidates: list[str] = []
    if text:
        candidates.append(text)

    # Prefer fenced content when present.
    for match in re.finditer(r"```(?:json)?\s*([\s\S]*?)```", text, flags=re.IGNORECASE):
        chunk = match.group(1).strip()
        if chunk:
            candidates.append(chunk)

    # Add balanced top-level {...} spans from raw text.
    start: int | None = None
    depth = 0
    in_string = False
    escape = False
    for idx, ch in enumerate(text):
        if in_string:
            if escape:
                escape = False
                continue
            if ch == "\\":
                escape = True
                continue
            if ch == '"':
                in_string = False
            continue
        if ch == '"':
            in_string = True
            continue
        if ch == "{":
            if depth == 0:
                start = idx
            depth += 1
            continue
        if ch == "}":
            if depth <= 0:
                continue
            depth -= 1
            if depth == 0 and start is not None:
                chunk = text[start : idx + 1].strip()
                if chunk:
                    candidates.append(chunk)
                start = None
    return candidates


def _normalize_json_like(text: str) -> str:
    normalized = text.lstrip("\ufeff")
    translation = str.maketrans(
        {
            "，": ",",
            "：": ":",
            "“": '"',
            "”": '"',
            "‘": "'",
            "’": "'",
        }
    )
    normalized = normalized.translate(translation)
    # Strip trailing commas in objects/arrays.
    while True:
        stripped = re.sub(r",\s*([}\]])", r"\1", normalized)
        if stripped == normalized:
            break
        normalized = stripped
    return normalized.strip()


def _parse_candidate(candidate: str) -> Dict[str, Any] | None:
    """Try to decode *candidate* as a dict, first as-is, then normalized.

    Each variant is tried with ``json.JSONDecoder.raw_decode`` (which accepts
    trailing text after the first value) and then with ``ast.literal_eval``
    (Python-style dicts: single quotes, True/False/None). The normalized
    variant from ``_normalize_json_like`` is only computed and tried when the
    raw text fails and normalization actually changes it.

    Returns:
        The first dict decoded, or ``None`` if neither variant yields one.
        Malformed or pathologically nested input never raises.
    """
    decoder = json.JSONDecoder()

    def _decode(variant: str) -> Dict[str, Any] | None:
        try:
            obj, _ = decoder.raw_decode(variant)
        except (json.JSONDecodeError, RecursionError):
            # RecursionError: deeply nested input overflows the JSON scanner.
            obj = None
        if isinstance(obj, dict):
            return obj
        # literal_eval evaluates literals only (no names/calls), so it is safe
        # on untrusted text, but it may raise ValueError, TypeError,
        # SyntaxError, MemoryError or RecursionError on malformed input —
        # all treated as "not a dict" here.
        try:
            literal = ast.literal_eval(variant)
        except Exception:
            return None
        return literal if isinstance(literal, dict) else None

    if candidate:
        parsed = _decode(candidate)
        if parsed is not None:
            return parsed
    normalized = _normalize_json_like(candidate)
    if normalized and normalized != candidate:
        return _decode(normalized)
    return None


def parse_action(text: str) -> Action:
    """Parse a model response into an :class:`Action`.

    The response may be bare JSON, fenced JSON, JSON embedded in prose, or a
    Python-style dict literal (see ``_extract_json``). The action name is read
    from ``"action"`` (falling back to ``"type"``), upper-cased, runs of
    whitespace/hyphens are collapsed to ``_``, and the result is mapped through
    ``ACTION_ALIASES`` before being looked up as an ``ActionType`` member name.

    Arguments come from a dict-valued ``"args"`` field when present; otherwise
    every remaining top-level key except ``"action"`` and ``"type"`` is used.
    A null ``"args"`` is dropped rather than passed through as an argument.

    Raises:
        ValueError: no dict could be extracted, the action name is missing or
            not a string, or it does not name a known ``ActionType``.
    """
    data = _extract_json(text)
    action_name = data.get("action") or data.get("type")
    if not action_name:
        raise ValueError("Action JSON must include 'action' field.")
    if not isinstance(action_name, str):
        # Guard before .strip(): a numeric/object action would otherwise raise
        # AttributeError instead of the ValueError used for all parse failures.
        raise ValueError(
            f"Action field must be a string, got {type(action_name).__name__}."
        )
    # "submit-search", "Submit Search" and "SUBMIT_SEARCH" all normalize alike.
    normalized = re.sub(r"[\s\-]+", "_", action_name.strip().upper())
    normalized = ACTION_ALIASES.get(normalized, normalized)
    try:
        action_type = ActionType[normalized]
    except KeyError as exc:
        raise ValueError(f"Unknown action type: {action_name}") from exc

    args = data.get("args")
    if not isinstance(args, dict):
        excluded = {"action", "type"}
        if args is None:
            # A null "args" carries nothing; don't echo it back as an arg named "args".
            excluded.add("args")
        args = {k: v for k, v in data.items() if k not in excluded}
    return Action(type=action_type, args=args)
