import contextlib, hashlib, io, json, os, re, sys, textwrap
from typing import List, Dict
from functools import partial

# ─────────────────────────── utilities ────────────────────────────
def run_llm(system: str, messages, model_id="gpt-4o"):
    """Send a chat request to the model *model_id* and return the backend's reply.

    The backend is the first entry of AgentOccam's ``MODEL_FAMILIES`` whose
    name occurs as a substring of *model_id*. The matching callable from
    ``CALL_MODEL_WITH_MESSAGES_FUNCTION_MAP`` is invoked with ``model_id``,
    ``system_prompt=system`` and ``messages=messages``. Its return value is
    passed through unchanged; callers in this module treat it as text.

    Raises:
        IndexError: if no known model family occurs in *model_id*.
    """
    # Deferred import: AgentOccam is only loaded when a model is actually called.
    from AgentOccam.AgentOccam import (
        MODEL_FAMILIES,
        CALL_MODEL_WITH_MESSAGES_FUNCTION_MAP,
    )

    candidates = [name for name in MODEL_FAMILIES if name in model_id]
    family = candidates[0]
    backend = CALL_MODEL_WITH_MESSAGES_FUNCTION_MAP[family]
    return backend(model_id=model_id, system_prompt=system, messages=messages)

# Legacy non-greedy pattern, kept in case other code imports it. It stops at the
# first closing bracket and so cannot match nested JSON; _grab_json no longer uses it.
_JSON_RE = re.compile(r"\{.*?\}|\[.*?\]", re.S)
# Positions where a JSON object or array may begin.
_JSON_START_RE = re.compile(r"[\[{]")
_JSON_DECODER = json.JSONDecoder()


def _grab_json(text: str):
    """Return the first JSON object or array embedded in *text*, or ``None``.

    Every ``{`` / ``[`` in *text* is tried in order as the start of a JSON
    document, and the first one that decodes wins. Trailing prose or code-fence
    markers after the document are ignored. Unlike a regex, this handles
    arbitrarily nested objects and arrays.

    Returns:
        The decoded ``dict`` or ``list``, or ``None`` if *text* is empty or no
        decodable JSON object/array is found.
    """
    if not text:
        return None
    for match in _JSON_START_RE.finditer(text):
        try:
            value, _end = _JSON_DECODER.raw_decode(text[match.start():])
        except ValueError:
            # Not JSON from this bracket (e.g. prose such as "[see above]");
            # move on to the next candidate.
            continue
        return value
    return None

def get_extractor_prompt(user_goal: str,
                         sample_page_text: str,
                         model_id: str = "gpt-4o") -> str:
    """
    Ask an LLM to design a reusable extraction prompt for *user_goal*.

    The LLM is shown the goal and one sample page (accessibility-tree text).
    It is asked to write a single prompt telling another LLM:
      1. what information to extract,
      2. the output format, including the JSON keys (a JSON object or a list
         of JSON objects),
      3. optionally, tips for locating the information on the page.
    In this module the result is passed to ``run_prompt_extractor`` together
    with concrete page text.

    Returns:
        The model's reply with surrounding whitespace stripped.
    """
    system = (
        "You are an expert prompt engineer.\n"
        "Design a SINGLE prompt that, when shown together with a web-page "
        "text accessibility tree, makes another LLM extract and return ONLY a JSON object or a list of JSON objects "
        "containing the fields that satisfy the user's goal.\n"
        "You need to specify 1) what information to be extracted, 2) the output format including JSON keys, 3) (Optional) any tips you think helpful for locating the information in the webpage. \n"
        "No code fences, no markdown."
    )
    # Plain concatenation instead of textwrap.dedent(f"..."): dedent runs after
    # interpolation, so any unindented line in the page text defeats it and the
    # template's indentation leaks into the prompt.
    user = (
        f"[USER GOAL]\n{user_goal}\n\n"
        f"[SAMPLE PAGE]\n{sample_page_text}\n"
    )
    messages = [{"role": "user",
                 "content": [{"type": "text", "text": user}]}]
    return run_llm(system, messages, model_id=model_id).strip()

