import json
import os
import re
import typing
from getpass import getpass
from typing import List, Any, Dict

from langchain_core.messages import ToolMessage, AIMessage
from langchain_openai import ChatOpenAI
from langgraph.graph import END

from agentS.agent_state import AgentState
from agentS.agents.agents_output_handlers import Agents_Output_Handlers
from agentS.agents.agents_prompt_creator import AgentsPromptCreator
# from agentS.thread_safety_tools import IntegratedBrowserTools
# from ibm_nl2ui.planning.agents.tools import Tools
from agentS.tools_manager import Tools, SychronousTools
from agentS.consts import OPENAI_MODEL_NAME_GPT_4O_MINI, LANGCHAIN_API_KEY, OPENAI_API_KEY, HUMAN_IN_THE_LOOP_FUNC_NAME


def initialize_llm(llm_name=OPENAI_MODEL_NAME_GPT_4O_MINI, llm=None):
    """Make sure the required API keys are set and return a chat model.

    Args:
        llm_name: Model name passed to ``ChatOpenAI`` when ``llm`` is not given.
        llm: Optional pre-built model; returned as-is when truthy.

    Returns:
        The supplied ``llm`` or a new ``ChatOpenAI(model=llm_name)``.
    """
    # Populate credentials BEFORE building the client: the OpenAI client resolves
    # its API key from the environment when it is constructed, so prompting
    # afterwards would be too late for a missing key.
    # Assumes the two constants hold environment variable names -- TODO confirm.
    for key in (LANGCHAIN_API_KEY, OPENAI_API_KEY):
        if not os.getenv(key):
            os.environ[key] = getpass(f"{key}=")
    return llm or ChatOpenAI(model=llm_name)


def setup_tools() -> Dict[str, Any]:
    """Return the registry of tools taken from ``Tools``, keyed by display label."""
    # Currently disabled: "Scroll" (Tools.scroll) and "Google" (Tools.to_google).
    labelled_tools = (
        ("Click", Tools.click),
        ("Type", Tools.type),
        ("ANSWER", Tools.answer),
        ("GoBack", Tools.goback),
        ("ReadPage", Tools.read_page),
        ("WebTaskPlanner", Tools.create_action_plan),
    )
    return dict(labelled_tools)


def setup_sync_tools(architecture='general', env_policies='') -> Dict[str, Any]:
    """Return the ``SychronousTools`` registry for the requested architecture.

    The keys are labels kept for reference only; insertion order is preserved.

    Args:
        architecture: ``'dynamic_policy'`` swaps ``ReadPage`` for ``UpdatePolicy``;
            every other value (including ``'general'``) yields the general set.
        env_policies: Only consulted for ``'dynamic_policy'``; any value other
            than ``''`` additionally exposes the ``HumanInTheLoop`` tool.

    Returns:
        Dict[str, Any]: Mapping of label to tool.
    """
    # Tools shared by every architecture.
    # Currently disabled: "Scroll" (SychronousTools.scroll) and "Google" (SychronousTools.to_google).
    toolset = {
        "Click": SychronousTools.click,
        "Type": SychronousTools.type,
        "ANSWER": SychronousTools.answer,
        "GoBack": SychronousTools.goback,
    }

    if architecture != 'dynamic_policy':
        # 'general' and any unrecognised architecture share the same tool set.
        toolset["ReadPage"] = SychronousTools.read_page
        toolset["SelectOption"] = SychronousTools.select_option
        return toolset

    toolset["UpdatePolicy"] = SychronousTools.update_policy
    toolset["SelectOption"] = SychronousTools.select_option
    if env_policies != '':
        # Human in the loop
        toolset["HumanInTheLoop"] = SychronousTools.human_in_the_loop
    return toolset


