import contextlib, hashlib, io, json, os, re, sys, textwrap, tqdm
from typing import List, Dict
from functools import partial
from AgentOccam.env import WebArenaEnvironmentWrapper
from concurrent.futures import ThreadPoolExecutor, as_completed
from webagents_step.utils.data_prep import *
from AgentOccam.llms.claude import call_claude, call_claude_with_messages, arrange_message_for_claude
from AgentOccam.llms.mistral import call_mistral, call_mistral_with_messages, arrange_message_for_mistral
from AgentOccam.llms.cohere import call_cohere, call_cohere_with_messages, arrange_message_for_cohere
from AgentOccam.llms.llama import call_llama, call_llama_with_messages, arrange_message_for_llama
from AgentOccam.llms.titan import call_titan, call_titan_with_messages, arrange_message_for_titan
from AgentOccam.llms.gpt import call_gpt, call_gpt_with_messages, arrange_message_for_gpt
from AgentOccam.llms.gemini import call_gemini, call_gemini_with_messages, arrange_message_for_gemini
from AgentOccam.llms.glm import call_glm, call_glm_with_messages, arrange_message_for_glm
from AgentOccam.analyzer import tri_phase_analyze

# Model family name -> function that sends a system prompt plus chat messages to
# that provider. run_llm routes a model id to the first family (in this order)
# whose name occurs in the id, so entry order is significant.
CALL_MODEL_WITH_MESSAGES_FUNCTION_MAP = dict(
    claude=call_claude_with_messages,
    mistral=call_mistral_with_messages,
    cohere=call_cohere_with_messages,
    llama=call_llama_with_messages,
    titan=call_titan_with_messages,
    gpt=call_gpt_with_messages,
    gemini=call_gemini_with_messages,
    glm=call_glm_with_messages,
)
# Supported family names, in the same (lookup) order as the map above.
MODEL_FAMILIES = list(CALL_MODEL_WITH_MESSAGES_FUNCTION_MAP)



# ─────────────────────────── observation ─────────────────────────────
def get_observation_from_url(
    url: str,
    prune: bool = True,
    objective: str = "",
):
    """Load *url* in a throw-away headless WebArena browser and return its observation.

    The environment is configured for an accessibility-tree observation that is
    not limited to the current viewport (1920x1080 window), and it is always
    closed again before returning.

    Args:
        url: Page to open; passed to the wrapper as the task's ``start_url``.
        prune: Forwarded as ``env.prune`` in the wrapper's global config.
        objective: Forwarded as the task ``intent``.

    Returns:
        Whatever ``env.observation()`` returns.
    """
    pseudo_config = {"start_url": url, "intent": objective}
    pseudo_global_config = DotDict({"env": {"prune": prune}})

    env = WebArenaEnvironmentWrapper(config_file=pseudo_config,
                                     max_browser_rows=500,
                                     max_steps=50,
                                     slow_mo=1,
                                     observation_type="accessibility_tree",
                                     current_viewport_only=False,
                                     viewport_size={"width": 1920, "height": 1080},
                                     headless=True,
                                     global_config=pseudo_global_config,
                                     evaluate_at_end=False)
    try:
        return env.observation()
    finally:
        # Release the browser even if taking the observation raises; previously
        # an exception here skipped env.close() and leaked the browser.
        env.close()
# ─────────────────────────── utilities ────────────────────────────
def run_llm(system: str, messages, model_id="glm-4.5-air-fp8"):
    """Send *system* and *messages* to the provider that serves *model_id*.

    The provider is the first entry of ``MODEL_FAMILIES`` whose name is a
    substring of *model_id* (e.g. ``"glm-4.5-air-fp8"`` -> ``"glm"``).

    Args:
        system: System prompt.
        messages: Chat messages, in the
            ``{"role": ..., "content": [{"type": "text", "text": ...}]}`` shape
            built by the callers in this module.
        model_id: Model identifier; also forwarded to the provider call.

    Returns:
        Whatever the provider's ``call_*_with_messages`` function returns
        (callers in this module treat it as the reply text).

    Raises:
        ValueError: if *model_id* does not contain any known family name.
    """
    family = next((fam for fam in MODEL_FAMILIES if fam in model_id), None)
    if family is None:
        # Previously this surfaced as a bare IndexError from `[...][0]`.
        raise ValueError(
            f"Unsupported model_id {model_id!r}: expected it to contain one of {MODEL_FAMILIES}"
        )
    call_fn = CALL_MODEL_WITH_MESSAGES_FUNCTION_MAP[family]
    return call_fn(system_prompt=system, messages=messages, model_id=model_id)

