from AgentOccam.obs_opt import parse_node_descendants, parse_node_ancestors, parse_node_siblings, action_set_invisible, action_set_visible, action_set_visible_if_with_name, translate_node_to_str, construct_new_DOM_with_visible_nodes
from AgentOccam.llms.claude import call_claude, call_claude_with_messages, arrange_message_for_claude
from AgentOccam.llms.mistral import call_mistral, call_mistral_with_messages, arrange_message_for_mistral
from AgentOccam.llms.cohere import call_cohere, call_cohere_with_messages, arrange_message_for_cohere
from AgentOccam.llms.llama import call_llama, call_llama_with_messages, arrange_message_for_llama
from AgentOccam.llms.titan import call_titan, call_titan_with_messages, arrange_message_for_titan
from AgentOccam.llms.gpt import call_gpt, call_gpt_with_messages, arrange_message_for_gpt
from AgentOccam.llms.gemini import call_gemini, call_gemini_with_messages, arrange_message_for_gemini
from AgentOccam.llms.glm import call_glm, call_glm_with_messages, arrange_message_for_glm
from AgentOccam.utils import CURRENT_DIR, HOMEPAGE_URL

from typing import Dict
import re
import copy
import os, shutil
from functools import partial
import random
import json

import warnings
warnings.filterwarnings("ignore")


# History keys every Agent always records / online-state keys it always tracks;
# configs may add more via `documented_interaction_elements` /
# `online_interaction_elements`.
DEFAULT_DOCUMENTED_INTERACTION_ELEMENTS = ["observation", "action"]
DEFAULT_ONLINE_INTERACTION_ELEMENTS = ["url", "observation"]

# One row per model family: (single-prompt call, message-list call, message
# arranger). The o3 / o4-mini families reuse the GPT backend.
# Row order matters: an Agent picks the FIRST family whose name occurs as a
# substring of the configured model id.
_MODEL_BACKENDS = {
    "claude": (call_claude, call_claude_with_messages, arrange_message_for_claude),
    "mistral": (call_mistral, call_mistral_with_messages, arrange_message_for_mistral),
    "cohere": (call_cohere, call_cohere_with_messages, arrange_message_for_cohere),
    "llama": (call_llama, call_llama_with_messages, arrange_message_for_llama),
    "titan": (call_titan, call_titan_with_messages, arrange_message_for_titan),
    "gpt": (call_gpt, call_gpt_with_messages, arrange_message_for_gpt),
    "gemini": (call_gemini, call_gemini_with_messages, arrange_message_for_gemini),
    "o3": (call_gpt, call_gpt_with_messages, arrange_message_for_gpt),
    "o4-mini": (call_gpt, call_gpt_with_messages, arrange_message_for_gpt),
    "glm": (call_glm, call_glm_with_messages, arrange_message_for_glm),
}

MODEL_FAMILIES = list(_MODEL_BACKENDS)
CALL_MODEL_MAP = {family: backend[0] for family, backend in _MODEL_BACKENDS.items()}
CALL_MODEL_WITH_MESSAGES_FUNCTION_MAP = {family: backend[1] for family, backend in _MODEL_BACKENDS.items()}
ARRANGE_MESSAGE_FOR_MODEL_MAP = {family: backend[2] for family, backend in _MODEL_BACKENDS.items()}

# Case-insensitive "answer: yes" / "answer: no" matcher.
_ANS_RE = re.compile(r"answer\s*:\s*(yes|no)", re.I)

class Agent:
    """Base class for the LLM-backed agents in this module.

    Stores the objective and prompt template, a per-key history of past
    interactions (``previous_interactions``: key -> list of values) and the
    latest online state (``online_interaction``: key -> value). It also binds
    model-calling helpers for the model family detected from ``config.model``.
    """

    def __init__(self, config, objective, prompt_template):
        self.config = config
        self.objective = objective
        self.prompt_template = prompt_template
        print("config", config)
        if hasattr(self.config, "documented_interaction_elements"):
            self.previous_interactions = {k: [] for k in set(DEFAULT_DOCUMENTED_INTERACTION_ELEMENTS+self.config.documented_interaction_elements)}
        else:
            self.previous_interactions = {k: [] for k in DEFAULT_DOCUMENTED_INTERACTION_ELEMENTS}
        if hasattr(self.config, "online_interaction_elements"):
            self.online_interaction = {k: None for k in set(DEFAULT_ONLINE_INTERACTION_ELEMENTS+self.config.online_interaction_elements)}
        else:
            self.online_interaction = {k: None for k in DEFAULT_ONLINE_INTERACTION_ELEMENTS}

        self._bind_model(self.config.model)

    def _bind_model(self, model_id):
        """Select the model family for ``model_id`` and bind its call helpers.

        The family is the first entry of ``MODEL_FAMILIES`` that occurs as a
        substring of ``model_id``.

        Raises:
            IndexError: if no known family name occurs in ``model_id``.
        """
        family = next((f for f in MODEL_FAMILIES if f in model_id), None)
        if family is None:
            raise IndexError(f"No known model family matches model id {model_id!r}; expected one of {MODEL_FAMILIES}.")
        self.model_family = family
        self.call_model = partial(CALL_MODEL_MAP[family], model_id=model_id)
        self.call_model_with_message = partial(CALL_MODEL_WITH_MESSAGES_FUNCTION_MAP[family], model_id=model_id)
        self.arrange_message_for_model = ARRANGE_MESSAGE_FOR_MODEL_MAP[family]

    def shift_model(self, model_id):
        """Rebind the model helpers to ``model_id``. ``self.config.model`` is left unchanged."""
        self._bind_model(model_id)

    def prune_message_list(self, message_list):
        """Drop empty ("text", "") entries, then merge adjacent text entries."""
        return self.merge_adjacent_text([m for m in message_list if not (m[0]=="text" and len(m[1])==0)])

    def merge_adjacent_text(self, message_list):
        """Concatenate runs of consecutive ("text", str) tuples into a single tuple.

        Non-text tuples (e.g. ("image", ...)) are kept in place and break runs.
        """
        merged_list = []
        current_tuple = None

        for tup in message_list:
            if tup[0] == "text":
                if current_tuple:
                    current_tuple = (current_tuple[0], current_tuple[1] + tup[1])
                else:
                    current_tuple = tup
            else:
                if current_tuple:
                    merged_list.append(current_tuple)
                    current_tuple = None
                merged_list.append(tup)

        if current_tuple:
            merged_list.append(current_tuple)

        return merged_list

    def get_step(self):
        """Return the current step index, i.e. the number of recorded actions."""
        return len(self.previous_interactions["action"])

    def update_objective(self, objective):
        self.objective = objective

    def update_online_state(self, **online_states):
        """Overwrite tracked online-state keys; unknown keys are ignored."""
        for k in online_states.keys():
            if k in self.online_interaction.keys():
                self.online_interaction[k] = online_states[k]

    def update_history(self, **interaction_dict):
        """Append each value to the history list of its key; unknown keys are ignored."""
        for k in interaction_dict.keys():
            if k in self.previous_interactions.keys():
                self.previous_interactions[k].append(interaction_dict[k])

    def equal_history_length(self):
        """Return True when every history list has the same length."""
        lengths = [len(self.previous_interactions[k]) for k in self.previous_interactions.keys()]
        return (len(set(lengths)) == 1)

    def parse_elements(self, text, key_list):
        """Extract "KEY:" sections from a model response.

        For each key, the text after ``KEY.upper() + ":"`` is captured up to the
        next line that looks like an upper-case header ("\\nWORDS:") or the end
        of the text. Missing keys map to "".
        """
        element_dict = {}
        for k in key_list:
            # Escape the key so names containing regex metacharacters match literally.
            _match = re.search(rf'{re.escape(k.upper())}:\s*(.*?)\s*(?=\n[A-Z\s]*:|$)', text, re.DOTALL)
            element_dict[k] = _match.group(1).strip() if _match else ""
        return element_dict

    def get_output_specifications(self):
        """Concatenate the spec file of each configured output as "NAME:\\n<file body>"."""
        sections = []
        for output_name in self.config.output:
            spec_path = os.path.join(CURRENT_DIR, "AgentOccam", "prompts", "output_specifications", "{}.txt".format(output_name.replace(" ", "_")))
            with open(spec_path, "r") as spec_file:
                sections.append(f"{output_name.upper()}:\n" + spec_file.read())
        return "\n".join(sections)

    def get_tips(self, website):
        """Return the non-blank, stripped lines of the tips file for ``website``.

        Looks under ``CURRENT_DIR`` like the other prompt files. If the file is
        not there, it falls back to the path relative to the working directory.
        """
        relative_path = os.path.join("AgentOccam", "prompts", "tips", f"{website}.txt")
        anchored_path = os.path.join(CURRENT_DIR, relative_path)
        tips_path = anchored_path if os.path.exists(anchored_path) else relative_path
        with open(tips_path, "r") as f:
            return "\n".join(line.strip() for line in f if line.strip())

    def parse_stipulated_action_list(self, text: str, action: str, actions: list) -> list:
        """Return every occurrence of ``action`` in ``text``.

        Each occurrence runs up to the next line that starts with any entry of
        ``actions``, or up to the end of the text.
        """
        pattern = rf'({re.escape(action)}\s*(.*?))(?=\n(?:{"|".join(map(re.escape, actions))})|$)'
        return [match[0].strip() for match in re.findall(pattern, text, re.DOTALL)]

    def parse_str_to_action_list(self, text:str, actions: list):
        """Split ``text`` into consecutive action strings.

        Parsing starts at the beginning of ``text``. It stops at the first chunk
        that does not start with one of ``actions``; that chunk and the rest are
        discarded.
        """
        stop_alternation = "|".join(map(re.escape, actions))
        remain_text = text
        action_list = []
        while remain_text:
            find_action = False
            for action in actions:
                if remain_text.startswith(action):
                    match = re.search(rf'({re.escape(action)}\s*(.*?))(?=\n(?:{stop_alternation})|$)', remain_text, re.DOTALL)
                    action_list.append(match[0])
                    remain_text = remain_text[len(match[0]):].strip()
                    find_action = True
                    # Restart the scan from the first action on the new remainder.
                    break
            if not find_action:
                break
        return action_list

    def get_observation_text(self, idx=None):
        """Return the text of history observation ``idx``, or of the current one when idx is None."""
        if isinstance(self.online_interaction["observation"], dict):
            if idx is not None:
                return self.previous_interactions["observation"][idx]["text"]
            return self.online_interaction["observation"]["text"]
        elif isinstance(self.online_interaction["observation"], str):
            if idx is not None:
                return self.previous_interactions["observation"][idx]
            return self.online_interaction["observation"]

    def get_observation_image(self, idx=None):
        """Return the image of history observation ``idx`` (or the current one); None for text-only observations."""
        if isinstance(self.online_interaction["observation"], dict):
            if idx is not None:
                return self.previous_interactions["observation"][idx]["image"]
            return self.online_interaction["observation"]["image"]
        elif isinstance(self.online_interaction["observation"], str):
            return None

    def get_observation_node(self, idx=None):
        """Return the tree node of history observation ``idx`` (or the current one); None for text-only observations."""
        if isinstance(self.online_interaction["observation"], dict):
            if idx is not None:
                return self.previous_interactions["observation"][idx]["node"]
            return self.online_interaction["observation"]["node"]
        elif isinstance(self.online_interaction["observation"], str):
            return None

    def get_observation_node_str(self, idx=None):
        """Return the "name_only" string of an observation's tree.

        For history entries, this reads the ``node_str`` cached by
        ``del_observation_node``. For the current observation, it is computed
        on the fly.
        """
        if isinstance(self.online_interaction["observation"], dict):
            if idx is not None:
                return self.previous_interactions["observation"][idx]["node_str"]
            return translate_node_to_str(self.online_interaction["observation"]["node"], mode="name_only")
        elif isinstance(self.online_interaction["observation"], str):
            return None

    def del_observation_node(self):
        """Replace stored history nodes with their string form and free the trees."""
        if isinstance(self.online_interaction["observation"], str):
            return
        if isinstance(self.online_interaction["observation"], dict):
            for obs in self.previous_interactions["observation"]:
                if obs.get("node"):
                    obs["node_str"] = translate_node_to_str(obs["node"], mode="name_only")
                    obs["node"].delete_tree()
                    obs["node"] = None

