from absint_ai.utils.Logger import logger
from absint_ai.Environment.agents.widening import WideningPolicy
from typing import TYPE_CHECKING
import absint_ai.utils.Util as Util
from ordered_set import OrderedSet
from abc import ABC, abstractmethod

from absint_ai.Environment.types.Type import *

if TYPE_CHECKING:
    from absint_ai.Environment.Environment import (
        Environment,
    )  # Import the class only for type checking


"""Different types of merging rules:
- Merge by allocation site (most basic).
- Merge by recency (most recent allocation at a site): keep the most recent object concrete.
- Merge by role: if the value of a particular field is the same, merge the objects.
- Merge by all fields: if all the fields are the same, merge the objects.
- Prototype merging: if the prototype is the same, merge the objects.
"""


class MergeRule:
    """A named rule pairing an object-selection pattern with an abstraction kind."""

    def __init__(self, name: str, pattern: str, abstraction: str):
        # pattern: e.g. 'alloc_site:L1' (presumably selects the allocation
        # sites the rule applies to); abstraction: 'site', 'recency', etc.
        self.name, self.pattern, self.abstraction = name, pattern, abstraction


class MergeStrategy(ABC):
    """Interface for policies that fold the objects of an allocation site into summaries."""

    @abstractmethod
    def merge(self, env: "Environment", allocation_site_id: str) -> None:
        """Merge the allocation site ``allocation_site_id`` in ``env``.

        Each implementation decides which addresses are summarized and which
        (if any) stay concrete.
        """

    @abstractmethod
    def has_converged(
        self, env: "Environment", previous_values: dict, current_values: dict
    ) -> bool:
        """Return whether the allocation site is unchanged between two snapshots."""


