from absint_ai.Environment.agents.base_agent import Agent
from typing import TYPE_CHECKING
import beeprint
from openai import OpenAI, BadRequestError
import absint_ai.utils.Util as Util
from absint_ai.utils.Logger import logger
from absint_ai.Environment.agents.actions import AgentAction
import json

if TYPE_CHECKING:
    from absint_ai.Environment.Environment import Environment


class LLMAgent(Agent):
    """Agent that delegates abstract-interpretation decisions to an LLM.

    Every ``decide_*`` method sends an OpenAI chat-completions request with
    ``tool_choice="required"`` and converts the tool call the model makes into
    an ``AgentAction``.
    """

    def __init__(self, api_key: str, base_url: str, model: str, use_agent: bool = True) -> None:
        """Create the agent.

        Args:
            api_key: API key passed to the ``OpenAI`` client.
            base_url: Base URL of the (OpenAI-compatible) endpoint.
            model: Model name used for every completion request.
            use_agent: When True, site selection may also use the
                information-gathering tools (``info_var``, ``execute_loop``)
                and a system prompt describing them is used; when False, the
                model may only select allocation sites.
        """
        self.mode = "selection"
        self.openai_client = OpenAI(api_key=api_key, base_url=base_url)  # type: ignore
        self.model = model
        # Number of site-selection turns spent on information gathering; once it
        # exceeds five, only the selection tool is offered.
        self.total_tool_calls = 0
        # Set after the API rejects a request whose error mentions "length";
        # afterwards the site-selection prompt lists site names only.
        self.too_long = False
        self.use_agent = use_agent
        if use_agent:
            self.system_prompt = {
                "role": "system",
                "content": """
        You are an expert static analysis assistant.

        Your task is to decide which heap allocation sites should be summarized in order to ensure that the program's abstract interpretation converges at a loop.

        If an allocation site produces objects whose values (e.g. fields or shapes) may grow or change across iterations, you should select that site for summarization. If an allocation site produces only one object or produces identical values each iteration, it may not need to be summarized.

        You may only respond by calling a tool. Do not produce natural language explanations unless instructed to do so.

        You must make decisions only after gathering the necessary information. Use the tools `info_var` and `execute_loop` provided to inspect variables or execute the loop. You should use them to gather information before making a decision. If you decide to use these tools, do NOT summarize the allocation sites in the same iteration. you should only summarize allocation sites after you have gathered all the necessary information. Specifically, `execute_loop` is ONLY used to gather information, nothing else. If you have executed the loop five times and you still don't have enough information, you should summarize the allocation sites anyway. Do NOT try to execute the loop more than five times in a row. If you do, you should summarize the allocation sites anyway. DO NOT REPEATLY EXECUTE THE FUCKING LOOP. This is a waste of time and resources. You should only execute the loop once to gather information. If you need more information, you should use the `info_var` tool.

        Be conservative in your choices — avoid summarizing sites that do not contribute to divergence.
        """,
            }
        else:
            self.system_prompt = {
                "role": "system",
                "content": """
        You are an expert static analysis assistant.

        Your task is to decide which heap allocation sites should be summarized in order to ensure that the program's abstract interpretation converges at a loop.

        If an allocation site produces objects whose values (e.g. fields or shapes) may grow or change across iterations, you should select that site for summarization. If an allocation site produces only one object or produces identical values each iteration, it may not need to be summarized.

        Be conservative in your choices — avoid summarizing sites that do not contribute to divergence.
        """
            }
        # Running site-selection conversation; reset to the system prompt when
        # the API rejects a request.
        self.site_selection_messages = [self.system_prompt]
        # Per-allocation-site message lists. Bookkeeping only: the merge and
        # widening requests send just the system prompt and the current prompt.
        self.message_chains = {}

    def decide_site_selection(
        self,
        env: "Environment",
        changed_allocation_sites: list[str],
        code: str,
        loop_iteration: int,
        loop_body: str,
        print_all_info: bool = False,
    ) -> "AgentAction":
        """Ask the LLM which changed allocation sites should be summarized.

        The model sees readable allocation-site names; its answer is mapped
        back to the raw ids in ``changed_allocation_sites``.

        Args:
            env: Analysis environment queried for site values and variables.
            changed_allocation_sites: Raw ids of the sites that changed.
            code: Full program source shown to the model.
            loop_iteration: Iteration number shown in the prompt.
            loop_body: Source of the loop body shown to the model.
            print_all_info: Unused; kept for interface compatibility.

        Returns:
            ``AgentAction("select", {"sites": [...]})`` with raw site ids,
            ``AgentAction("execute", {})`` when the model asks to run the loop,
            or ``AgentAction("continue", {})`` after an information-gathering
            turn, a rejected request, or a reply without a tool call.

        Raises:
            ValueError: If the model calls a tool this method does not handle.
        """
        allocation_site_values_raw = env.get_allocation_site_values(
            changed_allocation_sites, ignore_heap_frames=True
        )
        for allocation_site, site_value in allocation_site_values_raw.items():
            site_value["points_to"] = list(
                env.points_to_info_for_allocation_site(allocation_site)
            )

        # The LLM only ever sees readable names; keep both directions so its
        # answer can be translated back to raw ids.
        readable_by_id = {
            site: env.get_readable_allocation_site(site)
            for site in changed_allocation_sites
        }
        id_by_readable = {
            readable: site for site, readable in readable_by_id.items()
        }
        allocation_site_values = {
            readable_by_id[site]: value
            for site, value in allocation_site_values_raw.items()
        }

        # Information-gathering tools are withdrawn after too many such turns.
        offer_info_tools = self.use_agent and self.total_tool_calls <= 5
        tools = [self._select_allocation_sites_tool(list(allocation_site_values))]
        if offer_info_tools:
            tools += [self._info_var_tool(), self._execute_tool()]

        if self.too_long:
            # A previous request was too long: list site names only.
            values_str = ",".join(
                readable_by_id[site] for site in changed_allocation_sites
            )
        else:
            values_str = beeprint.pp(allocation_site_values, output=False)
        self.site_selection_messages.append(
            self._site_selection_prompt(
                code, loop_body, loop_iteration, values_str, offer_info_tools
            )
        )

        try:
            response = self.openai_client.chat.completions.create(
                model=self.model,
                messages=self.site_selection_messages,
                tools=tools,
                tool_choice="required",
            )
        except BadRequestError as e:
            logger.error(f"Error in OpenAI API call: {e}")
            if "length" in str(e):
                self.too_long = True
            self.site_selection_messages = [self.system_prompt]  # if it's broken, start over
            logger.info(
                f"Request: {beeprint.pp(self.site_selection_messages, output=False)}"
            )
            return AgentAction("continue", {})

        message = response.choices[0].message
        tool_calls = message.tool_calls or []
        if not tool_calls:
            logger.error("LLM response for site selection contained no tool call.")
            return AgentAction("continue", {})

        # Any information request turns this into a gathering turn, even if a
        # selection was attempted in the same message.
        if any("info" in call.function.name for call in tool_calls):
            self.total_tool_calls += 1
            return self._answer_tool_calls(env, message, tool_calls)

        first_call = tool_calls[0]
        function_name = first_call.function.name
        if function_name == "select_allocation_sites":
            function_args = self._parse_tool_arguments(first_call)
            selected_sites = [
                id_by_readable[name]
                for name in function_args.get("selected_sites") or []
                if name in id_by_readable
            ]
            return AgentAction("select", {"sites": selected_sites})
        if function_name == "execute_loop":
            return self._answer_tool_calls(env, message, tool_calls)
        raise ValueError(
            f"Unknown function name: {function_name}. Expected 'select_allocation_sites' or 'execute_loop'."
        )

    def _site_selection_prompt(
        self,
        code: str,
        loop_body: str,
        loop_iteration: int,
        values_str: str,
        include_tool_hints: bool,
    ) -> dict:
        """Build the user message for one site-selection request.

        ``include_tool_hints`` appends instructions for the information tools;
        it should be True only when those tools are actually offered.
        """
        content = (
            "The following code defines a loop. Some allocation sites changed during the first iteration.\n\n"
            "=== Code ===\n"
            f"{code}\n\n"
            "=== Loop Body ===\n"
            f"{loop_body}\n\n"
            f"=== Changed Allocation Sites for Iteration #{loop_iteration} ===\n"
            f"{values_str}\n\n"
            "Choose which allocation sites should be summarized to ensure convergence."
        )
        if include_tool_hints:
            content += (
                "\n\n"
                "If you want more information about a variable, use the `info_var` tool with a variable name to get the current value of that variable. \n\n"
                "If you want to execute the loop once for more information, use the `execute_loop` tool.\n\n"
            )
        return {"role": "user", "content": content}

    def _answer_tool_calls(self, env: "Environment", message, tool_calls) -> "AgentAction":
        """Record an assistant tool-call turn and reply to every call in it.

        The chat-completions API rejects the next request unless each
        ``tool_call_id`` of an assistant message is followed by a ``tool``
        message, so calls that are not acted on still get a short reply.

        Returns:
            ``AgentAction("execute", {})`` if the turn asked to run the loop
            (only the first such call is honoured), otherwise
            ``AgentAction("continue", {})``.
        """
        self.site_selection_messages.append(message.model_dump())
        executed = False
        for call in tool_calls:
            name = call.function.name
            if name == "info_var":
                content = self._describe_variable(env, call)
            elif name == "execute_loop" and not executed:
                executed = True
                content = "The loop was executed once. State updated."
            else:
                content = f"Tool call `{name}` was not performed in this turn."
            self.site_selection_messages.append(
                self._tool_message(call.id, name, content)
            )
        return AgentAction("execute", {}) if executed else AgentAction("continue", {})

    def _describe_variable(self, env: "Environment", tool_call) -> str:
        """Return the ``info_var`` tool result text for one tool call."""
        function_args = self._parse_tool_arguments(tool_call)
        # Models sometimes use "name" instead of the declared "var_name".
        var_name = function_args.get("var_name") or function_args.get("name")
        if not var_name:
            return "=== No variable name was provided. ==="
        var_lookup_results = env.lookup_and_derive(var_name, add_touched=False)
        if len(var_lookup_results) == 0:
            return f"=== Variable {var_name} has no values at this point. ==="
        result = f"\n === Variable {var_name} has the following values: ===\n"
        result += "=====\n".join(
            beeprint.pp(value, output=False) for value in var_lookup_results
        )
        return result

    @staticmethod
    def _tool_message(tool_call_id: str, name: str, content: str) -> dict:
        """Build a ``role="tool"`` reply message for one tool call."""
        return {
            "role": "tool",
            "name": name,
            "tool_call_id": tool_call_id,
            "content": content,
        }

    @staticmethod
    def _parse_tool_arguments(tool_call) -> dict:
        """Decode a tool call's JSON arguments; ``{}`` if they are not a JSON object."""
        raw_arguments = tool_call.function.arguments or "{}"
        try:
            parsed = json.loads(raw_arguments)
        except json.JSONDecodeError:
            logger.error(
                f"Could not parse arguments for tool {tool_call.function.name}: {raw_arguments!r}"
            )
            return {}
        return parsed if isinstance(parsed, dict) else {}

    @staticmethod
    def _first_tool_call(response, expected: str):
        """Return the first tool call of a completion.

        Raises:
            ValueError: If the response contains no tool call.
        """
        tool_calls = response.choices[0].message.tool_calls
        if not tool_calls:
            raise ValueError(f"LLM returned no tool call. Expected '{expected}'.")
        return tool_calls[0]

    def decide_merging_strategy_for_site(
        self, env: "Environment", allocation_site_id: str, code: str, loop_body: str
    ) -> "AgentAction":
        """Ask the LLM how the values of one allocation site should be merged.

        Returns:
            ``AgentAction("strategy", {...})`` with ``site`` and ``strategy``
            (``"all"``, ``"recency"`` or ``"role"``); ``"role"`` also carries
            ``field``. A ``"role"`` answer without a field falls back to
            ``"recency"``.

        Raises:
            ValueError: If the model returns no tool call, an unknown tool, or
                an invalid strategy.
        """
        readable_allocation_site_id = env.get_readable_allocation_site(
            allocation_site_id
        )
        allocation_site_value_str = Util.get_allocation_site_value_str(
            env, allocation_site_id
        )
        merge_strategy_tool = {
            "type": "function",
            "function": {
                "name": "select_merge_strategy",
                "description": "Choose a merge strategy for a specific allocation site.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "strategy": {
                            "type": "string",
                            "enum": ["all", "recency", "role"],
                            "description": "The name of the merging strategy to apply.",
                        },
                        "field": {
                            "type": "string",
                            "description": "Which field to use for the role-based merging. (only used if strategy is 'role').",
                        },
                    },
                    "required": ["strategy"],
                },
            },
        }
        user_prompt = {
            "role": "user",
            "content": (
                f"You are configuring the merge strategy for {readable_allocation_site_id}. Here is the information,\n\n"
                "=== Code ===\n"
                f"{code}\n\n"
                "=== Loop Body ===\n"
                f"{loop_body}\n\n"
                "=== Allocation Site Value ===\n"
                f"{allocation_site_value_str}\n\n"
                "Choose a strategy for merging the values of this allocation site.\n\n"
                "Valid options are: all, recency, role.\n\n"
                "- all: simply merge all possible addresses and primitives together into one mega-object. Do this if the object is not being re-allocated frequently, and is just one object being modified. \n"
                "- recency: Keep the most recent values of the allocation site, and merge the old ones into a single summary node. Only do this if the object is being re-allocated frequently.\n\n"
                "- role: merge objects that have a similar role. Specify the role as a field name, and all addresses that have the same value for that field will be merged together. For example, if you want to keep a separate abstract object for different user roles, or different types of AST nodes. If you choose this strategy, you MUST provide the field as a parameter.\n\n"
            ),
        }
        # Bookkeeping only; the request below is stateless.
        self.message_chains.setdefault(allocation_site_id, [self.system_prompt])
        response = self.openai_client.chat.completions.create(
            model=self.model,
            messages=[self.system_prompt, user_prompt],
            tools=[merge_strategy_tool],
            tool_choice="required",
        )
        tool_call = self._first_tool_call(response, "select_merge_strategy")
        function_name = tool_call.function.name
        function_args = self._parse_tool_arguments(tool_call)
        logger.info(
            f"Function name: {function_name}, Function args: {function_args} for allocation site {readable_allocation_site_id}"
        )
        if function_name != "select_merge_strategy":
            raise ValueError(
                f"Unknown function name: {function_name}. Expected 'select_merge_strategy'."
            )
        strategy = function_args.get("strategy")
        if strategy not in ("all", "recency", "role"):
            raise ValueError(
                f"Invalid strategy: {strategy}. Expected 'all', 'recency', or 'role'."
            )
        if strategy != "role":
            return AgentAction(
                "strategy", {"site": allocation_site_id, "strategy": strategy}
            )
        role = function_args.get("field")
        if not role:
            # Role-based merging is meaningless without a field name.
            logger.info(
                f"Role is None for allocation site {allocation_site_id}. Using recency strategy."
            )
            return AgentAction(
                "strategy",
                {
                    "site": allocation_site_id,
                    "strategy": "recency",
                },
            )
        return AgentAction(
            "strategy",
            {
                "site": allocation_site_id,
                "strategy": strategy,
                "field": role,
            },
        )

    # TODO I should also put in the previous merging strategy
    def decide_widening_strategy_for_site(
        self, env: "Environment", allocation_site_id: str, code: str, loop_body: str
    ) -> "AgentAction":
        """Ask the LLM how the values of one allocation site should be widened.

        Returns:
            ``AgentAction("strategy", {...})`` with keys ``site``,
            ``strategy`` (``"field_value"``, ``"all"``, ``"none"`` or
            ``"depth"``), ``field_paths`` (list for ``"field_value"``, else
            None) and ``depth`` (value for ``"depth"``, else None).

        Raises:
            ValueError: If the model returns no tool call, an unknown tool, or
                an invalid strategy.
        """
        readable_allocation_site_id = env.get_readable_allocation_site(
            allocation_site_id
        )
        allocation_site_value_str = Util.get_allocation_site_value_str(
            env, allocation_site_id
        )
        widening_strategy_tool = {
            "type": "function",
            "function": {
                "name": "select_widening_strategy",
                "description": (
                    "Choose a widening strategy for a specific allocation site. "
                    "This is used to force convergence when field values grow across loop iterations."
                ),
                "parameters": {
                    "type": "object",
                    "properties": {
                        "strategy": {
                            "type": "string",
                            "enum": ["field_value", "all", "none", "depth"],
                            "description": (
                                "'field_value' means widen selected fields. This is good when only a few fields are changing and you want to keep everything else other than that field concrete. Provide a space-separated list of field paths using dot notation.\n"
                                "- field_set: Combine all the fields into one. For example, if there is an infinitely growing list, you can combine all the fields into a single field. This is a good option if the object itself is growing. If a new object is being pushed at each iteration, you will need to decide on a separate merging/widening strategy for that.\n"
                                "- all:  widen the entire object. \n"
                                " -none: no widening is necessary.\n"
                                "- depth:  widen all values after a particular depth. Provide the depth as an integer. If the depth is 1, it will widen all field values. If the depth is 2, it will find all field values 2 levels deep and widen them, etc. This is a good option if you have a lot of fields that are all changing.  REMEMBER DEPTH IS 1-BASED.\n"
                            ),
                        },
                        "fields": {
                            "type": "array",
                            "items": {"type": "string"},
                            "description": "Dot-separated paths of fields to widen (only used if strategy is 'field_value').",
                            "default": [],
                        },
                        "depth": {
                            "type": "integer",
                            "description": "The depth of the object to abstract. (only used if strategy is 'depth').",
                            "default": 1,
                        },
                    },
                    "required": ["strategy"],
                },
            },
        }
        user_prompt = {
            "role": "user",
            "content": (
                f"You are configuring the widening strategy for {readable_allocation_site_id}. Here is the information,\n\n"
                "=== Code ===\n"
                f"{code}\n\n"
                "=== Loop Body ===\n"
                f"{loop_body}\n\n"
                "=== Allocation Site Value ===\n"
                f"{allocation_site_value_str}\n\n"
                "Choose a strategy for widening the values of this allocation site.\n\n"
                "Valid options are: field_value, all, depth.\n\n"
                "- field_value: widen the value for a few particular fields. You might do this if only a few fields are growing. Provide a space-separate list of field paths using dot notation.\n"
                "- all: widen the entire thing. This is very imprecise, so use it sparingly.\n\n"
                "- depth:  widen all values after a particular depth. Provide the depth as an integer. If the depth is 1, it will widen all field values. If the depth is 2, it will find all field values 2 levels deep and widen them, etc. This is a good option if you have a lot of fields that are all changing. This is a very good one. REMEMBER DEPTH IS 1-BASED.\n\n"
            ),
        }
        # Bookkeeping only: the prompt is recorded, but the request below is
        # stateless and sends just the system prompt and this user prompt.
        self.message_chains.setdefault(
            allocation_site_id, [self.system_prompt]
        ).append(user_prompt)
        response = self.openai_client.chat.completions.create(
            model=self.model,
            messages=[self.system_prompt, user_prompt],
            tools=[widening_strategy_tool],
            tool_choice="required",
        )

        tool_call = self._first_tool_call(response, "select_widening_strategy")
        function_name = tool_call.function.name
        function_args = self._parse_tool_arguments(tool_call)
        logger.info(
            f"Function name: {function_name}, Function args: {function_args} for allocation site {readable_allocation_site_id}"
        )
        if function_name != "select_widening_strategy":
            raise ValueError(
                f"Unknown function name: {function_name}. Expected 'select_widening_strategy'."
            )
        strategy = function_args.get("strategy")
        if strategy not in ("field_value", "all", "none", "depth"):
            raise ValueError(
                f"Invalid strategy: {strategy}. Expected 'field_value', 'all', 'none', or 'depth'."
            )
        fields = None
        depth = None
        if strategy == "field_value":
            fields = function_args.get("fields", [])
        elif strategy == "depth":
            depth = function_args.get("depth", 1)

        return AgentAction(
            "strategy",
            {
                "site": allocation_site_id,
                "strategy": strategy,
                "field_paths": fields,
                "depth": depth,
            },
        )

    def decide_primitives(
        self,
        env: "Environment",
        changed_vars: list[str],
        code: str = None,
        loop_body: str = None,
    ) -> "AgentAction":
        """Ask the LLM which changed primitive variables should be widened.

        Only variables for which ``env.is_primitive_variable`` is true are
        offered, and the model's answer is filtered to that set.

        Returns:
            ``AgentAction("select", {"variables": [...]})``.

        Raises:
            ValueError: If the model returns no tool call or an unknown tool.
        """
        changed_primitive_vars = [
            var for var in changed_vars if env.is_primitive_variable(var)
        ]
        primitive_var_values = env.get_all_reachable_object_variable_values(
            changed_primitive_vars
        )
        select_primitive_variables_tool = {
            "type": "function",
            "function": {
                "name": "select_primitive_variables",
                "description": (
                    "Select which primitive variables should be widened to ensure convergence. "
                    "These are variables that hold numbers, booleans, or strings that change across loop iterations."
                ),
                "parameters": {
                    "type": "object",
                    "properties": {
                        "variables": {
                            "type": "array",
                            "items": {"type": "string"},
                            "description": "A list of primitive variable names that should be widened.",
                        }
                    },
                    "required": ["variables"],
                },
            },
        }
        changed_var_string = "\n".join(changed_primitive_vars)
        user_content = [
            "You are now deciding what primitive variables need to be abstracted.\n\n"
        ]
        if code:
            user_content.append("=== Code ===\n")
            user_content.append(f"{code}\n\n")
        if loop_body:
            user_content.append("=== Loop Body ===\n")
            user_content.append(f"{loop_body}\n\n")
        user_content.append("=== Changed Primitive Variables ===\n")
        user_content.append(f"{changed_var_string}\n\n")
        user_content.append("=== Primitive Variable Values ===\n")
        user_content.append(f"{beeprint.pp(primitive_var_values, output=False)}\n\n")
        user_content.append("Select which primitive variables should be widened\n\n")
        user_content.append("Example: `select counter sum`")
        user_prompt = {
            "role": "user",
            "content": "".join(user_content),
        }
        response = self.openai_client.chat.completions.create(
            model=self.model,
            messages=[self.system_prompt, user_prompt],
            tools=[select_primitive_variables_tool],
            tool_choice="required",
        )
        tool_call = self._first_tool_call(response, "select_primitive_variables")
        function_name = tool_call.function.name
        function_args = self._parse_tool_arguments(tool_call)
        logger.info(
            f"Function name: {function_name}, Function args: {function_args} for primitives"
        )
        if function_name != "select_primitive_variables":
            raise ValueError(
                f"Unknown function name: {function_name}. Expected 'select_primitive_variables'."
            )
        # Drop anything the model invented that was not offered.
        selected_vars = [
            var
            for var in function_args.get("variables") or []
            if var in changed_primitive_vars
        ]
        return AgentAction("select", {"variables": selected_vars})

    def receive_info(self, info: str) -> None:
        """Receive information from the environment (currently ignored)."""
        pass

    def _select_allocation_sites_tool(self, allowed_allocation_sites: list[str]) -> dict:
        """Build the ``select_allocation_sites`` tool definition.

        The ``enum`` restricts answers to ``allowed_allocation_sites``.
        """
        return {
            "type": "function",
            "function": {
                "name": "select_allocation_sites",
                "description": "Select allocation sites that should be summarized to ensure convergence of the loop.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "selected_sites": {
                            "type": "array",
                            "items": {
                                "type": "string",
                                "enum": allowed_allocation_sites,
                            },
                            "description": "A list of allocation sites to summarize",
                        }
                    },
                    "required": ["selected_sites"],
                },
            },
        }

    def _info_var_tool(self) -> dict:
        """Build the ``info_var`` tool definition."""
        return {
            "type": "function",
            "function": {
                "name": "info_var",
                "description": "Get the current value or summary of a variable in the environment.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "var_name": {
                            "type": "string",
                            "description": "The name of the variable to inspect.",
                        }
                    },
                    "required": ["var_name"],
                },
            },
        }

    def _info_function_tool(self) -> dict:
        """Build the ``info_function`` tool definition (not currently offered)."""
        return {
            "type": "function",
            "function": {
                "name": "info_function",
                "description": "Get the current definition of a function by name in the environment.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "function_name": {
                            "type": "string",
                            "description": "The name of the function to inspect.",
                        }
                    },
                    "required": ["function_name"],
                },
            },
        }

    def _execute_tool(self) -> dict:
        """Build the argument-less ``execute_loop`` tool definition."""
        return {
            "type": "function",
            "function": {
                "name": "execute_loop",
                "description": "Execute the loop once to update the environment state.",
                # An explicit empty object schema; a bare {} is rejected by
                # some OpenAI-compatible servers.
                "parameters": {"type": "object", "properties": {}},
            },
        }