class PlanTreeNode:
    """A node of a hierarchical plan tree.

    Each node carries its identifier, type, text, depth ``level``, the ``url``
    and ``step`` at which it was created, plus per-node bookkeeping lists
    (notes, hints, resume reasons, steps taken). Invisible nodes are skipped
    by ``search_node_by_id``.
    """

    def __init__(self, id, type, text, level, url, step):
        self.visible = True
        self.id = id
        self.type = type
        self.text = text
        self.level = level
        self.url = url
        self.step = step
        self.children = []
        self.parent = None
        self.note = []
        self.hint = []
        self.resume_reason = []
        self.steps_taken = []

    def reset(self):
        """Make the node visible again and clear notes, hints and steps taken.

        ``resume_reason`` is intentionally left untouched.
        """
        self.visible = True
        self.note = []
        self.hint = []
        self.steps_taken = []

    def add_child(self, child):
        """Attach ``child`` as the last child of this node and set its parent."""
        child.parent = self
        self.children.append(child)

    def search_node_by_id(self, target_id):
        """Depth-first (pre-order) search for a visible node with ``id == target_id``.

        The children of an invisible node are still searched. Returns None when
        nothing matches.
        """
        if self.visible and self.id == target_id:
            return self
        for child in self.children:
            found = child.search_node_by_id(target_id)
            if found is not None:
                return found
        return None

    def traverse(self, action=None, tree_buffer=None):
        """Apply ``action`` to every node in pre-order and collect its results.

        Args:
            action: callable taking a node. A list result is extended into the
                buffer; any other truthy result is appended; falsy results are
                dropped. When None, nothing is collected.
            tree_buffer: list to collect into. A fresh list is created when
                omitted, so separate calls never share results.

        Returns:
            The buffer that results were collected into.
        """
        if tree_buffer is None:
            tree_buffer = []
        res_action = action(self) if action is not None else None
        if res_action:
            if isinstance(res_action, list):
                tree_buffer.extend(res_action)
            else:
                tree_buffer.append(res_action)
        for child in self.children:
            child.traverse(action, tree_buffer=tree_buffer)
        return tree_buffer

class QAActor(Agent):
    """Identity that answers the objective directly from the current observation.

    The model's RESPONSE section is wrapped into a ``note [...]`` action.
    The constructor is inherited unchanged from ``Agent``.
    """

    def get_instruction(self):
        """Return the raw instruction template; no placeholders are filled."""
        return self.prompt_template["instruction_template"]

    def get_online_input(self):
        """Fill ``input_template`` with the current observation text and the objective."""
        filled = self.prompt_template["input_template"]
        filled = filled.replace("{current_observation}", self.get_observation_text())
        filled = filled.replace("{objective}", self.objective)
        return [("text", filled)]

    def get_action(self, instruction, online_input):
        """Query the model and turn its RESPONSE section into a note action.

        Returns:
            ``(model_response, elements)``, where ``elements`` holds the parsed
            output sections plus "action", "instruction" and "input".
        """
        messages = self.arrange_message_for_model(online_input)
        model_response = self.call_model_with_message(system_prompt=instruction, messages=messages)
        elements = self.parse_elements(text=model_response, key_list=self.config.output)
        answer = elements["response"]
        elements["action"] = f"note [{answer}]"
        elements["instruction"] = instruction
        elements["input"] = online_input
        return model_response, elements
    
class PlanningActor(Agent):
    """Identity that produces a plan; its parsed PLAN section becomes the action.

    The instruction is built lazily and cached on ``self.instruction``.
    """

    def __init__(self, config, objective, prompt_template):
        super().__init__(config, objective, prompt_template)
        self.instruction = None

    def get_instruction(self):
        """Return the instruction with ``{output_specifications}`` filled in (cached)."""
        if not self.instruction:
            specs = self.get_output_specifications()
            self.instruction = self.prompt_template["instruction_template"].replace("{output_specifications}", specs)
        return self.instruction

    def get_online_input(self):
        """Planning identities take no per-step online input."""
        return None

    def get_action(self, instruction, online_input):
        """Query the model and move its PLAN section to the "action" key.

        Returns:
            ``(model_response, elements)``. ``elements`` has "action" (the
            plan), "reason" ("N/A"), "instruction" and "input"; "plan" is removed.
        """
        messages = self.arrange_message_for_model(online_input)
        model_response = self.call_model_with_message(system_prompt=instruction, messages=messages)
        elements = self.parse_elements(text=model_response, key_list=self.config.output)
        # parse_elements yields plain strings, so moving the value needs no copy.
        elements["action"] = elements.pop("plan")
        elements["reason"] = "N/A"
        elements["instruction"] = instruction
        elements["input"] = online_input
        return model_response, elements

class ReflectionActor(Agent):
    """Identity that reflects on progress.

    Its instruction embeds both the output and the navigation specifications
    and is cached on ``self.instruction`` after the first build.
    """

    def __init__(self, config, objective, prompt_template):
        super().__init__(config, objective, prompt_template)
        self.instruction = None

    def get_navigation_specifications(self):
        """Return one "- <spec file body>" entry per configured navigation command.

        Each file is closed as soon as it has been read.
        """
        specs = []
        for command in self.config.navigation_command:
            spec_path = os.path.join(CURRENT_DIR, "AgentOccam", "prompts", "navigation_specifications", f"{command}.txt")
            with open(spec_path, "r") as spec_file:
                specs.append("- " + spec_file.read())
        return "\n".join(specs)

    def get_instruction(self):
        """Return the instruction with both specification placeholders filled (cached)."""
        if self.instruction:
            return self.instruction
        output_specifications = self.get_output_specifications()
        navigation_specifications = self.get_navigation_specifications()
        instruction = self.prompt_template["instruction_template"]
        instruction = instruction.replace("{output_specifications}", output_specifications)
        instruction = instruction.replace("{navigation_specifications}", navigation_specifications)
        self.instruction = instruction
        return self.instruction

    def get_online_input(self):
        """Reflection identities take no per-step online input."""
        return None

    def get_action(self, instruction, online_input):
        """Query the model and return ``(model_response, parsed_elements)``."""
        model_response = self.call_model_with_message(system_prompt=instruction, messages=self.arrange_message_for_model(online_input))
        action_elements = self.parse_elements(text=model_response, key_list=self.config.output)
        action_elements["instruction"] = instruction
        action_elements["input"] = online_input
        return model_response, action_elements

# Maps an identity config's `name` to the Agent subclass that implements it.
IDENTITY_CLASS_MAP = dict(
    QA=QAActor,
    planning=PlanningActor,
    reflection=ReflectionActor,
)