class MergeByAllocationSite(MergeStrategy):
    """Baseline strategy: a single summary (abstract) address per allocation site.

    Every concrete object created at an allocation site is moved to the
    abstract heap and joined into the site's one summary address.
    """

    def __init__(self, widening_policy: WideningPolicy):
        # Applied to the summary address after each merge.
        self.widening_policy = widening_policy

    @staticmethod
    def _get_or_create_summary_address(
        env: "Environment",
        allocation_site: dict,
        allocation_site_id: str,
        **heap_kwargs,
    ):
        """Return the site's summary address, creating an empty one if there is none.

        Extra keyword arguments (e.g. ``allocation_site_type="heap_frame"``) are
        forwarded to ``env.add_object_to_abstract_heap``.
        """
        if len(allocation_site["summary_addresses"]) == 0:
            # A brand-new summary has no fields, so field lookups on it find
            # nothing until concrete objects are joined into it.
            abstract_address = env.add_object_to_abstract_heap(
                {},
                add_proto=False,
                allocation_site=allocation_site_id,
                **heap_kwargs,
            )
            allocation_site["summary_addresses"].add(abstract_address)
            return abstract_address
        # In this baseline strategy there is only one abstract address per allocation site.
        return allocation_site["summary_addresses"][0]

    def merge(self, env: "Environment", allocation_site_id: str) -> None:
        """Merge all concrete addresses of ``allocation_site_id`` into its summary, then widen it.

        Raises:
            Exception: if the allocation site's ``type`` is not recognised.
        """
        allocation_site = env.allocation_sites[allocation_site_id]
        site_type = allocation_site["type"]
        if site_type == "object" or site_type == "class":
            # Class objects are merged exactly like plain objects.
            self.merge_allocation_site_object(env, allocation_site_id)
        elif site_type == "function":
            self.merge_allocation_site_function(env, allocation_site_id)
        elif site_type == "heap_frame":
            self.merge_allocation_site_heap_frame(env, allocation_site_id)
        else:
            raise Exception(f"Unknown allocation site type {allocation_site['type']}")

        # The allocation site has now been merged, so there is at most one
        # abstract address for it. Widen that address if it exists.
        if len(allocation_site["summary_addresses"]) > 0:
            self.widening_policy.widen(
                env, allocation_site["summary_addresses"][0]
            )  # type: ignore

    def merge_allocation_site_class(
        self, env: "Environment", allocation_site_id: str
    ) -> None:
        """Merge a class allocation site.

        Class sites are merged like object sites, which is also how ``merge``
        dispatches them. The summary is created as a regular object rather
        than tagged as a heap frame. Concrete class objects are joined into it
        and the site's concrete addresses are cleared.
        """
        logger.info(
            f"Merging allocation site {env.get_readable_allocation_site(allocation_site_id)} as class"
        )
        self.merge_allocation_site_object(env, allocation_site_id)

    def merge_allocation_site_heap_frame(
        self, env: "Environment", allocation_site_id: str
    ) -> None:
        """Join every concrete heap frame of the site into its single summary heap frame."""
        allocation_site = env.allocation_sites[allocation_site_id]
        abstract_address = self._get_or_create_summary_address(
            env, allocation_site, allocation_site_id, allocation_site_type="heap_frame"
        )
        # TODO(review): if there are no concrete addresses but the abstract
        # address was modified, the heap frame should arguably still be
        # simplified; no such simplification happens here.
        for concrete_address in allocation_site["concrete_addresses"]:
            assert env.get_object_type(concrete_address) == "heap_frame"
            if concrete_address == abstract_address:
                continue
            # Move to the abstract heap before joining so the abstract heap
            # never ends up pointing into the concrete heap.
            env.move_object_to_abstract_heap(concrete_address)
            env.join_addresses(abstract_address, concrete_address, add_null=False)
        # Presumably redirects references to the merged concrete addresses to
        # the summary; confirm in Environment.abstract_addresses.
        env.abstract_addresses(allocation_site["concrete_addresses"], abstract_address)
        allocation_site["concrete_addresses"] = OrderedSet()

    def merge_allocation_site_object(
        self,
        env: "Environment",
        allocation_site_id: str,
        add_null: bool = False,  # TODO I should implement better may/must pointer analysis.
    ) -> None:
        """Join every concrete object/class of the site into its single summary object.

        Args:
            env: the abstract environment holding the heaps and allocation sites.
            allocation_site_id: key into ``env.allocation_sites``.
            add_null: forwarded to ``env.join_addresses``.
        """
        allocation_site = env.allocation_sites[allocation_site_id]
        abstract_address = self._get_or_create_summary_address(
            env, allocation_site, allocation_site_id
        )
        for concrete_address in allocation_site["concrete_addresses"]:
            if concrete_address == abstract_address:
                continue
            assert env.get_object_type(concrete_address) in ("object", "class")
            # Move to the abstract heap before joining; otherwise the abstract
            # heap could point into the concrete heap after the join.
            env.move_object_to_abstract_heap(concrete_address)
            env.join_addresses(abstract_address, concrete_address, add_null=add_null)
        env.abstract_addresses(allocation_site["concrete_addresses"], abstract_address)
        allocation_site["concrete_addresses"] = OrderedSet()

    def merge_allocation_site_function(
        self, env: "Environment", allocation_site_id: str
    ) -> None:
        """Summarize the function objects created at ``allocation_site_id``.

        If the site has no summary yet and exactly one concrete function, that
        function is moved to the abstract heap and becomes the summary. Any
        remaining concrete functions are passed to
        ``env.merge_functions_with_same_allocation_site``. In every case the
        site ends with no concrete addresses.
        """
        allocation_site = env.allocation_sites[allocation_site_id]
        if (
            len(allocation_site["summary_addresses"]) == 0
            and len(allocation_site["concrete_addresses"]) == 1
        ):
            concrete_function = list(allocation_site["concrete_addresses"])[0]
            env.move_object_to_abstract_heap(concrete_function)
            allocation_site["summary_addresses"].add(concrete_function)
            allocation_site["concrete_addresses"] = OrderedSet()
        if len(allocation_site["concrete_addresses"]) > 1 or (
            len(allocation_site["summary_addresses"]) > 0
            and len(allocation_site["concrete_addresses"]) > 0
        ):
            env.merge_functions_with_same_allocation_site(
                allocation_site["concrete_addresses"]
            )
            allocation_site["concrete_addresses"] = OrderedSet()

    def has_converged(
        self, env: "Environment", previous_values: dict, current_values: dict
    ) -> bool:
        """Return True when the single summary value is unchanged and nothing is concrete.

        Both snapshots must hold exactly one summary address, the current
        snapshot must hold no concrete addresses, and the two recorded summary
        values must compare equal.
        """
        previous_summary_addresses = previous_values["summary_addresses"]
        current_summary_addresses = current_values["summary_addresses"]
        current_concrete_addresses = current_values["concrete_addresses"]
        if len(previous_summary_addresses) != 1 or len(current_summary_addresses) != 1:
            return False
        if len(current_concrete_addresses) > 0:
            return False
        previous_summary_value = list(previous_summary_addresses.values())[0]
        current_summary_value = list(current_summary_addresses.values())[0]
        return previous_summary_value == current_summary_value


