from absint_ai.utils.Logger import logger
from typing import TYPE_CHECKING
from ordered_set import OrderedSet
import absint_ai.utils.Util as Util
import json
from absint_ai.Environment.types.Type import *

if TYPE_CHECKING:
    from absint_ai.Environment.Environment import (
        Environment,
    )  # Import the class only for type checking


def simplify_all_allocation_sites(env: "Environment", simplify: bool = True) -> None:
    """Merge every allocation site registered in ``env`` into its summary.

    Dispatches on each site's ``"type"`` entry:

    * ``"object"``     -> ``merge_allocation_site_object`` with ``add_null=False``
    * ``"function"``   -> ``merge_allocation_site_function``
    * ``"heap_frame"`` -> ``merge_allocation_site_heap_frame``
    * ``"class"``      -> ``merge_allocation_site_object`` (that merger accepts
      both ``"object"`` and ``"class"`` concrete objects)

    Args:
        env: Environment whose ``allocation_sites`` are merged in place.
        simplify: Forwarded to the object/class/heap-frame mergers; when False
            the summary addresses are not simplified after joining.

    Raises:
        Exception: If a site's type is not one of the four listed above.
    """
    for allocation_site_id in env.allocation_sites:
        allocation_site = env.allocation_sites[allocation_site_id]
        site_type = allocation_site["type"]
        if site_type == "object":
            merge_allocation_site_object(
                env, allocation_site_id, add_null=False, simplify=simplify
            )
        elif site_type == "function":
            merge_allocation_site_function(env, allocation_site_id)
        elif site_type == "heap_frame":
            merge_allocation_site_heap_frame(env, allocation_site_id, simplify=simplify)
        elif site_type == "class":
            # Forward ``simplify`` so class sites honour the caller's request
            # exactly like the other branches do.
            # NOTE(review): class sites keep the default add_null=True while
            # "object" sites use add_null=False in this pass — confirm intended.
            merge_allocation_site_object(env, allocation_site_id, simplify=simplify)
        else:
            raise Exception(f"Unknown allocation site type {allocation_site['type']}")


def merge_allocation_site_class(env: "Environment", allocation_site_id: str) -> None:
    """Give a class allocation site a summary address and move its concrete
    class objects to the abstract heap.

    If the site has no summary address yet, an empty abstract object (no
    prototype) is created and registered as the site's summary; otherwise the
    first existing summary address is reused. Each concrete address other than
    the summary itself is then moved to the abstract heap.

    NOTE(review): unlike ``merge_allocation_site_object`` this does not join
    the concrete objects into the summary, does not call
    ``env.abstract_addresses`` and does not reset ``concrete_addresses``. It is
    also not called by ``simplify_all_allocation_sites``, which routes class
    sites through ``merge_allocation_site_object``. The summary is created with
    ``allocation_site_type="heap_frame"``, which looks like it was copied from
    the heap-frame merger — confirm before relying on this function.

    Raises:
        AssertionError: If a concrete address does not hold a ``"class"`` object.
    """
    logger.info(f"Merging allocation site {allocation_site_id} as class")
    site = env.allocation_sites[allocation_site_id]

    if len(site["summary_addresses"]) != 0:
        # Baseline assumption: at most one abstract address per allocation site.
        summary_address = site["summary_addresses"][0]
    else:
        summary_address = env.add_object_to_abstract_heap(
            {},
            add_proto=False,
            allocation_site=allocation_site_id,
            allocation_site_type="heap_frame",
        )
        site["summary_addresses"].add(summary_address)

    for address in site["concrete_addresses"]:
        assert env.get_object_type(address) == "class"
        if address != summary_address:
            env.move_object_to_abstract_heap(address)