def run_prompt_extractor(extractor_prompt: str,
                         page_text: str,
                         model_id: str = "gpt-4o",
                         retries: int = 3,
                         retry_delay: float = 1.0):
    """
    Apply a pre-built *extractor_prompt* to concrete *page_text* via the LLM.

    Each attempt sends the requirement plus the page. The reply is parsed with
    ``_grab_json``, and the first non-empty JSON object/array is returned.
    Between failed attempts the function sleeps *retry_delay* seconds; there is
    no sleep after the final attempt, and none if *retry_delay* <= 0.

    Returns:
        The parsed ``dict``/``list``, or ``None`` if every one of *retries*
        attempts yields no usable JSON.
    """
    import time  # only needed for the pause between retries

    system = (
        "You are a data extractor.  Your task is to extract the needed information from a webpage. "
        "You will be provided with the webpage text accessibility tree as well as a requirement of what to extract and output format. Strictly follow it. "
        "Return ONLY the requested information or 'null' if absent."
    )
    # The request text is identical on every attempt, so build it once. Plain
    # concatenation avoids textwrap.dedent running after interpolation, which
    # multi-line page text defeats.
    user = (
        f"[REQUIREMENT]\n{extractor_prompt}\n\n"
        f"[SAMPLE PAGE]\n{page_text}\n"
    )

    for attempt in range(1, retries + 1):
        # Fresh message list per attempt in case the backend mutates it.
        messages = [{"role": "user",
                     "content": [{"type": "text", "text": user}]}]
        raw = run_llm(system=system, messages=messages, model_id=model_id)

        res = _grab_json(raw)
        if res is not None and res not in ("null", [], {}):
            return res

        print(f"   ↺ retry {attempt}/{retries} – no JSON extracted")
        if attempt < retries and retry_delay > 0:
            time.sleep(retry_delay)

    return None


def _read_obs(path: str) -> str:
    """Return the observation body stored in the UTF-8 dump file at *path*.

    Everything after the first ``FULL OBSERVATION TEXT:`` marker is kept. When
    the marker is absent, the whole file content is used. Surrounding
    whitespace is stripped in both cases.
    """
    with open(path, "r", encoding="utf-8") as handle:
        content = handle.read()
    _, marker, body = content.partition("FULL OBSERVATION TEXT:")
    chosen = body if marker else content
    return chosen.strip()


def _hash(text: str) -> str:
    """Return the SHA-1 hex digest of *text* with all whitespace removed.

    Because whitespace is dropped before hashing, dumps that differ only in
    spacing or line breaks hash identically. run_directory uses this to skip
    duplicate observations.
    """
    compact = re.sub(r"\s+", "", text)
    digest = hashlib.sha1(compact.encode())
    return digest.hexdigest()


def _extract_code_block(raw: str) -> str | None:
    m = re.search(r"```python\s*(.*?)\s*```", raw, re.S)
    return m.group(1) if m else None

def _flat_observation(raw: str) -> str:
    """Collapse every whitespace run in *raw* into one space and trim both ends."""
    words = raw.split()
    separator = " "
    return separator.join(words)
# ─────────────────── 1.  obtain extraction code ──────────────────
def get_extractor_code(user_goal: str, example_page: str,
                       model_id: str = "gpt-4o", retries=3) -> str:
    """Ask the LLM for Python source that extracts *user_goal* data from a page.

    The model sees the goal and *example_page* and must reply with one fenced
    Python block. A reply is accepted when its extracted code contains
    ``def extract(`` and the substring ``answer``. Otherwise a notice is
    printed and the same request is sent again, up to *retries* times.

    Returns:
        The accepted code string (fence markers removed).

    Raises:
        RuntimeError: if no acceptable code arrives within *retries* attempts.
    """
    system_lines = [
        "You are a web-page extraction agent.",
        "INPUTS you will receive:",
        "  • USER_GOAL   – natural language description of desired info",
        "  • PAGE        – raw accessibility-tree text",
        "",
        "Return **ONLY one fenced Python block** that:",
        " 1) defines a function  extract(page:str)->list[dict]",
        " 2) extracts the goal-specific info from *the given page text*;",
        " 3) assigns the result to variable  answer .",
        "Do NOT print anything else.",
    ]
    system = "\n".join(system_lines)
    request = f"[USER_GOAL]\n{user_goal}\n\n[PAGE]\n{example_page}"
    messages = [{"role": "user", "content": [{"type": "text", "text": request}]}]

    for attempt_no in range(1, retries + 1):
        reply = run_llm(system, messages, model_id=model_id)
        candidate = _extract_code_block(reply)
        looks_valid = (
            bool(candidate)
            and "def extract(" in candidate
            and "answer" in candidate
        )
        if looks_valid:
            return candidate
        print(f"↺ retry {attempt_no}/{retries} – extractor code invalid")

    raise RuntimeError("Failed to obtain valid extraction code from LLM.")


# ─────────── 2.  run extractor code on a page (exec — NOT sandboxed) ───────────
def run_extractor(code: str, page_text: str) -> Dict | List | None:
    sandbox_globals, sandbox_locals = {}, {"page": page_text}
    stdout = io.StringIO()
    code_wrapped = textwrap.dedent(code) + "\n\nanswer = extract(page)"
    try:
        with contextlib.redirect_stdout(stdout):
            exec(code_wrapped, sandbox_globals, sandbox_locals)
        return sandbox_locals.get("answer", None)
    except Exception as exc:
        print("⚠️ extractor runtime error:", exc)
        return None