class MergeByRecency(MergeStrategy):
    """Recency abstraction: keep the most recently allocated object of a site concrete.

    For object and class sites, every concrete object except the site's
    ``most_recent_address`` is joined into a single summary address. Function
    and heap-frame sites are summarized completely.
    """

    def __init__(self, widening_policy: WideningPolicy):
        # Applied to the summary address after each merge.
        self.widening_policy = widening_policy

    @staticmethod
    def _get_or_create_summary_address(
        env: "Environment",
        allocation_site: dict,
        allocation_site_id: str,
        **heap_kwargs,
    ):
        """Return the site's summary address, creating an empty one if there is none.

        Extra keyword arguments (e.g. ``allocation_site_type="heap_frame"``) are
        forwarded to ``env.add_object_to_abstract_heap``.
        """
        if len(allocation_site["summary_addresses"]) == 0:
            # A brand-new summary has no fields, so field lookups on it find
            # nothing until concrete objects are joined into it.
            abstract_address = env.add_object_to_abstract_heap(
                {},
                add_proto=False,
                allocation_site=allocation_site_id,
                **heap_kwargs,
            )
            allocation_site["summary_addresses"].add(abstract_address)
            return abstract_address
        # Only one summary address is kept per allocation site.
        return allocation_site["summary_addresses"][0]

    def merge(self, env: "Environment", allocation_site_id: str) -> None:
        """Merge ``allocation_site_id`` by recency, then widen its summary address.

        Raises:
            Exception: if the allocation site's ``type`` is not recognised.
        """
        allocation_site = env.allocation_sites[allocation_site_id]
        site_type = allocation_site["type"]
        if site_type == "object" or site_type == "class":
            self.merge_object_by_recency(env, allocation_site_id)
        elif site_type == "function":
            self.merge_function_by_recency(env, allocation_site_id)
        elif site_type == "heap_frame":
            # This class has no `merge_heap_frame_by_recency` (calling it raised
            # AttributeError); heap frames are fully summarized instead.
            self.merge_allocation_site_heap_frame(env, allocation_site_id)
        else:
            raise Exception(f"Unknown allocation site type {allocation_site['type']}")

        # Only the summary address is widened; the most recent object stays
        # concrete. A function site can still have no summary address here
        # (e.g. no functions were allocated yet), so guard the index.
        if len(allocation_site["summary_addresses"]) > 0:
            self.widening_policy.widen(
                env, allocation_site["summary_addresses"][0]
            )  # type: ignore

    def merge_allocation_site_class(
        self, env: "Environment", allocation_site_id: str
    ) -> None:
        """Merge a class allocation site by recency.

        Class sites are merged like object sites, which is also how ``merge``
        dispatches them. Older class objects are joined into the summary and
        the most recent one stays concrete.
        """
        logger.info(f"Merging allocation site {allocation_site_id} as class")
        self.merge_object_by_recency(env, allocation_site_id)

    def merge_allocation_site_heap_frame(
        self, env: "Environment", allocation_site_id: str
    ) -> None:
        """Join every concrete heap frame of the site into its single summary heap frame."""
        allocation_site = env.allocation_sites[allocation_site_id]
        abstract_address = self._get_or_create_summary_address(
            env, allocation_site, allocation_site_id, allocation_site_type="heap_frame"
        )
        # TODO(review): if there are no concrete addresses but the abstract
        # address was modified, the heap frame should arguably still be
        # simplified; no such simplification happens here.
        for concrete_address in allocation_site["concrete_addresses"]:
            assert env.get_object_type(concrete_address) == "heap_frame"
            if concrete_address == abstract_address:
                continue
            env.move_object_to_abstract_heap(concrete_address)
            env.join_addresses(abstract_address, concrete_address, add_null=False)
        env.abstract_addresses(allocation_site["concrete_addresses"], abstract_address)
        allocation_site["concrete_addresses"] = OrderedSet()

    def merge_object_by_recency(
        self,
        env: "Environment",
        allocation_site_id: str,
        add_null: bool = False,  # TODO I should implement better may/must pointer analysis.
    ) -> None:
        """Summarize all but the most recent concrete object of an object/class site.

        Args:
            env: the abstract environment holding the heaps and allocation sites.
            allocation_site_id: key into ``env.allocation_sites``.
            add_null: forwarded to ``env.join_addresses``.
        """
        allocation_site = env.allocation_sites[allocation_site_id]
        abstract_address = self._get_or_create_summary_address(
            env, allocation_site, allocation_site_id
        )
        most_recent_address = allocation_site["most_recent_address"]
        concrete_addresses = allocation_site["concrete_addresses"]
        addresses_to_abstract = OrderedSet()
        for concrete_address in concrete_addresses:
            if concrete_address == most_recent_address:
                continue
            assert env.get_object_type(concrete_address) in ("object", "class")
            # Move to the abstract heap before joining; otherwise the abstract
            # heap could point into the concrete heap after the join.
            env.move_object_to_abstract_heap(concrete_address)
            addresses_to_abstract.add(concrete_address)
            env.join_addresses(abstract_address, concrete_address, add_null=add_null)
        env.abstract_addresses(addresses_to_abstract, abstract_address)
        # Keep the most recent address concrete only if it really is one of this
        # site's concrete addresses. Otherwise (e.g. None) we would register a
        # non-concrete address as concrete.
        if most_recent_address in concrete_addresses:
            allocation_site["concrete_addresses"] = OrderedSet([most_recent_address])
        else:
            allocation_site["concrete_addresses"] = OrderedSet()

    def merge_function_by_recency(
        self, env: "Environment", allocation_site_id: str
    ) -> None:
        """Summarize the function objects created at ``allocation_site_id``.

        If the site has no summary yet and exactly one concrete function, that
        function becomes the summary. Any remaining concrete functions are
        passed to ``env.merge_functions_with_same_allocation_site``. The site
        always ends with no concrete addresses.
        """
        allocation_site = env.allocation_sites[allocation_site_id]
        if (
            len(allocation_site["summary_addresses"]) == 0
            and len(allocation_site["concrete_addresses"]) == 1
        ):
            concrete_function = list(allocation_site["concrete_addresses"])[0]
            env.move_object_to_abstract_heap(concrete_function)
            allocation_site["summary_addresses"].add(concrete_function)
            allocation_site["concrete_addresses"] = OrderedSet()
        if len(allocation_site["concrete_addresses"]) > 1 or (
            len(allocation_site["summary_addresses"]) > 0
            and len(allocation_site["concrete_addresses"]) > 0
        ):
            env.merge_functions_with_same_allocation_site(
                allocation_site["concrete_addresses"]
            )
            allocation_site["concrete_addresses"] = OrderedSet()

    def has_converged(
        self, env: "Environment", previous_values: dict, current_values: dict
    ) -> bool:
        """Return True when both the summary and the single concrete value are unchanged.

        Both snapshots must hold exactly one summary address and exactly one
        concrete address, and the recorded values must compare equal pairwise.
        """
        # NOTE(review): function and heap-frame sites are left with no concrete
        # addresses after merging, so this check can never succeed for them;
        # confirm the caller only relies on it for object/class sites.
        previous_summary_addresses = previous_values["summary_addresses"]
        current_summary_addresses = current_values["summary_addresses"]
        previous_concrete_addresses = previous_values["concrete_addresses"]
        current_concrete_addresses = current_values["concrete_addresses"]

        if len(previous_summary_addresses) != 1 or len(current_summary_addresses) != 1:
            return False
        if (
            len(previous_concrete_addresses) != 1
            or len(current_concrete_addresses) != 1
        ):
            return False
        previous_summary_value = list(previous_summary_addresses.values())[0]
        current_summary_value = list(current_summary_addresses.values())[0]
        previous_concrete_value = list(previous_concrete_addresses.values())[0]
        current_concrete_value = list(current_concrete_addresses.values())[0]
        return (
            previous_summary_value == current_summary_value
            and previous_concrete_value == current_concrete_value
        )

