from absint_ai.Environment.memory.RecordResult import RecordResult
from absint_ai.Environment.types.Type import *
from absint_ai.utils.Util import *
from absint_ai.Environment.EnvUtils.AllocationSite import AllocationSite
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from absint_ai.Environment.Environment import Environment


def update_global_env(env: "Environment", global_schema: dict[str, object]) -> None:
    """Set up the global scope, or register global names that appeared since.

    The first call builds the global environment from ``global_schema`` (file
    path and allocation site ``"global"``, no parent address, not a function),
    marks the environment as initialized, stores ``env.stack["global"][0]`` as
    ``env.global_stack_frame`` and installs the builtins.

    Every later call only adds names that are not yet present. Names from
    ``global_schema["private_vars"]`` go into the global stack frame. Names
    from ``global_schema["shared_vars"]`` go into the heap frame that the stack
    frame points to. Each new name is bound to a fresh ``RecordResult("local")``.
    """
    if env.initialized_global:
        # Need to add new global variables to the global environment
        stack_frame = env.stack["global"][0]
        heap_frame = env.lookup_address(stack_frame.get_heap_frame_address())
        targets = ((stack_frame, "private_vars"), (heap_frame, "shared_vars"))
        for frame, schema_key in targets:
            for var_name in global_schema[schema_key]:
                if var_name not in frame:
                    frame[var_name] = RecordResult("local")
        return

    env.initialize_from_schema(
        global_schema,
        file_path="global",
        allocation_site="global",
        parent_address=None,
        is_function=False,
    )
    env.initialized_global = True
    env.global_stack_frame = env.stack["global"][0]
    env.initialize_builtins()


# endregion

# region Adding/Updating


def _should_clear_values(
    var_record_result: "RecordResult",
    overwrite: bool,
    is_abstraction: bool,
    for_path_conditional: bool,
) -> bool:
    """Return whether an add must first drop the record's existing values."""
    return bool(
        overwrite
        and (
            is_abstraction
            or var_record_result.get_type() == "local"
            or for_path_conditional
        )
    )


def add(
    env: "Environment",
    var_name: str,
    value: object,
    kind: str = "var",
    overwrite: bool = False,
    is_abstraction: bool = False,
    for_path_conditional: bool = False,
) -> None:
    """Add ``value`` to the possible values of the variable ``var_name``.

    Args:
        env: Environment in which ``var_name`` is looked up.
        var_name: Variable whose record receives the value(s).
        value: A single ``Type``, or a ``list`` / ``set`` / ``OrderedSet`` of
            values. Each element of a collection is added in turn via
            ``env.add`` without overwrite.
        kind: Only forwarded to ``env.add`` for collection values.
        overwrite: Clear the record's existing values first. This only takes
            effect when ``is_abstraction`` or ``for_path_conditional`` is set,
            or the record's type is ``"local"``.
        is_abstraction: See ``overwrite``.
        for_path_conditional: See ``overwrite``.

    Raises:
        Exception: If ``value`` is neither a ``Type`` nor a supported collection.
    """
    var_record_result = env.lookup(var_name)
    if isinstance(value, Type):
        if value in var_record_result.get_all_values():
            # NOTE(review): this early return also skips the overwrite clearing
            # below. Other values already bound to the variable are therefore
            # kept even when ``overwrite`` is requested. Confirm this is intended.
            return
        if _should_clear_values(
            var_record_result, overwrite, is_abstraction, for_path_conditional
        ):
            var_record_result.clear_values()
        if isinstance(value, Address) and var_record_result.get_type() == "global":
            # Addresses bound to a "global" record are moved to the abstract heap.
            env.move_object_to_abstract_heap(value)
        if isinstance(value, Address) and env.is_object(value):
            # Compare the derived objects (prototype ignored) instead of the raw
            # addresses. item_contained_in_list decides what counts as equal.
            value_obj = env.lookup_and_derive_address(value, ignore_proto=True)
            var_values = [
                env.lookup_and_derive_address(existing, ignore_proto=True)
                for existing in var_record_result.get_all_values()
                if isinstance(existing, Address)
            ]
            if not item_contained_in_list(value_obj, var_values):
                var_record_result.add_value(value)
        elif not item_contained_in_list(value, var_record_result.get_all_values()):
            var_record_result.add_value(value)
    elif isinstance(value, (list, OrderedSet, set)):
        if _should_clear_values(
            var_record_result, overwrite, is_abstraction, for_path_conditional
        ):
            var_record_result.clear_values()
        for val in value:
            env.add(var_name, val, kind=kind, overwrite=False)
    else:
        raise Exception(f"Unknown value type {value} being added")