def create_agent(prompt, llm, tools, system_message: str, sync: bool = False, perform_action=None, agent_name=None,
                 env_policies=None):
    """Assemble the runnable pipeline for one agent.

    Args:
        prompt: Base prompt handed to the ``AgentsPromptCreator`` builder.
        llm: Chat model; must support ``bind_tools`` when tools are supplied.
        tools: Dict of label -> tool, or a sized collection of tools. An empty
            non-dict collection produces a pipeline without bound tools.
        system_message: Value filled into the prompt's ``system_message`` slot.
        sync: Not used by this function.
        perform_action: Optional runnable piped in front of the pipeline.
        agent_name: ``'RankerAgent'``, ``'ActionAgent'`` or ``'PlannerAgent'``
            select a dedicated prompt; anything else gets the general prompt.
        env_policies: Forwarded to the dedicated prompt builders.

    Returns:
        ``prompt | llm`` (with tools bound when given), optionally prefixed by
        ``perform_action``.
    """
    # Only the general prompt builder ignores env_policies.
    if agent_name == 'RankerAgent':
        agent_prompt = AgentsPromptCreator.create_ranker_prompt(prompt=prompt, env_policies=env_policies)
    elif agent_name == 'ActionAgent':
        agent_prompt = AgentsPromptCreator.create_action_agent_prompt(prompt=prompt, env_policies=env_policies)
    elif agent_name == 'PlannerAgent':
        agent_prompt = AgentsPromptCreator.create_planner_agent_prompt(prompt=prompt, env_policies=env_policies)
    else:
        agent_prompt = AgentsPromptCreator.create_general_prompt(prompt=prompt)

    agent_prompt = agent_prompt.partial(system_message=system_message)

    # Normalise the tool argument: a dict contributes its values, an empty
    # non-dict collection means "no tools at all".
    if isinstance(tools, dict):
        bound_tools = tools.values()
    elif len(tools) == 0:
        bound_tools = None
    else:
        bound_tools = tools

    if bound_tools is None:
        pipeline = agent_prompt | llm
    else:
        tool_names = ", ".join([tool.name for tool in bound_tools])
        pipeline = agent_prompt.partial(tool_names=tool_names) | llm.bind_tools(bound_tools)

    return (perform_action | pipeline) if perform_action else pipeline


# Helper function to create a node for a given agent
async def agent_node(state, agent, name):
    """Run ``agent`` asynchronously on ``state`` and fold its output into the state.

    ``ToolMessage`` results are forwarded untouched. Any other result is
    converted to an ``AIMessage`` named after the agent and, for
    ``"RankerAgent"`` / ``"ActionAgent"``, post-processed.

    Args:
        state: Mutable mapping with the shared graph state. This function may
            write its ``elements`` and ``observation`` keys in place and reads
            ``pointer_env``.
        agent: Object exposing an awaitable ``ainvoke(state)``.
        name: Name of the agent that is being run.

    Returns:
        dict: Copy of ``state`` with ``messages`` set to a one-element list
        holding the resulting message and ``sender`` set to ``name``, ``END``
        (a final answer was produced) or ``"call_tool"`` (the predicted element
        was rejected).
    """
    result = await agent.ainvoke(state)
    # We convert the agent output into a format that is suitable to append to the global state
    if not isinstance(result, ToolMessage):
        result = AIMessage(**result.dict(exclude={"type", "name"}), name=name)
        if name == "RankerAgent":
            # Drop "json" tags, code fences and ``//`` comments before parsing.
            cleaned_result = result.content.replace('json', "").replace('```', "")
            cleaned_json_data = re.sub(r'//.*', '', cleaned_result)
            try:
                parsed_result = json.loads(cleaned_json_data)
                # To add ranker output filtering to variables to the state and to adjust the prompt
                # NOTE(review): the fallback looks up the first key of `state` (not of
                # `parsed_result`) in the parsed JSON -- confirm this is intended.
                state['elements'] = parsed_result.get('Relevant Elements', parsed_result.get(list(state.keys())[0]))
            except json.JSONDecodeError:
                print(f'\n Error transforming to JSON: {cleaned_json_data}')
                state['elements'] = cleaned_json_data

        # Verify action agent output and adjust the prompt
        elif name == "ActionAgent":
            if result.tool_calls:
                for call in result.tool_calls:
                    if call["name"] in ["to_google", 'read_page']:
                        continue

                    if not call["args"]:
                        action_state = call["args"]
                    else:
                        action_state = call["args"]["state"]

                    if isinstance(action_state, dict):
                        try:
                            # The first value of the dict is taken as the predicted element.
                            element_predicted = action_state.get(list(action_state.keys())[0])
                        except IndexError:
                            # Empty dict: there is no first key to read.
                            print(f'Error element_predicted: {action_state}')
                            element_predicted = action_state
                            break
                    else:
                        element_predicted = action_state

                    if call["name"] == "answer":
                        result.content = f"FINAL ANSWER \n {element_predicted}"
                        name = END
                        continue

                    if not isinstance(element_predicted, int):
                        if 'FINAL ANSWER' in element_predicted:
                            state['observation'] = f"{element_predicted}"
                            name = END
                            result.content = element_predicted
                        else:
                            state[
                                'observation'] = [
                                f"ERROR: You have predicted to perform {call['name']} of {element_predicted} as the next action. Please provide an value of the element ID."]
                            name = "call_tool"
                            pointer_env = state.get('pointer_env', None)
                            # `pointer_env` is optional in the state; only record the
                            # feedback when an environment is actually attached.
                            if pointer_env is not None:
                                pointer_env.feedback.append(state['observation'])
                            result = AIMessage(content=state['observation'], name=name)
                        continue

                # Verify that there is only one call to click function, if not take the first one
                for i, call in enumerate(result.tool_calls):
                    if call['name'] == 'answer':
                        # Keep only the answer call; `tool_calls` must stay a list.
                        result.tool_calls = [call]
                        break
                    if call["name"] == "click":
                        result.tool_calls = result.tool_calls[:i + 1]
                        break

    return {
        **state,
        "messages": [result],
        # Since we have a strict workflow, we can
        # track the sender, so we know who to pass to next.
        "sender": name,
    }