class MergeByRole(MergeStrategy):
    def __init__(self, field: str, widening_policy: WideningPolicy):
        """Create a role-based merge strategy.

        Args:
            field: objects at one allocation site whose values for this field
                compare equal are merged into the same summary address.
            widening_policy: applied to the summary addresses after merging.
        """
        self.field = field
        self.widening_policy = widening_policy

    def merge(self, env: "Environment", allocation_site_id: str) -> None:
        """Merge ``allocation_site_id`` by role, then widen every summary address once.

        Object and class sites are merged by the value of ``self.field``.
        Function sites use ``merge_function_by_recency`` and heap-frame sites
        are fully summarized.

        Raises:
            Exception: if the allocation site's ``type`` is not recognised.
        """
        allocation_site = env.allocation_sites[allocation_site_id]
        site_type = allocation_site["type"]
        if site_type == "object" or site_type == "class":
            self.merge_object_by_role(env, allocation_site_id)
        elif site_type == "function":
            self.merge_function_by_recency(env, allocation_site_id)
        elif site_type == "heap_frame":
            # This class has no `merge_heap_frame_by_recency` (calling it raised
            # AttributeError); use the heap-frame summarizer it does define.
            self.merge_allocation_site_heap_frame(env, allocation_site_id)
        else:
            raise Exception(f"Unknown allocation site type {allocation_site['type']}")

        # Role merging can leave several summary addresses (one per role). Widen
        # each exactly once. The first summary used to be widened twice, and an
        # empty summary set raised IndexError.
        for summary_address in allocation_site["summary_addresses"]:
            self.widening_policy.widen(env, summary_address)  # type: ignore

    def merge_allocation_site_class(
        self, env: "Environment", allocation_site_id: str
    ) -> None:
        """Merge a class allocation site by role.

        Class sites are merged like object sites, which is also how ``merge``
        dispatches them. Concrete class objects are joined into role summaries
        (instead of a summary tagged as a heap frame), and the site's concrete
        addresses are cleared.
        """
        logger.info(f"Merging allocation site {allocation_site_id} as class")
        self.merge_object_by_role(env, allocation_site_id)

    def merge_allocation_site_heap_frame(
        self, env: "Environment", allocation_site_id: str
    ) -> None:
        """Fold every concrete heap frame of the site into one summary heap frame.

        Heap frames are not split by role; the site keeps a single summary
        address, created on demand, and ends with no concrete addresses.
        """
        site = env.allocation_sites[allocation_site_id]
        summaries = site["summary_addresses"]
        if len(summaries) > 0:
            summary = summaries[0]
        else:
            summary = env.add_object_to_abstract_heap(
                {},
                add_proto=False,
                allocation_site=allocation_site_id,
                allocation_site_type="heap_frame",
            )
            summaries.add(summary)
        # TODO(review): if there are no concrete addresses but the summary was
        # modified, the heap frame should arguably still be simplified.
        for address in site["concrete_addresses"]:
            assert env.get_object_type(address) == "heap_frame"
            if address == summary:
                continue
            # Move first so the join cannot create abstract -> concrete pointers.
            env.move_object_to_abstract_heap(address)
            env.join_addresses(summary, address, add_null=False)
        env.abstract_addresses(site["concrete_addresses"], summary)
        site["concrete_addresses"] = OrderedSet()

    def merge_object_by_role(
        self,
        env: "Environment",
        allocation_site_id: str,
        add_null: bool = False,  # TODO I should implement better may/must pointer analysis.
    ) -> None:
        """Merge the site's concrete objects into one summary address per role.

        An object's role is what ``obj[self.field].get_all_values()`` returns.
        Each concrete object is joined into the first summary address with an
        equal role. If none matches, the object itself becomes a new summary
        address. Objects lacking ``self.field`` share a single "field absent"
        role, so they are summarized together rather than silently dropped from
        the allocation site. Afterwards the site has no concrete addresses.

        Args:
            env: the abstract environment holding the heaps and allocation sites.
            allocation_site_id: key into ``env.allocation_sites``.
            add_null: forwarded to ``env.join_addresses``.
        """
        allocation_site = env.allocation_sites[allocation_site_id]
        readable_allocation_site = env.get_readable_allocation_site(
            allocation_site_id
        )
        allocation_site_value_str = Util.get_allocation_site_value_str(
            env, allocation_site_id
        )
        logger.info(
            f"Merging allocation site {readable_allocation_site} as object by role"
        )
        logger.info(f"Allocation site {readable_allocation_site} {allocation_site_value_str}\n length of concrete addresses: {len(allocation_site['concrete_addresses'])}")
        logger.info(f"concrete addresses: {allocation_site['concrete_addresses']}")

        # Private sentinel: every object without `self.field` maps to this role.
        field_absent = object()

        def role_of(obj):
            # Role key used for grouping; see the docstring.
            if self.field in obj:
                return obj[self.field].get_all_values()
            return field_absent

        summary_addresses = allocation_site["summary_addresses"]
        for concrete_address in allocation_site["concrete_addresses"]:
            logger.info(f"Concrete address {concrete_address}")
            if concrete_address in summary_addresses:
                # Already a summary; joining it with itself would be a no-op.
                continue
            assert env.get_object_type(concrete_address) in ("object", "class")
            concrete_role = role_of(env.lookup_address(concrete_address))
            if concrete_role is field_absent:
                logger.info(
                    f"Field {self.field} not in concrete object {concrete_address}")

            target_summary = None
            for summary_address in summary_addresses:
                summary_role = role_of(env.lookup_address(summary_address))
                logger.info(f"comparing {summary_role} with {concrete_role}")
                if summary_role == concrete_role:
                    target_summary = summary_address
                    break

            # Move to the abstract heap before joining; otherwise the abstract
            # heap could point into the concrete heap after the join.
            env.move_object_to_abstract_heap(concrete_address)
            if target_summary is None:
                # First object seen with this role: it becomes the role's summary.
                summary_addresses.add(concrete_address)
            else:
                env.join_addresses(target_summary, concrete_address, add_null=add_null)
                env.abstract_addresses([concrete_address], target_summary)
        allocation_site["concrete_addresses"] = OrderedSet([])

    def merge_function_by_recency(
        self, env: "Environment", allocation_site_id: str
    ) -> None:
        """Summarize the function objects created at ``allocation_site_id``.

        A lone concrete function at a site without a summary is moved to the
        abstract heap and becomes the summary. Otherwise, any remaining
        concrete functions are handed to
        ``env.merge_functions_with_same_allocation_site`` (whether that helper
        records a summary address is not visible here). The site always ends
        with no concrete addresses.
        """
        site = env.allocation_sites[allocation_site_id]
        if len(site["summary_addresses"]) == 0 and len(site["concrete_addresses"]) == 1:
            (only_function,) = site["concrete_addresses"]
            env.move_object_to_abstract_heap(only_function)
            site["summary_addresses"].add(only_function)
            site["concrete_addresses"] = OrderedSet()

        remaining = len(site["concrete_addresses"])
        has_summary = len(site["summary_addresses"]) > 0
        if remaining > 1 or (has_summary and remaining > 0):
            env.merge_functions_with_same_allocation_site(site["concrete_addresses"])
            site["concrete_addresses"] = OrderedSet()

    def has_converged(
        self, env: "Environment", previous_values: dict, current_values: dict
    ) -> bool:
        """
        Check if the allocation site has converged.
        """
        previous_summary_addresses = previous_values["summary_addresses"]
        current_summary_addresses = current_values["summary_addresses"]
        previous_concrete_addresses = previous_values["concrete_addresses"]
        current_concrete_addresses = current_values["concrete_addresses"]

        if len(previous_summary_addresses) != 1 or len(current_summary_addresses) != 1:
            return False
        if (
            len(previous_concrete_addresses) != 1
            or len(current_concrete_addresses) != 1
        ):
            return False
        previous_summary_value = list(previous_summary_addresses.values())[0]
        current_summary_value = list(current_summary_addresses.values())[0]
        previous_concrete_value = list(previous_concrete_addresses.values())[0]
        current_concrete_value = list(current_concrete_addresses.values())[0]
        if previous_summary_value != current_summary_value:
            return False
        if previous_concrete_value != current_concrete_value:
            return False
        return T