from typing import Tuple, Dict, Union
import copy
import esprima

import pythonmonkey as pm
import logging
from ordered_set import OrderedSet

import absint_ai.Environment.abstractions.LLMAbstractions as LLMAbstractions
import absint_ai.Environment.abstractions.LLMNaiveAbstractions as LLMNaiveAbstractions
import absint_ai.Environment.abstractions.AllocationSiteAbstractions as AllocationSiteAbstractions
import absint_ai.Environment.EnvUtils.Visualize as Visualize
from absint_ai.Environment.memory.AbstractHeap import AbstractHeap
from absint_ai.Environment.types.MemoizedFunction import MemoizedFunction
from absint_ai.Environment.types.Type import *
from absint_ai.Environment.memory.StackFrame import StackFrame
from absint_ai.Environment.memory.ConcreteHeap import ConcreteHeap
from absint_ai.utils.Util import *
import absint_ai.Environment.agents.merging as merging
import absint_ai.Environment.features.functions as functions
import absint_ai.Environment.EnvUtils.PrettyPrint as PrettyPrint
from absint_ai.utils.Logger import logger  # type: ignore
from absint_ai.Environment.memory.RecordResult import RecordResult
import absint_ai.Environment.features.summarization as summarization
import absint_ai.Environment.EnvUtils.EnvUtils as EnvUtils
import absint_ai.Environment.features.builtins.document as document
import absint_ai.Environment.EnvUtils.Update as Update
from absint_ai.Environment.EnvUtils.AllocationSite import AllocationSite
import absint_ai.Environment.EnvUtils.BookKeeping as BookKeeping
from absint_ai.Environment.agents.merging import *
from absint_ai.Environment.agents.widening import *
from absint_ai.Environment.agents.HumanCLIAgent import HumanCLIAgent
from absint_ai.Environment.agents.LLMAgent import LLMAgent
from absint_ai.Environment.agents.actions import *

# Disable the "matplotlib.font_manager" logger; presumably this keeps
# matplotlib's font-lookup chatter out of the analysis logs.
_font_manager_logger = logging.getLogger("matplotlib.font_manager")
_font_manager_logger.disabled = True
del _font_manager_logger