# Legacy pattern, kept for backward compatibility only: being non-greedy, it stops
# at the first closing bracket and so cannot capture nested JSON. _grab_json no
# longer uses it.
_JSON_RE = re.compile(r"\{.*?\}|\[.*?\]", re.S)
# Candidate start positions of an embedded JSON object or array.
_JSON_START_RE = re.compile(r"[\[{]")
_JSON_DECODER = json.JSONDecoder()

def _grab_json(text: str):
    """Parse JSON out of an LLM reply, tolerating surrounding prose or code fences.

    The whole string is tried first. If that fails, each ``{`` / ``[`` in *text*
    is tried in order as the start of a JSON value, and the first one that
    decodes is returned. ``raw_decode`` follows the real JSON grammar, so nested
    objects/arrays and brackets inside string values are handled correctly, and
    trailing text after the value is ignored.

    Returns:
        The decoded value, or None if *text* is not a string or no decodable
        JSON object/array is found.
    """
    if not isinstance(text, str):
        return None
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        pass
    for match in _JSON_START_RE.finditer(text):
        try:
            value, _end = _JSON_DECODER.raw_decode(text, match.start())
        except json.JSONDecodeError:
            continue
        return value
    return None

def _freeze(value):
    """Return a hashable stand-in for a JSON-like *value* (dict/list/tuple/set nest).

    Equal inputs map to equal keys: dicts become a tagged frozenset of their
    items, so key order does not matter; lists/tuples become tagged tuples.
    Leaf values are returned unchanged and may still be unhashable.
    """
    if isinstance(value, dict):
        return (dict, frozenset((k, _freeze(v)) for k, v in value.items()))
    if isinstance(value, (list, tuple)):
        return (list, tuple(_freeze(v) for v in value))
    if isinstance(value, (set, frozenset)):
        return (set, frozenset(_freeze(v) for v in value))
    return value

def deduplicate_with_order(extraction_results: List[Dict]) -> List[Dict]:
    """Drop repeated items from *extraction_results*, keeping first occurrences in order.

    Items are compared by value. This includes dicts whose keys appear in a
    different order, dicts with list/dict-valued fields, and non-dict items
    such as plain strings.

    Returns:
        A new list; the input is not modified.
    """
    seen = set()
    unique_results = []
    for item in extraction_results:
        try:
            key = _freeze(item)
            if key in seen:
                continue
            seen.add(key)
        except TypeError:
            # An unhashable leaf value: fall back to a linear equality check.
            if item in unique_results:
                continue
        unique_results.append(item)
    return unique_results

def _flat_observation(raw: str) -> str:
    """Collapse every whitespace run in *raw* to one space and trim both ends."""
    tokens = raw.split()
    return " ".join(tokens)
# ─────────────────── 1.  obtain extraction code ──────────────────
def get_extractor_code(user_goal: str, example_page: str,
                       model_id: str = "glm-4.5-air-fp8", retries=3) -> str:
    """Ask the LLM to write Python code that extracts *user_goal* data from a page.

    The same request (goal plus *example_page*) is sent up to *retries* times.
    The first reply whose fenced code block (as found by ``_extract_code_block``)
    contains both ``def extract(`` and ``answer`` is returned.

    Raises:
        RuntimeError: if no attempt yields acceptable code.
    """
    instructions = "".join([
        "You are a web-page extraction agent.\n",
        "INPUTS you will receive:\n",
        "  • USER_GOAL   – natural language description of desired info\n",
        "  • PAGE        – raw accessibility-tree text\n\n",
        "Return **ONLY one fenced Python block** that:\n",
        " 1) defines a function  extract(page:str)->list[dict]\n",
        " 2) extracts the goal-specific info from *the given page text*;\n",
        " 3) assigns the result to variable  answer .\n",
        "Do NOT print anything else.",
    ])
    user_text = f"[USER_GOAL]\n{user_goal}\n\n[PAGE]\n{example_page}"
    conversation = [{"role": "user", "content": [{"type": "text", "text": user_text}]}]

    attempt = 0
    while attempt < retries:
        attempt += 1
        candidate = _extract_code_block(run_llm(instructions, conversation, model_id=model_id))
        looks_valid = bool(candidate) and "def extract(" in candidate and "answer" in candidate
        if looks_valid:
            return candidate
        print(f"↺ retry {attempt}/{retries} – extractor code invalid")
    raise RuntimeError("Failed to obtain valid extraction code from LLM.")