class Actor(Agent):
    def __init__(self, config, objective, prompt_template, initial_plan=None):
        """Create the actor, its log-file paths and its auxiliary identities.

        Args:
            config: actor configuration. ``others.logname`` (optional) suffixes
                the play/trash log names. ``identities.identity_0``,
                ``identity_1``, ... (optional) describe extra identity agents.
            objective: task objective shared with every identity.
            prompt_template: dict-like; each identity receives the entry keyed
                by its config ``name``.
            initial_plan: optional starting plan, also used as the current plan.
        """
        super().__init__(config, objective, prompt_template)
        self.initial_plan = initial_plan
        self.current_plan = initial_plan
        self.plan_progress = []
        self.output_specifications = None
        self.navigation_specifications = None
        self.criticism_element_list = None

        logname = getattr(self.config.others, "logname", "")
        if logname == "":
            play_name, trash_name = "play.txt", "trash.txt"
        else:
            play_name, trash_name = f"play-{logname}.txt", f"trash-{logname}.txt"
        self.output_play_path = os.path.join(CURRENT_DIR, play_name)
        self.output_trash_path = os.path.join(CURRENT_DIR, trash_name)

        self.identities = []
        if not hasattr(self.config, "identities"):
            return
        index = 0
        # Identities are numbered contiguously; the first gap ends the list.
        while hasattr(self.config.identities, f"identity_{index}"):
            identity_config = getattr(self.config.identities, f"identity_{index}")
            identity_cls = IDENTITY_CLASS_MAP[identity_config.name]
            self.identities.append(identity_cls(identity_config, objective=objective, prompt_template=prompt_template[identity_config.name]))
            index += 1
    
    def update_online_state(self, **online_states):
        """Update this actor's online state, then mirror the same update onto every identity."""
        super().update_online_state(**online_states)
        for sub_agent in self.identities:
            sub_agent.update_online_state(**online_states)

    # def update_plan(self, new_plan):
    #     """Update the current plan based on judge feedback"""
    #     self.current_plan = new_plan
    #     self.plan_progress.append({
    #         "step": self.get_step(),
    #         "old_plan": self.current_plan,
    #         "new_plan": new_plan,
    #         "reason": "Plan updated by judge"
    #     })

    def get_current_plan(self):
        """Return the current plan, or a placeholder string when no (truthy) plan is set."""
        return self.current_plan or "No plan available"

    def get_plan_progress(self):
        """Return the plan-progress history list itself (not a copy)."""
        progress = self.plan_progress
        return progress

    def is_navigation(self, action):
        """Return the navigation command that ``action`` starts with, ignoring ``note`` segments.

        ``note ...`` segments, which run up to the next line that starts with a
        navigation command, are removed first. Returns False if no command matches.
        """
        commands = self.config.navigation_command
        note_pattern = rf'(note\s*(.*?))(?=\n(?:{"|".join(map(re.escape, commands))})|$)'
        remainder = re.sub(note_pattern, "", action).strip()
        return next((command for command in commands if remainder.startswith(command)), False)
    
    def is_valid_action(self, action_str):
        action = (
            action_str.split("[")[0].strip()
            if "[" in action_str
            else action_str.split()[0].strip()
        )
        match action:
            case "click":
                # match = re.search(r"click ?\[(\d+)\]", action_str)
                match = re.search(r"click ?\[?(\d+)\]?", action_str)
                if not match:
                    return False
                element_id = match.group(1)
                if element_id in self.get_observation_text():
                    return True
                return False
            case "type":
                if not (action_str.endswith("[0]") or action_str.endswith("[1]")):
                    action_str += " [1]"

                match = re.search(
                    r"type ?\[(\d+)\] ?\[(.*)\] ?\[(\d+)\]", action_str, re.DOTALL
                )
                if not match:
                    return False
                element_id, text, enter_flag = (
                    match.group(1),
                    match.group(2),
                    match.group(3),
                )
                enter_flag = True if enter_flag == "1" else False
                if enter_flag:
                    text += "\n"
                if element_id in self.get_observation_text():
                    return True
            case "go_back":
                return True
            case "go_home":
                return True
            # case "note":
            #     return True
            case "stop":
                return True
            case "goto":
                return True
            case "scroll":
                return True

    def are_valid_actions(self, actions):
        """Return True iff ``actions`` parses into at least one action and every parsed action is valid.

        Parsing recognizes the configured navigation commands plus "goto".
        """
        parsed = self.parse_str_to_action_list(actions, self.config.navigation_command+["goto"])
        return bool(parsed) and all(self.is_valid_action(single) for single in parsed)

    def get_previous_plans(self, verbose=False):
        """Render the current plan; with ``verbose``, also list each recorded plan-progress step and reason."""
        parts = [f"Current Plan:\n{self.get_current_plan()}\n"]
        if verbose and self.plan_progress:
            parts.append("\nPlan Progress History:\n")
            parts.extend(f"Step {entry['step']}: {entry['reason']}\n" for entry in self.plan_progress)
        return "".join(parts)
    
    def get_active_plan(self):
        """Alias for ``get_current_plan``."""
        active = self.get_current_plan()
        return active
    
    def get_interaction_history(self, interaction_history_config=False, mode="highlight"):
        """Render past steps (observation, reason, action) for the actor prompt.

        Args:
            interaction_history_config: history settings; falls back to
                ``self.config.interaction_history`` when falsy. Reads
                ``step_num`` ("all" or an int), ``verbose`` and ``type``.
            mode: "highlight" uses the stored highlight subtree for past
                observations that differ from the current one, if it is short
                (< 200 words). "full" uses the stored observation text when it
                is short.

        Returns:
            A pruned message list of ("text", str) / ("image", ...) tuples.

        Raises:
            ValueError: if a highlight observation cannot be rendered.
        """
        interaction_history_config = interaction_history_config if interaction_history_config else self.config.interaction_history

        # Get all previous steps (simplified approach)
        all_steps = list(range(len(self.previous_interactions.get("action", []))))

        # The current observation does not change while the history is walked.
        # Serialise it once instead of re-translating the whole current tree
        # (twice) for every past step.
        current_node_str = self.get_observation_node_str() if all_steps else None
        current_node = self.get_observation_node() if all_steps else None

        previous_observation = []
        for i in all_steps:
            # Only read the stored string when there is a current one to compare against.
            step_node_str = self.get_observation_node_str(i) if current_node_str else None
            if current_node_str and step_node_str and current_node_str != step_node_str:
                highlight_node = self.previous_interactions["observation highlight"][i]
                if highlight_node and mode == "highlight" and len(translate_node_to_str(highlight_node, mode="name_only", retained_ids=self.previous_interactions["retained element ids"][i]).split()) < 200:
                    try:
                        previous_observation.append({"text": translate_node_to_str(highlight_node, mode="name_only", retained_ids=self.previous_interactions["retained element ids"][i]), "image": self.get_observation_image(i)})
                    except Exception as err:
                        print(i, self.previous_interactions["observation"][i]["text"])
                        raise ValueError("Cannot translate highlight node to text.") from err
                else:
                    previous_observation.append({"text": self.previous_interactions["observation summary"][i], "image": self.get_observation_image(i)})
            elif not current_node or mode == "full":
                if len(self.get_observation_text(i).split()) < 200:
                    previous_observation.append({"text": self.get_observation_text(i), "image": self.get_observation_image(i)})
                else:
                    previous_observation.append({"text": self.previous_interactions["observation summary"][i], "image": self.get_observation_image(i)})
            else:
                previous_observation.append({"text": "The same as the CURRENT OBSERVATION (see below CURRENT OBSERVATION section).", "image": self.get_observation_image(i)})

        previous_observation_summary = [self.previous_interactions["observation summary"][i] for i in all_steps]

        def get_text(obs):
            if isinstance(obs, dict):
                return obs["text"]
            elif isinstance(obs, str):
                return obs

        def get_image(obs):
            # NOTE(review): the str branch returns the text itself as the "image".
            # Every entry built above is a dict, so that branch is not reached here.
            if isinstance(obs, dict):
                return obs["image"]
            elif isinstance(obs, str):
                return obs

        if interaction_history_config.step_num == "all":
            textual_observations = [get_text(obs) for obs in previous_observation] if interaction_history_config.verbose else previous_observation_summary
            visual_observations = [get_image(obs) for obs in previous_observation]
        else:
            # Older steps use summaries without images; the last `step_num`
            # steps use the full (or verbose) rendering.
            textual_observations = previous_observation_summary[:-interaction_history_config.step_num]
            visual_observations = [None] * len(previous_observation_summary[:-interaction_history_config.step_num])
            textual_observations += [get_text(obs) for obs in previous_observation][-interaction_history_config.step_num:] if interaction_history_config.verbose else previous_observation_summary[-interaction_history_config.step_num:]
            visual_observations += [get_image(obs) for obs in previous_observation][-interaction_history_config.step_num:]

        plans = [self.previous_interactions["plan"][i] for i in all_steps]
        reasons = [self.previous_interactions["reason"][i] for i in all_steps]
        actions = [self.previous_interactions["action"][i] for i in all_steps]

        if "image" in interaction_history_config.type:
            message_list = []
            for step, (obs, vi_obs, plan, reason, action) in enumerate(zip(textual_observations, visual_observations, plans, reasons, actions)):
                message_list.append(("text", f"<step_{step}_interaction>\n"))
                if vi_obs:
                    message_list.append(("text", "VISUAL OBSERVATION:\n"))
                    message_list.append(("image", vi_obs))
                message_list.append(("text", f"TEXTUAL OBSERVATION:\n{obs}\nREASON FOR ACTION:\n{reason}\nACTION:\n{action}\n</step_{step}_interaction>\n"))

            return self.prune_message_list(message_list=message_list)
        else:
            message = ""
            for step, (obs, plan, reason, action) in enumerate(zip(textual_observations, plans, reasons, actions)):
                message += f"<step_{step}_interaction>\nOBSERVATION:\n{obs}\nREASON FOR ACTION:\n{reason}\nACTION:\n{action}\n</step_{step}_interaction>\n"

            return self.prune_message_list(message_list=[("text", message)])
        
    def pre_process_atomic_actions(self, atomic_action_list=("combobox",)):
        """Re-render the current observation text with menu/combobox/listbox roles hidden.

        This only happens when the current observation has a tree node and
        "combobox" is in ``atomic_action_list``. The default is an immutable
        tuple, so no state can leak between calls through a shared default.
        """
        if self.get_observation_node() and "combobox" in atomic_action_list:
            self.online_interaction["observation"]["text"] = translate_node_to_str(self.get_observation_node(), mode="concise", hidden_roles=["menu", "combobox", "listbox"])

    def get_online_input(self, criticism_elements):
        """Assemble the actor's per-step user message.

        ``input_template`` is split around its ``{input}`` placeholder. Each
        section named in ``self.config.input`` is rendered between the two
        halves as a "NAME:" text message. String content is inlined. Non-empty
        list content (message tuples) is spliced in after the header. Falsy
        content is skipped.

        A ``"critic: <key>"`` entry renders ``criticism_elements[<key>]``, when
        present and truthy, under a "FROM USER: <KEY>" header. The current
        screenshot is appended last when ``self.config.current_observation.type``
        mentions "image".

        NOTE(review): "step" maps to an int, which is neither str nor list, so
        it never reaches the output.

        Returns:
            The pruned list of ("text", ...) / ("image", ...) message tuples.
        """
        before_input, after_input = self.prompt_template["input_template"].split("{input}")
        # All candidate sections are computed up front, in this order.
        section_content = {
            "step": self.get_step(),
            "objective": self.objective,
            "current plan": self.get_current_plan(),
            "interaction history": self.get_interaction_history(),
            "current observation": self.get_observation_text(),
            "current visual observation": self.get_observation_image()
        }
        critic_prefix = "critic: "
        sections = []
        for input_type in self.config.input:
            if input_type == "current visual observation":
                continue  # appended separately below
            content = None
            if input_type in section_content:
                content = section_content[input_type]
            elif input_type.startswith(critic_prefix) and criticism_elements:
                critic_key = input_type[len(critic_prefix):]
                if critic_key in criticism_elements.keys() and criticism_elements[critic_key]:
                    content = criticism_elements[critic_key]
                    input_type = "FROM USER: " + critic_key
            if not content:
                continue
            if isinstance(content, str):
                sections.append(("text", f"{input_type.upper()}:\n{content}\n"))
            elif isinstance(content, list):
                sections.append(("text", f"{input_type.upper()}:\n"))
                sections.extend(content)

        if "image" in self.config.current_observation.type:
            sections.append(("text", "CURRENT VISUAL OBSERVATION:\n"))
            sections.append(("image", section_content["current visual observation"]))

        return self.prune_message_list(message_list=[("text", before_input)] + sections + [("text", after_input)])
    

    
    def get_navigation_specifications(self):
        """Return (and cache) one "- <spec file body>" entry per navigation command.

        Each specification file is closed right after it is read.
        """
        if self.navigation_specifications:
            return self.navigation_specifications
        specs = []
        for command in self.config.navigation_command:
            spec_path = os.path.join(CURRENT_DIR, "AgentOccam", "prompts", "navigation_specifications", f"{command}.txt")
            with open(spec_path, "r") as spec_file:
                specs.append("- " + spec_file.read())
        self.navigation_specifications = "\n".join(specs)
        return self.navigation_specifications
    

    def get_actor_instruction(self, examples=None, website=None, task_type=None):
        """Assemble the actor's system instruction.

        Args:
            examples: optional list of {"input", "output"} dicts. Defaults to
                ``prompt_template["examples"]`` (or none).
            website: selects the tips file inlined into ``{website_tips}``
                when ``config.tips`` is set.
            task_type: explicit key into ``prompt_template["instruction_template"]``.
                Without it, "with_planning_w_tips" or "without_planning" is
                chosen depending on ``config.tips``.

        Returns:
            The instruction string, with a final-step warning appended on the
            last step of the budget.
        """
        templates = self.prompt_template["instruction_template"]
        if task_type is not None:
            instruction = templates[task_type]
        elif self.config.tips:
            instruction = templates["with_planning_w_tips"]
        else:
            instruction = templates["without_planning"]
        output_specifications = self.get_output_specifications()
        navigation_specifications = self.get_navigation_specifications()
        instruction = instruction.replace("{output_specifications}", output_specifications)
        instruction = instruction.replace("{navigation_specifications}", navigation_specifications)
        if self.config.tips and website:
            website_tips = self.get_tips(website)
            instruction = instruction.replace("{website_tips}", website_tips)
        example_source = examples if examples is not None else self.prompt_template.get("examples", [])
        if len(example_source) > 0:
            instruction += "\n\n## Here are a few examples:"
            for i, example in enumerate(example_source):
                example_input = example["input"]
                example_output = example["output"]
                if "example_template" in self.prompt_template.keys():
                    # str.replace needs a str replacement, so the index is converted.
                    # The output fills both the {example_output} and the
                    # {example_response} placeholder spellings.
                    example_text = (
                        self.prompt_template["example_template"]
                        .replace("{i}", str(i))
                        .replace("{example_input}", example_input)
                        .replace("{example_output}", example_output)
                        .replace("{example_response}", example_output)
                    )
                    instruction += "\n\n" + example_text
                else:
                    instruction += f"\n\n| Example {i}\n\n### Input:\n{example_input}\n\n### Response: Let's think step by step.\n{example_output}"

        if self.get_step() == self.config.others.max_steps - 1:
            instruction += f"\n\nWARNING: You have a {self.config.others.max_steps}-step budget, and this would be your FINAL STEP. Wrap up your observations and return your answer with `stop [answer]` to maximize the reward."

        return instruction
    
    def verbose(self, instruction, online_input, model_response_list, action_element_list):
        """Write debugging logs for the current step.

        When both ``config.others.verbose`` and ``config.verbose`` are positive:
        - the ``config.trash`` items are appended to the trash log;
        - the play log is truncated and rewritten. It holds the ``config.play``
          items that are not action-element keys, followed by the action-element
          keys of every agent (with a per-agent header when there are several).
        """
        first_elements = action_element_list[0]
        action_element_keys = [k for k in self.config.play if k in first_elements.keys()]
        other_play_keys = [k for k in self.config.play if k not in first_elements.keys()]

        content_by_key = {
            "step": self.get_step(),
            "objective": self.objective,
            "previous plans": self.get_previous_plans(verbose=True),
            "url": self.online_interaction["url"],
            "observation": self.get_observation_text(),
            "response": "\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n".join([f"|\tAgent {i}:\n{model_response}" for i, model_response in enumerate(model_response_list[:self.config.number])]) if self.config.number > 1 else model_response_list[0],
            "instruction": instruction,
            "online input": "\n".join([i[1] for i in online_input if i[0]=="text"]),
            "alter ego response": "\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n".join(["|\tAgent {}:\n{}".format(identity.config.name, response) for identity, response in zip(self.identities, model_response_list[self.config.number:])])
        }

        if not (self.config.others.verbose > 0 and self.config.verbose > 0):
            return

        with open(self.output_trash_path, "a") as trash_file:
            trash_file.write("-"*32+"ACTOR"+"-"*32+"\n")
            for key in self.config.trash:
                trash_file.write(f"{key.upper()}:\n{content_by_key.get(key, '')}\n\n")

        # Opening in "w" mode truncates the previous step's play log.
        with open(self.output_play_path, "w") as play_file:
            for key in other_play_keys:
                play_file.write(f"{key.upper()}:\n{content_by_key.get(key, '')}\n\n")
            several_agents = len(action_element_list) > 1
            for agent_index, action_elements in enumerate(action_element_list):
                if several_agents:
                    play_file.write("-"*32+f"AGENT {agent_index}"+"-"*32+"\n")
                for key in action_element_keys:
                    play_file.write(f"{key.upper()}:\n{action_elements.get(key, 'N/A')}\n\n")
    
    # def update_plan_from_judge(self, new_plan):
    #     """Update plan based on judge feedback"""
    #     if new_plan and new_plan != self.current_plan:
    #         self.update_plan(new_plan)
    #         return True
    #     return False
    
    def go_home(self, action):
        """Map a ``go_home`` command onto a concrete navigation command.

        Args:
            action: raw action text from the model. Matching is a plain substring
                test, so any text containing ``go_home`` qualifies.

        Returns:
            ``"goto [<HOMEPAGE_URL>] [1]"`` when ``go_home`` appears in ``action``,
            otherwise None.
        """
        if "go_home" not in action:
            return None
        return f"goto [{HOMEPAGE_URL}] [1]"
    
    def parse_action(self, action_str):
        try:
            DOM_root_node = self.get_observation_node()
            action_str = action_str.strip()
            action = (
                action_str.split("[")[0].strip()
                if "[" in action_str
                else action_str.split()[0].strip()
            )
            match action:
                case "click":
                    match = re.search(r"click ?\[(\d+)\]", action_str)
                    if not match:
                        raise ValueError(f"Invalid click action {action_str}")
                    element_id = match.group(1)
                    node = DOM_root_node.search_node_by_id(element_id)
                    return f"click [{element_id}] ({node.role} {node.name})"
                case "hover":
                    match = re.search(r"hover ?\[(\d+)\]", action_str)
                    if not match:
                        raise ValueError(f"Invalid hover action {action_str}")
                    element_id = match.group(1)
                    node = DOM_root_node.search_node_by_id(element_id)
                    return f"hover [{element_id}] ({node.role} {node.name})"
                case "type":
                    if not (action_str.endswith("[0]") or action_str.endswith("[1]")):
                        action_str += " [1]"

                    match = re.search(
                        r"type ?\[(\d+)\] ?\[(.+)\] ?\[(\d+)\]", action_str
                    )
                    if not match:
                        raise ValueError(f"Invalid type action {action_str}")
                    element_id, text, enter_flag = (
                        match.group(1),
                        match.group(2),
                        match.group(3),
                    )
                    enter_flag = True if enter_flag == "1" else False
                    if enter_flag:
                        text += "\n"
                    node = DOM_root_node.search_node_by_id(element_id)
                    return action + f" ({node.name})"
                case "scroll":
                    return action_str
                case "goto":
                    return action
                case "new_tab":
                    return action
                case "go_back":
                    return action
                case "go_forward":
                    return action
                case "stop":
                    return action

            return False
        except:
            return False
    
    def parse_actions_to_element_ids(self, actions):
        """Collect the element ids targeted by click/hover/type commands in ``actions``.

        ``actions`` is split into individual commands with
        ``parse_stipulated_action_list``, once per verb in
        ``config.navigation_command``. Ids are therefore grouped by verb, in the
        order of ``navigation_command``.

        Args:
            actions: raw (possibly multi-command) action text from the model.

        Returns:
            list[int]: element ids from every well-formed click/hover/type
            command. Other verbs and malformed commands are skipped.
        """
        action_str_list = []
        for a in self.config.navigation_command:
            action_str_list += self.parse_stipulated_action_list(text=actions, action=a, actions=self.config.navigation_command+["goto"])

        retained_element_ids = []
        for action_str in action_str_list:
            action_str = action_str.strip()
            if not action_str:
                continue
            action = (
                action_str.split("[")[0].strip()
                if "[" in action_str
                else action_str.split()[0].strip()
            )
            if action in ("click", "hover"):
                found = re.search(rf"{action} ?\[(\d+)\]", action_str)
            elif action == "type":
                # The press-enter flag is optional; append a default so the
                # full three-bracket form still validates.
                if not action_str.endswith(("[0]", "[1]")):
                    action_str += " [1]"
                found = re.search(r"type ?\[(\d+)\] ?\[(.+)\] ?\[(\d+)\]", action_str)
            else:
                # scroll/goto/new_tab/go_back/go_forward/stop/note carry no element id.
                continue
            if found:
                retained_element_ids.append(int(found.group(1)))
            # Malformed commands are skipped; keep scanning the remaining ones
            # (previously the loop returned after the first parsed command).
        return retained_element_ids
        
    def get_observation_highlight(self, action_elements: dict):
        """Replace the model's highlight ids with a pruned rendering of the observation.

        Mutates ``action_elements`` in place:

        * ``"observation highlight idxs"`` keeps a copy of the raw comma-separated
          id string the model produced;
        * ``"observation highlight"`` becomes a DOM built from only the highlighted
          nodes (plus their descendants, ancestors and named siblings), or None when
          there is no observation tree, the result has 30 or more visible nodes,
          the result is empty, or construction fails;
        * ``"retained element ids"`` receives the element ids referenced by
          ``action_elements["action"]``. Not set when there is no observation tree.

        The tree is always made fully visible again before returning.
        """
        action_elements["observation highlight idxs"] = copy.deepcopy(action_elements.get("observation highlight", ""))
        DOM_root_node = self.get_observation_node()
        if not DOM_root_node:
            action_elements["observation highlight"] = None
            return

        observation_highlight_idxs = [int(idx.strip()) for idx in action_elements.get("observation highlight", "").split(",") if idx.strip().isdigit()]
        if observation_highlight_idxs:
            # Hide everything, then re-reveal the neighbourhood of each highlighted node.
            parse_node_descendants(node=DOM_root_node, action=action_set_invisible)
            for idx in observation_highlight_idxs:
                try:
                    node = DOM_root_node.search_node_by_id(idx)
                    parse_node_descendants(node=node, action=action_set_visible)
                    parse_node_ancestors(node=node, action=action_set_visible)
                    parse_node_siblings(node=node, action=action_set_visible_if_with_name)
                except Exception:
                    # Best-effort: ids that cannot be resolved in the tree are ignored.
                    continue

        highlight = None
        try:
            # Explicit check instead of `assert` so the size limit survives `python -O`;
            # the DOM is also constructed only once.
            if DOM_root_node.get_visible_node_number() < 30:
                highlight = construct_new_DOM_with_visible_nodes(DOM_root=DOM_root_node) or None
        except Exception:
            highlight = None
        finally:
            # Restore full visibility regardless of outcome.
            parse_node_descendants(node=DOM_root_node, action=action_set_visible)
        action_elements["observation highlight"] = highlight

        action_elements["retained element ids"] = self.parse_actions_to_element_ids(action_elements["action"])

    def parse_action_from_action_candidates(self, action_elements):
        """Parse an "...action candidates" text field into a list of reason/action dicts.

        If ``action_elements`` already has an ``"action"`` key it is returned
        unchanged. Otherwise the first key containing ``"action candidates"`` is
        rewritten in place from text of the form::

            - reason: [why]
            - action: [command]

        into ``[{"reason": ..., "action": ...}, ...]``. Entries missing either part
        are dropped.

        Raises:
            ValueError: if neither an ``"action"`` nor an action-candidates key exists.
        """
        if "action" in action_elements:
            return action_elements
        candidate_keys = [k for k in action_elements if "action candidates" in k]
        if not candidate_keys:
            raise ValueError("Expected an 'action' or 'action candidates' field in the model output.")
        action_candidates_key = candidate_keys[0]

        def parse_reasons_and_actions(input_string):
            pattern = r'- reason: \[?(.*?)\]?\s*(?:- action: \[?(.*?)\]?)?\s*(?:\n|\Z)'
            matches = re.findall(pattern, input_string, re.DOTALL)

            parsed_data = []
            for raw_reason, raw_action in matches:
                reason = raw_reason.strip()
                action = raw_action.strip()
                # The optional trailing `\]?` can swallow the command's own closing
                # bracket (e.g. "click [3]" -> "click [3"). Restore it for any command
                # left with a dangling "[", but never append one to bracket-less
                # commands such as a bare "stop".
                if "[" in action and not action.endswith("]"):
                    action += "]"
                if reason and action:
                    parsed_data.append({'reason': reason, 'action': action})
            return parsed_data

        action_elements[action_candidates_key] = parse_reasons_and_actions(action_elements[action_candidates_key])
        return action_elements

    def predict_action(self, criticism_elements, website, task_type):
        """Query the model until it proposes valid action(s); log and return them.

        When ``config.debug > 1`` the action is read from stdin instead. Otherwise
        ``config.number`` samples are drawn from the actor model, then one sample per
        alter-ego identity in ``self.identities``. Each sample is re-queried until
        ``are_valid_actions`` accepts it. After a repetitive note or an invalid
        action, the system prompt is extended with a warning about that mistake.

        Args:
            criticism_elements: critic output forwarded to ``get_online_input``.
            website: forwarded to ``get_actor_instruction``.
            task_type: forwarded to ``get_actor_instruction``.

        Returns:
            list[dict]: one action-element dict per sample (actor samples first, then
            identities). Actor samples carry the prompt input under ``"input"``.

        Raises:
            ValueError: when a sample exceeds the retry budget (repetitive notes,
                invalid actions, or responses without parseable action candidates).
            NotImplementedError: when a response has neither an action nor action
                candidates.
        """
        # Maximum re-queries of each kind before giving up on a sample.
        max_retries = 50
        if self.config.debug > 1:
            action_elements = {k: "" for k in self.config.output}
            human_input = input("ACTION: ")
            action_elements["action"] = human_input
            return [action_elements]

        self.pre_process_atomic_actions()
        instruction = self.get_actor_instruction(website=website, task_type=task_type)
        online_input = self.get_online_input(criticism_elements=criticism_elements)  # includes objective and current observation

        model_response_list = []
        action_element_list = []
        for _ in range(self.config.number):
            get_valid_actions = False
            repetitive_note = False
            invalid_actions = False
            repetitive_note_count = 0
            invalid_actions_count = 0
            unparseable_count = 0
            while not get_valid_actions:
                if repetitive_note_count > max_retries:
                    raise ValueError("Too many repetitive notes.")
                if invalid_actions_count > max_retries:
                    raise ValueError("Too many invalid actions.")
                if unparseable_count > max_retries:
                    raise ValueError("Too many responses without parseable action candidates.")

                # Escalate the prompt with a warning about the previous mistake, if any.
                if repetitive_note:
                    system_prompt = instruction+"\nGenerating the command `note [{}]` will be severely punished! Don't generate repetitive notes!".format(getattr(self, "note_buffer", ""))
                elif invalid_actions:
                    system_prompt = instruction+"\nGenerating the command `{}` will be severely punished! Don't generate invalid actions! We don't have that element id in the current observation!".format(invalid_action_str)
                else:
                    system_prompt = instruction
                model_response = self.call_model_with_message(system_prompt=system_prompt, messages=self.arrange_message_for_model(online_input))

                print(model_response)
                action_elements = self.parse_elements(text=model_response, key_list=self.config.output)
                action_elements = self.parse_action_from_action_candidates(action_elements=action_elements)
                assert not ("action" in action_elements.keys() and any("action candidates" in k for k in action_elements.keys()))

                if "action" in action_elements.keys():
                    if self.are_valid_actions(action_elements["action"]):
                        note_buffer = getattr(self, "note_buffer", "")
                        if note_buffer and f"note [{note_buffer}" in action_elements["action"]:
                            print(f"Repetitive note: {note_buffer}")
                            repetitive_note = True
                            repetitive_note_count += 1
                            continue
                        get_valid_actions = True
                        action_elements["input"] = online_input
                        model_response_list.append(model_response)
                        action_element_list.append(action_elements)
                    else:
                        invalid_action_str = action_elements["action"]
                        print(f"Invalid actions: {invalid_action_str}")
                        invalid_actions = True
                        invalid_actions_count += 1
                elif any("action candidates" in k for k in action_elements.keys()):
                    action_candidates_key = [k for k in action_elements.keys() if "action candidates" in k][0]
                    candidates = action_elements[action_candidates_key]
                    if isinstance(candidates, str) or not candidates:
                        # Nothing parseable: retry, but count it so the loop is bounded.
                        unparseable_count += 1
                        continue
                    filtered_action_candidates = []
                    note_buffer = getattr(self, "note_buffer", "")
                    for action_reason_pair in candidates:
                        action = action_reason_pair["action"]
                        reason = action_reason_pair["reason"]
                        if self.are_valid_actions(action):
                            if note_buffer and f"note [{note_buffer}" in action:
                                print(f"Repetitive note: {note_buffer}")
                                repetitive_note = True
                                repetitive_note_count += 1
                                continue
                            filtered_action_candidates.append({'reason': reason, 'action': action})
                        else:
                            invalid_action_str = action
                            print(f"Invalid actions: {invalid_action_str}")
                            invalid_actions = True
                            invalid_actions_count += 1
                    if filtered_action_candidates:
                        action_elements[action_candidates_key] = filtered_action_candidates
                        get_valid_actions = True
                        action_elements["input"] = online_input
                        model_response_list.append(model_response)
                        action_element_list.append(action_elements)
                else:
                    raise NotImplementedError("You have to generate either action or action candidates.")

        # Each alter-ego identity contributes one additional sample.
        for identity in self.identities:
            identity_instruction = identity.get_instruction() if identity.get_instruction() else instruction
            identity_online_input = identity.get_online_input() if identity.get_online_input() else online_input
            get_valid_actions = False
            invalid_actions = False
            identity_invalid_count = 0
            while not get_valid_actions:
                if identity_invalid_count > max_retries:
                    raise ValueError("Too many invalid actions.")
                if invalid_actions:
                    model_response, action_elements = identity.get_action(identity_instruction+"\nGenerating the command `{}` will be severely punished! Don't generate invalid actions! We don't have that element id in the current observation!".format(invalid_action_str), identity_online_input)
                else:
                    model_response, action_elements = identity.get_action(identity_instruction, identity_online_input)
                if self.are_valid_actions(action_elements["action"]):
                    get_valid_actions = True
                    model_response_list.append(model_response)
                    action_element_list.append(action_elements)
                else:
                    invalid_action_str = action_elements["action"]
                    print(f"Invalid actions: {invalid_action_str}")
                    invalid_actions = True
                    identity_invalid_count += 1

        self.verbose(instruction=instruction, online_input=online_input, model_response_list=model_response_list, action_element_list=action_element_list)

        # Interactive override of each proposed action in debug mode.
        if self.config.others.debug or self.config.debug:
            for i in range(len(action_element_list)):
                human_input = input(f"ACTION {i}: ")
                if human_input != "":
                    action_element_list[i]["action"] = human_input

        return action_element_list
    
    def _page_has_needed_info(self, obs_text: str, extraction_objective: str, retries: int = 3) -> bool:
        """Ask the model whether a page contains information needed for an extraction objective.

        The model is instructed to reply with a ``reason:`` line and an
        ``answer: yes|no`` line. The answer is located with the module-level
        ``_ANS_RE`` pattern (its first group is expected to hold yes/no).

        Args:
            obs_text: accessibility-tree text of the page.
            extraction_objective: description of the information being sought.
            retries: number of model queries to attempt when a reply cannot be
                parsed into a yes/no answer.

        Returns:
            True for a "yes" answer; False for "no", or when no parseable answer
            is obtained within ``retries`` attempts.
        """
        system = ("You are a yes/no classifier for extracting necessary information for a user task.\n"
                  "Given the task objective for information extraction and a webpage accessibility tree text, you need to think step-by-step about whether the page contains exact "
                  "information that would help accomplish the OBJECTIVE.\n"
                  "Then output two lines exactly:\n"
                  "reason: <your short explanation>\n"
                  "answer: yes|no")

        prompt = f"[Extraction OBJECTIVE]\n{extraction_objective}\n\n[PAGE]\n{obs_text}"
        print("*" * 100)
        print(prompt)
        print("*" * 100)
        messages = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]

        for _ in range(retries):
            raw = self.call_model_with_message(system_prompt=system, messages=messages)
            if not raw:
                # Empty/None reply: treat like an unparseable answer and ask again.
                continue
            raw = raw.strip()
            m = _ANS_RE.search(raw)
            if not m:
                continue

            answer = m.group(1).lower()
            if answer not in ("yes", "no"):
                continue

            # Prefer the explicit "reason:" line; fall back to the first line.
            lines = raw.splitlines()
            reason_line = next((line for line in lines if line.strip().lower().startswith("reason:")), lines[0])
            reason = re.sub(r"^\s*reason:\s*", "", reason_line, flags=re.IGNORECASE).strip()

            print(f"Relevance Judgement: {answer}\n{reason}")
            return answer == "yes"

        # Default fallback when no usable answer was produced.
        return False


    def finalize_action(self, action_elements):
        """Post-process the chosen action elements in place and return them.

        Runs ``get_observation_highlight`` first. Then, if ``go_home`` maps the
        action to a navigation command, that command is stored under
        ``"navigation action"``.
        """
        self.get_observation_highlight(action_elements=action_elements)

        home_navigation = self.go_home(action=action_elements["action"])
        if not home_navigation:
            return action_elements
        action_elements["navigation action"] = home_navigation
        return action_elements