def add_object_to_heap(
    env: "Environment",
    obj: dict,
    allocation_site: str | None,
    add_proto: bool = True,
    allocation_site_type: str = "object",
) -> Address:
    """Store ``obj`` on the concrete heap and return its concrete address.

    ``obj`` is mutated in place.

    - If ``add_proto`` is set and no ``__proto__`` key exists, a
      ``RecordResult("local")`` holding ``env.object_prototype`` is inserted.
    - If no ``__meta__`` key exists, one is created. It holds the enumerable
      keys (every key except ``__meta__``/``__proto__``, captured before
      insertion), the allocation site and ``allocation_site_type``.
    - A truthy ``allocation_site`` is always written into ``__meta__``.

    With a truthy ``allocation_site``, the new address is added to that site's
    ``concrete_addresses``, and the site is created on first use. The address
    becomes the site's ``most_recent_address`` unless it is already among the
    site's ``summary_addresses``. The address is also recorded on the current
    stack frame, if there is one.
    """
    # Capture the user-visible keys before any bookkeeping keys are inserted.
    visible_keys = [
        key for key in obj.keys() if key != "__meta__" and key != "__proto__"
    ]
    enumerable_values = OrderedSet(visible_keys)

    if add_proto and "__proto__" not in obj:
        obj["__proto__"] = RecordResult("local")
        obj["__proto__"].add_value(env.object_prototype)

    if "__meta__" not in obj:
        obj["__meta__"] = {
            "enumerable_values": enumerable_values,
            "allocation_site": allocation_site,
            "type": allocation_site_type,
        }
    if allocation_site:
        obj["__meta__"]["allocation_site"] = allocation_site

    address = Address(env.concrete_heap.add(obj), addr_type="concrete")

    if allocation_site:
        if allocation_site not in env.allocation_sites:
            env.allocation_sites[allocation_site] = AllocationSite(
                allocation_site_type=allocation_site_type,
                allocation_site=allocation_site,
            )
        site = env.allocation_sites[allocation_site]
        site["concrete_addresses"].add(address)
        if address not in site["summary_addresses"]:
            site["most_recent_address"] = address

    if env.cur_stack_frame:
        env.cur_stack_frame.add_allocated_address(address)
    return address