# ─────────────────── 2.  run code safely on a page ────────────────
def run_extractor(code: str, page_text: str) -> Dict | List | None:
    sandbox_globals, sandbox_locals = {}, {"page": page_text}
    stdout = io.StringIO()
    code_wrapped = textwrap.dedent(code) + "\n\nanswer = extract(page)"
    try:
        with contextlib.redirect_stdout(stdout):
            exec(code_wrapped, sandbox_globals, sandbox_locals)
        return sandbox_locals.get("answer", None)
    except Exception as exc:
        print("⚠️ extractor runtime error:", exc)
        return None

def get_extractor_prompt(user_goal: str,
                         sample_page_text: str,
                         model_id: str = "glm-4.5-air-fp8") -> str:
    """
    Ask an LLM to design an extraction prompt given:
      • user_goal  (high-level intent)
      • sample_page_text  (one snapshot of a page's accessibility-tree text)

    The LLM is asked for a concise, self-contained prompt that states
      1. what information to extract,
      2. the JSON keys each extracted object should use (including one
         identifier field, to help later de-duplication),
      3. one simple example of the extracted JSON list.
    The page itself is not part of the returned prompt; run_prompt_extractor
    appends each page when the prompt is used.

    Returns:
        The LLM's reply with surrounding whitespace stripped.
    """
    SYSTEM = (
        "You are an expert prompt engineer.\n"
        "Design a SINGLE prompt that, when shown together with a web-page "
        "text accessibility tree, makes another LLM extract and return ONLY a list of JSON objects "
        "containing the fields that satisfy the user goal.\n"
        "Only extract the information specified in the user goal. Make sure each extracted entry also has one identifier field (add only one if there is no such key specified in user goal) that will help accurate deduplication in the later stage.\n"
        # "Besides the information specified in the user goal, you can also include other highly relevant data entries for more comprehensive extraction result. However, do not include any web page element id (e.g. [1315]) or other unnecessary page setting information in the extraction result."
        "You need to specify 1) what information to be extracted, 2) what keys should be used for each JSON object in extracted list, 3) one simple example of the extracted JSON list.\n"
        # "Note that the page can also be irrelevant to the user goal. Therefore, the prompt should also let the LLM judge whether the page contains the needed information first. If it is not relevant, return an empty list or object.\n"
        "Make your prompt concise and only include this necessary information.\n"
    )
    # Built explicitly rather than with textwrap.dedent on an f-string: once the
    # (typically unindented) page text is interpolated there is no common indent
    # left, so dedent was a no-op and the section headers kept their indentation.
    USER = f"[USER GOAL]\n{user_goal}\n\n[SAMPLE PAGE]\n{sample_page_text}\n"
    messages = [{"role": "user",
                 "content": [{"type": "text", "text": USER}]}]
    return run_llm(SYSTEM, messages, model_id=model_id).strip()

