"""A lot of bookkeeping stuff, like initializing schemas, what to do after finishing an iteration, etc."""

from typing import TYPE_CHECKING, Union
from absint_ai.Environment.memory.ConcreteHeap import ConcreteHeap
from absint_ai.Environment.memory.RecordResult import RecordResult
from absint_ai.Environment.memory.StackFrame import StackFrame
from absint_ai.Environment.types.Type import *
from absint_ai.utils.Util import *
from openai import OpenAI

if TYPE_CHECKING:
    from absint_ai.Environment.Environment import Environment


def add_schema(env: "Environment", schema: dict[str, dict[str, list]]) -> None:
    """Merge a scope schema into ``env.scope_schema``.

    The first schema added is stored as ``env.scope_schema`` directly. For any
    later schema:

    * the ``"global"`` entry's ``private_vars`` and ``shared_vars`` are appended
      to the existing global entry (one is created if none exists yet);
    * every other key replaces whatever is currently stored under it.

    Args:
        env: Environment whose ``scope_schema`` attribute is updated in place.
        schema: Mapping of scope key to a dict with ``private_vars`` and
            ``shared_vars`` lists.
    """
    if env.scope_schema is None:
        env.scope_schema = schema
        return

    for key, value in schema.items():
        if key == "global":
            # The stored schema may not have a global scope yet. Start an empty
            # one with fresh lists instead of raising KeyError; fresh lists also
            # keep later extends from mutating the caller's lists.
            global_scope = env.scope_schema.setdefault(
                "global", {"private_vars": [], "shared_vars": []}
            )
            global_scope["private_vars"].extend(value["private_vars"])
            global_scope["shared_vars"].extend(value["shared_vars"])
        else:
            env.scope_schema[key] = value


def finish_iteration(env: "Environment", simplify: bool = True) -> None:
    """Reset per-iteration state at the end of an analysis iteration.

    For every module in ``env.stack`` that has frames, all frames except the
    first (index 0) are dropped. When ``simplify`` is true:

    * every variable in that remaining frame that holds primitives has them
      merged;
    * the frame's heap frame is then simplified once.

    Afterwards:

    * ``env.concrete_heap`` is replaced with a fresh ``ConcreteHeap``;
    * ``env.memoized_functions`` is emptied;
    * ``env.first_iteration`` is set to False;
    * every allocation site's ``concrete_addresses`` is reset to an empty
      ``OrderedSet``.

    The abstract heap and the allocation-site entries themselves are kept.

    Args:
        env: Environment to reset in place.
        simplify: Whether to merge primitives and simplify each module's heap
            frame.
    """
    for module_name in env.stack:
        frames = env.stack[module_name]
        if not frames:
            continue
        module_frame = frames[0]
        env.stack[module_name] = [module_frame]
        if not simplify:
            continue

        for var_name in module_frame.get_variable_names():
            variable = module_frame.get_variable(var_name)
            if variable.get_all_primitives():
                variable.merge_primitives()

        # Simplify once per module frame, after all variables are merged. This
        # also covers frames with no variables.
        env.simplify_heap_frame(module_frame.get_heap_frame_address())

    env.concrete_heap = ConcreteHeap()
    env.memoized_functions = []
    env.first_iteration = False
    for site_info in env.allocation_sites.values():
        site_info["concrete_addresses"] = OrderedSet()


def initialize_openai(env: "Environment", api_key: str, base_url: str) -> None:
    """Store OpenAI connection settings on ``env`` and build its client.

    ``env.api_key`` and ``env.base_url`` are set first. Then
    ``env.openai_client`` is assigned a new ``OpenAI`` client constructed from
    those same two values.
    """
    client_kwargs = {"api_key": api_key, "base_url": base_url}
    env.api_key, env.base_url = api_key, base_url
    env.openai_client = OpenAI(**client_kwargs)  # type: ignore


def initialize_model(env: "Environment", model: str) -> None:
    """Store ``model`` on the environment as ``env.model``."""
    setattr(env, "model", model)


def get_stack_depth(env: "Environment", file_path: str) -> int:
    """Return the number of stack frames recorded for ``file_path``.

    ``file_path`` is normalized with ``convert_path_to_underscore`` to form the
    ``env.stack`` key. A module with no entry yields 0.
    """
    module_key = convert_path_to_underscore(file_path)
    return len(env.stack.get(module_key, ()))