def add_object_to_abstract_heap(
    env: "Environment",
    obj: dict,
    add_proto: bool = True,
    allocation_site=None,
    parent_address=None,
    allocation_site_type: str = "object",
    add_to_concrete: bool = True,
) -> Address:
    """Store ``obj`` on the abstract heap and return its abstract address.

    ``obj`` is mutated in place.

    - If ``add_proto`` is set and no ``__proto__`` key exists, a
      ``RecordResult("local")`` holding ``env.object_prototype`` is inserted.
    - If no ``__meta__`` key exists, one is created. It holds the enumerable
      keys (captured before insertion) and ``allocation_site_type``, plus the
      allocation site and ``__parent__`` when those are truthy.
    - A truthy ``allocation_site`` is always written into ``__meta__``.

    With a truthy ``allocation_site``, the site is created on first use. The
    new address is added to the site's ``summary_addresses``, and also to its
    ``concrete_addresses`` unless ``add_to_concrete`` is false. The address
    always becomes the site's ``most_recent_address``. It is also recorded on
    the current stack frame, if there is one.
    """
    # Capture the user-visible keys before any bookkeeping keys are inserted.
    visible_keys = [
        key for key in obj.keys() if key != "__meta__" and key != "__proto__"
    ]
    enumerable_values = OrderedSet(visible_keys)

    if add_proto and "__proto__" not in obj:
        obj["__proto__"] = RecordResult("local")
        obj["__proto__"].add_value(env.object_prototype)

    if "__meta__" not in obj:
        meta = {"enumerable_values": enumerable_values, "type": allocation_site_type}
        if allocation_site:
            meta["allocation_site"] = allocation_site
        if parent_address:
            meta["__parent__"] = parent_address
        obj["__meta__"] = meta
    if allocation_site:
        obj["__meta__"]["allocation_site"] = allocation_site

    address = Address(env.abstract_heap.add(obj), addr_type="abstract")

    if allocation_site:
        if allocation_site not in env.allocation_sites:
            env.allocation_sites[allocation_site] = AllocationSite(
                allocation_site_type=allocation_site_type,
                allocation_site=allocation_site,
            )
        site = env.allocation_sites[allocation_site]
        if add_to_concrete:
            site["concrete_addresses"].add(address)
        if address not in site["summary_addresses"]:
            site["summary_addresses"].add(address)
        site["most_recent_address"] = address

    if env.cur_stack_frame:
        env.cur_stack_frame.add_allocated_address(address)
    return address


# region Functions
# used for FunctionDeclarations
def add_function(
    env: "Environment",
    name: str,
    params: list[str],
    value: esprima.nodes.BlockStatement,
    scope_id: str,
    file_path: str,
    rest: str | None = None,
) -> None:  # expr type and loc used to get the schema for the function
    """Create a function object for a declaration and bind it to ``name``.

    For declarations at stack depth 1 after the first iteration, this reuses
    an existing summary address. If the function's allocation site already has
    one, that address is bound to ``name`` and no new object is allocated.

    Otherwise the function does the following:

    1. Allocates a fresh prototype object on the concrete heap.
    2. Allocates a concrete function object whose ``__meta__`` holds the
       parameters, body, scope schema, rest parameter and the current heap
       frame as ``__parent__``.
    3. Registers the object with its (function-typed) allocation site.
    4. Records it on the current stack frame and binds it to ``name`` with
       ``overwrite=True``.

    Args:
        env: Environment to allocate into.
        name: Variable the function object is bound to.
        params: Parameter names of the function.
        value: Function body node; the allocation site is derived from it.
        scope_id: Key into ``env.scope_schema`` for the function's scope.
        file_path: File containing the declaration.
        rest: Name of the rest parameter, if any.

    Raises:
        Exception: If no allocation site can be derived from ``value``.
    """
    schema = env.scope_schema[scope_id]
    allocation_site = allocation_site_from_expr(value, file_path)
    env_stack_depth = env.get_stack_depth(file_path)
    if env_stack_depth == 1 and not env.first_iteration:
        # If this is not the first iteration, we should just re-use the functions in the previous iterations. It's global so we don't need to worry about closures.
        if env.allocation_site_exists(allocation_site):
            allocation_site_for_function = env.get_allocation_site_by_id(
                allocation_site
            )
            if len(allocation_site_for_function["summary_addresses"]) > 0:
                function_object_address = allocation_site_for_function[
                    "summary_addresses"
                ][0]
                env.add(name, function_object_address, overwrite=True)
                return
    if not allocation_site:
        # Fail before allocating, so a bad node leaves no orphaned heap objects.
        raise Exception(
            f"Allocation site not found for function {name} on line {value.loc.start.line}"
        )
    logger.info(f"adding function {name} to stack depth {env_stack_depth}")
    prototype_address = add_object_to_heap(
        env, {}, allocation_site=None, add_proto=False
    )
    function_info = {
        "type": "function",
        "enumerable_values": OrderedSet(),
        "allocation_site": allocation_site,
        "params": params,
        "body": value,
        "schema": schema,
        "schema_id": scope_id,
        "file_path": file_path,
        "rest": rest,
        "__parent__": env.get_current_heap_frame(),
        # Same value as allocation_site (derived from the same node and file).
        "function_id": allocation_site,
    }
    function_object = env.concrete_heap.add(
        {
            "prototype": [prototype_address],
            "__meta__": function_info,
        }
    )
    function_object_address = Address(function_object, "concrete")

    logger.info(f"adding new function at allocation site {allocation_site}")
    if allocation_site not in env.allocation_sites:
        env.allocation_sites[allocation_site] = AllocationSite(
            allocation_site_type="function", allocation_site=allocation_site
        )
    site = env.allocation_sites[allocation_site]
    assert site["type"] == "function"
    site["concrete_addresses"].add(function_object_address)
    # "most_recent_address" is the key used by add_object_to_heap,
    # add_object_to_abstract_heap and initialize_new_function.
    site["most_recent_address"] = function_object_address

    if env.cur_stack_frame:
        env.cur_stack_frame.add_allocated_address(function_object_address)
    add(env, name, function_object_address, overwrite=True)