class Critic(Agent):
    """LLM agent that critiques the actor's progress.

    The critic builds a system prompt from its output specifications, the
    actor's navigation specifications and optional website tips. It feeds the
    model the actor's current state and parses the reply into
    ``criticism_elements`` keyed by ``config.output``.
    """

    def __init__(self, config, objective, prompt_template):
        super().__init__(config, objective, prompt_template)
        # Cached system prompt; built on the first get_critic_instruction() call.
        self.instruction = None
        # Actor state snapshot; set through update_actor_basic_info() before querying.
        self.actor_basic_info_dict = None

        logname = getattr(self.config.others, "logname", "")
        play_name = f"play-{logname}.txt" if logname != "" else "play.txt"
        trash_name = f"trash-{logname}.txt" if logname != "" else "trash.txt"
        self.output_play_path = os.path.join(CURRENT_DIR, play_name)
        self.output_trash_path = os.path.join(CURRENT_DIR, trash_name)

    def verbose(self, instruction, online_input, model_response):
        """Append a CRITIC section with the ``config.trash`` fields to the trash log.

        Only writes when both ``config.others.verbose`` and ``config.verbose`` are
        positive. Unknown keys are logged as empty rather than raising KeyError,
        matching the actor's logger.
        """
        VERBOSE_TO_CONTENT_MAP = {
            "url": self.online_interaction["url"],
            "objective": self.objective,
            "instruction": instruction,
            "online input": "\n".join([i[1] for i in online_input if i[0]=="text"]),
            "response": model_response
        }
        if self.config.others.verbose > 0 and self.config.verbose > 0:
            with open(self.output_trash_path, "a") as af:
                af.write("-"*32+"CRITIC"+"-"*32+"\n")
                for t in self.config.trash:
                    content = VERBOSE_TO_CONTENT_MAP.get(t, "")
                    af.write(f"{t.upper()}:\n{content}\n\n")

    def update_actor_basic_info(self, **actor_basic_info_dict):
        """Store the actor's latest state (step, history, specifications, plan...)."""
        self.actor_basic_info_dict = actor_basic_info_dict

    def get_output_specifications(self):
        """Concatenate the output-specification files for each ``config.output`` key.

        For each key, a character-specific file ``<key>_<character>.txt`` is preferred
        over the generic ``<key>.txt`` (spaces in keys become underscores). Each file
        is emitted under an upper-cased ``KEY:`` header.
        """
        spec_dir = os.path.join(CURRENT_DIR, "AgentOccam", "prompts", "output_specifications")
        sections = []
        for o in self.config.output:
            stem = o.replace(" ", "_")
            filepath = os.path.join(spec_dir, "{}_{}.txt".format(stem, self.config.character))
            if not os.path.exists(filepath):
                filepath = os.path.join(spec_dir, "{}.txt".format(stem))
            # Context manager so the file handle is closed (previously leaked).
            with open(filepath, "r") as f:
                sections.append(f"{o.upper()}:\n" + f.read())
        return "\n".join(sections)

    def get_critic_instruction(self, website=None):
        """Build (once) and return the critic's system prompt.

        Fills ``{output_specifications}``, ``{navigation_specifications}`` (from the
        actor info) and ``{website_tips}``. Tips are only fetched when
        ``config.tips`` is set and a website is given; otherwise the placeholder is
        replaced with an empty string.
        """
        if self.instruction:
            return self.instruction
        instruction = self.prompt_template["instruction_template"]
        output_specifications = self.get_output_specifications()
        # Default to no tips: previously `website_tips` was left unbound (UnboundLocalError)
        # whenever tips were disabled or no website was passed — including every call
        # from get_criticism_elements().
        website_tips = self.get_tips(website) if (self.config.tips and website) else ""
        instruction = instruction.replace("{output_specifications}", output_specifications)
        instruction = instruction.replace("{navigation_specifications}", self.actor_basic_info_dict["navigation_specifications"])
        instruction = instruction.replace("{website_tips}", website_tips)
        self.instruction = instruction
        print("-------------------------------------------")
        print("self.instruction", self.instruction)
        print("-------------------------------------------")
        return self.instruction

    def get_online_input(self):
        """Assemble the (type, payload) message list sent to the critic model.

        Sections come from ``config.input`` in order. String contents become a
        ``KEY:\\n<content>\\n`` text part; list contents get a header followed by
        their items. Empty or non-str/list contents are skipped. The visual
        observation is appended as an image part when ``config.current_observation.type``
        contains ``"image"``. The list is wrapped by the input template's prefix and
        suffix and passed through ``prune_message_list``.
        """
        input_template = self.prompt_template["input_template"]
        input_prefix, input_suffix = input_template.split("{input}")
        INPUT_TYPE_TO_CONTENT_MAP = {
            "step": self.actor_basic_info_dict["step"],
            "objective": self.objective,
            "current plan": self.actor_basic_info_dict.get("current_plan", "No plan available"),
            "interaction history": self.actor_basic_info_dict["interaction_history"],
            "current observation": self.get_observation_text(),
            "current visual observation": self.get_observation_image()
        }
        input_list = []
        for input_type in self.config.input:
            if input_type == "current visual observation":
                continue  # appended as an image part below, when enabled
            input_content = INPUT_TYPE_TO_CONTENT_MAP.get(input_type)
            if not input_content:
                continue
            if isinstance(input_content, str):
                input_list.append(("text", f"{input_type.upper()}:\n{input_content}\n"))
            elif isinstance(input_content, list):
                input_list.append(("text", f"{input_type.upper()}:\n"))
                input_list += input_content

        if "image" in self.config.current_observation.type:
            input_type = "current visual observation"
            input_list.append(("text", f"{input_type.upper()}:\n"))
            input_list.append(("image", INPUT_TYPE_TO_CONTENT_MAP["current visual observation"]))

        return self.prune_message_list(message_list=[("text", input_prefix)] + input_list + [("text", input_suffix)])

    def get_criticism_elements(self):
        """Return the critic's parsed feedback keyed by ``config.output``.

        * ``config.mode`` falsy: returns ``{}``.
        * ``config.debug > 1``: returns a random canned criticism for every key.
        * Otherwise: queries the model, logs the exchange, and parses the reply.
          The prompt input is attached under ``"input"``. In debug mode each field
          can be overridden from stdin.
        """
        if not self.config.mode:
            return {}
        if self.config.debug > 1:
            criticism_elements = {k: random.choice(["I don't think the task is finished. Don't issue identical actions like taking the same notes. It's annoying. Continue.", "You have make a reasoning mistake. Continue.", "You have missed important details on this page. Continue.", "You don't follow the task requirements. Continue.", "The task assigner might just want to challenge you to answer no and there might be no answer for this brain teaser question. Who knows?", "You should break down the task by using the planning commands.", "You have not gone over all the relevant pages. Continue."]) for k in self.config.output}
            return criticism_elements

        instruction = self.get_critic_instruction()
        online_input = self.get_online_input()
        model_response = self.call_model_with_message(system_prompt=instruction, messages=self.arrange_message_for_model(online_input))
        self.verbose(instruction=instruction, online_input=online_input, model_response=model_response)

        criticism_elements = self.parse_elements(text=model_response, key_list=self.config.output)
        criticism_elements["input"] = online_input

        if self.config.others.debug or self.config.debug:
            for k in self.config.output:
                human_input = input(f"{k.upper()}: ")
                if not human_input == "":
                    criticism_elements[k] = human_input

        return criticism_elements