def run_prompt_extractor(extractor_prompt: str,
                         page_text: str,
                         model_id: str = "glm-4.5-air-fp8",
                         retries: int = 3,
                         retry_delay: float = 1.0):
    """
    Feed the *pre-built* extractor_prompt plus concrete page_text to the LLM.

    The reply is parsed with ``_grab_json``. A result of None, the string
    ``"null"``, ``[]`` or ``{}`` counts as a failed attempt. The attempt is then
    retried after sleeping ``retry_delay`` seconds, up to ``retries`` attempts
    in total.

    Returns:
        Parsed JSON (dict/list) or None if extraction fails after retries.
    """
    import time  # only needed for the pause between retries

    system = (
        "You are a data extractor.  Your task is to extract the needed information from a webpage. "
        "You will be provided with the webpage text accessibility tree as well as a requirement of what to extract and output format. Strictly follow it. "
        "Return ALL the requested information in the page or 'null' if some data entries are absent."
    )
    # The request is identical on every attempt, so build it once. Plain
    # concatenation keeps the section headers flush left; textwrap.dedent was a
    # no-op once unindented page text had been interpolated.
    user_text = f"[REQUIREMENT]\n{extractor_prompt}\n\n[SAMPLE PAGE]\n{page_text}\n"
    messages = [{"role": "user",
                 "content": [{"type": "text", "text": user_text}]}]

    for attempt in range(1, retries + 1):
        raw = run_llm(system=system, messages=messages, model_id=model_id)
        res = _grab_json(raw)
        if res is not None and res not in ("null", [], {}):
            return res

        print(f"   ↺ retry {attempt}/{retries} – no JSON extracted")
        print("Raw output was:\n", raw)
        # Back off before the next attempt (not after the final one).
        if attempt < retries and retry_delay > 0:
            time.sleep(retry_delay)

    return None

def _extract_code_block(raw: str) -> str | None:
    m = re.search(r"```python\s*(.*?)\s*```", raw, re.S)
    return m.group(1) if m else None

def _read_obs(path: str) -> str:
    """Load an observation dump file and return just the observation text.

    When the file contains ``FULL OBSERVATION TEXT:``, everything after its first
    occurrence is returned; otherwise the whole file is. Either way the result is
    stripped of surrounding whitespace.
    """
    marker = "FULL OBSERVATION TEXT:"
    with open(path, "r", encoding="utf-8") as handle:
        content = handle.read()
    _before, found, after = content.partition(marker)
    body = after if found else content
    return body.strip()


def _hash(text: str) -> str:
    """Return the SHA-1 hex digest of *text* with all whitespace removed.

    Two texts that differ only in whitespace therefore hash identically.
    """
    compact = re.sub(r"\s+", "", text)
    digest = hashlib.sha1(compact.encode())
    return digest.hexdigest()

# ─────────────────── 2.  run code safely on a page ────────────────
def run_extractor(code: str, page_text: str) -> Dict | List | None:
    sandbox_globals, sandbox_locals = {}, {"page": page_text}
    stdout = io.StringIO()
    code_wrapped = textwrap.dedent(code) + "\n\nanswer = extract(page)"
    try:
        with contextlib.redirect_stdout(stdout):
            exec(code_wrapped, sandbox_globals, sandbox_locals)
        return sandbox_locals.get("answer", None)
    except Exception as exc:
        print("⚠️ extractor runtime error:", exc)
        return None

# def run_directory(task_dir: str, user_goal: str,
#                   model_id="glm-4.5-air-fp8") -> List[Dict]:
#     try:
#         txt_files = sorted([f for f in os.listdir(task_dir)
#                             if f.startswith("obs_step") and f.endswith(".txt")],
#                         key=lambda x: int(re.findall(r"\d+", x)[0]))
#         if not txt_files:
#             raise FileNotFoundError("No observation dumps found.")

#         hashes, unique = set(), []
#         for f in txt_files:
#             txt = _read_obs(os.path.join(task_dir, f))
#             h = _hash(txt)
#             if h not in hashes:
#                 hashes.add(h); unique.append((f, txt))
#             else:
#                 print(f"🌀 duplicate skipped: {f}")

#         results = []
#         extractor_prompt = get_extractor_prompt(user_goal, unique[-1][1], model_id)
#         print("🔧 extractor prompt obtained")
#         print(extractor_prompt)
#         for f, txt in unique:
#             out = run_prompt_extractor(extractor_prompt, txt)
#             if out:
#                 results+=out; print(f"✔ extracted from {f}")
#             else:
#                 print(f"✖ nothing in {f}")

#         print(f"⚙️ Final extraction results: \n{results}")
#         return results
#     except:
#         return ["No observation is dumped."]

def _summarize_steps(steps: List[Dict]) -> List[Dict]:
    """Condense trajectory entries into the fields shown to the step-selection LLM.

    Each summary's ``step`` is the entry's position in *steps*. ``reason`` and
    ``action`` are copied, and ``observation_description`` is exposed as
    ``observation_summary``. Missing keys default to "".
    """
    return [
        {
            "step": step_num,
            "reason": s.get("reason", ""),
            "action": s.get("action", ""),
            "observation_summary": s.get("observation_description", ""),
        }
        # enumerate rather than steps.index(s): linear overall, and correct even
        # when two trajectory entries compare equal.
        for step_num, s in enumerate(steps)
    ]


