import os
import time
import re
import argparse
import shutil, hashlib
import yaml
from AgentOccam.env import WebArenaEnvironmentWrapper
from AgentOccam.AgentOccam import AgentOccam
from webagents_step.utils.data_prep import *
from AgentOccam.prompts import AgentOccam_prompt
from AgentOccam.llms.claude import call_claude, call_claude_with_messages, arrange_message_for_claude
from AgentOccam.llms.mistral import call_mistral, call_mistral_with_messages, arrange_message_for_mistral
from AgentOccam.llms.cohere import call_cohere, call_cohere_with_messages, arrange_message_for_cohere
from AgentOccam.llms.llama import call_llama, call_llama_with_messages, arrange_message_for_llama
from AgentOccam.llms.titan import call_titan, call_titan_with_messages, arrange_message_for_titan
from AgentOccam.llms.gpt import call_gpt, call_gpt_with_messages, arrange_message_for_gpt
from AgentOccam.llms.gemini import call_gemini, call_gemini_with_messages, arrange_message_for_gemini
from AgentOccam.llms.glm import call_glm, call_glm_with_messages, arrange_message_for_glm
from functools import partial
import subprocess

# Matches one "Part N – Title" header line of a decomposition, e.g.
# "### Part 1 – Navigation & Collection".  Compared with the earlier, stricter
# form (exactly "###", a single digit, en dash only), leading whitespace and
# '#' markers are optional, the part number may have several digits, and a
# plain hyphen is accepted as the separator.
HEADER_RE = re.compile(
    r'^\s*#*\s*'      # optional indentation and markdown '#' markers
    r'Part\s+(\d+)'   # "Part <number>"; group 1 is the part number
    r'\s*[–-]\s*'     # en dash or hyphen separator
    r'(.+)$',         # group 2: the title, up to the end of the line
    flags=re.MULTILINE,
)


def extract_decomposed_task(text):
    """
    Split a decomposition into its numbered parts.

    The text is expected to contain header lines recognised by HEADER_RE,
    such as

        ### Part 1 – Navigation & Collection
        ### Part 2 – Analysis
        ### Part 3 – Final Navigation   (optional)

    The result maps "part<N>" to the header title and to the stripped text
    that follows the header, up to the next header or the end of the input:
    {
        "part1": {"title": "...", "body": "..."},
        "part2": {"title": "...", "body": "..."}
    }

    If the same part number appears more than once, the last occurrence wins.

    Raises ValueError when no header is found; the message includes the text.
    """
    matches = list(HEADER_RE.finditer(text))
    if not matches:
        raise ValueError(f"No valid part headers found.\n{text}")

    # A body stops where the following header starts; the last one runs to EOF.
    body_ends = [following.start() for following in matches[1:]]
    body_ends.append(len(text))

    sections = {}
    for match, body_end in zip(matches, body_ends):
        sections[f"part{match.group(1)}"] = {
            "title": match.group(2).strip(),
            "body": text[match.end():body_end].strip(),
        }
    return sections


# One row per supported model family:
#   (family key, single-prompt call, call-with-messages, message arranger)
# "o3" and "o4-mini" reuse the gpt functions.
# Row order is significant: run_llm picks the first family whose key occurs
# as a substring of the requested model id.
_MODEL_FAMILY_TABLE = (
    ("claude", call_claude, call_claude_with_messages, arrange_message_for_claude),
    ("mistral", call_mistral, call_mistral_with_messages, arrange_message_for_mistral),
    ("cohere", call_cohere, call_cohere_with_messages, arrange_message_for_cohere),
    ("llama", call_llama, call_llama_with_messages, arrange_message_for_llama),
    ("titan", call_titan, call_titan_with_messages, arrange_message_for_titan),
    ("gpt", call_gpt, call_gpt_with_messages, arrange_message_for_gpt),
    ("gemini", call_gemini, call_gemini_with_messages, arrange_message_for_gemini),
    ("o3", call_gpt, call_gpt_with_messages, arrange_message_for_gpt),
    ("o4-mini", call_gpt, call_gpt_with_messages, arrange_message_for_gpt),
    ("glm", call_glm, call_glm_with_messages, arrange_message_for_glm),
)

