from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, Optional
import json
import gymnasium as gym
from openai import OpenAI
import os
import re

@dataclass()
class AgentOutput:
    """Result of one ``propose`` call: the chosen action plus free-form metadata."""

    # Action to hand to the environment; its type depends on the action space.
    action: Any
    # Diagnostics describing how the action was produced, e.g. {"kind": "random"}.
    meta: Dict[str, Any]

class BaseAgent:
    """
    Minimal interface shared by all agents in this module.

    Subclasses override :meth:`propose` to turn an observation into an action.
    """

    def __init__(self, agent_id: int, policy_text: str = "") -> None:
        # Identifier of this agent, plus optional free-text policy that
        # subclasses may consult when choosing actions.
        self.agent_id, self.policy_text = agent_id, policy_text

    def propose(self, obs: Any, info: Dict[str, Any]) -> AgentOutput:
        """Return the next action for ``obs``; concrete agents must override this."""
        raise NotImplementedError

class RandomAgent(BaseAgent):
    """Baseline agent that ignores the observation and samples the action space."""

    def __init__(self, agent_id: int, action_space: gym.Space, policy_text: str = "") -> None:
        super().__init__(agent_id, policy_text=policy_text)
        # Kept so every propose() call can draw a fresh sample.
        self.action_space = action_space

    def propose(self, obs: Any, info: Dict[str, Any]) -> AgentOutput:
        """Return ``action_space.sample()``; ``obs`` and ``info`` are unused."""
        return AgentOutput(action=self.action_space.sample(), meta={"kind": "random"})

class DummyLLMAgent(BaseAgent):
    """
    Stand-in for an LLM agent that only samples random actions.

    Useful for validating the rollout wiring without calling a model.
    ``temperature`` is stored on the instance but not used by this class.
    """

    def __init__(self, agent_id: int, action_space: gym.Space, policy_text: str = "", temperature: float = 0.7) -> None:
        super().__init__(agent_id, policy_text=policy_text)
        self.action_space, self.temperature = action_space, temperature

    def propose(self, obs: Any, info: Dict[str, Any]) -> AgentOutput:
        """Draw one sample from ``action_space``; ``obs`` and ``info`` are ignored."""
        sampled = self.action_space.sample()
        tag = {"kind": "dummy_llm"}
        return AgentOutput(action=sampled, meta=tag)

def _space_to_schema_hint(space: gym.Space) -> str:
    """
    Describe an action space in one short line for use in a model prompt.

    A ``gym.spaces.Dict`` becomes ``"{ key: SubSpaceType, ... }"``; any other
    space is described by its class name alone. This is only a hint, not a
    real JSON Schema.
    """
    if not isinstance(space, gym.spaces.Dict):
        return type(space).__name__
    fields = ", ".join(
        f"{name}: {type(sub_space).__name__}" for name, sub_space in space.spaces.items()
    )
    return "{ " + fields + " }"

def _extract_json(text: str) -> Optional[Any]:
    """
    Best-effort parse of JSON embedded in model output.

    Strategy:
      1. If the whole (stripped) text is valid JSON, return the decoded value.
         This may be any JSON type, not only an object.
      2. Otherwise try each ``{`` from left to right and return the first JSON
         object that decodes starting at that position. This tolerates prose or
         Markdown code fences around the object, and trailing text that itself
         contains braces. (Slicing from the first ``{`` to the last ``}`` would
         fail on such trailing text.)

    Args:
        text: Raw model output.

    Returns:
        The decoded value, or ``None`` if nothing could be decoded.
    """
    text = text.strip()
    try:
        return json.loads(text)
    except (ValueError, RecursionError):
        # JSONDecodeError is a ValueError; RecursionError covers pathologically
        # nested input. Either way, fall through to the embedded-object scan.
        pass

    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            # raw_decode parses one value starting at `start` and ignores
            # whatever follows it.
            obj, _end = decoder.raw_decode(text, start)
        except (ValueError, RecursionError):
            start = text.find("{", start + 1)
            continue
        return obj
    return None