class Judge(Agent):
    def __init__(self, config, objective, prompt_template):
        super().__init__(config, objective, prompt_template)
        # System prompt cache, filled lazily by get_judge_instruction().
        self.instruction = None
        # Actor state snapshot, provided through update_actor_basic_info().
        self.actor_basic_info_dict = None

        # Log files are shared by name with the other agents; a non-empty
        # `config.others.logname` gives them a per-run suffix.
        logname = getattr(self.config.others, "logname", "")
        if logname != "":
            play_name, trash_name = f"play-{logname}.txt", f"trash-{logname}.txt"
        else:
            play_name, trash_name = "play.txt", "trash.txt"
        self.output_play_path = os.path.join(CURRENT_DIR, play_name)
        self.output_trash_path = os.path.join(CURRENT_DIR, trash_name)
    
    def update_actor_basic_info(self, **actor_basic_info_dict):
        """Record the actor's latest state for building the judge's prompt.

        The keyword arguments become ``self.actor_basic_info_dict``. The judge reads
        ``step``, ``interaction_history``, ``navigation_specifications``,
        ``navigation_command`` and, optionally, ``current_plan`` from it.
        """
        snapshot = actor_basic_info_dict
        self.actor_basic_info_dict = snapshot

    def get_judge_instruction(self):
        """Build (once) and return the judge's system prompt.

        Fills the ``{output_specifications}`` and ``{navigation_specifications}``
        placeholders of the instruction template. The result is cached on
        ``self.instruction``.
        """
        if not self.instruction:
            template = self.prompt_template["instruction_template"]
            specs = self.get_output_specifications()
            self.instruction = (
                template
                .replace("{output_specifications}", specs)
                .replace("{navigation_specifications}", self.actor_basic_info_dict["navigation_specifications"])
            )
        return self.instruction
    
    def get_online_input(self, action_element_list):
        """Assemble the (type, payload) message list sent to the judge model.

        Like the critic's input, plus an ``"action choices"`` section that lists
        every candidate action with its index and reason. Sections come from
        ``config.input`` in order. Empty or non-str/list contents are skipped. An
        image part is appended when ``config.current_observation.type`` contains
        ``"image"``.
        """
        input_prefix, input_suffix = self.prompt_template["input_template"].split("{input}")
        info = self.actor_basic_info_dict
        content_by_type = {
            "step": info["step"],
            "objective": self.objective,
            "current plan": info.get("current_plan", "No plan available"),
            "interaction history": info["interaction_history"],
            "current observation": self.get_observation_text(),
            "current visual observation": self.get_observation_image(),
            "action choices": "\n\n".join(
                "|\taction [{0}]:\n{1}\n|\treason for action [{0}]:\n{2}".format(idx, elements["action"], elements.get("reason", "N/A"))
                for idx, elements in enumerate(action_element_list)
            ),
        }

        parts = []
        for input_type in self.config.input:
            if input_type == "current visual observation":
                continue
            content = content_by_type.get(input_type)
            if not content:
                continue
            header = f"{input_type.upper()}:\n"
            if isinstance(content, str):
                parts.append(("text", f"{header}{content}\n"))
            elif isinstance(content, list):
                parts.append(("text", header))
                parts.extend(content)

        if "image" in self.config.current_observation.type:
            visual_key = "current visual observation"
            parts.append(("text", f"{visual_key.upper()}:\n"))
            parts.append(("image", content_by_type[visual_key]))

        return self.prune_message_list(message_list=[("text", input_prefix), *parts, ("text", input_suffix)])
    
    def verbose(self, instruction, online_input, model_response):
        """Append a JUDGE section with the ``config.trash`` fields to the trash log.

        Only writes when both ``config.others.verbose`` and ``config.verbose`` are
        positive. Keys not in the content map are logged as empty, consistent with
        the actor's logger, instead of raising KeyError.
        """
        VERBOSE_TO_CONTENT_MAP = {
            "url": self.online_interaction["url"],
            "objective": self.objective,
            "instruction": instruction,
            "online input": "\n".join([i[1] for i in online_input if i[0]=="text"]),
            "response": model_response
        }
        if self.config.others.verbose > 0 and self.config.verbose > 0:
            # One handle for the whole section instead of reopening per field.
            with open(self.output_trash_path, "a") as af:
                af.write("-"*32+"JUDGE"+"-"*32+"\n")
                for t in self.config.trash:
                    content = VERBOSE_TO_CONTENT_MAP.get(t, "")
                    af.write(f"{t.upper()}:\n{content}\n\n")

    def flatten_action_element_list(self, action_element_list):
        """Expand action-candidate entries into one entry per candidate, then shuffle.

        An entry with an "...action candidates" key becomes one deep copy per
        candidate, with ``"action"``/``"reason"`` set from that candidate. Other
        entries are kept as-is (same objects). The result is shuffled in place
        with ``random.shuffle``, so the order is randomized.
        """
        flattened = []
        for elements in action_element_list:
            candidate_keys = [key for key in elements.keys() if "action candidates" in key]
            if not candidate_keys:
                flattened.append(elements)
                continue
            template = copy.deepcopy(elements)
            for pair in elements[candidate_keys[0]]:
                template["action"] = pair["action"]
                template["reason"] = pair["reason"]
                flattened.append(copy.deepcopy(template))
        random.shuffle(flattened)
        return flattened
    
    def judge(self, action_element_list):
        """Select one action element out of the actor's proposals.

        The proposals are flattened (one element per action candidate) and
        shuffled first. No model call is made, and the first element is
        returned, when judging is disabled (``not config.mode``), when
        ``config.debug > 1``, or when every proposal has the same action.

        In strict mode (``config.strict``), ``stop [...]`` and ``note [...]``
        commands are split out of each proposal before deduplication:
          * if at least 60% of the proposals contain a stop command, the
            proposal holding the longest stop command is returned unchanged;
          * if no other action remains, the longest note block becomes the
            action (returned on a copy so the caller's element is untouched);
          * if exactly one distinct action remains, it is returned with the
            longest note block (if any) prepended on its own line.
        Otherwise the judge model chooses among the deduplicated actions; in
        strict mode the longest note block (if any) is prepended to the pick.

        Args:
            action_element_list: list of dicts, each with an "action" string.

        Returns:
            ``(selected_action_elements, judgement_elements)`` where the second
            item is the parsed judge output (``{}`` when no model call was made).
            If the judge's "action selection" is missing, unparsable or out of
            range, the first flattened proposal is returned instead.
        """
        action_element_list = self.flatten_action_element_list(action_element_list)
        if not self.config.mode or self.config.debug > 1:
            return action_element_list[0], {}
        if all(action_elements["action"] == action_element_list[0]["action"] for action_elements in action_element_list):
            return action_element_list[0], {}

        def deduplicate_action_element_list_strict(lst):
            """Split note/stop commands out of every proposal, then deduplicate.

            Returns:
                note_list: one ("\\n"-joined note commands, proposal index) pair
                    per proposal, in order (text is "" when it had no notes).
                stop_list: (stop command, proposal index) pairs.
                deduplicated_list: deep copies of proposals whose remaining
                    actions are non-empty and unseen, with "action" rewritten to
                    those remaining actions.
            """
            seen = set()
            note_list = []
            stop_list = []
            deduplicated_list = []

            for i, item in enumerate(lst):
                item = copy.deepcopy(item)
                action_list = self.parse_str_to_action_list(item["action"], self.actor_basic_info_dict["navigation_command"])
                note_list.append([])
                none_note_stop_action_list = []
                for a in action_list:
                    if a.startswith("stop ["):
                        stop_list.append((a, i))
                    elif a.startswith("note ["):
                        note_list[-1].append(a)
                    else:
                        none_note_stop_action_list.append(a)
                item["action"] = "\n".join(none_note_stop_action_list)
                if item["action"] and item["action"] not in seen:
                    seen.add(item["action"])
                    deduplicated_list.append(item)
            note_list = [("\n".join(notes), i) for i, notes in enumerate(note_list)]
            return note_list, stop_list, deduplicated_list

        def deduplicate_action_element_list(lst):
            """Return deep copies of proposals with a non-empty, not-yet-seen action."""
            seen = set()
            deduplicated_list = []

            for item in lst:
                item = copy.deepcopy(item)
                if item["action"] and item["action"] not in seen:
                    seen.add(item["action"])
                    deduplicated_list.append(item)
            return deduplicated_list

        def longest(pairs):
            """Return the first (text, index) pair whose text is longest."""
            return max(pairs, key=lambda pair: len(pair[0]))

        def prepend_notes(notes, action):
            """Put *notes* on a line before *action*; leave *action* alone if there are none."""
            return notes + "\n" + action if notes else action

        strict = getattr(self.config, "strict", False)
        note_action_choice = ""
        if strict:
            note_list, stop_list, deduplicated_action_element_list = deduplicate_action_element_list_strict(action_element_list)
            # A clear majority (>= 60%) of proposals wants to stop: honour the longest stop command.
            if len(stop_list) >= 0.6 * len(action_element_list):
                _, stop_action_id = longest(stop_list)
                return action_element_list[stop_action_id], {}
            note_action_choice, note_action_id = longest(note_list)
            if not deduplicated_action_element_list:
                # Only note/stop commands were proposed. Work on a copy: the
                # flattened list may share elements with the caller's list.
                action_elements = copy.deepcopy(action_element_list[note_action_id])
                action_elements["action"] = note_action_choice
                return action_elements, {}
            if len(deduplicated_action_element_list) == 1:
                action_elements = deduplicated_action_element_list[0]
                action_elements["action"] = prepend_notes(note_action_choice, action_elements["action"])
                return action_elements, {}
        else:
            deduplicated_action_element_list = deduplicate_action_element_list(action_element_list)

        instruction = self.get_judge_instruction()
        online_input = self.get_online_input(deduplicated_action_element_list)
        model_response = self.call_model_with_message(system_prompt=instruction, messages=self.arrange_message_for_model(online_input))
        self.verbose(instruction=instruction, online_input=online_input, model_response=model_response)

        judgement_elements = self.parse_elements(text=model_response, key_list=self.config.output)
        judgement_elements["input"] = online_input

        if self.config.others.debug or self.config.debug:
            # Interactive override: an empty answer keeps the model's value.
            for k in self.config.output:
                human_input = input(f"{k.upper()}: ")
                if human_input != "":
                    judgement_elements[k] = human_input

        try:
            # The judge names its pick by index into the deduplicated list.
            action_selection = int(re.search(r'\d+', judgement_elements["action selection"]).group())
            selected_action_elements = deduplicated_action_element_list[action_selection]
        except (KeyError, TypeError, AttributeError, IndexError, ValueError):
            # Missing, unparsable or out-of-range selection: fall back to the first proposal.
            return action_element_list[0], judgement_elements
        if strict:
            selected_action_elements["action"] = prepend_notes(note_action_choice, selected_action_elements["action"])
        return selected_action_elements, judgement_elements