# Family keys, in lookup order.
MODEL_FAMILIES = [row[0] for row in _MODEL_FAMILY_TABLE]
# family key -> single-prompt call function
CALL_MODEL_MAP = {row[0]: row[1] for row in _MODEL_FAMILY_TABLE}
# family key -> call function that takes a message list
CALL_MODEL_WITH_MESSAGES_FUNCTION_MAP = {row[0]: row[2] for row in _MODEL_FAMILY_TABLE}
# family key -> function that arranges messages for that family's call
ARRANGE_MESSAGE_FOR_MODEL_MAP = {row[0]: row[3] for row in _MODEL_FAMILY_TABLE}

def run_llm(system: str, messages, model_id: str = None):
    """Call the chosen LLM and return its raw response.

    Args:
        system: System prompt, forwarded as ``system_prompt``.
        messages: Messages in the form accepted by the family's
            ``arrange_message_for_*`` function; they are arranged before the call.
        model_id: Model identifier. Falls back to "gpt-5-mini" when empty or
            None. The model family is the first entry of MODEL_FAMILIES that
            occurs as a substring of the identifier.

    Returns:
        Whatever the family's ``call_*_with_messages`` function returns.

    Raises:
        ValueError: If the identifier contains none of the known family names
            (previously this surfaced as an opaque IndexError).
    """
    model_id = model_id or "gpt-5-mini"
    model_family = next(
        (family for family in MODEL_FAMILIES if family in model_id), None
    )
    if model_family is None:
        raise ValueError(
            f"Unsupported model id {model_id!r}: "
            f"it must contain one of {MODEL_FAMILIES}."
        )
    call_fn = CALL_MODEL_WITH_MESSAGES_FUNCTION_MAP[model_family]
    arrange_fn = ARRANGE_MESSAGE_FOR_MODEL_MAP[model_family]
    return call_fn(
        system_prompt=system,
        messages=arrange_fn(messages),
        model_id=model_id,
    )

# Task ids that are preceded by a 30-minute pause (announced as
# "Reddit post task. Sleep 30 mins.").
_REDDIT_SLEEP_TASK_IDS = frozenset(range(600, 650)) | frozenset(range(681, 689))

# Task ids that count towards the periodic Reddit reset: every third one that
# is actually run triggers scripts/reset/reset_reddit.sh and a 30-minute pause.
_REDDIT_RESET_TASK_IDS = frozenset({
    30170, 30171, 30172, 30173, 30174,
    50000, 50001, 50002,
    50010, 50011, 50012, 50013, 50014,
    50020, 50021, 50022, 50023, 50024,
    50030, 50031, 50032, 50033, 50034,
    50040,
    50050, 50051, 50052,
    50060, 50061, 50062, 50063, 50064,
})


def _load_decomposer_system_prompt(decomp_mode, task_domain):
    """Return the decomposer prompt for *decomp_mode* followed by the domain tips.

    Raises NotImplementedError for any mode other than "naive" or "wise".
    """
    if decomp_mode == "naive":
        decomposer_prompt_path = "AgentOccam/prompts/decomposer/two_stage_naive.txt"
    elif decomp_mode == "wise":
        decomposer_prompt_path = "AgentOccam/prompts/decomposer/two_stage_wise.txt"
    else:
        raise NotImplementedError(f"Unsupported decomp_mode: {decomp_mode!r}")

    with open(decomposer_prompt_path) as fp:
        system = fp.read()
    with open(f"AgentOccam/prompts/tips/{task_domain}.txt") as fp:
        tips = fp.read()
    return system + "\n\n" + tips


