import argparse
import ast
import asyncio
import logging
import re
import typing
import functools
from langgraph.checkpoint.memory import MemorySaver
from dotenv import load_dotenv
from langchain_core.messages import BaseMessage, HumanMessage, AIMessage, ChatMessage
from langchain.memory import ConversationBufferMemory
from langgraph.graph import END, StateGraph, START
from agentS import utils
from agentS.prompts import RANKER_PROMPT, ACTION_AGENT_PROMPT, PLANNER_AGENT_PROMPT, \
    SYSTEM_MESSAGE_PLANNER_AGENT, ENHANCED_PROMPTS

from agentS.utils import process_action_history, MemoryAgent, process_page_understanding
from agentS.consts import DEFAULT_WEB, OPENAI_MODEL_NAME_GPT_4O_MINI, OPENENDED_TASK, WEBARENA_TASK
from agentS.agent_state import AgentState
from langgraph.prebuilt import ToolNode

from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, prune_html


class AgentsManager:
    """Build and run a LangGraph multi-agent workflow for web-browsing tasks.

    The manager creates three LLM-backed agents via ``utils.create_agent``
    (``RankerAgent``, ``ActionAgent``, ``PlannerAgent``) plus a ``MemoryAgent``
    node (:meth:`update_scratchpad`) and wires them into a ``StateGraph`` whose
    topology is selected by ``architecture`` (see :meth:`construct_graph`).
    """

    def __init__(self, llm_name=OPENAI_MODEL_NAME_GPT_4O_MINI, llm_type='genai', llm=None, members=None,
                 system_prompt=None, architecture='general', sync=False,
                 args=None, flags=None, max_retry=None, tools=None, env_policies=None):
        """Initialise the LLM, tools, agents and the compiled graph.

        Args:
            llm_name: model name passed to ``utils.initialize_llm``.
            llm_type: stored on ``self.llm_type``; not otherwise used by this class.
            llm, members, system_prompt, args, flags, max_retry: accepted for
                backward compatibility; currently ignored.
            architecture: graph topology selector (see :meth:`construct_graph`).
            sync: when True, use the synchronous agent node and sync tool set.
            tools: optional mapping of tool name -> tool callable; when falsy the
                default tool set from ``utils`` is used.
            env_policies: forwarded to the sync tool setup and to every agent.
        """
        # Load API keys and other settings from a .env file before the LLM client is built.
        load_dotenv()
        self.llm_type = llm_type
        self.sync = sync
        self.architecture = architecture
        self.llm = utils.initialize_llm(llm_name=llm_name)
        self.memory = ConversationBufferMemory()
        self.env_policies = env_policies

        if tools:
            self.tools = tools
        elif sync:
            self.tools = utils.setup_sync_tools(architecture=architecture, env_policies=self.env_policies)
        else:
            self.tools = utils.setup_tools()
        # ToolNode only needs the callables; the dict keys (tool names) are not used here.
        self.tool_node = ToolNode(list(self.tools.values()))

        self.create_agents()
        self.construct_graph(architecture=architecture or "general")

    def mark_page_sync(self, state: AgentState):
        """Refresh the page-derived fields of ``state`` from the browser environment.

        Reads the latest observation from ``state['pointer_env']`` (falling back
        to ``state['env']``), stores it as ``state['annotations']`` and derives
        ``read_page``, ``url`` and a textual ``elements`` description. The element
        text comes from, in order of preference: the ``nocodeui_pu``
        page-understanding map, the accessibility tree enriched with
        ``extra_element_properties``, or the plain accessibility tree.

        Returns:
            A shallow copy of the (mutated) state.
        """
        pointer_env = state.get('pointer_env', state.get('env'))
        annotations = pointer_env.obs
        state['annotations'] = annotations
        state['read_page'] = annotations.get('read_page', '')
        url = state.get("page").url

        if 'nocodeui_pu' in annotations:
            elements = process_page_understanding(annotations['nocodeui_pu'].map)
        elif 'extra_element_properties' in annotations:
            elements = flatten_axtree_to_str(
                AX_tree=annotations['axtree_object'],
                extra_properties=annotations['extra_element_properties'],
                with_visible=True,
                with_clickable=True,
                filter_visible_only=True,
                filter_with_bid_only=True,
                remove_redundant_static_text=True,
                skip_generic=True,
                hide_bid_if_invisible=True
            )
        else:
            elements = flatten_axtree_to_str(annotations['axtree_object'])

        state['url'] = url
        state['elements'] = elements
        return {**state}

    def create_agents(self):
        """Create the Ranker, Action and Planner agents and their graph node callables.

        Each agent is built with ``utils.create_agent`` and wrapped with
        ``functools.partial`` around the sync or async ``agent_node`` from
        ``utils`` so it can be added to a ``StateGraph``. Sets
        ``self.research_node``, ``self.action_agent_node`` and ``self.planner_node``.
        """
        agent_node = utils.sync_agent_node if self.sync else utils.agent_node
        # NOTE(review): ``mark_page_async`` is not defined on this class, so
        # constructing with sync=False raises AttributeError unless a subclass
        # provides it — confirm where it is supposed to come from.
        perform_action = self.mark_page_sync if self.sync else self.mark_page_async

        research_agent = utils.create_agent(
            prompt=ENHANCED_PROMPTS['RANKER_PROMPT'],
            llm=self.llm,
            tools=[],
            system_message="You should provide accurate elements for the ActionAgent (action agent) to use.",
            sync=self.sync,
            perform_action=perform_action,
            agent_name="RankerAgent",
            env_policies=self.env_policies,
        )
        self.research_node = functools.partial(agent_node, agent=research_agent, name="RankerAgent")

        action_agent = utils.create_agent(
            prompt=ENHANCED_PROMPTS['ACTION_AGENT_PROMPT'],
            llm=self.llm,
            tools=self.tools,
            system_message="Your task is to determine the next action to take.",
            sync=self.sync,
            agent_name="ActionAgent",
            env_policies=self.env_policies,
        )
        self.action_agent_node = functools.partial(agent_node, agent=action_agent, name="ActionAgent")

        planner_agent = utils.create_agent(
            prompt=ENHANCED_PROMPTS['PLANNER_AGENT_PROMPT'],
            llm=self.llm,
            tools=self.tools,
            system_message=SYSTEM_MESSAGE_PLANNER_AGENT,
            sync=self.sync,
            agent_name="PlannerAgent",
            env_policies=self.env_policies,
        )
        self.planner_node = functools.partial(agent_node, agent=planner_agent, name="PlannerAgent")

    def construct_graph(self, architecture='general'):
        """Compile ``self.graph`` for the requested ``architecture``.

        Any name containing ``'dynamic_policy'`` builds the planner-driven graph.
        Every other value, including ``'general'`` and ``'parser_agent'`` (the
        architecture that :meth:`call_agent_async` runs for), builds the general
        Ranker -> Action -> Memory loop. This means ``self.graph`` is always set.
        """
        if architecture and 'dynamic_policy' in architecture:
            self.construct_graph_dynamic_policy()
        else:
            self.construct_graph_general_policy()

    def construct_graph_general_policy(self):
        """Compile the default Ranker -> Action -> Memory loop into ``self.graph``.

        Topology:
        - START goes to RankerAgent.
        - RankerAgent and ActionAgent both route via :meth:`router`:
          ``'continue'`` alternates between them, ``'update_scratchpad'`` goes
          to MemoryAgent, and ``'__end__'`` stops.
        - MemoryAgent always returns to RankerAgent.
        - ActionAgent additionally routes on ``state['sender']`` to MemoryAgent.
        """
        workflow = StateGraph(AgentState)

        workflow.add_node("RankerAgent", self.research_node)
        workflow.add_node("ActionAgent", self.action_agent_node)
        workflow.add_node("MemoryAgent", self.update_scratchpad)

        workflow.add_conditional_edges(
            "RankerAgent",
            self.router,
            {"continue": "ActionAgent", "__end__": END, 'update_scratchpad': 'MemoryAgent'},
        )
        workflow.add_conditional_edges(
            "ActionAgent",
            self.router,
            {"continue": "RankerAgent", "__end__": END, 'update_scratchpad': 'MemoryAgent'},
        )

        workflow.add_edge("MemoryAgent", "RankerAgent")

        # NOTE(review): this is a second conditional edge out of ActionAgent, in
        # addition to the router edge above. LangGraph keeps both branches, so
        # both appear to fire after every ActionAgent step. Confirm that routing
        # to MemoryAgent in parallel with the router's choice (including
        # '__end__') is intended.
        workflow.add_conditional_edges(
            "ActionAgent",
            lambda x: x["sender"],
            {
                "ActionAgent": "MemoryAgent",
                "RankerAgent": "MemoryAgent",
                "call_tool": "MemoryAgent",
            },
        )
        workflow.add_edge(START, "RankerAgent")

        # Checkpointing is disabled. To enable it, compile with
        # checkpointer=MemorySaver() (optionally interrupt_before=["MemoryAgent"]).
        self.graph = workflow.compile()

    def construct_graph_dynamic_policy(self):
        """Compile the planner-driven graph into ``self.graph``.

        START goes to PlannerAgent. Every node routes via
        :meth:`router_dynamic_policy`; ``'replan'`` from ActionAgent returns to
        PlannerAgent.
        """
        workflow = StateGraph(AgentState)

        workflow.add_node("RankerAgent", self.research_node)
        workflow.add_node("ActionAgent", self.action_agent_node)
        workflow.add_node("MemoryAgent", self.update_scratchpad)
        workflow.add_node("PlannerAgent", self.planner_node)

        workflow.add_edge(START, "PlannerAgent")

        workflow.add_conditional_edges(
            "PlannerAgent",
            self.router_dynamic_policy,
            {
                "continue": "MemoryAgent",
                "update_scratchpad": "MemoryAgent",
                'replan': "ActionAgent",
                "__end__": END
            }
        )

        workflow.add_conditional_edges(
            "RankerAgent",
            self.router_dynamic_policy,
            {
                "continue": "ActionAgent",
                'update_scratchpad': 'MemoryAgent',
                "__end__": END
            }
        )

        workflow.add_conditional_edges(
            "ActionAgent",
            self.router_dynamic_policy,
            {
                "continue": "RankerAgent",
                "update_scratchpad": "MemoryAgent",
                "replan": "PlannerAgent",
                "__end__": END
            }
        )

        workflow.add_conditional_edges(
            "MemoryAgent",
            self.router_dynamic_policy,
            {
                "continue": "RankerAgent",
                # The scratchpad is cumulative, so once any observation contains
                # 'ERROR' the router returns 'update_scratchpad' after every
                # memory update. Route it like 'continue' instead of failing on
                # an unmapped branch.
                "update_scratchpad": "RankerAgent",
                "__end__": END
            }
        )

        self.graph = workflow.compile()

    @staticmethod
    def router(state) -> typing.Literal["call_tool", "__end__", "continue", 'update_scratchpad']:
        """Choose the next edge after an agent step in the general graph.

        Returns:
            - ``'update_scratchpad'`` when the last message contains ``'ERROR'``,
              or when it carries tool calls that were all prepared successfully.
            - ``'__end__'`` on ``'FINAL ANSWER'``, or when ``state['sender']`` is
              ``END``.
            - ``'continue'`` otherwise, including when a tool call cannot be
              prepared.

            ``'call_tool'`` is declared but not currently returned.

        Side effect: each tool call's ``args`` is rewritten in place to
        ``{'state': {..., 'cur_state': state}}`` so tools receive the graph state.
        """
        last_message = state["messages"][-1]

        if 'ERROR' in last_message.content:
            return 'update_scratchpad'

        if "FINAL ANSWER" in last_message.content or state.get('sender') == END:
            return "__end__"

        tool_calls = getattr(last_message, "tool_calls", None)
        if not tool_calls:
            return "continue"

        for call in tool_calls:
            # Inspect *this* call's args (not the first call's) so that mixed
            # no-arg / with-arg tool calls are each handled correctly.
            args = call.get('args')
            if not args:
                if call['name'] in {'goback', 'read_page', 'to_google'}:
                    # Argument-less navigation tools only need the current state.
                    call['args'] = {'state': {'cur_state': state}}
                    continue
                print(f"Tool call args: {args} without args")
                return "continue"
            try:
                if 'state' not in args:
                    # Nest the LLM-provided args under 'state' so the state can be injected.
                    call['args'] = {'state': args}
                call['args']['state']['cur_state'] = state
            except TypeError:
                # Args were not a mutable mapping; let the agents retry.
                print(f"Tool call args: {args} without state")
                return "continue"
        return "update_scratchpad"

    @staticmethod
    def update_policy(state: AgentState, new_policy: str, reason: str) -> None:
        """Store ``new_policy`` and ``reason`` on ``state`` (mutates it in place)."""
        state['policy'] = new_policy
        state['update_policy_reason'] = reason

    @staticmethod
    def router_dynamic_policy(state: AgentState) -> typing.Literal[
        "call_tool", "__end__", "continue", "update_scratchpad", "replan", "return_to_action"]:
        """Choose the next edge in the dynamic-policy graph.

        Uses the same rules as :meth:`router`, with two additions: an
        ``update_policy`` tool call returns ``'replan'``, and an ``answer``
        tool call returns ``END``. Tool-call args are rewritten in place to
        carry the current state, as in :meth:`router`.
        """
        last_message = state["messages"][-1]

        if 'ERROR' in last_message.content:
            return 'update_scratchpad'

        if "FINAL ANSWER" in last_message.content or state.get('sender') == END:
            return END

        tool_calls = getattr(last_message, "tool_calls", None)
        if not tool_calls:
            return "continue"

        for call in tool_calls:
            name = call.get('name')
            args = call.get('args')
            if not args:
                if name in {'goback', 'read_page', 'to_google'}:
                    call['args'] = {'state': {'cur_state': state}}
                elif name == 'update_policy':
                    return "replan"
                elif name == 'answer':
                    return END
                else:
                    # Use the already-fetched value: 'args' may be absent from the call.
                    print(f"Tool call args: {args} without args")
                    return "continue"
                continue

            try:
                if 'state' not in args:
                    call['args'] = {'state': args}
                call['args']['state']['cur_state'] = state
            except TypeError:
                print(f"Tool call args: {call['args']} without state")
                return "continue"

            if name == 'update_policy':
                return "replan"
            if name == 'answer':
                return END

        return "update_scratchpad"

    @staticmethod
    def update_scratchpad(state: AgentState, max_memory_length: int = 200) -> AgentState:
        """MemoryAgent node: append the latest action observations to the scratchpad.

        The scratchpad is a single ``AIMessage``. Its content is the header line
        ``"Previous action observations:"`` followed by numbered observation
        lines. New observations are taken from the environment's ``feedback``
        buffer (which is then cleared) and filtered with
        ``process_action_history``. Numbering continues from
        ``MemoryAgent.get_last_step``. Only the header plus the newest
        ``max_memory_length`` lines are kept (a sliding window).

        Side effects: mutates and returns ``state``. Sets ``observation``,
        ``messages``, ``scratchpad`` and ``sender``, and clears the
        environment's ``feedback``.
        """
        scratchpad = state.get('scratchpad')
        old_content = scratchpad[0].content if scratchpad else ""
        lines = old_content.split("\n") if old_content else ["Previous action observations:"]

        step = MemoryAgent.get_last_step(lines) + 1
        pointer_env = state.get('pointer_env', state.get('env'))

        filtered_obs = process_action_history(pointer_env.feedback)
        # The feedback has been consumed; clear it so the next step does not repeat it.
        pointer_env.feedback = []
        state['observation'] = [AIMessage(content="\n".join(filtered_obs))]

        # NOTE: read_page content is not added to memory
        # (MemoryAgent.update_memory_with_read_page is currently disabled).

        lines.extend(f"{step + offset}. {obs}" for offset, obs in enumerate(filtered_obs))

        # Sliding window: keep the header line plus the newest entries. Guard
        # against max_memory_length <= 0, where lines[-0:] would be the whole list.
        if len(lines) > max_memory_length + 1:
            recent = lines[-max_memory_length:] if max_memory_length > 0 else []
            lines = lines[:1] + recent

        updated_content = "\n".join(lines)
        print(f"Updated sessions memory:\n\n{updated_content}")
        state['messages'] = [AIMessage(content=updated_content)]

        state['scratchpad'] = [AIMessage(content=updated_content)]
        state['sender'] = 'MemoryAgent'
        return state

    def _initial_graph_input(self, question: str) -> dict:
        """Build the initial graph state for ``question``.

        NOTE(review): ``self.page`` and ``self.extention_obj`` are not assigned
        anywhere in this class. They are presumably set by
        ``initialize_web_connection_async`` (not shown) — confirm.
        """
        return {
            "page": self.page,
            "extension_obj": self.extention_obj,
            "input": question,
            "scratchpad": [],
        }

    @staticmethod
    def _print_event(event) -> None:
        """Pretty-print one non-terminal graph stream event, keyed by node name."""
        if "ActionAgent" in event:
            update = event["ActionAgent"]
            extra = update.get("additional_kwargs")
            if extra:
                print(extra)
            # Always derive ``pred`` from the last message; previously it was left
            # unassigned (or stale) whenever additional_kwargs were present.
            messages = update.get("messages") or []
            last = messages[-1] if messages else {}
            content = getattr(last, "content", "")
            pred = content if content != "" else last
            print("\n ActionAgent prediction: \n")
        elif "RankerAgent" in event:
            pred = event["RankerAgent"].get("elements") or {}
            print("\n RankerAgent prediction:\n")
        elif "call_tool" in event:
            performed = "\n".join([msg.content for msg in event["call_tool"].get("messages")]) or {}
            pred = f"\n Agent performed: \n {performed} \n"
        elif "MemoryAgent" in event:
            pred = "\n MemoryAgent updated memory \n"
        else:
            pred = 'prediction'

        print(pred)
        print("-" * 50)

    async def call_agent_async(self, question: str, max_steps: int = 150):
        """Stream the graph asynchronously for ``question`` and print each step.

        Only runs when ``self.architecture == 'parser_agent'``; otherwise it does
        nothing. ``max_steps`` is passed as LangGraph's ``recursion_limit``.
        Always returns ``None``.
        """
        if self.architecture != 'parser_agent':
            return
        event_stream = self.graph.astream(
            self._initial_graph_input(question),
            {"recursion_limit": max_steps},
        )
        async for event in event_stream:
            if '__end__' in event:
                print(event)
                break
            self._print_event(event)

    async def call_agent_sync(self, question: str, max_steps: int = 150):
        """Stream the graph synchronously (inside a coroutine) and print each step.

        Same contract as :meth:`call_agent_async`, but uses ``graph.stream``.
        """
        if self.architecture != 'parser_agent':
            return
        event_stream = self.graph.stream(
            self._initial_graph_input(question),
            {"recursion_limit": max_steps},
        )
        for event in event_stream:
            if '__end__' in event:
                print(event)
                break
            self._print_event(event)