def generate_initial_plan(objective: str, observation: str, prompt_template: dict, config) -> str:
    """Generate an initial plan for *objective* with a single LLM call.

    Args:
        objective: task objective shown to the model.
        observation: initial observation text, appended after the objective.
        prompt_template: dict whose "instruction_template" entry is used as the
            system prompt.
        config: object with a ``model`` attribute. The first entry of
            ``MODEL_FAMILIES`` that is a substring of ``config.model`` selects
            the backend call/arrange functions.

    Returns:
        The result of the selected backend's call-with-messages function
        (the model's response).

    Raises:
        ValueError: if ``config.model`` contains none of ``MODEL_FAMILIES``.
    """
    instruction = prompt_template["instruction_template"]

    user_prompt = f"OBJECTIVE:\n{objective}\n\nINITIAL OBSERVATION:{observation}"
    messages = [("text", user_prompt)]

    model_family = next((family for family in MODEL_FAMILIES if family in config.model), None)
    if model_family is None:
        raise ValueError(
            f"Unsupported model {config.model!r}: expected a name containing one of {MODEL_FAMILIES}"
        )
    call_fn = partial(CALL_MODEL_WITH_MESSAGES_FUNCTION_MAP[model_family], model_id=config.model)
    arrange_fn = ARRANGE_MESSAGE_FOR_MODEL_MAP[model_family]
    return call_fn(system_prompt=instruction, messages=arrange_fn(messages))