def run():
    """Run the decompose-then-navigate pipeline over the configured tasks.

    Command-line arguments:
        --config    YAML config file location (required).
        --auth_dir  Directory handed to scripts/login_setup.sh (default ".auth").

    For each task config file the function: skips it if a log for the task
    already exists, runs the login script, applies the Reddit pauses/resets,
    creates the environment and agent, asks the LLM to decompose the objective,
    lets the agent act on the "part1" body, and (when logging is enabled)
    writes the per-task log and summary row through ``log_run``. A task whose
    ``agent.act`` call raises is reported on stdout and not logged.

    NOTE(review): DotDict, json, copy, random and log_run are not imported
    explicitly in this file; presumably they come from the wildcard import of
    webagents_step.utils.data_prep -- confirm.
    """
    parser = argparse.ArgumentParser(
        description="Only the config file argument should be passed"
    )
    parser.add_argument(
        "--config", type=str, required=True, help="yaml config file location"
    )
    parser.add_argument(
        "--auth_dir", type=str, default=".auth", help="auth directory for cookies"
    )
    args = parser.parse_args()
    print("Config file:", args.config)
    print("Auth directory:", args.auth_dir)
    with open(args.config, "r") as file:
        config = DotDict(yaml.safe_load(file))

    # dstdir stays None when logging is disabled; every later use is guarded,
    # so a run without logging no longer fails with a NameError.
    dstdir = None
    if config.logging:
        run_name = config.logname or time.strftime('%Y%m%d-%H%M%S')
        dstdir = f"{config.logdir}/{run_name}"
        os.makedirs(dstdir, exist_ok=True)
        # Keep a copy of the config next to the logs it produced.
        config_copy = os.path.join(dstdir, os.path.basename(args.config))
        shutil.copyfile(args.config, config_copy)
        print(config_copy)
    random.seed(42)

    if hasattr(config.env, "relative_task_dir"):
        relative_task_dir = config.env.relative_task_dir
    else:
        relative_task_dir = "tasks"
    task_dir = f"config_files/{relative_task_dir}"

    task_ids = config.env.task_ids
    if task_ids == "all" or task_ids == ["all"]:
        task_ids = [
            filename[:-len(".json")]
            for filename in os.listdir(task_dir)
            if filename.endswith(".json")
        ]
    config_file_list = [f"{task_dir}/{task_id}.json" for task_id in task_ids]

    fullpage = config.env.fullpage if hasattr(config.env, "fullpage") else True
    current_viewport_only = not fullpage

    # Number of reset-list tasks run so far; drives the every-third reset.
    count = 0

    for config_file in config_file_list:
        # NOTE: the extraction/analysis stage that used to fill these in is
        # currently disabled, so both stay None and the log always carries the
        # EXTRACTION_FAILED / ANALYSIS_FAILED markers.
        extraction_result: list | None = None
        analysis_result:   str  | None = None

        with open(config_file, "r") as f:
            task_config = json.load(f)
        task_id = task_config['task_id']
        print(f"Task {task_id}.")

        # A log file from an earlier run means the task is already done.
        if dstdir is not None and os.path.exists(os.path.join(dstdir, f"{task_id}.json")):
            print(f"Skip {task_id}.")
            continue

        # NOTE: shell=True with an interpolated CLI argument; auth_dir is
        # supplied by the operator running this script.
        subprocess.run(f"bash scripts/login_setup.sh {args.auth_dir}", shell=True)

        if task_id in _REDDIT_SLEEP_TASK_IDS:
            print("Reddit post task. Sleep 30 mins.")
            time.sleep(1800)
        if task_id in _REDDIT_RESET_TASK_IDS:
            count += 1
            if count % 3 == 0:
                print("Reddit post task. Restart.")
                subprocess.run("bash scripts/reset/reset_reddit.sh", shell=True)
                time.sleep(1800)

        # Kept as an equality test so non-bool config values behave as before.
        if config.agent.actor.current_observation.auto == True:  # noqa: E712
            if task_config["required_obs"] == "image":
                config.agent.actor.current_observation.type = ["text", "image"]
            else:
                config.agent.actor.current_observation.type = ["text"]

        env = WebArenaEnvironmentWrapper(config_file=config_file,
                                         max_browser_rows=config.env.max_browser_rows,
                                         max_steps=config.max_steps,
                                         slow_mo=1,
                                         observation_type="accessibility_tree",
                                         current_viewport_only=current_viewport_only,
                                         viewport_size={"width": 1920, "height": 1080},
                                         headless=config.env.headless,
                                         global_config=config,
                                         evaluate_at_end=False)

        try:
            sites = task_config.get("sites", None)
            website = sites[0]
            if website == "wikipedia":
                website = sites[1]

            agent = AgentOccam(
                prompt_dict={k: v for k, v in AgentOccam_prompt.__dict__.items() if isinstance(v, dict)},
                config=copy.deepcopy(config.agent),
                website=website,
                task_id=task_id,
                task_type="pure_navigation"
            )
            objective = env.get_objective()

            # Use the defaulted relative_task_dir so a config without
            # env.relative_task_dir works here as it does for the task list.
            task_domain = relative_task_dir.split("/")[-1]
            system = _load_decomposer_system_prompt(config.decomp_mode, task_domain)

            messages = [("text", f"[USER TASK]\n{objective}")]
            decomposer_response = run_llm(
                system=system, messages=messages, model_id=config.agent.actor.model
            )

            parsed = extract_decomposed_task(decomposer_response)
            navigation_objective = parsed["part1"]['body']
            analysis_objective = parsed["part2"]['body']
        except Exception:
            # Setup failed before the agent ran: release the environment, then
            # let the error propagate exactly as it did before.
            env.close()
            raise

        print("*" * 100)
        print("navigation objective:\n", navigation_objective)
        print("*" * 100)
        print("analysis_objective:\n", analysis_objective)
        print("*" * 100)

        # A failing task is reported and skipped (no log is written for it).
        try:
            print("Starting agent.act()...")
            status = agent.act(objective=navigation_objective, env=env)
            print("agent.act() completed successfully")
            env.close()
        except ValueError as e:
            print(f"ValueError caught: {e}")
            env.close()
            continue
        except KeyError as e:
            print(f"KeyError caught: {e}")
            env.close()
            continue
        except Exception as e:
            print(f"Unexpected error: {e}")
            import traceback
            traceback.print_exc()
            env.close()
            continue

        if config.logging:
            model_name = config.agent.actor.model if hasattr(config.agent, "actor") else config.agent.model_name
            log_file = os.path.join(dstdir, f"{task_id}.json")
            log_data = {
                "task": config_file,
                "id": task_id,
                "eval": task_config['eval'],
                "nav_obj": navigation_objective,
                "ana_obj": analysis_objective,
                "model": model_name,
                "type": config.agent.type,
                "trajectory": agent.get_trajectory(),
            }
            summary_file = os.path.join(dstdir, "summary.csv")
            summary_data = {
                "task": config_file,
                "task_id": task_id,
                "model": model_name,
                "type": config.agent.type,
                # Last two path components, i.e. "<run dir>/<task id>.json".
                "logfile": re.search(r"/([^/]+/[^/]+\.json)$", log_file).group(1),
            }

            if extraction_result is None:
                log_data["ext_result"] = "EXTRACTION_FAILED"
            if analysis_result is None:
                log_data["analysis_result"] = "ANALYSIS_FAILED"

            if status:
                summary_data.update(status)
            log_run(
                log_file=log_file,
                log_data=log_data,
                summary_file=summary_file,
                summary_data=summary_data,
            )
    
# Script entry point: run the pipeline only when this file is executed directly.
if __name__ == "__main__":
    run()