def merge_allocation_site_heap_frame(
    env: "Environment", allocation_site_id: str, simplify: bool = True
) -> None:
    """Fold every concrete heap frame of an allocation site into its summary.

    The site's summary address is created on first use: an empty abstract
    object without a prototype, tagged ``allocation_site_type="heap_frame"``.
    Otherwise the first existing summary address is used. Every other concrete
    frame is moved to the abstract heap and joined into the summary with
    ``add_null=False``. References to the concrete addresses are then
    redirected to the summary via ``env.abstract_addresses``, and the site's
    concrete set is reset to an empty ``OrderedSet``.

    Args:
        env: Environment owning the allocation site and both heaps.
        allocation_site_id: Key into ``env.allocation_sites``.
        simplify: When True, run ``env.simplify_heap_frame`` on the summary.
            This runs regardless of whether any join reported a change: the
            summary may have been modified elsewhere even when there are no
            concrete frames left to join, so the per-join "modified" result is
            not a reliable trigger.

    Raises:
        AssertionError: If a concrete address does not hold a ``"heap_frame"``.
    """
    allocation_site = env.allocation_sites[allocation_site_id]
    if len(allocation_site["summary_addresses"]) == 0:
        abstract_address = env.add_object_to_abstract_heap(
            {},
            add_proto=False,
            allocation_site=allocation_site_id,
            allocation_site_type="heap_frame",
        )
        allocation_site["summary_addresses"].add(abstract_address)
    else:
        # In this baseline there is only one abstract address per allocation site.
        abstract_address = allocation_site["summary_addresses"][0]

    for concrete_address in allocation_site["concrete_addresses"]:
        assert env.get_object_type(concrete_address) == "heap_frame"
        if concrete_address == abstract_address:
            continue
        env.move_object_to_abstract_heap(concrete_address)
        # The join's "modified" return value is intentionally ignored; see the
        # ``simplify`` note in the docstring.
        env.join_addresses(abstract_address, concrete_address, add_null=False)

    env.abstract_addresses(allocation_site["concrete_addresses"], abstract_address)
    allocation_site["concrete_addresses"] = OrderedSet()
    if simplify:
        env.simplify_heap_frame(abstract_address)


def merge_allocation_site_object(
    env: "Environment",
    allocation_site_id: str,
    add_null: bool = True,
    simplify: bool = True,
) -> None:
    """Fold every concrete object (or class) of an allocation site into its summary.

    The site's summary address is created on first use as an empty abstract
    object without a prototype. Otherwise the first existing summary address
    is used. Every other concrete object is moved to the abstract heap and
    joined into the summary. References to the concrete addresses are then
    redirected to the summary via ``env.abstract_addresses``, and the site's
    concrete set is reset to an empty ``OrderedSet``.

    Args:
        env: Environment owning the allocation site and both heaps.
        allocation_site_id: Key into ``env.allocation_sites``.
        add_null: Passed through unchanged to ``env.join_addresses``.
        simplify: When True, run ``env.simplify_address`` on the summary.
            This runs regardless of whether any join reported a change: the
            summary may have been modified elsewhere even when there are no
            concrete objects left to join, so the per-join "modified" result
            is not a reliable trigger.

    Raises:
        AssertionError: If a concrete address holds neither an ``"object"``
            nor a ``"class"``.
    """
    allocation_site = env.allocation_sites[allocation_site_id]
    if len(allocation_site["summary_addresses"]) == 0:
        # Because this is a new address, field_in_object will always return
        # Nothing. So everything will be NULL.
        abstract_address = env.add_object_to_abstract_heap(
            {}, add_proto=False, allocation_site=allocation_site_id
        )
        allocation_site["summary_addresses"].add(abstract_address)
    else:
        abstract_address = allocation_site["summary_addresses"][0]

    for concrete_address in allocation_site["concrete_addresses"]:
        if concrete_address == abstract_address:
            continue
        assert (
            env.get_object_type(concrete_address) == "object"
            or env.get_object_type(concrete_address) == "class"
        )
        # We should abstract this otherwise we could get pointers from the
        # abstract heap to the concrete heap after joining
        env.move_object_to_abstract_heap(concrete_address)
        # Combine the concrete object into the summary. The "modified" return
        # value is intentionally ignored; see the ``simplify`` note above.
        env.join_addresses(abstract_address, concrete_address, add_null=add_null)

    env.abstract_addresses(allocation_site["concrete_addresses"], abstract_address)
    allocation_site["concrete_addresses"] = OrderedSet()
    if simplify:
        env.simplify_address(abstract_address)