def _parse_step_list(raw) -> list | None:
    """Extract the step-number list from the selection LLM's reply, or None.

    If the whole reply is a JSON list, that list is used. Otherwise the LAST
    bracketed list that parses as JSON is used: the prompt asks for the
    analysis first and the list at the end. Earlier brackets are typically
    element references such as ``click[1316]`` quoted in the analysis.
    """
    if not isinstance(raw, str):
        return None
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        parsed = None
    if isinstance(parsed, list):
        return parsed
    for candidate in reversed(re.findall(r"\[[^\[\]]*\]", raw)):
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, list):
            return parsed
    return None


def _valid_step_indices(candidates: list, num_steps: int) -> List[int]:
    """Keep in-range integer step numbers from *candidates*, in order, without duplicates.

    Anything else is reported and dropped: out-of-range or negative numbers,
    strings, floats and booleans. This prevents an IndexError/TypeError on
    ``steps[idx]``, and stops a negative index from silently selecting a step
    counted from the end of the trajectory.
    """
    valid, seen = [], set()
    for idx in candidates:
        if isinstance(idx, bool) or not isinstance(idx, int) or not 0 <= idx < num_steps:
            print(f"Ignoring invalid step number from LLM: {idx!r}")
            continue
        if idx not in seen:
            seen.add(idx)
            valid.append(idx)
    return valid