def initialize_new_function(
    env: "Environment",
    name: str,
    params: list[str],
    value: esprima.nodes.BlockStatement,
    scope_id: str,
    file_path: str,
    rest: str | None = None,
) -> Address:
    """Create a function object and return its address without binding it.

    For functions at stack depth 1 after the first iteration, this returns an
    existing summary address when the function's allocation site already has
    one. No new object is allocated in that case.

    Otherwise the function does the following:

    1. Allocates a prototype object via ``env.add_object_to_heap``.
    2. Allocates a concrete function object whose ``__meta__`` holds the
       parameters, body, scope schema, rest parameter and the current heap
       frame as ``__parent__``.
    3. Registers the object with its (function-typed) allocation site as the
       ``most_recent_address``.
    4. Records it on the current stack frame.

    Args:
        env: Environment to allocate into.
        name: Function name; used only for logging and error messages here.
        params: Parameter names of the function.
        value: Function body node; the allocation site is derived from it.
        scope_id: Key into ``env.scope_schema`` for the function's scope.
        file_path: File containing the function.
        rest: Name of the rest parameter, if any. Defaults to ``None``, which
            matches the previous behavior of never recording one.

    Returns:
        The address of the (new or reused) function object.

    Raises:
        Exception: If no allocation site can be derived from ``value``.
    """
    schema = env.scope_schema[scope_id]
    allocation_site = allocation_site_from_expr(value, file_path)
    env_stack_depth = env.get_stack_depth(file_path)
    if env_stack_depth == 1 and not env.first_iteration:
        # If this is not the first iteration, we should just re-use the functions in the previous iterations. It's global so we don't need to worry about closures.
        if env.allocation_site_exists(allocation_site):
            allocation_site_for_function = env.get_allocation_site_by_id(
                allocation_site
            )
            if len(allocation_site_for_function["summary_addresses"]) > 0:
                return allocation_site_for_function["summary_addresses"][0]
    if not allocation_site:
        # Fail before allocating, so a bad node leaves no orphaned heap objects.
        raise Exception(
            f"Allocation site not found for function {name} on line {value.loc.start.line}"
        )
    logger.info(f"adding function {name} to stack depth {env_stack_depth}")
    prototype_address = env.add_object_to_heap(
        {}, allocation_site=None, add_proto=False
    )
    function_info = {
        "type": "function",
        "enumerable_values": OrderedSet(),
        "allocation_site": allocation_site,
        "params": params,
        "body": value,
        "schema": schema,
        "schema_id": scope_id,
        "file_path": file_path,
        "rest": rest,
        "__parent__": env.get_current_heap_frame(),
        # Same value as allocation_site (derived from the same node and file).
        "function_id": allocation_site,
    }
    function_object = env.concrete_heap.add(
        {
            "prototype": [prototype_address],
            "__meta__": function_info,
        }
    )
    function_object_address = Address(function_object, "concrete")

    logger.info(f"adding new function at allocation site {allocation_site}")
    if allocation_site not in env.allocation_sites:
        env.allocation_sites[allocation_site] = AllocationSite(
            allocation_site_type="function", allocation_site=allocation_site
        )
    site = env.allocation_sites[allocation_site]
    assert site["type"] == "function"
    site["concrete_addresses"].add(function_object_address)
    site["most_recent_address"] = function_object_address

    logger.info("adding function")
    if env.cur_stack_frame:
        env.cur_stack_frame.add_allocated_address(function_object_address)
    logger.info("function added")
    return function_object_address