def merge_allocation_site_function(env: "Environment", allocation_site_id: str) -> None:
    """Summarise the function objects recorded for one allocation site.

    Two steps, evaluated in order:

    1. If the site has no summary yet and exactly one concrete function, that
       function is moved to the abstract heap and becomes the site's summary
       address; the concrete set is reset.
    2. If there are then several concrete functions, or concrete functions
       alongside an existing summary, they are passed to
       ``env.merge_functions_with_same_allocation_site`` and the concrete set
       is reset.

    If neither condition holds, the site is left untouched.

    NOTE(review): step 2 passes only the concrete addresses and does not update
    ``summary_addresses`` here; presumably the Environment method locates or
    records the summary itself — confirm.
    """
    site = env.allocation_sites[allocation_site_id]

    is_first_function = (
        len(site["summary_addresses"]) == 0 and len(site["concrete_addresses"]) == 1
    )
    if is_first_function:
        promoted = next(iter(site["concrete_addresses"]))
        env.move_object_to_abstract_heap(promoted)
        site["summary_addresses"].add(promoted)
        site["concrete_addresses"] = OrderedSet()

    # Counts are re-read because step 1 may have just emptied the concrete set.
    concrete_count = len(site["concrete_addresses"])
    has_summary = len(site["summary_addresses"]) > 0
    if concrete_count > 1 or (has_summary and concrete_count > 0):
        env.merge_functions_with_same_allocation_site(site["concrete_addresses"])
        site["concrete_addresses"] = OrderedSet()


def merge_primitives_and_allocation_sites(
    env: "Environment",
    changed_variables: list[str],
    allocation_sites: list[str],
    simplify: bool = True,
) -> None:
    """Merge primitive values of changed variables, then the given allocation sites.

    Args:
        env: Environment holding the variables and allocation sites.
        changed_variables: Variable names to look up with ``env.lookup``. When
            ``simplify`` is True, each one whose record contains only primitives
            has ``merge_primitives()`` called on it.
        allocation_sites: Allocation-site ids (keys of ``env.allocation_sites``)
            to merge.
        simplify: Enables the primitive-merging step and is forwarded to the
            object/class merger.

    Allocation sites are dispatched on their ``"type"``: ``"object"`` and
    ``"class"`` go to ``merge_allocation_site_object`` (which accepts both kinds
    of concrete object). ``"function"`` goes to ``merge_allocation_site_function``.
    ``"heap_frame"`` sites are skipped in this pass.

    Raises:
        Exception: If an allocation site has any other type.
    """
    if simplify:
        for var_name in changed_variables:
            record_result = env.lookup(var_name)
            if record_result.only_contains_primitives():
                logger.info(f"{var_name} before merging: {record_result}")
                record_result.merge_primitives()
                logger.info(f"{var_name} after merging: {record_result}")

    for allocation_site in allocation_sites:
        site = env.allocation_sites[allocation_site]
        site_type = site["type"]
        if site_type in ("object", "class"):
            # "class" sites were previously rejected as "Unknown" here even though
            # simplify_all_allocation_sites merges them as objects.
            merge_allocation_site_object(env, allocation_site, simplify=simplify)
        elif site_type == "function":
            merge_allocation_site_function(env, allocation_site)
        elif site_type == "heap_frame":
            # Heap-frame sites are not merged by this function.
            continue
        else:
            raise Exception(f"Unknown allocation site type {site}")