# ─────────────────── 3.  driver with deduplication ───────────────
# def run_directory(task_dir: str, user_goal: str,
#                   model_id="gpt-4o") -> List[Dict]:
#     try:
#         txt_files = sorted([f for f in os.listdir(task_dir)
#                             if f.startswith("obs_step") and f.endswith(".txt")],
#                         key=lambda x: int(re.findall(r"\d+", x)[0]))
#         if not txt_files:
#             raise FileNotFoundError("No observation dumps found.")

#         hashes, unique = set(), []
#         for f in txt_files:
#             txt = _read_obs(os.path.join(task_dir, f))
#             h = _hash(txt)
#             if h not in hashes:
#                 hashes.add(h); unique.append((f, txt))
#             else:
#                 print(f"🌀 duplicate skipped: {f}")

#         # 1) get extractor code once
#         extractor_code = get_extractor_code(user_goal, _flat_observation(unique[-1][1]), model_id)
#         print("🔧 extractor code obtained")
#         print(extractor_code)
#         # 2) apply to every unique page
#         results = []
#         for f, txt in unique:
#             out = run_extractor(extractor_code, _flat_observation(txt))
#             if out:
#                 results+=out; print(f"✔ extracted from {f}")
#             else:
#                 print(f"✖ nothing in {f}")
#         print(f"🏁 unique pages: {len(unique)}  extracted: {len(results)}")

#         if len(results) == 0:
#             print("⚠️ Nothing extracted, switching to text extraction.")
#             extractor_prompt = get_extractor_prompt(user_goal, unique[-1][1], model_id)
#             print("🔧 extractor prompt obtained")
#             print(extractor_prompt)
#             for f, txt in unique:
#                 out = run_prompt_extractor(extractor_prompt, txt)
#                 if out:
#                     results+=out; print(f"✔ extracted from {f}")
#                 else:
#                     print(f"✖ nothing in {f}")

#         print(f"⚙️ Final extraction results: \n{results}")
#         return results
#     except:
#         return ["No observation is dumped."]

def run_directory(task_dir: str, user_goal: str,
                  model_id="gpt-4o") -> List[Dict]:
    """Extract goal-relevant data from every distinct observation dump in *task_dir*.

    Steps:
      1. Collect ``obs_step*.txt`` files, ordered by the first integer in the name.
      2. Skip dumps whose whitespace-insensitive hash (``_hash``) was already seen.
      3. Build one extraction prompt from the last unique dump with *model_id*,
         then apply it to every unique dump with the same model.

    Returns:
        The accumulated results. A JSON object returned for a page is appended
        as one item; a JSON array is spliced in element by element. On any
        failure (no dumps, missing directory, LLM error, ...) the cause is
        printed and the sentinel ``["No observation is dumped."]`` is returned.
    """
    try:
        txt_files = sorted(
            (f for f in os.listdir(task_dir)
             if f.startswith("obs_step") and f.endswith(".txt")),
            key=lambda name: int(re.findall(r"\d+", name)[0]),
        )
        if not txt_files:
            raise FileNotFoundError("No observation dumps found.")

        seen_hashes, unique = set(), []
        for f in txt_files:
            txt = _read_obs(os.path.join(task_dir, f))
            h = _hash(txt)
            if h in seen_hashes:
                print(f"🌀 duplicate skipped: {f}")
                continue
            seen_hashes.add(h)
            unique.append((f, txt))

        results = []
        extractor_prompt = get_extractor_prompt(user_goal, unique[-1][1], model_id)
        print("🔧 extractor prompt obtained")
        print(extractor_prompt)
        for f, txt in unique:
            # Forward the caller's model; previously the default was always used.
            out = run_prompt_extractor(extractor_prompt, txt, model_id=model_id)
            if out:
                # `list += dict` would splice in the dict's *keys*; keep objects whole.
                if isinstance(out, dict):
                    results.append(out)
                else:
                    results.extend(out)
                print(f"✔ extracted from {f}")
            else:
                print(f"✖ nothing in {f}")

        print(f"⚙️ Final extraction results: \n{results}")
        return results
    except Exception as exc:
        # Best-effort contract kept: callers get the sentinel, not an exception.
        # The cause is now reported, and KeyboardInterrupt/SystemExit propagate.
        print(f"⚠️ run_directory failed: {exc}")
        return ["No observation is dumped."]