class OpenAILLMAgent(BaseAgent):
    """
    LLM-backed agent that asks an OpenAI model for one browser action per step.

    Each :meth:`propose` call renders a text prompt from the observation and
    sends it through the OpenAI Responses API. It expects a JSON object of the
    form ``{"action": "<command string>"}`` back. Any failure is reported and
    replaced by a fallback action, so a bad reply never crashes the episode
    loop. Failures include an API error, an unparsable reply, and a missing,
    empty or non-string ``action``.

    The prompt demands strict formatting, but with ``temperature > 0`` the
    model's output is not deterministic.
    """

    # Observation keys left out of the raw OBSERVATION dump. Image-like fields
    # are useless as text. The DOM / AX-tree / element-property payloads are
    # large and already condensed into the ELEMENTS section.
    _HEAVY_OBS_KEYS = frozenset(
        {
            "image",
            "screenshot",
            "pixels",
            "rgb",
            "vision",
            "raw_html",
            "dom_object",
            "axtree_object",
            "extra_element_properties",
        }
    )

    # numpy arrays with at most this many elements are shown by value in the
    # observation dump. Larger ones are summarised by shape/dtype only.
    _MAX_INLINE_ARRAY_SIZE = 32

    # Example replies shown to the model. They are rendered with json.dumps so
    # the inner quotes are escaped exactly the way the model must escape them.
    _EXAMPLE_ACTIONS = (
        "noop()",
        "scroll(0, 300)",
        'click("123")',
        'fill("456", "hello")',
        'press("456", "Enter")',
        'goto("http://localhost:7770/")',
    )

    # Action grammar presented to the model. It must agree with
    # _EXAMPLE_ACTIONS. It is intended to match BrowserGym's high-level action
    # set: scroll takes pixel deltas, and tab_focus takes an integer index.
    _ACTION_FORMATS = (
        "- noop(): Do nothing",
        '- click("BID"): Click an element',
        '- hover("BID"): Hover over an element',
        '- fill("BID", "text"): Type text into an element, e.g. fill("12", "green tea")',
        '- press("BID", "key"): Press a key combination on an element, e.g. press("12", "Enter")',
        "- scroll(delta_x, delta_y): Scroll by pixels, e.g. scroll(0, 300) to go down, scroll(0, -300) to go up",
        "- tab_focus(index): Focus the i-th tab, e.g. tab_focus(0)",
        "- new_tab(): Open a new tab",
        "- tab_close(): Close the current tab",
        "- go_back(): Visit the previous URL",
        "- go_forward(): Undo go_back",
        '- goto("URL"): Go to a full URL',
        "- STOP [Task Complete]: The task (search or other) is complete. Submit for success.",
    )

    # Site/task-specific rules appended to every prompt. They were written for
    # one WebArena shopping task (search box with BID "272" on
    # http://localhost:7770/). Pass `task_hints` to the constructor to replace
    # them for other tasks.
    DEFAULT_TASK_HINTS = "\n".join(
        (
            "- Explore the page with scroll and occasional fill/press/click on BIDs that are not 272. "
            "Repeat actions if needed to find elements.",
            "- The search box BID is 272: fill it, then submit with action STOP. Do not click or press on 272.",
            "- When you have completed the task (e.g., search submitted and results loaded), reply with: "
            "STOP [your final answer or 'Task complete']. This is required for success.",
            "- You MUST stay on the shopping website at http://localhost:7770/ — DO NOT goto external sites like google.com!",
            "- Use the site's own search bar for searching (fill the search input field).",
            "- Never use goto unless the URL is on the current site, and always give a full URL "
            "(e.g., 'goto(\"http://localhost:7770/...\")').",
        )
    )

    def __init__(
        self,
        agent_id: int,
        action_space: gym.Space,
        policy_text: str = "",
        model_name: str = "gpt-4o-mini",
        temperature: float = 0.2,
        max_tokens: int = 256,
        fallback_action: str = "noop()",
        task_hints: Optional[str] = None,
        verbose: bool = True,
    ) -> None:
        """
        Args:
            agent_id: Identifier forwarded to ``BaseAgent``.
            action_space: The environment's action space. It is only consulted
                when choosing a fallback action.
            policy_text: Optional extra instructions, inserted as a POLICY section.
            model_name: OpenAI model to query.
            temperature: Sampling temperature passed to the API.
            max_tokens: Value passed as ``max_output_tokens``.
            fallback_action: Preferred action when the model call or parsing fails.
            task_hints: Task/site-specific rules appended to the prompt.
                ``None`` uses ``DEFAULT_TASK_HINTS``; ``""`` disables them.
            verbose: Print the raw model output and parsed action for debugging.
        """
        super().__init__(agent_id, policy_text=policy_text)
        self.action_space = action_space
        self.model_name = model_name
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.fallback_action = fallback_action
        self.task_hints = self.DEFAULT_TASK_HINTS if task_hints is None else task_hints
        self.verbose = verbose

        # The API key is read from the OPENAI_API_KEY environment variable.
        self.client = OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))

    # ------------------------------------------------------------------ helpers

    def _log(self, *parts: Any) -> None:
        """Print debug output only when ``self.verbose`` is set."""
        if self.verbose:
            print(*parts)

    @classmethod
    def _json_default(cls, o: Any) -> Any:
        """``json.dumps`` fallback for values that JSON cannot encode natively."""
        try:
            import numpy as np
        except ImportError:  # numpy is optional for this agent
            np = None
        if np is not None:
            if isinstance(o, np.ndarray):
                # Small arrays stay readable by value. Large ones (e.g. images)
                # are summarised so they do not flood the prompt.
                if o.size <= cls._MAX_INLINE_ARRAY_SIZE:
                    return o.tolist()
                return {"__ndarray__": True, "shape": list(o.shape), "dtype": str(o.dtype)}
            if isinstance(o, np.generic):
                return o.item()
        if isinstance(o, (bytes, bytearray)):
            return {"__bytes__": True, "len": len(o)}
        return str(o)

    @classmethod
    def _safe_json_dumps(cls, x: Any, max_len: int = 8000) -> str:
        """Pretty-print ``x`` as JSON (or ``str(x)`` if that fails), truncated to ``max_len`` chars."""
        try:
            s = json.dumps(x, indent=2, ensure_ascii=False, default=cls._json_default)
        except (TypeError, ValueError):
            # e.g. non-string dict keys such as tuples, or circular references.
            s = str(x)
        if len(s) > max_len:
            s = s[:max_len] + "\n...<TRUNCATED>..."
        return s

    @classmethod
    def _compact_obs(cls, obs: Any) -> Any:
        """Return ``obs`` without the keys in ``_HEAVY_OBS_KEYS`` (non-dicts pass through)."""
        if not isinstance(obs, dict):
            return obs
        return {k: v for k, v in obs.items() if k not in cls._HEAVY_OBS_KEYS}

    @staticmethod
    def _field_text(x: Any) -> str:
        """
        Coerce an element/AX-tree field to plain text.

        AX-tree fields may be ``{"value": ...}`` wrappers (the role/name
        handling in this agent unwraps them), so the ``value`` (or ``text``) is
        extracted first. An empty wrapper yields ``""`` instead of the dict's repr.
        """
        if isinstance(x, dict):
            x = x.get("value", x.get("text"))
        if x is None:
            return ""
        if isinstance(x, str):
            return x
        try:
            return json.dumps(x, ensure_ascii=False)
        except (TypeError, ValueError):
            return str(x)

    @classmethod
    def _summarize_elements(cls, obs: Dict[str, Any], max_items: int = 60) -> str:
        """
        List up to ``max_items`` elements that have a BID and a non-empty label.

        Sources, in order:
          1. ``obs["extra_element_properties"]``, a mapping of bid to metadata dict.
          2. ``obs["axtree_object"]`` nodes. These are read from the ``nodes``,
             ``tree`` or ``axTree`` key, whichever is present first.

        Entries are de-duplicated on (bid, label). Each line is formatted as
        ``"- <bid> | <role> | <label>"``.
        """
        items = []

        eep = obs.get("extra_element_properties")
        if isinstance(eep, dict):
            for bid, meta in eep.items():
                if not isinstance(meta, dict):
                    continue
                text = cls._field_text(
                    meta.get("text") or meta.get("name") or meta.get("aria_label")
                ).strip()
                role = cls._field_text(meta.get("role")).strip()
                tag = cls._field_text(meta.get("tag")).strip()
                if text:
                    items.append((str(bid), role or tag, text))

        ax = obs.get("axtree_object")
        nodes = None
        if isinstance(ax, dict):
            # The node-list key varies between versions.
            for key in ("nodes", "tree", "axTree"):
                nodes = ax.get(key)
                if nodes is not None:
                    break
        if isinstance(nodes, list):
            for node in nodes:
                if not isinstance(node, dict):
                    continue
                bid = node.get("bid") or node.get("browsergym_id") or node.get("id")
                name = cls._field_text(node.get("name") or node.get("text")).strip()
                role = cls._field_text(node.get("role")).strip()
                if bid and name:
                    items.append((str(bid), role, name))

        seen = set()
        out = []
        for bid, role, text in items:
            key = (bid, text)
            if key in seen:
                continue
            seen.add(key)
            out.append(f"- {bid} | {role} | {text}")
            if len(out) >= max_items:
                break
        return "\n".join(out) if out else "<no elements extracted>"

    def _fallback_action(self) -> str:
        """
        Choose the action to use when the model call or parsing fails.

        The first candidate that ``action_space.contains`` accepts wins. If
        none is accepted, or the space cannot validate strings,
        ``self.fallback_action`` is returned anyway.
        """
        space = self.action_space
        # dict.fromkeys de-duplicates while preserving order.
        candidates = dict.fromkeys((self.fallback_action, "noop()", "wait", "scroll down"))
        for candidate in candidates:
            try:
                if hasattr(space, "contains") and space.contains(candidate):
                    return candidate
            except Exception:  # best-effort probe; arbitrary spaces may raise
                continue
        return self.fallback_action

    def _parse_action(self, text: str) -> str:
        """
        Extract the ``action`` string from the model's reply.

        Raises:
            ValueError: No JSON object was found, or ``action`` is missing,
                not a string, or blank.
        """
        parsed = _extract_json(text)
        self._log("[PARSED JSON]", parsed)
        if not isinstance(parsed, dict):
            raise ValueError("No JSON object found in model output")
        action = parsed.get("action")
        self._log("[ACTION FROM MODEL]", repr(action), "type=", type(action).__name__)
        if action is None:
            raise ValueError("Missing 'action' field")
        # WebArena uses a Unicode() action space, for which "contains" is not a
        # reliable validator. The env reports invalid commands via
        # last_action_error, so only the basic shape is checked here.
        if not isinstance(action, str):
            raise ValueError(f"Action must be a string, got {type(action)}: {action}")
        action = action.strip()
        if not action:
            raise ValueError("Empty 'action' field")
        return action

    # --------------------------------------------------------------- public API

    def _build_prompt(self, obs: Any, info: Dict[str, Any]) -> str:
        """
        Render the full text prompt for one decision step.

        Args:
            obs: Environment observation. Dict observations are mined for
                ``goal``, ``last_action``, ``last_action_error`` and element
                BIDs. Any other type is only dumped into OBSERVATION.
            info: Step info dict (currently unused).

        Returns:
            The prompt, as newline-joined lines without leading indentation.
        """
        # Non-dict observations must not crash the prompt builder.
        obs_dict: Dict[str, Any] = obs if isinstance(obs, dict) else {}
        goal = obs_dict.get("goal") or ""
        last_action = obs_dict.get("last_action") or ""
        last_err = obs_dict.get("last_action_error") or ""

        lines = [
            "You are controlling a browser in WebArena (BrowserGym).",
            "Your environment expects ONE action as a SINGLE STRING command.",
            "",
            "RULES:",
            "- Output MUST be a single JSON object and NOTHING else.",
            '- The JSON MUST have exactly one key: "action"',
            '- The value of "action" MUST be a STRING command (not a dict).',
            "",
        ]
        if self.policy_text:
            lines += ["POLICY:", self.policy_text, ""]
        lines += [
            "GOAL:",
            str(goal),
            "",
            "OBSERVATION (truncated):",
            self._safe_json_dumps(self._compact_obs(obs)),
            "",
            "ELEMENTS (use these BIDs in actions that take a BID):",
            self._summarize_elements(obs_dict),
            "",
            "LAST ACTION:",
            str(last_action),
            "",
            "LAST ACTION ERROR (if any):",
            str(last_err),
            "",
            "VALID OUTPUT EXAMPLES (JSON ONLY):",
            *(json.dumps({"action": a}, ensure_ascii=False) for a in self._EXAMPLE_ACTIONS),
            "",
            "ALLOWED ACTIONS (use EXACTLY these formats):",
            *self._ACTION_FORMATS,
            "",
            "IMPORTANT:",
            "- Avoid repeating the exact same action more than 2 times.",
            "- Every BID you use MUST appear in ELEMENTS.",
            "- If LAST ACTION ERROR is not empty, fix the action format or try a different action.",
        ]
        if self.task_hints:
            lines.append(self.task_hints)
        lines += ["", "Now output the next action as JSON only."]
        return "\n".join(lines)

    def propose(self, obs: Any, info: Dict[str, Any]) -> AgentOutput:
        """
        Ask the model for the next action.

        Returns:
            On success, an ``AgentOutput`` with the stripped command string
            from the model and ``meta["kind"] == "llm"``. If the request or
            parsing fails for any reason, the fallback action instead, with
            ``meta["kind"] == "fallback"`` and the error text in ``meta["error"]``.
        """
        prompt = self._build_prompt(obs, info)
        try:
            # NOTE: JSON mode (text={"format": {"type": "json_object"}}) could
            # make replies stricter; it is not enabled here.
            resp = self.client.responses.create(
                model=self.model_name,
                input=prompt,
                temperature=self.temperature,
                max_output_tokens=self.max_tokens,
            )
            text = (resp.output_text or "").strip()
            self._log("\n[LLM RAW OUTPUT]")
            self._log(text[:500])
            self._log("[/LLM RAW OUTPUT]\n")
            action = self._parse_action(text)
        except Exception as e:  # top-level boundary: any failure degrades to the fallback
            print("\n[LLM ERROR -> FALLBACK]")
            print(type(e).__name__, ":", str(e))
            fallback = self._fallback_action()
            return AgentOutput(
                action=fallback,
                meta={"kind": "fallback", "error": str(e), "model": self.model_name},
            )

        return AgentOutput(
            action=action,
            meta={
                "kind": "llm",
                "model": self.model_name,
            },
        )