def sync_agent_node(state, agent, name):
    """Run ``agent`` synchronously and route its output to the matching handler.

    ``ToolMessage`` results are returned as-is inside the state update. Other
    results are converted to an ``AIMessage`` named after the agent and passed
    to the ``Agents_Output_Handlers`` method for ``"RankerAgent"``,
    ``"ActionAgent"`` or ``"PlannerAgent"``; any other agent name gets the
    plain state update.
    """
    message = agent.invoke(state)

    if not isinstance(message, ToolMessage):
        message = AIMessage(**message.dict(exclude={"type", "name"}), name=name)

        if name == "RankerAgent":
            return Agents_Output_Handlers.handle_ranker_agent(state, message, name)
        if name == "ActionAgent":
            return Agents_Output_Handlers.handle_action_agent(state, message, name)
        if name == "PlannerAgent":
            return Agents_Output_Handlers.handle_planner_agent(state, message, name)

    # Tool messages and agents without a dedicated handler share this update.
    return {**state, "messages": [message], "sender": name}


# Lookup table for tool_node: keyed by each tool's own ``name`` attribute (not the registry label).
tools_by_name = {tool.name: tool for tool in setup_tools().values()}


def remove_non_ascii(text):
    """
    Replace every run of non-ASCII characters with a single space, then
    normalise the whitespace of the result via ``clean_extra_space``.

    Args:
        text: The string to clean.

    Returns:
        The cleaned string, containing ASCII characters only.
    """
    non_ascii_run = re.compile(r'[^\x00-\x7F]+')
    ascii_only = non_ascii_run.sub(' ', text)
    return clean_extra_space(ascii_only)


def clean_extra_space(text: str):
    """Trim ``text`` and collapse every run of whitespace into one space."""
    trimmed = text.strip()
    whitespace_run = re.compile(r'\s+')
    return whitespace_run.sub(' ', trimmed)


def process_page_understanding(filtered_elements: Dict[str, Any]) -> str:
    """Render the filtered page elements as one ``id=...; type=...; text=...`` line each.

    Args:
        filtered_elements: Mapping of element id to element. Each element is
            read through ``element.text.value`` and ``element.match.rule.type``.

    Returns:
        str: The rendered lines joined with newlines (empty string when there
        are no elements).
    """
    rendered_lines = []
    for element_id, element in filtered_elements.items():
        element_name = remove_non_ascii(element.text.value)
        element_type = element.match.rule.type
        # Explicit ``!s`` keeps str() semantics and also accepts non-string ids.
        rendered_lines.append(f"id={element_id!s}; type={element_type!s}; text={element_name!s}")

    return '\n'.join(rendered_lines)


