from absint_ai.Environment.agents.base_agent import Agent
from typing import TYPE_CHECKING
import beeprint
import absint_ai.utils.Util as Util

from absint_ai.Environment.agents.actions import AgentAction

if TYPE_CHECKING:
    from absint_ai.Environment.Environment import Environment


class HumanCLIAgent(Agent):
    """Agent that asks a human at the terminal to make abstraction decisions.

    Every ``decide_*`` method prints context about the current analysis state,
    then reads commands with ``input()`` until a recognised one is entered and
    returns the matching :class:`AgentAction`. Unrecognised or malformed
    commands print a warning and re-prompt instead of raising.
    """

    # Merge strategies that ``decide_merging_strategy_for_site`` can return.
    _MERGE_STRATEGIES = ("all", "recency", "role")
    # Widening strategies that ``decide_widening_strategy_for_site`` can return.
    _WIDENING_STRATEGIES = ("field_value", "field_set", "depth", "all", "none")

    def __init__(self):
        # Interaction phase: "selection" until sites are selected, then "strategy".
        self.mode = "selection"
        # Raw allocation-site ids chosen by the most recent ``select`` command.
        self.selected_sites: list[str] = []

    @staticmethod
    def _read_tokens(prompt: str) -> list[str]:
        """Prompt once and return the whitespace-separated tokens of the reply.

        A blank (or whitespace-only) reply yields an empty list.
        """
        return input(prompt).split()

    def decide_site_selection(
        self,
        env: "Environment",
        changed_allocation_sites: list[str],
        code: str,
        loop_body: str,
        loop_iteration: int,
        print_all_info: bool = False,
    ) -> "AgentAction":
        """Show the changed allocation sites and ask the user what to do next.

        Sites are displayed to, and typed by, the user using the readable
        names returned by ``env.get_readable_allocation_site``. They are
        translated back to raw ids before being returned.

        Args:
            env: Analysis environment queried for allocation-site values.
            changed_allocation_sites: Raw ids of the sites that changed.
            code: Code for the current loop iteration (printed only when
                ``print_all_info`` is true).
            loop_body: Source of the loop body (printed only when
                ``print_all_info`` is true).
            loop_iteration: Iteration number, used in the printed header.
            print_all_info: When true, print the code, loop body and values.

        Returns:
            One of the following actions:

            * ``AgentAction("select", {"sites": [...]})`` with de-duplicated
              raw site ids. This also stores the ids in
              ``self.selected_sites`` and sets ``self.mode`` to
              ``"strategy"``.
            * ``AgentAction("info", {"type": ..., "target": ...})``.
            * ``AgentAction("execute", {})``.
            * ``AgentAction("exit", {})``.
        """
        allocation_site_values_raw = env.get_allocation_site_values(
            changed_allocation_sites, ignore_heap_frames=True
        )
        # Annotate each site's value dict in place with its points-to info.
        for allocation_site in allocation_site_values_raw:
            points_to_info = env.points_to_info_for_allocation_site(allocation_site)
            allocation_site_values_raw[allocation_site]["points_to"] = list(
                points_to_info
            )

        # Two-way mapping between raw ids and the readable names the user types.
        allocation_site_mapping = {}
        allocation_site_mapping_backwards = {}
        for allocation_site in changed_allocation_sites:
            readable_allocation_site = env.get_readable_allocation_site(allocation_site)
            allocation_site_mapping[allocation_site] = readable_allocation_site
            allocation_site_mapping_backwards[readable_allocation_site] = (
                allocation_site
            )

        if print_all_info:
            # Only pretty-print the values when they are actually shown.
            allocation_site_values = {
                allocation_site_mapping[k]: v
                for k, v in allocation_site_values_raw.items()
            }
            allocation_site_values_str = beeprint.pp(
                allocation_site_values, output=False
            )
            print("\n=== The code for the current loop iteration ===")
            print(code)
            print("\n=== The loop body ===")
            print(loop_body)
            print(
                f"\n=== Changed Allocation Sites (Loop Iteration #{loop_iteration}) ==="
            )
            print(f"{allocation_site_values_str}")

        print("\nCommands:")
        print("  select <Alloc@id> ...         # Select sites to summarize")
        print("  info var <name>               # Show variable info")
        print("  info function <name>          # Show function source")
        print("  execute                       # Execute the loop for more info")
        print("  exit                          # Do not abstract any allocation sites")

        while True:
            tokens = self._read_tokens("agent(selection)> ")
            if not tokens:
                continue
            cmd, *args = tokens

            if cmd == "select":
                if not args:
                    print("⚠️ Please specify at least one allocation site.")
                    continue
                unknown = [
                    arg for arg in args if arg not in allocation_site_mapping_backwards
                ]
                if unknown:
                    print(
                        "⚠️ Ignoring unknown allocation site(s): " + ", ".join(unknown)
                    )
                # dict.fromkeys drops duplicates while keeping the typed order.
                selected_sites = list(
                    dict.fromkeys(
                        allocation_site_mapping_backwards[arg]
                        for arg in args
                        if arg in allocation_site_mapping_backwards
                    )
                )
                if not selected_sites:
                    # Re-prompt rather than silently selecting nothing on a typo.
                    print(
                        "⚠️ No valid allocation sites selected. Choose from: "
                        + ", ".join(map(str, allocation_site_mapping_backwards))
                    )
                    continue
                print(
                    f"Selected allocation sites: {beeprint.pp(selected_sites, output=False)}"
                )
                self.selected_sites = selected_sites
                self.mode = "strategy"
                return AgentAction("select", {"sites": selected_sites})

            if cmd == "info" and len(args) == 2:
                return AgentAction("info", {"type": args[0], "target": args[1]})
            if cmd == "execute":
                return AgentAction("execute", {})
            if cmd == "exit":
                return AgentAction("exit", {})

            print("❌ Invalid command.")

    def decide_merging_strategy_for_site(
        self,
        env: "Environment",
        allocation_site_id: str,
        code: str = None,
        loop_body: str = None,
    ) -> "AgentAction":
        """Ask the user which merge strategy to apply to one allocation site.

        Args:
            env: Analysis environment used to describe the site.
            allocation_site_id: Raw id of the site being configured.
            code: Whole-program source; printed first when not ``None``.
            loop_body: Accepted for interface parity; not used here.

        Returns:
            One of the following actions:

            * ``AgentAction("strategy", {"strategy": "all" | "recency"})``.
            * ``AgentAction("strategy", {"strategy": "role", "field": <name>})``.
            * ``AgentAction("exit", {})``.
        """
        readable_allocation_site_id = env.get_readable_allocation_site(
            allocation_site_id
        )
        allocation_site_value_str = Util.get_allocation_site_value_str(
            env, allocation_site_id
        )
        if code is not None:
            print("\n--- The code for the entire program is ---")
            print(code)
        print(f"\n--- Configuring Merge Strategy for {readable_allocation_site_id} ---")
        print("The values for this allocation site are:")
        print(f"{allocation_site_value_str}")
        # Merge strategies considered but not accepted below yet:
        # - field_set: merge objects that have the same set of fields
        # - field_value: merge objects that have the same value for a particular field.
        print(
            f"""NOTE: right now the only ones that work are `all` and `recency`.
The possible strategies are:
- all: simply merge all possible addresses and primitives together into one mega-object
- recency: merge all previous values of the allocation site into one object, and keep the most recent value as a separate object
- role: merge objects that have a similar role. Specify the role as a field name, and all addresses that have the same value for that field will be merged together.
"""
        )
        print("Commands:")
        print(
            "  strategy <strategy_name>      # Choose merge strategy (e.g., all, recency, role <field>)"
        )
        print("  exit                          # Abort")

        while True:
            tokens = self._read_tokens(
                f"agent(strategy:{readable_allocation_site_id})> "
            )
            if not tokens:
                continue
            cmd, *args = tokens

            if cmd == "strategy":
                if not args:
                    print("⚠️ Please specify a strategy.")
                    continue
                merge_strategy = args[0]
                if merge_strategy not in self._MERGE_STRATEGIES:
                    print(
                        f"⚠️ Invalid strategy: {merge_strategy}. Valid options are: "
                        f"{', '.join(self._MERGE_STRATEGIES)}."
                    )
                    continue
                if merge_strategy == "role":
                    # `role` needs the field whose value groups addresses together.
                    if len(args) < 2:
                        print("⚠️ Please specify a field.")
                        continue
                    return AgentAction(
                        "strategy",
                        {
                            "strategy": merge_strategy,
                            "field": args[1],
                        },
                    )
                return AgentAction("strategy", {"strategy": merge_strategy})

            if cmd == "exit":
                return AgentAction("exit", {})

            print("❌ Invalid command.")

    def decide_widening_strategy_for_site(
        self,
        env: "Environment",
        allocation_site_id: str,
        code: str = None,
        loop_body: str = None,
    ) -> "AgentAction":
        """Ask the user which widening strategy to apply to one allocation site.

        Args:
            env: Analysis environment used to describe the site.
            allocation_site_id: Raw id of the site being configured.
            code: Whole-program source; printed first when not ``None``.
            loop_body: Accepted for interface parity; not used here.

        Returns:
            One of the following ``AgentAction("strategy", {...})`` payloads:

            * ``{"strategy": "all" | "none" | "field_set"}``.
            * ``{"strategy": "field_value", "field_paths": [...]}``.
            * ``{"strategy": "depth", "depth": <int>}``.

        Raises:
            NotImplementedError: If a strategy listed in
                ``_WIDENING_STRATEGIES`` has no handling branch.
        """
        readable_allocation_site_id = env.get_readable_allocation_site(
            allocation_site_id
        )
        allocation_site_value_str = Util.get_allocation_site_value_str(
            env, allocation_site_id
        )
        if code is not None:
            print("\n--- The code for the entire program is ---")
            print(code)
        print(
            f"\n--- Configuring Widening Strategy for {readable_allocation_site_id} ---"
        )
        print("The values for this allocation site are:")
        print(f"{allocation_site_value_str}")
        # Widening strategies considered but not accepted below yet:
        # - field key: merge a set of fields into one.
        # - depth limited: widen all values after a particular depth.
        print(
            f"""What widening strategy would you like to use for this allocation site?
The possible strategies are:
- field_value: widen the value for a few particular fields. You might do this if only a few fields are growing. Provide a space-separate list of field paths using dot notation.
- depth: widen all values after a particular depth. Provide the depth as an integer. If the depth is 1, it will widen all field values. If the depth is 2, it will find all field values 2 levels deep and widen them, etc. This is a good option if you have a lot of fields that are all changing.
- field_set: Combine all the fields into one. For example, if there is an infinitely growing list, you can combine all the fields into a single field. This is a good option if the object itself is growing. You'll need to decide on a separate merging/widening strategy for the values themselves.
- all: widen the entire thing. Do this if every field of the object is growing. 
- none: do not widen anything. Do this if the object will converge without only merging and no widening."""
        )
        print("Commands:")
        print(
            "  strategy <strategy_name>      # Choose widening strategy (e.g., all, none, field_set, field_value <paths...>, depth <n>)"
        )
        # NOTE: no `exit` command is offered here; the user must pick a strategy.

        while True:
            tokens = self._read_tokens(
                f"agent(strategy:{readable_allocation_site_id})> "
            )
            if not tokens:
                continue
            cmd, *args = tokens

            if cmd == "strategy":
                if not args:
                    print("⚠️ Please specify a strategy.")
                    continue
                widening_strategy = args[0]
                if widening_strategy not in self._WIDENING_STRATEGIES:
                    print(
                        f"⚠️ Invalid strategy: {widening_strategy}. Valid options are: "
                        f"{', '.join(self._WIDENING_STRATEGIES)}."
                    )
                    continue

                if widening_strategy in ("all", "none", "field_set"):
                    return AgentAction("strategy", {"strategy": widening_strategy})

                if widening_strategy == "field_value":
                    field_paths = args[1:]
                    if not field_paths:
                        print("⚠️ Please specify at least one field.")
                        continue
                    return AgentAction(
                        "strategy",
                        {
                            "strategy": widening_strategy,
                            "field_paths": field_paths,
                        },
                    )

                if widening_strategy == "depth":
                    if len(args) != 2:
                        print("⚠️ Please specify a depth.")
                        continue
                    try:
                        depth = int(args[1])
                    except ValueError:
                        print("⚠️ Depth must be an integer.")
                        continue
                    return AgentAction(
                        "strategy",
                        {
                            "strategy": widening_strategy,
                            "depth": depth,
                        },
                    )

                # Guard against _WIDENING_STRATEGIES growing without a branch above.
                raise NotImplementedError(
                    f"Widening strategy {widening_strategy} not implemented yet."
                )

            print("❌ Invalid command.")

    def decide_primitives(
        self,
        env: "Environment",
        changed_vars: list[str],
        code: str = None,
        loop_body: str = None,
    ) -> "AgentAction":
        """Ask the user which changed primitive variables should be widened.

        Only the variables in ``changed_vars`` for which
        ``env.is_primitive_variable`` is true are listed.

        Args:
            env: Analysis environment queried for variable kinds and values.
            changed_vars: Names of variables that changed.
            code: Whole-program source; printed when not ``None``.
            loop_body: Loop body source; printed when not ``None``.

        Returns:
            One of the following actions:

            * ``AgentAction("select", {"variables": [...]})`` with the names
              exactly as typed. They are not checked against the listed
              variables.
            * ``AgentAction("noop", {})``.
        """
        changed_primitive_vars = [
            var for var in changed_vars if env.is_primitive_variable(var)
        ]
        primitive_var_values = env.get_all_reachable_object_variable_values(
            changed_primitive_vars
        )
        if code is not None:
            print("\n--- The code for the entire program is ---")
            print(code)
        if loop_body is not None:
            print("\n--- The loop body is ---")
            print(loop_body)
        print("\n=== Primitive Variables That Changed ===")
        for var in changed_primitive_vars:
            print(f"  - {var}")
        print("\n=== Values ===")
        print(beeprint.pp(primitive_var_values, output=False))
        print("\nSelect which variables should be widened.")
        print("Example: `select counter sum`")
        print("Or type `noop` to skip.")

        while True:
            tokens = self._read_tokens("agent(primitives)> ")
            if not tokens:
                continue
            cmd, *args = tokens

            if cmd == "noop":
                return AgentAction("noop", {})
            if cmd == "select":
                return AgentAction("select", {"variables": args})

            print("❌ Invalid command. Use `select var1 var2` or `noop`.")

    def receive_info(self, info: str) -> None:
        """Receive information from the environment and print it to stdout."""
        print(info)