class Environment:
    # region initialization and finalization
    def __init__(  # type: ignore
        self,
        concrete_heap: ConcreteHeap = None,  # type: ignore
        abstract_heap: AbstractHeap = None,  # type: ignore
        stack: dict[str, list[StackFrame]] = None,  # type: ignore
    ):
        """Create an analysis environment.

        Args:
            concrete_heap: heap to start from; a fresh ``ConcreteHeap()`` is
                created when omitted.
            abstract_heap: abstract heap to start from; a fresh
                ``AbstractHeap()`` is created when omitted.
            stack: mapping from a string key to a list of stack frames; a new
                empty dict is created when omitted, so instances never share a
                mutable default. NOTE(review): keys look like file paths
                (cf. ``get_stack_depth(file_path)``) — confirm.
        """
        if concrete_heap is None:
            concrete_heap = ConcreteHeap()
        self.concrete_heap: ConcreteHeap = concrete_heap
        if abstract_heap is None:
            abstract_heap = AbstractHeap()
        self.abstract_heap: AbstractHeap = abstract_heap
        if stack is None:
            stack = {}
        self.stack: dict[str, list[StackFrame]] = stack
        self.first_iteration = True
        self.cur_stack_frame: StackFrame = None  # type: ignore
        self.num_input_tokens = 0
        self.num_output_tokens = 0
        self.logging = True
        self.scope_schema: dict[str, dict[str, list]] = None  # type: ignore
        self.should_visualize = False
        self.memoized_functions: list[MemoizedFunction] = []
        self.initialized_global = False
        # The exported identifiers are used to keep track of the identifiers that are exported from a module
        self.exported_identifiers: dict[str, list[str]] = {}
        self.exports_for_modules: dict[str, RecordResult] = {}
        self.global_stack_frame = None
        self.object_prototype = None
        # Allocation-site ID (as a string) -> bookkeeping dict for that site
        # (read elsewhere in this class via keys such as "type",
        # "readable_allocation_site", "merging_strategies" and
        # "global_merging_strategy").
        self.allocation_sites: dict[str, dict[str, object]] = {}
        self.model = ""
        # NOTE(review): add_object_prototype() runs before the attributes
        # below are assigned. The original order is preserved; confirm it
        # neither reads nor writes any of them.
        self.add_object_prototype()
        self.abstracted_addresses_map: dict[str, str] = {}
        self.openai_client: OpenAI = None  # type: ignore
        self.builtins: list[str] = ["document"]
        self.entrypoint_functions: OrderedSet = OrderedSet()
        self.builtin_addresses: list[Address] = []
        # loop_id -> OrderedSet of variable names whose primitives are merged
        # when that loop is simplified.
        self.loop_primitives_to_summarize: dict[str, OrderedSet] = {}
        # Variable names whose primitives are merged by simplify_global.
        self.global_primitives_to_summarize: OrderedSet = OrderedSet()
        # Function name -> source text of its definition.
        self.function_definitions: dict[str, str] = {}
        # Used for debugging. Only gets updated if absint_ai.debug is on.
        self.variable_state: dict = {}
        self.summarized_loop_ids: OrderedSet = OrderedSet()

    def add_schema(self, schema: dict[str, dict[str, list]]) -> None:
        """Register a scope schema for this environment.

        Thin wrapper over ``BookKeeping.add_schema``; presumably it populates
        ``self.scope_schema`` (NOTE(review): confirm in BookKeeping).
        """
        BookKeeping.add_schema(self, schema)

    def get_schema_for_scope_id(self, scope_id: str) -> dict[str, list]:
        """Return the schema entry stored under ``scope_id``.

        Raises ``KeyError`` for an unknown id; ``scope_schema`` starts as
        ``None`` so calling this before a schema is set raises ``TypeError``.
        """
        schema = self.scope_schema
        return schema[scope_id]

    def schema_contains_scope_id(self, scope_id: str) -> bool:
        """Return ``True`` when the scope schema has an entry for ``scope_id``."""
        known_scopes = self.scope_schema
        return scope_id in known_scopes

    # Clears everything except for abstract heap and allocation sites
    def finish_iteration(self, simplify: bool = True) -> None:
        """End the current analysis iteration.

        Delegates to ``BookKeeping.finish_iteration``, forwarding ``simplify``.
        """
        BookKeeping.finish_iteration(self, simplify)

    def initialize_openai(self, api_key: str, base_url: str) -> None:
        """Set up the OpenAI client via ``BookKeeping.initialize_openai``.

        Args:
            api_key: credential forwarded unchanged.
            base_url: endpoint forwarded unchanged.
        """
        BookKeeping.initialize_openai(self, api_key, base_url)

    def get_stack_depth(self, file_path: str) -> int:
        """Return the stack depth reported by ``BookKeeping`` for ``file_path``."""
        depth = BookKeeping.get_stack_depth(self, file_path)
        return depth

    def initialize_from_schema(
        self,
        schema: dict[str, object],
        file_path: str,
        allocation_site: str,
        parent_address: Address | None,
        this_address: Address = None,  # type: ignore
        abstract_parents: bool = False,
        is_module: bool = False,
        is_function: bool = True,
    ) -> None:
        """Initialise a scope from ``schema``.

        All arguments are forwarded unchanged to
        ``BookKeeping.initialize_from_schema``; ``is_function`` is passed by
        keyword, the rest positionally.
        """
        positional = (
            schema,
            file_path,
            allocation_site,
            parent_address,
            this_address,
            abstract_parents,
            is_module,
        )
        BookKeeping.initialize_from_schema(self, *positional, is_function=is_function)

    # endregion

    # region Adding/Updating

    def update_global_env(self, global_schema: dict[str, object]) -> None:
        """Update the global environment from ``global_schema`` (via ``Update``)."""
        Update.update_global_env(self, global_schema)

    def add(
        self,
        var_name: str,
        value: object,
        kind: str = "var",
        overwrite: bool = False,
        is_abstraction: bool = False,
    ) -> None:
        """Bind ``var_name`` to ``value`` in the environment.

        Thin wrapper over ``Update.add``; every argument is forwarded
        positionally and unchanged.
        """
        Update.add(
            self,
            var_name,
            value,
            kind,
            overwrite,
            is_abstraction,
        )

    def add_object_to_heap(
        self,
        obj: dict,
        allocation_site: str | None,
        add_proto: bool = True,
        allocation_site_type: str = "object",
    ) -> Address:
        """Allocate ``obj`` on the heap and return its address.

        Delegates to ``Update.add_object_to_heap``.
        """
        address = Update.add_object_to_heap(
            self, obj, allocation_site, add_proto, allocation_site_type
        )
        return address

    def add_object_to_abstract_heap(
        self,
        obj: dict,
        add_proto: bool = True,
        allocation_site=None,
        parent_address=None,
        allocation_site_type: str = "object",
        add_to_concrete: bool = True,
    ) -> Address:
        """Allocate ``obj`` on the abstract heap and return its address.

        Delegates to ``Update.add_object_to_abstract_heap``; ``add_to_concrete``
        is forwarded by keyword.
        """
        address = Update.add_object_to_abstract_heap(
            self,
            obj,
            add_proto,
            allocation_site,
            parent_address,
            allocation_site_type,
            add_to_concrete=add_to_concrete,
        )
        return address

    # region Functions
    # used for FunctionDeclarations
    def add_function(
        self,
        name: str,
        params: list[str],
        value: esprima.nodes.BlockStatement,
        scope_id: str,
        file_path: str,
        rest: str = None,
        full_function_node: esprima.nodes.FunctionDeclaration = None,
    ) -> None:  # expr type and loc used to get the schema for the function
        """Declare a function (used for FunctionDeclarations).

        For a non-empty ``name`` the source text of ``full_function_node``
        (from ``Util.ast_to_str``) is remembered in ``self.function_definitions``;
        the declaration itself is performed by ``Update.add_function``.

        NOTE(review): ``full_function_node`` defaults to ``None`` yet is passed
        to ``Util.ast_to_str`` whenever ``name`` is non-empty — confirm callers
        always provide it.
        """
        has_name = len(name) > 0
        if has_name:
            source_text = Util.ast_to_str(full_function_node)
            self.function_definitions[name] = source_text
        Update.add_function(self, name, params, value, scope_id, file_path, rest)

    def initialize_new_function(
        self,
        name: str,
        params: list[str],
        value: esprima.nodes.BlockStatement,
        scope_id: str,
        file_path: str,
        full_function_node: esprima.nodes.FunctionDeclaration = None,
    ) -> Address:
        """Create a function object and return its address.

        For a non-empty ``name`` the source text of ``full_function_node`` is
        recorded in ``self.function_definitions`` first; creation is delegated
        to ``Update.initialize_new_function``.
        """
        has_name = len(name) > 0
        if has_name:
            source_text = Util.ast_to_str(full_function_node)
            self.function_definitions[name] = source_text
        function_address = Update.initialize_new_function(
            self, name, params, value, scope_id, file_path
        )
        return function_address

    # endregion
    def update(
        self,
        address: Address,
        possibleProps: list[Type],
        possibleValues: list[Type],
        overwrite: bool = False,
        objName: str | None = None,
    ) -> None:
        """Write ``possibleValues`` to ``possibleProps`` of the object at ``address``.

        Delegates to ``Update.update`` with all arguments unchanged.
        """
        Update.update(
            self,
            address,
            possibleProps,
            possibleValues,
            overwrite,
            objName,
        )

    def add_value_for_field(
        self, address: Address, field: str, value: Type, overwrite=False
    ) -> None:
        """Add ``value`` to ``field`` of the object at ``address`` (via ``Update``)."""
        Update.add_value_for_field(
            self,
            address,
            field,
            value,
            overwrite,
        )

    def update_this(
        self,
        possibleProps: list[Type],
        possibleValues: list[Type],
        overwrite: bool = False,
    ) -> None:
        """Update properties of ``this``; delegates to ``Update.update_this``."""
        Update.update_this(
            self,
            possibleProps,
            possibleValues,
            overwrite,
        )

    # endregion

    # region Lookup Functions

    def lookup_values_for_field_path(
        self, addr: Address, field_path: list[str]
    ) -> list[RecordResult]:
        """Return the records reached by following ``field_path`` from ``addr``.

        Delegates to ``EnvUtils.lookup_values_for_field_path``.
        """
        records = EnvUtils.lookup_values_for_field_path(self, addr, field_path)
        return records

    def lookup_values_for_depth(self, addr: Address, depth: int) -> list[RecordResult]:
        """Return records reachable from ``addr`` up to ``depth`` (via ``EnvUtils``)."""
        records = EnvUtils.lookup_values_for_depth(self, addr, depth)
        return records

    def get_function_definition(self, name: str) -> str:
        """Return the recorded source text of function ``name``, or ``None``."""
        return self.function_definitions.get(name)

    def get_readable_allocation_site(self, allocation_site_id: str) -> str:
        """Return the human-readable label of an allocation site.

        Unknown ids yield the empty string.
        """
        if allocation_site_id not in self.allocation_sites:
            return ""
        site = self.allocation_sites[allocation_site_id]
        return site["readable_allocation_site"]

    def lookup_address(self, addr: Address) -> dict[str | AbstractType, RecordResult]:
        """Return the field -> record mapping stored at ``addr`` (via ``EnvUtils``)."""
        fields = EnvUtils.lookup_address(self, addr)
        return fields

    def lookup(
        self, name: str, add_touched: bool = True, loc: str = None, log: bool = False
    ) -> RecordResult:
        """Resolve variable ``name``; delegates to ``EnvUtils.lookup``."""
        record = EnvUtils.lookup(self, name, add_touched, loc, log)
        return record

    def lookup_field(
        self, address: Address, field: Type, searched_addresses=None, line_number=None
    ) -> RecordResult:
        """Resolve ``field`` on the object at ``address`` via ``EnvUtils.lookup_field``."""
        record = EnvUtils.lookup_field(
            self,
            address,
            field,
            searched_addresses,
            line_number,
        )
        return record

    def points_to_info_for_allocation_site(self, allocation_site: str) -> dict:
        """Return points-to information for ``allocation_site`` (via ``EnvUtils``)."""
        info = EnvUtils.points_to_info_for_allocation_site(self, allocation_site)
        return info

    def get_allocation_site_values(
        self, allocation_sites: list[str], ignore_heap_frames=True
    ) -> dict:
        """Return values for ``allocation_sites`` via ``EnvUtils.get_allocation_site_values``."""
        values = EnvUtils.get_allocation_site_values(
            self, allocation_sites, ignore_heap_frames=ignore_heap_frames
        )
        return values

    def lookup_and_derive_address(
        self, addr: Address, log=False, ignore_proto=False
    ) -> dict:
        """Return the fields at ``addr`` with every value converted by ``EnvUtils.type_to_value``.

        Keys are stringified. ``"__meta__"`` is always dropped and
        ``"__proto__"`` is dropped when ``ignore_proto`` is true. ``log`` is
        currently unused.
        """
        derived = {}
        for key, record in self.lookup_address(addr).items():
            if key == "__meta__":
                continue
            if ignore_proto and key == "__proto__":
                continue
            derived[str(key)] = [
                EnvUtils.type_to_value(self, value) for value in record.get_all_values()
            ]
        return derived

    def lookup_and_derive(self, name: str, add_touched: bool = True) -> list:
        """Look up ``name`` and convert each of its values with ``self.type_to_value``.

        NOTE(review): ``lookup_and_derive_address`` uses
        ``EnvUtils.type_to_value(self, ...)`` instead; confirm that
        ``self.type_to_value`` exists and is meant to differ. Maybe the result
        could be sorted recursively.
        """
        values = self.lookup(name, add_touched=add_touched).get_all_values()
        derived = []
        for value in values:
            derived.append(self.type_to_value(value))
        return derived

    # endregion

    # region Splitting and Merging the state
    # Splits the environment: split_state returns deep copies of the concrete heap, abstract heap and current stack frame.
    def split_state(self, file_path) -> Tuple[ConcreteHeap, AbstractHeap, StackFrame]:
        """Return independent deep copies of the concrete heap, abstract heap and current frame.

        Each component is copied with its own ``copy.deepcopy`` call.
        ``file_path`` is currently unused.
        """
        concrete_copy = copy.deepcopy(self.concrete_heap)
        abstract_copy = copy.deepcopy(self.abstract_heap)
        frame_copy = copy.deepcopy(self.cur_stack_frame)
        return concrete_copy, abstract_copy, frame_copy

    def set_state(
        self,
        concrete_heap: ConcreteHeap,
        abstract_heap: AbstractHeap,
        stack_frame: StackFrame,
        file_path: str,
    ) -> None:
        """Replace the current state with deep copies of the given components.

        The caller's objects are never aliased. ``file_path`` is currently
        unused.
        """
        # Order kept: stack frame first, then the concrete and abstract heaps.
        self.cur_stack_frame = copy.deepcopy(stack_frame)
        self.concrete_heap = copy.deepcopy(concrete_heap)
        self.abstract_heap = copy.deepcopy(abstract_heap)

    def merge_state(
        self,
        concrete_heap: ConcreteHeap,
        abstract_heap: AbstractHeap,
        stack_frame: StackFrame,
        file_path: str,
    ) -> None:
        """Join another state (e.g. from ``split_state``) into the current one, in place.

        Calls ``merge`` on the current stack frame, concrete heap and abstract
        heap, in that order. No shape check is performed here.
        ``file_path`` is currently unused.
        """
        pairs = (
            (self.cur_stack_frame, stack_frame),
            (self.concrete_heap, concrete_heap),
            (self.abstract_heap, abstract_heap),
        )
        for mine, theirs in pairs:
            mine.merge(theirs)

    def overwrite_variable_for_conditional(
        self, var_name: str, values: OrderedSet[Type]
    ) -> None:
        """Overwrite ``var_name`` with ``values`` for a path-conditional refinement.

        Calls ``Update.add`` with ``overwrite=True, for_path_conditional=True``.
        """
        Update.add(
            self,
            var_name,
            values,
            overwrite=True,
            for_path_conditional=True,
        )

    # endregion

    # region Summarization
    def get_strategy_from_agent(
        self,
        mode: str,
        code_window: str,
        loop_id: str,
        loop_body: str,
        loop_iteration: int,
        changed_allocation_sites: list[str],
        changed_variables: list[str],
        use_agent=True
    ) -> None | str:
        """Ask a human or LLM agent how to summarize the state a loop changed.

        Phase 1 lets the agent pick which of ``changed_allocation_sites`` to
        abstract.

        Phase 2 asks, for every selected (and known) site, for a merging and a
        widening strategy. The combined strategy is stored in
        ``allocation_site["merging_strategies"][loop_id]``.

        Phase 3 runs only when some changed variable is primitive. It lets the
        agent choose variables to add to
        ``self.loop_primitives_to_summarize[loop_id]``.

        In every phase an ``"info"`` action is answered with
        ``handle_info_action`` and the agent is queried again. An ``"exit"``
        action ends the phase; in phase 2 it skips only the current site.

        Args:
            mode: ``"manual"`` (``HumanCLIAgent``) or ``"llm"`` (``LLMAgent``).
            code_window: code shown to the agent.
            loop_id: key under which strategies and primitives are recorded.
            loop_body: loop body source shown to the agent.
            loop_iteration: current iteration, shown during site selection.
            changed_allocation_sites: ids of allocation sites the loop changed.
            changed_variables: names of variables the loop changed.
            use_agent: forwarded to ``LLMAgent``.

        Returns:
            ``"execute"`` if the agent asked to run the loop again during site
            selection, otherwise ``None``.

        Raises:
            Exception: if ``mode`` is neither ``"manual"`` nor ``"llm"``.
        """
        if mode == "manual":
            agent = HumanCLIAgent()
        elif mode == "llm":
            agent = LLMAgent(
                self.openai_client.api_key,
                self.openai_client.base_url,
                model=self.model,
                use_agent=use_agent,
            )
        else:
            raise Exception(f"Unknown agent mode {mode}")

        # NOTE(review): the original called handle_info_action with three
        # different argument lists ((self, action, agent), (action, agent) and
        # (action)). At most one can match its signature, so all calls now use
        # the site-selection form. Confirm against agents/actions.py.
        def answer_info(action) -> None:
            handle_info_action(self, action, agent)

        # Sentinel returned by ask_for_site_strategy when the agent exits.
        agent_exited = object()

        def ask_for_site_strategy(decide, allocation_site_id):
            """Query ``decide`` until it yields a strategy; ``agent_exited`` on exit."""
            while True:
                action = decide(self, allocation_site_id, code_window, loop_body)
                if action.type == "strategy":
                    return action.args
                if action.type == "exit":
                    return agent_exited
                if action.type == "info":
                    answer_info(action)

        def record_primitives(var_names) -> None:
            """Add ``var_names`` to the primitives summarized for ``loop_id``."""
            for var_name in var_names:
                if loop_id in self.loop_primitives_to_summarize:
                    self.loop_primitives_to_summarize[loop_id].add(var_name)
                else:
                    self.loop_primitives_to_summarize[loop_id] = OrderedSet(
                        [var_name]
                    )

        # PHASE 1: site selection
        selected_sites = []
        if len(changed_allocation_sites) > 0:
            logger.info("starting!")
            first_iteration = True
            while True:
                action = agent.decide_site_selection(
                    self,
                    changed_allocation_sites,
                    code=code_window,
                    loop_body=loop_body,
                    loop_iteration=loop_iteration,
                    print_all_info=first_iteration,
                )
                first_iteration = False
                if action.type == "exit":
                    break
                if action.type == "select":
                    selected_sites = action.args["sites"]
                    break
                if action.type == "info":
                    answer_info(action)
                elif action.type == "execute":
                    logger.info("agent wants to execute the loop again")
                    return "execute"
                # Any other action (e.g. "continue") simply re-queries the agent.

        logger.info(
            f"Selected sites are: {','.join([self.get_readable_allocation_site(site) for site in selected_sites])}"
        )

        # PHASE 2: strategy per selected site
        for allocation_site_id in selected_sites:
            if allocation_site_id not in self.allocation_sites:
                continue
            allocation_site = self.allocation_sites[allocation_site_id]
            merge_strategy_info = ask_for_site_strategy(
                agent.decide_merging_strategy_for_site, allocation_site_id
            )
            if merge_strategy_info is agent_exited:
                continue
            widen_strategy_info = ask_for_site_strategy(
                agent.decide_widening_strategy_for_site, allocation_site_id
            )
            if widen_strategy_info is agent_exited:
                continue
            merge_strategy: MergeStrategy = generate_strategy_for_site(
                merge_strategy_info, widen_strategy_info
            )
            allocation_site["merging_strategies"][loop_id] = merge_strategy

        # PHASE 3: primitive variables to summarize
        changed_primitive_vars = [
            var for var in changed_variables if self.is_primitive_variable(var)
        ]
        if len(changed_primitive_vars) > 0:
            while True:
                # NOTE(review): the agent is offered every changed variable,
                # not only changed_primitive_vars — confirm this is intended.
                action = agent.decide_primitives(
                    self,
                    changed_variables,
                    code=code_window,
                    loop_body=loop_body,
                )
                if action.type == "select":
                    record_primitives(action.args["variables"])
                    break
                if action.type == "exit":
                    break
                if action.type == "info":
                    answer_info(action)
        return None

    def get_strategy_from_agent_for_global_allocation_site(
        self, method: str, code: str, allocation_site_id: str
    ) -> None:
        """Ask an agent for the global merging strategy of one allocation site.

        The agent is queried first for a merging strategy and then for a
        widening strategy. The combination is stored in
        ``allocation_site["global_merging_strategy"]``. An ``"info"`` action is
        answered and the agent re-queried. An ``"exit"`` action in either step
        abandons the site without storing anything.

        Args:
            method: ``"manual"`` (``HumanCLIAgent``) or ``"llm"`` (``LLMAgent``).
            code: code shown to the agent.
            allocation_site_id: id of the site; must exist in
                ``self.allocation_sites`` (``KeyError`` otherwise).

        Raises:
            Exception: if ``method`` is neither ``"manual"`` nor ``"llm"``.
        """
        if method == "manual":
            agent = HumanCLIAgent()
        elif method == "llm":
            agent = LLMAgent(
                self.openai_client.api_key,
                self.openai_client.base_url,
                model=self.model,
            )
        else:
            raise Exception(f"Unknown agent mode {method}")

        # Looked up before querying the agent so unknown ids fail fast.
        allocation_site = self.allocation_sites[allocation_site_id]

        # NOTE(review): handle_info_action is called as (env, action, agent),
        # matching the site-selection call in get_strategy_from_agent; the
        # original used (action) here. Confirm against agents/actions.py.
        while True:
            action = agent.decide_merging_strategy_for_site(
                self, allocation_site_id, code=code, loop_body=None
            )
            if action.type == "strategy":
                merge_strategy_info = action.args
                break
            if action.type == "exit":
                return
            if action.type == "info":
                handle_info_action(self, action, agent)

        while True:
            action = agent.decide_widening_strategy_for_site(
                self, allocation_site_id, code=code, loop_body=None
            )
            if action.type == "strategy":
                widen_strategy_info = action.args
                break
            # Previously unhandled here, which made the loop spin forever.
            if action.type == "exit":
                return
            if action.type == "info":
                handle_info_action(self, action, agent)

        merge_strategy: MergeStrategy = generate_strategy_for_site(
            merge_strategy_info, widen_strategy_info
        )
        allocation_site["global_merging_strategy"] = merge_strategy

    def get_strategy_from_agent_for_global_primitives(
        self, method: str, code: str, changed_variables: list[str]
    ) -> None:
        """Let an agent choose which global primitive variables to summarize.

        Chosen names are added to ``self.global_primitives_to_summarize``. An
        ``"info"`` action is answered and the agent is asked again; previously
        the method returned right after answering, so no selection could
        follow. Any other action ends the interaction, as before.

        Args:
            method: ``"manual"`` (``HumanCLIAgent``) or ``"llm"`` (``LLMAgent``).
            code: code shown to the agent.
            changed_variables: names of the changed variables offered to the agent.

        Raises:
            Exception: if ``method`` is neither ``"manual"`` nor ``"llm"``.
        """
        if method == "manual":
            agent = HumanCLIAgent()
        elif method == "llm":
            agent = LLMAgent(
                self.openai_client.api_key,
                self.openai_client.base_url,
                model=self.model,
            )
        else:
            raise Exception(f"Unknown agent mode {method}")

        while True:
            action = agent.decide_primitives(
                self, changed_variables, code=code, loop_body=None
            )
            if action.type == "info":
                # NOTE(review): called as (env, action, agent), matching the
                # site-selection call in get_strategy_from_agent; confirm.
                handle_info_action(self, action, agent)
                continue
            if action.type == "select":
                for var_name in action.args["variables"]:
                    self.global_primitives_to_summarize.add(var_name)
            return

    # Generate merging strategies for each of the changed allocation sites.
    def generate_merging_strategies(
        self,
        method: str,
        code_window: str,
        loop_body: str,
        loop_id: str,
        loop_iteration: int,
        changed_variables: list[str],
        changed_allocation_sites: list[str],
        use_agent=True
    ) -> str | None:
        """Choose merging strategies for the allocation sites a loop changed.

        The fixed methods give every changed site a fresh strategy, stored in
        ``allocation_site["merging_strategies"][loop_id]``. They also mark
        every changed variable for primitive summarization under ``loop_id``:

        * ``"allocation_sites"``: ``MergeByAllocationSite`` + ``WidenAll``
        * ``"recency"``: ``MergeByRecency`` + ``WidenAll``
        * ``"depth"``: ``MergeByRecency`` + ``WidenDepthLimited(1)``

        ``"manual"`` and ``"llm"`` garbage-collect, record ``loop_id`` in
        ``summarized_loop_ids`` and defer to ``get_strategy_from_agent``. Any
        other method is a no-op.

        Returns:
            For ``"manual"``/``"llm"``, the agent result: usually ``None``, or
            ``"execute"`` when the loop should be executed again. Otherwise
            ``None``. (The old ``-> None`` annotation hid the ``"execute"`` case.)
        """
        # Each factory builds a fresh widening policy and strategy per site,
        # so sites never share strategy instances.
        fixed_strategy_factories = {
            "allocation_sites": lambda: MergeByAllocationSite(
                widening_policy=WidenAll()
            ),
            "recency": lambda: MergeByRecency(widening_policy=WidenAll()),
            "depth": lambda: MergeByRecency(widening_policy=WidenDepthLimited(1)),
        }

        if method in fixed_strategy_factories:
            make_strategy = fixed_strategy_factories[method]
            for allocation_site_id in changed_allocation_sites:
                if method == "allocation_sites":
                    logger.info(
                        f"Generating merging strategy for allocation site {allocation_site_id}"
                    )
                allocation_site = self.allocation_sites[allocation_site_id]
                allocation_site["merging_strategies"][loop_id] = make_strategy()
            for var_name in changed_variables:
                if loop_id in self.loop_primitives_to_summarize:
                    self.loop_primitives_to_summarize[loop_id].add(var_name)
                else:
                    self.loop_primitives_to_summarize[loop_id] = OrderedSet([var_name])
            return None

        if method == "manual" or method == "llm":
            self.garbage_collect()
            self.summarized_loop_ids.add(loop_id)
            # Most of the time this is None, but it can be "execute" to run
            # the loop again.
            return self.get_strategy_from_agent(
                mode=method,
                code_window=code_window,
                loop_id=loop_id,
                loop_body=loop_body,
                loop_iteration=loop_iteration,
                changed_allocation_sites=changed_allocation_sites,
                changed_variables=changed_variables,
                use_agent=use_agent,
            )
        return None

    def generate_global_merging_strategies(
        self,
        method: str,
        code: str,
        changed_variables: list[str],
        global_allocation_sites: list[str],
        num_iterations: int,
    ) -> None:
        """Ask the agent for global strategies (``"manual"``/``"llm"`` only).

        Per-site strategies are requested only when ``num_iterations`` is 0 or
        a multiple of 5. Primitive summarization is requested whenever at
        least one changed variable is primitive.
        """
        if method not in ("manual", "llm"):
            return
        logger.info(
            f"generating global merging strategies for {global_allocation_sites}"
        )
        for allocation_site_id in global_allocation_sites:
            if not (num_iterations == 0 or num_iterations % 5 == 0):
                continue
            self.get_strategy_from_agent_for_global_allocation_site(
                method=method, code=code, allocation_site_id=allocation_site_id
            )
        primitive_vars = [
            name for name in changed_variables if self.is_primitive_variable(name)
        ]
        if primitive_vars:
            self.get_strategy_from_agent_for_global_primitives(
                method=method, code=code, changed_variables=changed_variables
            )

    def simplify_global(
        self,
        method: str,
        changed_variables: list[str],
        changed_allocation_sites: list[str],
    ) -> None:
        """Apply the chosen global strategies to the global state.

        Only acts when ``method`` is ``"manual"`` or ``"llm"``. For each site in
        ``changed_allocation_sites`` whose ``"global_merging_strategy"`` is
        truthy, that strategy's ``merge`` is run. Afterwards, every variable in
        ``self.global_primitives_to_summarize`` has its primitives merged.
        ``changed_variables`` is currently unused.
        """
        if method not in ("manual", "llm"):
            return
        for site_id in changed_allocation_sites:
            strategy: MergeStrategy = self.allocation_sites[site_id][
                "global_merging_strategy"
            ]
            if not strategy:
                continue
            strategy.merge(self, site_id)
        for var_name in self.global_primitives_to_summarize:
            self.lookup(var_name).merge_primitives()

    # Once merging strategies have been chosen for the allocation sites, simplify applies them for a given loop.
    def simplify(
        self,
        loop_id: str,
        changed_allocation_sites: list[str] | None = None,
    ) -> list[Address]:  # ignore the return type for now
        """Apply the strategies chosen for ``loop_id``, then garbage-collect.

        For every changed site that has a strategy recorded under ``loop_id``,
        that strategy's ``merge`` is run. Sites without one are skipped, since
        the agent chose not to abstract them. Every variable registered in
        ``self.loop_primitives_to_summarize[loop_id]`` then has its primitives
        merged.

        Args:
            loop_id: loop whose strategies should be applied.
            changed_allocation_sites: ids of changed sites. ``None`` (the
                default, replacing a shared mutable ``[]``) means no sites.

        Returns:
            Always an empty list.
        """
        if changed_allocation_sites is None:
            changed_allocation_sites = []
        logger.info(
            f"changed allocation sites: {[self.get_readable_allocation_site(allocation_site_id) for allocation_site_id in changed_allocation_sites]}"
        )
        for allocation_site_id in changed_allocation_sites:
            allocation_site = self.allocation_sites[allocation_site_id]
            if loop_id not in allocation_site["merging_strategies"]:
                continue  # the agent decided not to abstract this one
            merge_strategy: MergeStrategy = allocation_site["merging_strategies"][
                loop_id
            ]
            merge_strategy.merge(self, allocation_site_id)
        if loop_id in self.loop_primitives_to_summarize:
            for var_name in self.loop_primitives_to_summarize[loop_id]:
                self.lookup(var_name).merge_primitives()
        self.garbage_collect()
        return []

    def simplify_fields_for_address(self, address: Address, fields_to_abstract) -> None:
        """Abstract ``fields_to_abstract`` of the object at ``address`` (via ``summarization``)."""
        summarization.simplify_fields_for_address(
            self,
            address,
            fields_to_abstract,
        )

    def simplify_variable_global(
        self, var_name: str, code: str, model: str, invoke_llm=False
    ) -> None:
        """Simplify global ``var_name``; delegates to ``LLMAbstractions.simplify_variable_global``."""
        LLMAbstractions.simplify_variable_global(
            self,
            var_name,
            code,
            model,
            invoke_llm,
        )

    def merge_all_addresses_in_record_result(
        self, record_result: RecordResult, simplify: bool = True
    ) -> Address:
        """Merge every address in ``record_result`` and return the resulting address.

        Delegates to ``summarization.merge_all_addresses_in_record_result``.
        """
        merged_address = summarization.merge_all_addresses_in_record_result(
            self, record_result, simplify
        )
        return merged_address

    def merge_functions_with_same_allocation_site(
        self, functions: list[Address]
    ) -> None:
        """Merge function objects sharing an allocation site (via ``summarization``).

        Note: the ``functions`` parameter shadows the ``functions`` module
        inside this method only.
        """
        summarization.merge_functions_with_same_allocation_site(
            self,
            functions,
        )

    def is_local(self, name: str) -> bool:
        """Report whether ``name`` is local, as decided by ``summarization.is_local``."""
        local = summarization.is_local(self, name)
        return local

    def simplify_all_allocation_sites(self, simplify: bool = True) -> None:
        """Merge every allocation site with ``MergeByAllocationSite`` + ``WidenAll``.

        The strategy is also stored as the site's ``"global_merging_strategy"``.
        ``simplify`` is currently unused; an older implementation delegated to
        ``AllocationSiteAbstractions.simplify_all_allocation_sites``.
        """
        for site_id, site in self.allocation_sites.items():
            strategy = MergeByAllocationSite(widening_policy=WidenAll())
            site["global_merging_strategy"] = strategy
            strategy.merge(self, site_id)

    def simplify_all_function_allocation_sites(self) -> None:
        """Merge every ``"function"``-typed allocation site by allocation site.

        Each such site gets ``MergeByAllocationSite`` + ``WidenAll`` stored as
        its ``"global_merging_strategy"``, which is then applied.
        """
        for allocation_site_id, allocation_site in self.allocation_sites.items():
            if allocation_site["type"] == "function":
                strategy = MergeByAllocationSite(widening_policy=WidenAll())
                allocation_site["global_merging_strategy"] = strategy
                logger.info(
                    f"Simplifying function allocation site {allocation_site_id} with type {allocation_site['type']}"
                )
                strategy.merge(self, allocation_site_id)

    def allocation_site_exists(self, allocation_site_id: str) -> bool:
        """Return ``True`` if ``allocation_site_id`` is a known allocation site."""
        known_sites = self.allocation_sites
        return allocation_site_id in known_sites

    def get_allocation_site_by_id(self, allocation_site_id: str) -> AllocationSite:
        """Return the stored allocation-site entry; raises ``KeyError`` if unknown."""
        site = self.allocation_sites[allocation_site_id]
        return site

    def simplify_all_objects_with_recency(self) -> None:
        """Merge every ``"object"``-typed allocation site without widening.

        Each object site gets ``MergeByAllocationSite`` + ``WidenNone`` stored as
        its ``"global_merging_strategy"``, which is then applied.
        NOTE(review): despite the method name, ``MergeByRecency`` is not used —
        confirm this is intended.
        """
        for allocation_site_id, allocation_site in self.allocation_sites.items():
            if allocation_site["type"] == "object":
                strategy = MergeByAllocationSite(widening_policy=WidenNone())
                allocation_site["global_merging_strategy"] = strategy
                strategy.merge(self, allocation_site_id)

    def simplify_address(self, address: Address, seen: list = None) -> None:
        """Simplify the object at ``address``; delegates to ``summarization.simplify_address``."""
        summarization.simplify_address(
            self,
            address,
            seen,
        )

    def simplify_heap_frame(self, heap_frame: Address) -> None:
        """Simplify ``heap_frame``; delegates to ``summarization.simplify_heap_frame``."""
        summarization.simplify_heap_frame(
            self,
            heap_frame,
        )

    def abstract_addresses(
        self, addresses: list[Address] | OrderedSet[Address], abstract_address: Address
    ):
        """Fold ``addresses`` into ``abstract_address`` (via ``summarization.abstract_addresses``)."""
        summarization.abstract_addresses(
            self,
            addresses,
            abstract_address,
        )

    def validate_address_abstraction(
        self, address: Address, address_abstraction: dict
    ) -> bool:
        """Check ``address_abstraction`` for ``address`` via ``summarization``."""
        is_valid = summarization.validate_address_abstraction(
            self, address, address_abstraction
        )
        return is_valid

    def validate_all_variable_abstractions(self, abstractions: dict) -> bool:
        """Check every entry of ``abstractions`` via ``summarization``."""
        is_valid = summarization.validate_all_variable_abstractions(self, abstractions)
        return is_valid

    def validate_variable_abstraction(
        self, var_name: str, abstraction_values: list
    ) -> bool:
        """Check a proposed abstraction for ``var_name`` via ``summarization``."""
        is_valid = summarization.validate_variable_abstraction(
            self, var_name, abstraction_values
        )
        return is_valid

    def join_addresses(self, addr1, addr2, add_null: bool = True) -> bool:
        """Join ``addr1`` and ``addr2``; returns whatever ``summarization.join_addresses`` returns."""
        outcome = summarization.join_addresses(self, addr1, addr2, add_null)
        return outcome

    # endregion

    # region Snapshotting

    def snapshot(self) -> dict:
        """Return a snapshot of the environment produced by ``EnvUtils.snapshot``."""
        state = EnvUtils.snapshot(self)
        return state

    def allocation_sites_snapshot(self, only_abstract=False) -> dict:
        """Return an allocation-site snapshot from ``EnvUtils.allocation_sites_snapshot``."""
        state = EnvUtils.allocation_sites_snapshot(self, only_abstract)
        return state

    def changed_allocation_sites_from_snapshot(
        self, allocation_sites_snapshot: dict, log: bool = True
    ) -> list:
        """List allocation sites that differ from ``allocation_sites_snapshot`` (via ``EnvUtils``)."""
        changed = EnvUtils.changed_allocation_sites_from_snapshot(
            self, allocation_sites_snapshot, log
        )
        return changed

    def changed_object_allocation_sites_from_snapshot(
        self, previous_allocation_sites_snapshot: dict, loop_id: str
    ) -> list:
        """Return ``EnvUtils.get_changed_object_allocation_sites_from_snapshot``.

        Args:
            previous_allocation_sites_snapshot: Earlier snapshot to diff against.
            loop_id: Loop identifier, forwarded unchanged.
        """
        changed = EnvUtils.get_changed_object_allocation_sites_from_snapshot(
            self, previous_allocation_sites_snapshot, loop_id
        )
        return changed

    def changed_from_snapshot(self, snapshot: dict, log: bool = True) -> list:
        """Return ``EnvUtils.changed_from_snapshot(self, snapshot, log)``."""
        changed = EnvUtils.changed_from_snapshot(self, snapshot, log)
        return changed

    def changed_primitives_from_snapshot(self, snapshot: dict) -> list:
        """Return ``EnvUtils.changed_from_snapshot`` restricted via ``primitives_only=True``."""
        changed = EnvUtils.changed_from_snapshot(
            self,
            snapshot,
            primitives_only=True,
        )
        return changed

    # endregion

    # region Functions
    def get_return_values(self) -> list[Type]:
        """Return the result of ``functions.get_return_values`` for this env.

        Per the original author's note, recursive calls are handled here so
        their return values can be backfilled.
        """
        values = functions.get_return_values(self)
        return values

    def get_raw_return_values(self) -> list[Type]:
        """Return the current stack frame's return values directly (no post-processing)."""
        frame = self.cur_stack_frame
        return frame.get_return_values()

    def add_return_value(self, return_value: Type) -> None:
        """Record ``return_value`` on the current stack frame."""
        frame = self.cur_stack_frame
        frame.add_return_value(return_value)

    def get_memoized_function(
        self, func: Address, caller_params: dict
    ) -> MemoizedFunction:
        """Return ``functions.get_memoized_function`` for ``func`` and ``caller_params``."""
        memo = functions.get_memoized_function(self, func, caller_params)
        return memo

    def is_recursive_call(self, func: Address, caller_params: dict) -> bool:
        """Return ``functions.is_recursive_call`` for ``func`` with ``caller_params``."""
        recursive = functions.is_recursive_call(self, func, caller_params)
        return recursive

    def add_recursive_call(self, func: Address, caller_params: dict) -> None:
        """Register ``func``/``caller_params`` via ``functions.add_recursive_call``."""
        functions.add_recursive_call(
            self,
            func,
            caller_params,
        )

    def memoize_and_return_from_function(
        self,
        func: Address,
        params_to_memoize: dict[str, list[Type]],
        function_file_path: str,
        call_site_file_path: str,
    ) -> None:
        """Forward to ``functions.memoize_and_return_from_function``.

        Args:
            func: Address of the function being returned from.
            params_to_memoize: Parameter name -> values, forwarded unchanged.
            function_file_path: File containing the function definition.
            call_site_file_path: File containing the call site.
        """
        functions.memoize_and_return_from_function(
            self,
            func,
            params_to_memoize,
            function_file_path,
            call_site_file_path,
        )

    def return_from_function(
        self, func: Address | None, file_path: str
    ) -> None:
        """Delegate to ``functions.return_from_function`` for this environment.

        The receiver was previously spelled ``env``; it is now ``self`` to match
        every other method on the class. Positional/bound-method calls are
        unaffected.

        Args:
            func: Address of the function being returned from, or None.
            file_path: Path of the file whose stack is being unwound.
        """
        functions.return_from_function(self, func, file_path)

    # Returns from stack frames (schemas) created for non-function scopes, such as block scopes

    def return_from_schema_non_function(self, file_path: str) -> None:
        """Pop the current non-function frame and merge its bookkeeping into the parent.

        The popped frame's touched addresses, raw return values and
        ``has_recursive_call`` flag are captured before popping, then replayed
        onto the new current frame (the new top of ``self.stack[...]``).

        NOTE(review): the parent's recursive flag is overwritten with the
        child's value via ``functions.set_has_recursive_call``; unless that
        helper ORs internally, a parent's True could be cleared. Confirm.
        NOTE(review): if this pops the module's last frame, ``[-1]`` raises
        IndexError.
        """
        module_key = convert_path_to_underscore(file_path)

        # Capture everything from the frame that is about to be discarded.
        pending_touched = functions.get_touched_addresses(self)
        pending_returns = self.get_raw_return_values()
        pending_recursive_flag = self.cur_stack_frame.has_recursive_call

        frames = self.stack[module_key]
        frames.pop()
        self.cur_stack_frame = frames[-1]

        # Replay the captured state onto the parent frame.
        for touched in pending_touched:
            functions.add_touched_address(self, touched)
        for value in pending_returns:
            self.add_return_value(value)
        functions.set_has_recursive_call(self, pending_recursive_flag)

    def set_has_recursive_call(self, has_recursive_call: bool) -> None:
        """Forward ``has_recursive_call`` to ``functions.set_has_recursive_call``."""
        functions.set_has_recursive_call(
            self,
            has_recursive_call,
        )

    # endregion

    # region Builtins
    def initialize_builtins(self) -> None:
        """Install builtins; currently this only sets up the DOM builtin."""
        init_dom = self.initialize_dom_builtin
        init_dom()

    def is_builtin(self, value: Type) -> bool:
        """Return True iff ``value`` is an Address whose object type is ``builtin_function``."""
        return (
            isinstance(value, Address)
            and self.get_object_type(value) == "builtin_function"
        )

    def execute_builtin(self, obj: dict, args: OrderedSet[Type]) -> OrderedSet[Type]:
        """Return the result of ``document.execute_builtin`` on ``obj`` with ``args``."""
        outcome = document.execute_builtin(self, obj, args)
        return outcome

    def add_builtin_function(
        self, name: str, params: list[str], rest: str | None = None
    ) -> Address:
        """Register a builtin function via ``document.add_builtin_function``.

        Args:
            name: Builtin's name.
            params: Positional parameter names.
            rest: Optional rest-parameter name; ``None`` is forwarded as-is.

        Returns:
            The Address returned by the helper.
        """
        return document.add_builtin_function(self, name, params, rest)

    # endregion

    def initialize_dom_builtin(self) -> None:
        """Set up the DOM builtin by delegating to ``document.initialize_dom_builtin``."""
        setup = document.initialize_dom_builtin
        setup(self)

    def add_object_prototype(self) -> None:
        """Add the object prototype by delegating to ``document.add_object_prototype``."""
        setup = document.add_object_prototype
        setup(self)

    # endregion

    # region Utility Functions
    def check_no_pointers_from_abstract_to_concrete(self) -> None:
        """Run ``EnvUtils.check_no_pointers_from_abstract_to_concrete`` on this env."""
        checker = EnvUtils.check_no_pointers_from_abstract_to_concrete
        checker(self)

    def in_subroutine(self, file_path: str) -> bool:
        """Return ``EnvUtils.in_subroutine(self, file_path)``."""
        inside = EnvUtils.in_subroutine(self, file_path)
        return inside

    def is_loop_id_summarized(self, loop_id: str) -> bool:
        return loop_id in self.summarized_loop_ids

    def is_primitive_variable(self, var_name: str) -> bool:
        var_record_result = self.lookup(var_name)
        return var_record_result.contains_primitives()

    def get_meta(self, address: Address) -> dict:
        """Return the ``__meta__`` entry of the object stored at ``address``.

        Raises KeyError if the object has no ``__meta__`` key.
        """
        heap_obj = self.lookup_address(address)
        meta = heap_obj["__meta__"]
        return meta

    def get_parent_heap_frame(self, address: Address) -> Address | None:
        """Return the ``__parent__`` recorded in the object's ``__meta__``.

        The return annotation now says ``Address | None`` because the method
        returns None when no parent is recorded; the old ``-> Address`` was wrong.

        Args:
            address: Address of the heap object to inspect.

        Returns:
            The parent heap-frame address, or None if ``__parent__`` is absent.

        Raises:
            KeyError: if the object has no ``__meta__`` entry.
        """
        meta = self.lookup_address(address)["__meta__"]
        if "__parent__" not in meta:
            return None
        return meta["__parent__"]

    def get_readable_allocation_site(self, allocation_site: str) -> str:
        return self.allocation_sites[allocation_site]["readable_allocation_site"]

    def set_parent_heap_frame(self, address: Address, parent: Address) -> None:
        """Store ``parent`` as ``__meta__['__parent__']`` of the object at ``address``."""
        meta = self.lookup_address(address)["__meta__"]
        meta["__parent__"] = parent

    def get_exports_for_module(self, module: str) -> RecordResult:
        """Return the exports entry stored for ``module`` (KeyError if absent)."""
        exports = self.exports_for_modules[module]
        return exports

    def change_module(self, module: str) -> None:
        """Make the top stack frame of ``module`` the current stack frame.

        The module path is normalised with ``convert_path_to_underscore``
        before it is used as the ``self.stack`` key.
        """
        stack_key = convert_path_to_underscore(module)
        module_frames = self.stack[stack_key]
        self.cur_stack_frame = module_frames[-1]

    def add_exported_identifier(self, file_path: str, identifier: str) -> None:
        """Record ``identifier`` as exported from ``file_path`` via ``EnvUtils``."""
        EnvUtils.add_exported_identifier(
            self,
            file_path,
            identifier,
        )

    def add_import(self, file_path: str, identifier: str) -> None:
        """Record an import of ``identifier`` in ``file_path`` via ``EnvUtils``."""
        EnvUtils.add_import(
            self,
            file_path,
            identifier,
        )

    # simple mark and sweep garbage collection
    def garbage_collect(
        self, expr_loc: str | None = None, ignore_memoized: bool = False
    ) -> None:
        """Run ``EnvUtils.garbage_collect`` on this environment.

        ``expr_loc`` is now annotated ``str | None`` because its default is None.

        Args:
            expr_loc: Optional expression location, forwarded unchanged.
            ignore_memoized: Forwarded unchanged to the collector.
        """
        EnvUtils.garbage_collect(self, expr_loc, ignore_memoized)

    def get_current_heap_frame(self) -> Address:
        return self.cur_stack_frame.get_heap_frame_address()

    def move_all_objects_to_abstract_heap(self) -> None:
        """Delegate to ``EnvUtils.move_all_objects_to_abstract_heap`` for this env."""
        mover = EnvUtils.move_all_objects_to_abstract_heap
        mover(self)

    # This must be done recursively to ensure there are no pointers from the abstract heap to the concrete heap
    def move_object_to_abstract_heap(
        self, concrete_address: Address, count=0, seen=None
    ) -> None:
        """Delegate to ``EnvUtils.move_object_to_abstract_heap``.

        Args:
            concrete_address: Concrete-heap address of the object to move.
            count: Forwarded unchanged (presumably a recursion counter; confirm).
            seen: Forwarded unchanged (presumably a visited collection; confirm).
        """
        EnvUtils.move_object_to_abstract_heap(
            self,
            concrete_address,
            count,
            seen,
        )

    def value_contained_in_record_result(
        self, value_as_type: Type, record_result: RecordResult
    ) -> bool:
        """Return ``EnvUtils.value_contained_in_record_result`` for the given pair."""
        contained = EnvUtils.value_contained_in_record_result(
            self, value_as_type, record_result
        )
        return contained

    def get_all_variable_names(self) -> list[str]:
        """Return ``EnvUtils.get_all_variable_names`` for this environment."""
        names = EnvUtils.get_all_variable_names(self)
        return names

    def get_all_variable_values(self) -> dict[str, list[Type]]:
        """Return ``EnvUtils.get_all_variable_values`` for this environment."""
        values = EnvUtils.get_all_variable_values(self)
        return values

    def update_variable_state(self) -> None:
        """Refresh ``self.variable_state`` from the currently reachable variables.

        Collects reachable variable names via ``EnvUtils`` and stores the
        object values ``EnvUtils`` reports for them.
        """
        reachable = EnvUtils.get_all_reachable_variable_names(self)
        state = EnvUtils.get_all_reachable_object_variable_values(self, reachable)
        self.variable_state = state

    # Gets all record results that are pointed to by variables.
    # (get_all_variable_names does not distinguish variables that share a name across different scopes.)
    def get_all_variable_record_results(self) -> list[RecordResult]:
        """Return ``EnvUtils.get_all_variable_record_results`` for this environment."""
        results = EnvUtils.get_all_variable_record_results(self)
        return results

    def get_reachable_allocation_sites(
        self, variable_names: OrderedSet[str], ignore_functions=False
    ) -> dict[str, list[Type]]:
        """Return ``EnvUtils.get_reachable_allocation_sites`` for ``variable_names``.

        ``ignore_functions`` is forwarded as a keyword argument.
        """
        sites = EnvUtils.get_reachable_allocation_sites(
            self, variable_names, ignore_functions=ignore_functions
        )
        return sites

    def get_all_reachable_object_variable_names(self) -> OrderedSet[str]:
        """Return ``EnvUtils.get_all_reachable_object_variable_names`` for this env."""
        names = EnvUtils.get_all_reachable_object_variable_names(self)
        return names

    def get_all_reachable_object_variable_values(
        self, variable_names: OrderedSet[str]
    ) -> dict[str, list[Type]]:
        """Return ``EnvUtils.get_all_reachable_object_variable_values`` for ``variable_names``."""
        values = EnvUtils.get_all_reachable_object_variable_values(
            self, variable_names
        )
        return values

    def get_all_reachable_variable_names(self) -> list[str]:
        """Return ``EnvUtils.get_all_reachable_variable_names`` for this environment."""
        names = EnvUtils.get_all_reachable_variable_names(self)
        return names

    def value_to_type(
        self, value: object, abstract: bool = False, allocation_site: str | None = None
    ) -> Type:
        """Convert a raw ``value`` into a ``Type`` via ``EnvUtils.value_to_type``.

        ``allocation_site`` is now annotated ``str | None`` because its default is None.

        Args:
            value: Value to convert.
            abstract: Forwarded positionally to the helper.
            allocation_site: Optional allocation-site id, forwarded by keyword.
        """
        return EnvUtils.value_to_type(
            self, value, abstract, allocation_site=allocation_site
        )

    def type_to_value(
        self, type: Type, seen: list | None = None, log: bool = False
    ) -> object:
        """Convert a ``Type`` back to a plain value via ``EnvUtils.type_to_value``.

        ``seen`` is now annotated ``list | None`` because its default is None.
        The parameter name ``type`` shadows the builtin but is kept so that
        keyword callers keep working.

        Args:
            type: The Type to convert.
            seen: Optional list forwarded unchanged to the helper.
            log: Forwarded unchanged to the helper.
        """
        return EnvUtils.type_to_value(self, type, seen, log)

    def list_only_contains_primitives(self, values: list[Type]) -> bool:
        """Return True iff every item is a ``Primitive`` or ``AbstractType`` (True if empty)."""
        allowed = (Primitive, AbstractType)
        return all(isinstance(item, allowed) for item in values)

    def list_only_contains_addresses(self, values: list[Type]) -> bool:
        """Return True iff every item is an ``Address`` (True for an empty list)."""
        return not any(not isinstance(item, Address) for item in values)

    def contains_only_functions(self, values: list[Type]) -> bool:
        """Return True iff every item is an Address whose object type is ``function``.

        An empty list yields True. Iteration stops at the first non-function.
        """
        for item in values:
            if not isinstance(item, Address):
                return False
            if self.get_object_type(item) != "function":
                return False
        return True

    def record_result_contains_objects(self, record_result: RecordResult) -> bool:
        """Return True if any value in ``record_result`` is an Address of type ``object``."""
        for candidate in record_result.get_all_values():
            if isinstance(candidate, Address) and self.is_object(candidate):
                return True
        return False

    # Returns the type tag stored in a heap object's __meta__ (e.g. "object", "function", "class", "heap_frame", "builtin_function")
    def get_object_type(self, address: Address) -> str:
        """Return ``__meta__['type']`` of the object at ``address``.

        Raises:
            Exception: if the object has no ``__meta__`` entry.
        """
        heap_obj = self.lookup_address(address)
        if "__meta__" not in heap_obj:
            raise Exception(f"__meta__ not found in {heap_obj}")
        return heap_obj["__meta__"]["type"]

    def get_allocation_site(self, address: Address) -> str | None:
        """Return the allocation-site id of the object at ``address``.

        The annotation is now ``str | None``: the method returns None when
        ``__meta__`` has no ``allocation_site``, so the old ``-> str`` was wrong.

        Returns:
            The allocation-site id, or None if the object has none. The
            original comment says builtins have none and don't need merging.

        Raises:
            Exception: if the object has no ``__meta__`` entry.
        """
        obj = self.lookup_address(address)
        if "__meta__" not in obj:
            raise Exception(f"__meta__ not found in {obj}")
        meta = obj["__meta__"]
        if "allocation_site" not in meta:
            return None  # builtins don't have an allocation site and don't need to be merged
        return meta["allocation_site"]

    def is_function(self, value: Type) -> bool:
        """Return True iff ``value`` is an Address whose object type is ``function``."""
        return (
            isinstance(value, Address)
            and self.get_object_type(value) == "function"
        )

    def is_heap_frame(self, value: Type) -> bool:
        """Return True iff ``value`` is an Address whose object type is ``heap_frame``."""
        return (
            isinstance(value, Address)
            and self.get_object_type(value) == "heap_frame"
        )

    def is_object(self, value: Type) -> bool:
        """Return True iff ``value`` is an Address whose object type is ``object``."""
        return (
            isinstance(value, Address)
            and self.get_object_type(value) == "object"
        )

    def is_class(self, value: Type) -> bool:
        """Return True iff ``value`` is an Address whose object type is ``class``."""
        return (
            isinstance(value, Address)
            and self.get_object_type(value) == "class"
        )

    # endregion

    # region Printing

    def pretty_print(self, derive: bool = False) -> str:
        """Return ``PrettyPrint.pretty_print`` output; ``derive`` is forwarded unchanged."""
        rendered = PrettyPrint.pretty_print(self, derive)
        return rendered

    def pretty_print_variables(self, variable_names: list[str]) -> str:
        """Return ``PrettyPrint.pretty_print_variables`` output for ``variable_names``."""
        rendered = PrettyPrint.pretty_print_variables(self, variable_names)
        return rendered

    def pretty_print_allocation_sites(self) -> str:
        """Return ``PrettyPrint.pretty_print_allocation_sites`` output for this env."""
        rendered = PrettyPrint.pretty_print_allocation_sites(self)
        return rendered

    def pretty_print_all_variables(self) -> str:
        """Return ``PrettyPrint.pretty_print_all_variables`` output for this env."""
        rendered = PrettyPrint.pretty_print_all_variables(self)
        return rendered

    def pretty_print_primitives(self) -> str:
        """Return ``PrettyPrint.pretty_print_primitives`` output for this env."""
        rendered = PrettyPrint.pretty_print_primitives(self)
        return rendered

    # partial doesn't print out heap frames
    def pretty_print_concrete_heap(self, partial: bool = True) -> str:
        """Return ``PrettyPrint.pretty_print_concrete_heap`` output; ``partial`` is forwarded."""
        rendered = PrettyPrint.pretty_print_concrete_heap(self, partial)
        return rendered

    # Print a list of addresses, ignore heap frames
    def pretty_print_addresses(self, addresses: list[Address]) -> str:
        """Return ``PrettyPrint.pretty_print_addresses`` output for ``addresses``."""
        rendered = PrettyPrint.pretty_print_addresses(self, addresses)
        return rendered

    def pretty_print_abstract_heap(self, partial: bool = True) -> str:
        """Return ``PrettyPrint.pretty_print_abstract_heap`` output; ``partial`` is forwarded."""
        rendered = PrettyPrint.pretty_print_abstract_heap(self, partial)
        return rendered

    def visualize(self, prev_statement: str = "", cur_statement: str = "") -> None:
        """Call ``Visualize.visualize`` with the two statements and this env (env last)."""
        Visualize.visualize(
            prev_statement,
            cur_statement,
            self,
        )