def process_action_history(feedback_list: List[str]) -> List[str]:
    """
    Turn collected action feedback into natural language descriptions.

    Only entries starting with "Action_feedback: " or "ERROR: " are converted;
    every other entry is skipped.

    Args:
        feedback_list (List[str]): Feedback strings collected during actions.

    Returns:
        List[str]: One description per recognised entry, in input order.
    """
    recognised_prefixes = ("Action_feedback: ", "ERROR: ")
    descriptions = []
    for entry in feedback_list:
        if not entry.startswith(recognised_prefixes):
            continue
        parsed_entry = parse_feedback(entry)
        descriptions.append(feedback_to_natural_language(parsed_entry))
    return descriptions


def parse_feedback(feedback: str) -> Dict[str, Any]:
    """
    Parse a single feedback string into a dictionary.

    Args:
        feedback (str): A feedback string starting with "ERROR: " or
            "Action_feedback: ", followed by a JSON object or free text.

    Returns:
        Dict[str, Any]: Always a dictionary.
            * "ERROR: " + JSON object -> that object.
            * "ERROR: " + anything else -> ``{"error": <text>, "action":
              "unknown", "element_id": "unknown"}``.
            * "Action_feedback: " + JSON object -> that object.
            * Everything else (other prefixes, invalid or non-object JSON
              after "Action_feedback: ") -> ``{}``.
    """
    error_prefix = "ERROR: "
    action_prefix = "Action_feedback: "

    if feedback.startswith(error_prefix):
        error_text = feedback.removeprefix(error_prefix)
        try:
            parsed = json.loads(error_text)
        except json.JSONDecodeError:
            parsed = None
        # Valid JSON that is not an object (a number, string, list, ...) is
        # still just an error text; callers rely on receiving a dict.
        if isinstance(parsed, dict):
            return parsed
        return {"error": error_text, "action": "unknown", "element_id": "unknown"}

    if feedback.startswith(action_prefix):
        try:
            parsed = json.loads(feedback.removeprefix(action_prefix))
        except json.JSONDecodeError:
            return {}  # Return an empty dict for invalid JSON
        return parsed if isinstance(parsed, dict) else {}

    return {}


def feedback_to_natural_language(feedback: Dict[str, Any]) -> str:
    """
    Describe one parsed feedback dictionary in natural language.

    Args:
        feedback (Dict[str, Any]): A dictionary containing action feedback.
            Keys read here: "action", "status", "value", "error",
            "element_id", "element_name" and the optional detail keys
            listed in the body.

    Returns:
        str: A sentence describing the action, its error, or a fixed message
        when ``feedback`` is empty.
    """
    if not feedback:
        return "Invalid or empty feedback received."

    action = feedback.get("action", "unknown action")
    status = feedback.get("status", "unknown status")
    message = feedback.get("value", "")

    # Failures: an explicit "error" key or an "error" status.
    failed = "error" in feedback or status == "error"
    if failed:
        print(feedback)
        error_message = feedback.get("error", "An unspecified error occurred")
        element_id = feedback.get('element_id', 'unknown element')
        if action in HUMAN_IN_THE_LOOP_FUNC_NAME:
            context = f" with the message: '{message}', but encountered an error: {error_message}"
        else:
            context = f" on {element_id}, but encountered an error: {error_message}"
        return f"Error: tried to perform a {action} action{context} Please adjust your behavior accordingly."

    # Handle User_response action
    if action == "User_response":
        return f"Received user response: {message}"

    # Handle other human-in-the-loop actions
    if action in HUMAN_IN_THE_LOOP_FUNC_NAME:
        return f"Performed action human in the loop and applied {action} with the message: {message}"

    # Regular action: headline first, then the element it targeted.
    pieces = [f"Action '{action}' {status}"]
    if "element_id" in feedback:
        pieces.append(f" on element '{feedback['element_id']}'")
    if "element_name" in feedback:
        pieces.append(f" (element name: {feedback['element_name']})")

    # Optional details, rendered in this fixed order. "value" is not rendered
    # here and "error" cannot be present at this point.
    detail_formatters = (
        ("options", lambda o: f"selecting options {o}"),
        ("button", lambda b: f"using {b} button"),
        ("modifiers", lambda m: f"with modifiers {m}" if m else None),
        ("content", lambda c: f"Content: {c}"),
        ("update_reason", lambda r: f"Update reason: {r}"),
        ("retry_with_force", lambda r: "retry with force was used" if r else None),
        ("execution_time", lambda t: f"Execution time: {t:.2f} seconds"),
    )
    extras = []
    for key, render in detail_formatters:
        if key not in feedback:
            continue
        rendered = render(feedback[key])
        if rendered:
            extras.append(rendered)

    if extras:
        pieces.append(". " + " ".join(extras))

    return "".join(pieces).strip() + "."