async def main():
    """Run one example task end-to-end with the async agent pipeline.

    Builds an :class:`AgentsManager` (``openai`` LLM type, async mode,
    ``parser_agent`` architecture), opens the Workday login URL and asks the
    agents to look up an employee ID.

    The login credentials are read from the ``WORKDAY_USERNAME`` and
    ``WORKDAY_PASSWORD`` environment variables. A ``.env`` file is honoured via
    ``load_dotenv``. Credentials must never be hard-coded in source.

    Raises:
        RuntimeError: if either credential variable is unset or empty.
    """
    import os

    load_dotenv()
    username = os.environ.get("WORKDAY_USERNAME")
    password = os.environ.get("WORKDAY_PASSWORD")
    if not username or not password:
        raise RuntimeError(
            "Set WORKDAY_USERNAME and WORKDAY_PASSWORD (e.g. in a .env file) before running this example.")

    llm_type = 'openai'  # alternatively 'genai'
    sync = False
    architecture = "parser_agent"

    agents = AgentsManager(llm_type=llm_type, sync=sync, architecture=architecture)

    url = r"https://impl.workday.com/wday/authgwy/ibmsrv_dpt1/login.htmld?returnTo=%2fibmsrv_dpt1%2fd%2fhome.htmld"
    # The fragments are adjacent literals; keep the separating spaces so the
    # sentences do not run together.
    question = (
        f"Please log into the IBM Workday system and navigate to the 'Home' page, "
        f"username: '{username}', password: '{password}'. "
        f"Then, get into Jhon Davis page and find his employee ID")

    res = None
    if not sync:
        await agents.initialize_web_connection_async(url=url)  # Initialize the web connection asynchronously
        res = await agents.call_agent_async(question=question)

    print(f"Final response: {res}")


if __name__ == "__main__":
    # Script entry point: run the async example on a fresh event loop.
    entry_coroutine = main()
    asyncio.run(entry_coroutine)