class AgentOccam:
    def __init__(self,
                 config = None,
                 prompt_dict: Dict = None,
                 website = None,
                 task_id = None,
                 task_type = None,
                 analysis_result = None
                 ):
        self.config = config
        self.prompt_dict = {} if prompt_dict is None else prompt_dict
        self.website = website

        self.objective = None
        self.online_observation = None
        self.online_url = None
        self.actor = None
        self.critic = None
        self.task_id = task_id
        self.trajectory = []
        self.done = False
        self.task_type = task_type
        self.analysis_result = analysis_result

    def _dump_current_observation(self):
        """Write full observation into   <logdir>/task_<id>/obs_stepN.txt (+png)."""
        # 1) Build path  <logdir>/<task_folder>/
        base_logdir = getattr(self.config.actor, "obsdir",
                            os.path.join(CURRENT_DIR, "obs_dumps"))
        
        os.makedirs(base_logdir, exist_ok=True)
        task_id     = self.task_id
        task_dir    = os.path.join(base_logdir, f"task_{task_id}")
        os.makedirs(task_dir, exist_ok=True)

        # 2) Filenames for this step
        fname_txt = os.path.join(task_dir, f"obs_step{self.get_step()}.txt")
        fname_png = os.path.join(task_dir, f"obs_step{self.get_step()}.png")

        # 3) Fetch observation
        text_obs = self.get_observation_text() or ""
        img_obj  = self.get_observation_image()
        img_path = "(no image)"

        if img_obj is not None:
            try:
                if isinstance(img_obj, str):            # already a file path
                    shutil.copy(img_obj, fname_png)     # optional safe-copy
                else:                                   # PIL.Image or bytes
                    img_obj.save(fname_png)
                img_path = fname_png
            except Exception as exc:
                img_path = f"(image-save failed: {exc})"

        # 4) Write text file
        with open(fname_txt, "w", encoding="utf-8") as fp:
            # fp.write("URL:\n" + url + "\n\n")
            fp.write("IMAGE FILE:\n" + img_path + "\n\n")
            fp.write("FULL OBSERVATION TEXT:\n" + text_obs)

    def get_refined_objective(self):
        """Ask Claude to rephrase the objective; store it on ``self.objective_refined``.

        The query is ``self.root_prompt_template["objective_rephrasing_query"]``
        with ``{objective}`` replaced by ``self.objective``. The refined text is
        taken from the "REFINED OBJECTIVE:" section of the response, or set to
        None when that section is absent.

        NOTE(review): ``root_prompt_template`` is not assigned anywhere in this
        class's visible code — confirm it is set before this is called.
        """
        query_template = self.root_prompt_template["objective_rephrasing_query"]
        response = call_claude(query_template.replace("{objective}", self.objective))
        found = re.search(r'REFINED OBJECTIVE:\s*(.*?)\s*(?=\n[A-Z]|$)', response, re.DOTALL)
        self.objective_refined = None if found is None else found.group(1)
        
    def get_observation_text(self):
        if isinstance(self.online_observation, dict):
            return self.online_observation["text"]
        else:
            return self.online_observation

    # Migrated from agent class for dumping
    def get_observation_image(self, idx=None):
        if isinstance(self.online_observation, dict):
            return self.online_observation["image"]
        else:
            return self.online_observation

    def init_actor(self):
        """Create the Actor for the current objective, seeded with an initial plan.

        Expects ``self.objective``, ``self.sites`` and ``self.online_observation``
        to be set (``act`` sets them before calling this). Shares
        ``self.config.others`` with the actor config, adds the ``go_home``
        navigation command when there is more than one site, and truncates the
        actor's trash output file.
        """
        self.config.actor.others = self.config.others
        # The config object may be reused across tasks: add "go_home" at most
        # once instead of appending another copy on every task.
        if len(self.sites) > 1 and "go_home" not in self.config.actor.navigation_command:
            self.config.actor.navigation_command += ["go_home"]

        # Generate the initial plan with a single LLM call. Only the observation
        # text is passed: a dict observation (text + image) would otherwise be
        # formatted into the prompt as a raw dict repr.
        initial_plan = generate_initial_plan(
            objective=self.objective,
            observation=self.get_observation_text(),
            prompt_template=self.prompt_dict["initial_plan_generator"],
            config=self.config.initial_plan_generator,
        )

        self.actor = Actor(
            config=self.config.actor,
            objective=self.objective,
            prompt_template=self.prompt_dict["actor"],
            initial_plan=initial_plan,
        )
        # Start the run with an empty trash file.
        with open(self.actor.output_trash_path, "w") as _:
            pass

    def init_critic(self):
        """Create the Critic, sharing ``self.config.others`` with its config.

        The prompt template is chosen by the critic's configured ``character``
        from ``self.prompt_dict["critic"]``.
        """
        critic_config = self.config.critic
        critic_config.others = self.config.others
        self.critic = Critic(
            config=critic_config,
            objective=self.objective,
            prompt_template=self.prompt_dict["critic"][critic_config.character],
        )
    
    def init_judge(self):
        """Create the Judge, sharing ``self.config.others`` with its config."""
        judge_config = self.config.judge
        judge_config.others = self.config.others
        self.judge = Judge(
            config=judge_config,
            objective=self.objective,
            prompt_template=self.prompt_dict["judge"],
        )
        
    def predict_action(self):
        self.critic.update_actor_basic_info(step=self.get_step(), navigation_specifications=self.actor.get_navigation_specifications(), interaction_history=self.actor.get_interaction_history(interaction_history_config=self.critic.config.interaction_history), current_plan=self.actor.get_current_plan())
        criticism_elements = self.critic.get_criticism_elements() if not self.get_step()==0 else {}
        action_element_list = self.actor.predict_action(criticism_elements=criticism_elements, website=self.website, task_type=self.task_type)
        self.judge.update_actor_basic_info(step=self.get_step(), navigation_specifications=self.actor.get_navigation_specifications(), interaction_history=self.actor.get_interaction_history(interaction_history_config=self.judge.config.interaction_history), current_plan=self.actor.get_current_plan(), navigation_command=self.actor.config.navigation_command)
        selected_action_elements, judgement_elements = self.judge.judge(action_element_list)

        # print("-"*100)
        # print("judgement_elements: ", judgement_elements)
        # print("-"*100)
        
        # Check if judge wants to update the plan
        # if "plan update" in judgement_elements and judgement_elements["plan update"]:
        #     self.actor.update_plan_from_judge(judgement_elements["plan update"])
        
        selected_action_elements = self.actor.finalize_action(selected_action_elements)
        return {**selected_action_elements, **{"critic:"+k: criticism_elements[k] for k in criticism_elements.keys()}, **{"judge:"+k: judgement_elements[k] for k in judgement_elements.keys()}}, action_element_list
    
    def update_online_state(self, url, observation):
        self.online_url = url
        self.online_observation = observation

    def get_step(self):
        return self.actor.get_step()
    
    def is_navigation(self, action):
        return self.actor.is_navigation(action=action)
    
    def get_actor_active_plan(self):
        return self.actor.get_active_plan()
    
    def get_trajectory(self):
        return self.trajectory

    def act(self, objective, env):
        # import ipdb; ipdb.set_trace()
        self.objective = objective
        self.sites = env.get_sites()
        observation = env.observation()
        url = env.get_url()
        self.update_online_state(url=url, observation=observation)
        self.init_actor()
        self.actor.agent_occam_ref = self
        self.init_critic()
        self.init_judge()
        while not env.done():
            observation = env.observation()
            url = env.get_url()
            self.update_online_state(url=url, observation=observation)
            self.actor.update_online_state(url=url, observation=observation)
            self.critic.update_online_state(url=url, observation=observation)
            self.judge.update_online_state(url=url, observation=observation)
            action_elements, action_element_list = self.predict_action()
            action = action_elements["action"]
            print("Reason:\n", action_elements["reason"])
            print("Plan:\n", self.get_actor_active_plan())
            navigation_action = action_elements["action"] if not action_elements.get("navigation action", "") else action_elements.get("navigation action", "")
            status = env.step(navigation_action)
            if navigation_action and self.is_navigation(action=navigation_action) and status == False: # means invalid action
                print(f"STEP {self.get_step()}: Invalid action \"{action}\" generated. Strictly follow the action specifications.")          
            DOCUMENTED_INTERACTION_ELEMENT_KEY_TO_CONTENT_MAP = {
                "observation": observation,
                "action": action,
                "url": url,
                "plan": self.get_actor_active_plan(),
                "reason": action_elements.get("reason", ""),
                "observation highlight": action_elements.get("observation highlight", ""),
                "retained element ids": action_elements.get("retained element ids", []),
                "observation summary": action_elements.get("observation description", "") ,
                # "dump decision": action_elements.get("dump decision", "") ,
                # "dump rationale": action_elements.get("dump rationale", "") ,                
            }
            self.actor.update_history(**DOCUMENTED_INTERACTION_ELEMENT_KEY_TO_CONTENT_MAP)

            # if action_elements.get("dump decision", "").strip().lower().startswith("y"):
            #     self._dump_current_observation()
            # print("Dump Decision: ", action_elements.get("dump decision", ""))
            # print("Dump Reason: ", action_elements.get("dump rationale", ""))

            self.actor.del_observation_node()
            assert self.actor.equal_history_length()

            if len(action_element_list) > 1:
                if self.config.others.logging:
                    self.log_step(
                        status=status if "status" in locals() and isinstance(status, dict) else env.status(),
                        plan=self.get_actor_active_plan(),
                        **action_elements,
                        **{f"actor {i}:{k}": _action_elements[k] for i, _action_elements in enumerate(action_element_list) for k in _action_elements.keys() if k != "input" and k != "instruction"}
                    )
            else:
                if self.config.others.logging:
                    self.log_step(
                        status=status if "status" in locals() and isinstance(status, dict) else env.status(),
                        plan=self.get_actor_active_plan(),
                        **action_elements,
                    )

        return status if "status" in locals() and isinstance(status, dict) else env.status()
    
    def log_step(self, status, **kwargs):
        def serialize_message_list(message_list):
            if not isinstance(message_list, list):
                return message_list
            return "".join([m[1] for m in message_list if m[0]=="text"])
        data_to_log = {}
        data_to_log['objective'] = self.objective
        data_to_log['url'] = self.online_url
        data_to_log['observation'] = self.get_observation_text()
        for (k, v) in status.items():
            data_to_log[k] = v
        for k in kwargs.keys():
            try:
                json.dumps(kwargs[k])
                data_to_log[k.replace(" ", "_")] = kwargs[k] if "input" not in k else serialize_message_list(kwargs[k])
            except:
                pass
        self.trajectory.append(data_to_log)