def _heap_allocator_for(env: "Environment", file_path: str):
    """Return the allocation method for objects owned by ``file_path``.

    The ``"global"`` pseudo-module allocates on the abstract heap. Every other
    module uses ``env.add_object_to_heap``, which is the concrete heap per this
    module's convention.
    """
    if file_path == "global":
        return env.add_object_to_abstract_heap
    return env.add_object_to_heap


def initialize_from_schema(
    env: "Environment",
    schema: dict[str, object],
    file_path: str,
    allocation_site: str,
    parent_address: Address | None,
    this_address: Address | None = None,
    abstract_parents: bool = False,
    is_module: bool = False,
    is_function: bool = True,
) -> None:
    """Create a stack frame and heap frame from a scope schema and make it current.

    Variables listed in ``schema["private_vars"]`` become stack-frame
    variables. Variables in ``schema["shared_vars"]`` become entries of a
    newly allocated heap frame. Each of these starts as ``RecordResult("local")``.

    Args:
        env: Environment whose stack, heaps and ``cur_stack_frame`` are updated.
        schema: Scope schema with ``private_vars`` and ``shared_vars`` lists.
        file_path: Module the frame belongs to. It is normalized with
            ``convert_path_to_underscore`` to key ``env.stack``. ``"global"``
            allocates on the abstract heap; anything else uses
            ``env.add_object_to_heap``.
        allocation_site: Allocation site recorded for the heap frame.
        parent_address: Stored as the heap frame's ``__meta__["__parent__"]``.
        this_address: If truthy, bound as ``this`` in the new stack frame. If
            the object it points to has a ``"prototype"`` entry, that entry is
            bound as ``super``.
        abstract_parents: If true and ``parent_address`` is truthy, the parent
            object is moved to the abstract heap before the heap frame is
            allocated.
        is_module: If true and the module already has frames on ``env.stack``,
            the most recent one is made current and nothing new is created.
            Otherwise the heap frame also gets ``exports`` and ``module``
            bindings, and the exports record is registered in
            ``env.exports_for_modules[file_path]``.
        is_function: Whether the schema is for a function (vs. a block). It is
            forwarded to ``StackFrame``; the original note says it decides
            whether things need moving to the abstract heap.
    """
    path_underscore = convert_path_to_underscore(file_path)
    if is_module:
        if path_underscore in env.stack and len(env.stack[path_underscore]) > 0:
            # Module already has a frame: reuse the most recent one.
            env.cur_stack_frame = env.stack[path_underscore][-1]
            return

    allocate = _heap_allocator_for(env, file_path)

    stack_frame_vars = {
        var_name: RecordResult("local") for var_name in schema["private_vars"]  # type: ignore
    }
    heap_frame_vars: dict[str, Union[dict, RecordResult]] = {
        var_name: RecordResult("local") for var_name in schema["shared_vars"]  # type: ignore
    }

    if is_module:
        exports_record_result = RecordResult("global")
        heap_frame_vars["exports"] = exports_record_result
        module_record_result = RecordResult("global")
        heap_frame_vars["module"] = module_record_result

        # Allocate the (initially empty) exports object, then a `module` object
        # whose `exports` field points at the same record.
        exports_address = allocate(
            {},
            allocation_site=None,
            add_proto=False,
            allocation_site_type="heap_frame",
        )
        exports_record_result.add_value(exports_address)
        modules_address = allocate(
            {"exports": exports_record_result},
            allocation_site=None,
            add_proto=False,
        )
        env.exports_for_modules[file_path] = exports_record_result
        module_record_result.add_value(modules_address)

    heap_frame_vars["__meta__"] = {
        "type": "heap_frame",
        "__parent__": parent_address,
    }
    if abstract_parents and parent_address:
        env.move_object_to_abstract_heap(parent_address)

    heap_frame_addr = allocate(
        heap_frame_vars,
        allocation_site=allocation_site,
        add_proto=False,
        allocation_site_type="heap_frame",
    )

    stack_frame = StackFrame(heap_frame_addr, stack_frame_vars, is_function=is_function)
    stack_frame.add_allocated_address(heap_frame_addr)
    if this_address:
        stack_frame.add_variable("this", this_address)
        this_obj = env.lookup_address(this_address)
        if "prototype" in this_obj:
            stack_frame.add_variable("super", this_obj["prototype"])

    # TODO We should be appending to a different stack frame if the function is
    # imported from another module
    env.stack.setdefault(path_underscore, []).append(stack_frame)
    env.cur_stack_frame = stack_frame