# endregion
def update(
    env: "Environment",
    address: Address,
    possibleProps: list[Type],
    possibleValues: list[Type],
    overwrite: bool = False,
    objName: str | None = None,
) -> None:
    """Write ``possibleValues`` into each property of the object at ``address``.

    Each entry of ``possibleProps`` is used directly as the key when it is an
    ``AbstractType``; otherwise its ``get_value()`` is the key.

    - For a ``"concrete"`` address, ``overwrite`` is forwarded to
      ``env.concrete_heap.update``.
    - For an ``"abstract"`` address, every ``Address`` in ``possibleValues`` is
      first passed to ``env.move_object_to_abstract_heap``. The values are then
      written with ``env.abstract_heap.update``, which is not given
      ``overwrite``.

    Any other address type is ignored. ``objName`` is accepted but not used.
    """

    def key_for(prop):
        return prop if isinstance(prop, AbstractType) else prop.get_value()

    addr_type = address.get_addr_type()
    if addr_type == "concrete":
        heap_id = address.get_value()
        for prop in possibleProps:
            env.concrete_heap.update(heap_id, key_for(prop), possibleValues, overwrite)  # type: ignore
    elif addr_type == "abstract":
        heap_id = address.get_value()
        for candidate in possibleValues:
            if isinstance(candidate, Address):
                env.move_object_to_abstract_heap(candidate)
        for prop in possibleProps:
            env.abstract_heap.update(heap_id, key_for(prop), possibleValues)  # type: ignore


def add_value_for_field(
    env: "Environment", address: Address, field: str, value: Type, overwrite=False
) -> None:
    """Add ``value`` to ``field`` of the object stored at ``address``.

    Concrete addresses go through ``env.concrete_heap.update`` with the given
    ``overwrite``. Abstract addresses go through ``env.abstract_heap.update``,
    which is not passed ``overwrite``. Any other address type is ignored.
    """
    addr_type = address.get_addr_type()
    if addr_type == "concrete":
        env.concrete_heap.update(
            address.get_value(), field, [value], overwrite=overwrite
        )
    elif addr_type == "abstract":
        env.abstract_heap.update(address.get_value(), field, [value])


# def update_object_fields_for_conditional(env: "Environment", addresses: list[Address], fields: list[list[str]], value: Type) -> None:
#    for address in addresses:


def update_this(
    env: "Environment",
    possibleProps: list[Type],
    possibleValues: list[Type],
    overwrite: bool = False,
) -> None:
    """Apply ``update`` to every address currently bound to ``this``.

    Raises:
        Exception: If ``this`` has no values, or if one of its values is not
            an ``Address``.
    """
    this_values = env.lookup("this").get_all_values()
    if not this_values:
        raise Exception("No this address found!")
    for candidate in this_values:
        if not isinstance(candidate, Address):
            raise Exception(f"Expected address, got {candidate}")
        update(env, candidate, possibleProps, possibleValues, overwrite)