async def tool_node(state: dict):
    """Execute every tool call of the last message and return the tool outputs.

    Each call is resolved through ``tools_by_name`` and invoked with its
    ``args``; the outputs are wrapped in ``ToolMessage`` objects carrying the
    originating call id.
    """
    pending_calls = state["messages"][-1].tool_calls
    tool_messages = []
    for call in pending_calls:
        selected_tool = tools_by_name[call["name"]]
        output = selected_tool.invoke(call["args"])
        tool_messages.append(ToolMessage(content=output, tool_call_id=call["id"]))
    return {"messages": tool_messages}


class MemoryAgent:
    """Static helpers for maintaining a numbered, line-based action memory."""

    @staticmethod
    def extract_step_number(line: str) -> int:
        """Return the integer formed by the leading digits of ``line``, or 0 if there are none."""
        match = re.match(r"(\d+)", line)
        return int(match.group(1)) if match else 0

    @staticmethod
    def get_last_step(lines: typing.List[str]) -> int:
        """Return the step number of the last line with a positive leading number, or 0."""
        for line in reversed(lines):
            step = MemoryAgent.extract_step_number(line)
            if step > 0:
                return step
        return 0

    @staticmethod
    def indent_text(text: str, indent: int = 2) -> str:
        """Prefix every non-blank line of ``text`` with ``indent`` spaces."""
        padding = ' ' * indent
        return '\n'.join(
            padding + line if line.strip() else line
            for line in text.split('\n')
        )

    @staticmethod
    def format_step_content(step: int, content: str) -> str:
        """Format ``content`` as "<step>. <first line>" followed by its indented remainder."""
        first_line, _, remainder = content.partition('\n')
        header = f"{step}. {first_line}"
        return f"{header}\n{MemoryAgent.indent_text(remainder)}"

    @staticmethod
    def update_memory_with_read_page(
            lines: typing.List[str], read_page: str, step: int
    ) -> typing.Tuple[typing.List[str], int]:
        """Replace earlier "Read page" entries with a fresh one and renumber the rest.

        Args:
            lines: Memory entries; the first entry is never renumbered.
            read_page: Page content to store; nothing is appended when falsy.
            step: Step number used for the new "Read page" entry.

        Returns:
            A new list of entries (the input list is not modified) and the
            step number to use next (``step + 1`` when an entry was appended,
            otherwise ``step``).
        """
        # Remove all previous read_page entries
        kept = [
            line for line in lines
            if not line.startswith(f"{MemoryAgent.extract_step_number(line)}. Read page:")
        ]

        # Renumber remaining steps. Entries without a ". " separator carry no
        # step prefix: they are kept verbatim and do not consume a number
        # (previously such an entry raised IndexError).
        renumbered = kept[:1]
        next_number = 1
        for line in kept[1:]:
            _, separator, remainder = line.partition('. ')
            if separator:
                renumbered.append(f"{next_number}. {remainder}")
                next_number += 1
            else:
                renumbered.append(line)

        # Add new read_page action with informative message
        if read_page:
            formatted_content = MemoryAgent.format_step_content(step, read_page)
            informative_message = (
                f"{step}. Read page: Adding to memory. Page content follows:\n"
                f"{MemoryAgent.indent_text(formatted_content)}"
            )
            renumbered.append(informative_message)
            step += 1

        return renumbered, step