def run_second_stage(record_dir: str, task_id: str, model_id: str = "glm-4.5-air-fp8") -> Dict:
    """Run the extraction + analysis stage for one recorded navigation task.

    1. Load ``<record_dir>/<task_id>.json``.
    2. Ask the LLM which trajectory steps' observations contain the information
       named by the record's ``nav_obj``.
    3. Build an extraction prompt from the last selected observation and run it
       over every selected observation.
    4. De-duplicate the extracted items and pass them, with the record's
       ``ana_obj``, to ``tri_phase_analyze``.

    Returns:
        A dict with ``task_id``, ``analysis_result`` and ``eval``. On an early
        failure ``analysis_result`` holds an error message. On success the dict
        also carries ``extraction_prompt``, ``extraction_results``,
        ``analysis_code`` and ``selected_steps``.
    """
    # 1. Find the record file for the task.
    record_file = os.path.join(record_dir, task_id + ".json")
    if not os.path.exists(record_file):
        error_msg = f"Record file not found: {record_file}"
        print(error_msg)
        # No record has been loaded, so there is no "eval" entry to echo back
        # (this previously referenced `record` before assignment and crashed).
        return {"task_id": task_id, "analysis_result": error_msg, "eval": {}}

    # 2. Load the record file.
    with open(record_file, "r", encoding="utf-8") as f:
        record = json.load(f)
    eval_info = record.get("eval", {})

    # 3. Navigation objective and trajectory.
    user_goal = record.get("nav_obj")
    steps = record.get("trajectory", [])
    if not user_goal or not steps:
        error_msg = "Record missing user_goal or steps."
        print(error_msg)
        return {"task_id": task_id, "analysis_result": error_msg, "eval": eval_info}

    step_summaries = _summarize_steps(steps)

    # 4. LLM call: choose the steps whose observations hold the needed information.
    SYSTEM = (
        "You are an agent for information extraction. "
        "Given an objective (which specifies what information to be found in the web environment) and a list of web navigation agent interaction history (with reason, action, and observation summary), "
        "your responsibility is to select the step numbers whose webpage observations are most likely to contain the information in the objective. "
        "After you select these steps, the webpage observations of these steps will be used to extract the information. "
        "Some steps are only for navigating to the relevant pages. You should not select these steps even if they may have relevant information, because they have not started collecting the relevant information. (One example is adjusting the filters before starting collection in a relevant page.) "
        "Some other steps are for collecting the information. These steps should be selected.\n"
        "Note:\n"
        "1) The action in a step will be executed and reflected in the observation in the next step. For example, if the action is 'click on the home page button', the observation in the next step will be the home page.\n"
        "2) The action you see at each step may contain a number, like 'click[1316]'. This number is the index of the element in the observation. You may not know which element is clicked, but you can still use the reason to infer what that element is.\n"
        "3) In your output, analyze each step in one or two sentences first. After this, return a JSON list of step numbers (e.g., [2, 5, 7]) that you believe contains the needed information in their observations.\n"
        "4) Your result JSON list must be consistent with your analysis."
    )
    USER = f"[OBJECTIVE]\n{user_goal}\n\n[STEPS]\n" + "".join(
        f"Step {s['step']}\nReason: {s['reason']}\nAction: {s['action']}\n"
        f"Observation Summary: {s['observation_summary']}\n\n"
        for s in step_summaries
    )
    messages = [{"role": "user", "content": [{"type": "text", "text": USER}]}]
    raw = run_llm(SYSTEM, messages, model_id=model_id)
    print(raw)

    step_candidates = _parse_step_list(raw)
    if not isinstance(step_candidates, list):
        print("LLM did not return a valid list of step numbers. Aborting.")
        print("Raw output was:\n", raw)
        error_msg = "LLM did not return a valid list of step numbers. Aborting. Raw output was:\n" + str(raw)
        return {"task_id": task_id, "analysis_result": error_msg, "eval": eval_info}

    print(f"\nFirst round relevant steps: {step_candidates}\n")

    # 5. A second LLM round that re-checked the full observations of these steps
    # was prototyped earlier; for now the first-round selection is used directly.
    final_steps = _valid_step_indices(step_candidates, len(steps))
    selected_obs = [(idx, steps[idx].get("observation", "")) for idx in final_steps]
    if not selected_obs:
        error_msg = "No valid observations found for selected steps."
        print(error_msg)
        return {"task_id": task_id, "analysis_result": error_msg, "eval": eval_info}

    # 6. Design one extraction prompt from the last selected page, apply it to all.
    sample_obs = selected_obs[-1][1]
    extractor_prompt = get_extractor_prompt(user_goal, sample_obs, model_id)
    extractor_prompt = extractor_prompt + "\nNOTE: 1) Make sure you extract ALL the relevant items in the page without missing any product information. 2) Do not extract any information that is not related to the user goal. If you believe the page is irrelevant to the user goal, return an empty list."

    # TODO: The extraction may be different for different pages to some tasks (shopping fees).
    print("\n🔧 extractor prompt obtained:")
    print(extractor_prompt)

    extraction_results = []
    for step_num, obs_text in selected_obs:
        out = run_prompt_extractor(extractor_prompt, obs_text, model_id=model_id)
        if not out:
            print(f"✖ nothing extracted from step {step_num}")
            continue
        # A single JSON value counts as one item. len() of a dict would count its
        # keys, and len() of a number would raise.
        items = out if isinstance(out, list) else [out]
        extraction_results.extend(items)
        print(f"✔ {len(items)} items extracted from step {step_num}")

    try:
        extraction_results = deduplicate_with_order(extraction_results)
    except Exception as e:
        # Best effort: keep the un-deduplicated list.
        print(f"Error deduplicating extraction results: {e}")

    print(f"\n⚙️ Final extraction results: \n{extraction_results}")
    print(f"\n⚙️ {len(extraction_results)} elements in total")

    # 7. Analysis over the extracted items.
    analysis_objective = record.get("ana_obj")
    analysis_dict = tri_phase_analyze(analysis_objective, extraction_results)
    analysis_result = analysis_dict['answer']

    print("*" * 100)
    print(analysis_result)
    print("*" * 100)

    return {
        "task_id": task_id,
        "analysis_result": analysis_result,
        "eval": eval_info,
        "extraction_prompt": extractor_prompt,
        "extraction_results": extraction_results,
        "analysis_code": analysis_dict['code'],
        "selected_steps": final_steps,
    }

if __name__ == "__main__":

    record_dir = "TwoStage_judge_naive_glm/reddit/auto/AgentOccam"
    # record_dir = "extraction_debug"
    # task_ids = [20000, 20001, 20002, 20003, 20004, 20010, 20011, 20012, 20013, 20014, 20020, 20021, 20022, 20023, 20024, 20030, 20031, 20032, 20033, 20034, 20040, 20041, 20042, 20043, 20044, 20050, 20051, 20052, 20053, 20054, 20060, 20061, 20062, 20063, 20064, 20070, 20071, 20072, 20073, 20074, 20080, 20081, 20082, 20083, 20084, 20090, 20091, 20100, 20101, 20102, 20103, 20104, 20110, 20111, 20112, 20113, 20114, 20120, 20121, 20122, 20123, 20124, 20130, 20131, 20132, 20133, 20134, 20140, 20141, 20142, 20143, 20144, 20150, 20151, 20152, 20153, 20154, 20160, 20161, 20162, 20170, 20171, 20172, 20173, 20174, 20180, 20181, 20182, 20183, 20184, 20190, 20191, 20192, 20193, 20194, 20200, 20201, 20202, 20203, 20204, 20210, 20211, 20220, 20221, 20222, 20223, 20224, 20230, 20231, 20232, 20233, 20240, 20241, 20242, 20243, 20244, 20245]
    task_ids = [30000, 30001, 30010, 30020, 30021, 30022, 30023, 30024, 30030, 30031, 30032, 30033, 30034, 30040, 30041, 30042, 30043, 30050, 30051, 30052, 30053, 30054, 30060, 30061, 30062, 30063, 30070, 30071, 30072, 30073, 30074, 30080, 30081, 30082, 30083, 30084, 30090, 30091, 30092, 30093, 30094, 30100, 30101, 30102, 30103, 30104, 30110, 30111, 30112, 30113, 30114, 30120, 30121, 30122, 30123, 30124, 30130, 30131, 30132, 30133, 30134, 30140, 30141, 30142, 30143, 30144, 30150, 30151, 30152, 30153, 30154, 30160, 30161, 30162, 30163, 30164, 30170, 30171, 30172, 30173, 30174, 30180, 30181, 30182, 30183, 30184, 30190, 30191, 30192, 30193, 30194]
    # task_ids = [30000]
    # task_ids = [20131,20132,20133,20134]
    model_id = "glm-4.5-air-fp8"

    res_dir = os.path.join(record_dir, "result")
    os.makedirs(res_dir, exist_ok=True)

    def _record_path(tid):
        """Path of the first-stage record for task *tid*."""
        return os.path.join(record_dir, str(tid) + ".json")

    def _result_path(tid):
        """Path where the second-stage result for task *tid* is written."""
        return os.path.join(res_dir, f"{tid}_second_stage_result.json")

    def process_single_task(tid):
        """Run the second stage for one task and save its result as JSON.

        The skip checks are repeated here so the function is also safe to call
        on its own.
        """
        print(f"Processing task {tid}...")
        if not os.path.exists(_record_path(tid)):
            print(f"Task run {tid} does not exist. Skipping task {tid}.")
            return None
        output_path = _result_path(tid)
        if os.path.exists(output_path):
            print(f"Task {tid} already processed. Skipping.")
            return None

        try:
            result = run_second_stage(record_dir, str(tid), model_id)
        except Exception as e:
            print(f"Error processing task {tid}: {e}")
            # str(tid) matches the task_id type run_second_stage itself returns.
            result = {"task_id": str(tid), "analysis_result": f"Error: {e}"}

        # Serialize before opening the file. If the result is not
        # JSON-serializable, no partial file is left behind; a partial file
        # would make later runs skip this task as "already processed".
        payload = json.dumps(result, indent=4)
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(payload)
        print(f"Result saved to {output_path}\n")
        return result

    # Filter out tasks that should be skipped.
    tasks_to_process = []
    for tid in task_ids:
        if not os.path.exists(_record_path(tid)):
            print(f"Task run {tid} does not exist. Skipping task {tid}.")
        elif os.path.exists(_result_path(tid)):
            print(f"Task {tid} already processed. Skipping.")
        else:
            tasks_to_process.append(tid)

    # Process remaining tasks concurrently (the work is dominated by LLM I/O).
    with ThreadPoolExecutor(max_workers=32) as executor:
        futures = {executor.submit(process_single_task, tid): tid for tid in tasks_to_process}
        for future in tqdm.tqdm(as_completed(futures), total=len(tasks_to_process), desc="Processing tasks"):
            tid = futures[future]
            try:
                future.result()
            except Exception as e:
                print(f"Unexpected error processing task {tid}: {e}")
