from absint_ai.Environment.types.Type import *
from typing_extensions import Self
from absint_ai.utils.Util import *
from ordered_set import OrderedSet
from absint_ai.utils.Logger import logger  # type: ignore


class RecordResultIterator:
    """Iterator over a snapshot of a collection of values.

    The input is copied into a list at construction time. Later mutation of
    the source collection does not affect an iteration already in progress.
    """

    def __init__(self, data: "OrderedSet") -> None:
        # Snapshot the values so the owner may be modified while iterating.
        self.data = list(data)
        self.index = 0

    def __iter__(self) -> "RecordResultIterator":
        # The iterator protocol requires iterators to be iterable themselves,
        # so that iter(), list(), zip(), itertools, etc. accept them directly.
        return self

    def __next__(self) -> "Type":
        """Return the next value, or raise StopIteration when exhausted."""
        if self.index >= len(self.data):
            raise StopIteration
        item = self.data[self.index]
        self.index += 1
        return item


# Represents all possible values for a variable.
# RecordResults are always associated with a variable name, and are stored either on the stack or in the heap.
class RecordResult:
    """Set-like container holding the possible values (``Type`` objects) of a variable.

    Values are stored in an ``OrderedSet``: duplicates are dropped and
    insertion order is preserved. ``_type`` records where this result is
    stored, e.g. ``"local"`` for the local stack/heap; other values such as a
    global-heap marker are presumably passed by callers (confirm).
    """

    def __init__(
        self,
        _type: str = "local",
        values: list[Type] | None = None,
    ):
        # NOTE(review): not read anywhere in this class; kept for external users.
        self.counter = 0
        self.values = OrderedSet(values) if values else OrderedSet()
        # Denotes whether this record result is stored in the local stack/heap
        # or in the global heap. Honor the caller's argument (previously it was
        # silently ignored and "local" was always stored).
        self._type: str = _type

    def __str__(self) -> str:
        return f"{self._type} RecordResult: {self.values}"

    def __len__(self) -> int:
        return len(self.values)

    def __repr__(self) -> str:
        return f"{self._type} RecordResult: {self.values}"

    def __hash__(self) -> int:
        # NOTE: the hash depends on the (mutable) contents, so a RecordResult
        # must not be mutated while it is used as a dict key or set member.
        return hash(tuple(self.values))

    def __iter__(self) -> RecordResultIterator:
        """Return an iterator over a snapshot of the current values.

        Before iterating, the backing OrderedSet is rebuilt from its own
        contents, which re-deduplicates it. Presumably this guards against
        values whose equality/hash changed after insertion; confirm.
        """
        self.values = OrderedSet(list(self.values))
        return RecordResultIterator(self.values)

    def __getitem__(self, index: int) -> Type:
        return self.values[index]

    def __next__(self) -> None:
        # A RecordResult is iterable, not an iterator; iterate via __iter__.
        raise Exception("This shouldn't be called!")

    def __eq__(self, other: object) -> bool:
        """Two RecordResults are equal when their value sets are equal (``_type`` is ignored)."""
        if not isinstance(other, RecordResult):
            return False
        return self.values == other.values

    def get_concrete_values(self) -> list[Type]:
        """Return Primitive values plus Address values whose address type is ``"concrete"``."""
        concrete_values: list[Type] = []
        for value in self.values:
            if isinstance(value, Primitive):
                concrete_values.append(value)
            elif isinstance(value, Address):
                if value.get_addr_type() == "concrete":
                    concrete_values.append(value)
        return concrete_values

    def get_abstract_values(self) -> list[Type]:
        """Return Address values whose address type is ``"abstract"`` plus AbstractType values."""
        abstract_values: list[Type] = []
        for value in self.values:
            if isinstance(value, Address):
                if value.get_addr_type() == "abstract":
                    abstract_values.append(value)
            elif isinstance(value, AbstractType):
                abstract_values.append(value)
        return abstract_values

    def only_contains_primitives(self) -> bool:
        """True if every value is a Primitive or an AbstractType (vacuously True when empty)."""
        return all(
            isinstance(value, Primitive) or isinstance(value, AbstractType)
            for value in self.values
        )

    def contains_primitives(self) -> bool:
        """True if at least one value is a Primitive."""
        return any(isinstance(value, Primitive) for value in self.values)

    def get_all_primitives(self) -> OrderedSet[Type]:
        """Return a new OrderedSet containing only the Primitive values."""
        return OrderedSet(
            [value for value in self.values if isinstance(value, Primitive)]
        )

    def merge_primitives(self) -> None:
        """Replace every Primitive value with its abstract type.

        Each Primitive is converted via ``value_to_abstract_type`` applied to
        its ``get_value()``. All primitives are then removed, and the abstract
        types not already present are appended. Non-primitive values keep
        their relative order.
        """
        primitives = [value for value in self.values if isinstance(value, Primitive)]
        abstract_values = [
            value for value in self.values if isinstance(value, AbstractType)
        ]
        for primitive in primitives:
            abstract_value = value_to_abstract_type(primitive.get_value())
            if abstract_value not in abstract_values:
                abstract_values.append(abstract_value)
        # Drop all primitives in a single pass instead of rebuilding the set
        # once per primitive (which was quadratic in the number of values).
        self.values = OrderedSet(
            [value for value in self.values if not isinstance(value, Primitive)]
        )
        self.add_values(abstract_values)

    def only_contains_addresses(self) -> bool:
        """True if every value is an Address (vacuously True when empty)."""
        return all(isinstance(value, Address) for value in self.values)

    def get_all_addresses(self) -> OrderedSet[Address]:
        """Return a new OrderedSet containing only the Address values."""
        return OrderedSet(
            [value for value in self.values if isinstance(value, Address)]
        )

    def get_all_values(self) -> OrderedSet[Type]:
        """Return the backing OrderedSet itself (not a copy).

        Mutating the returned set mutates this RecordResult.
        """
        assert isinstance(self.values, OrderedSet)
        return self.values

    def clear_values(self) -> None:
        """Remove all values."""
        self.values = OrderedSet()

    def deduplicate_values(self) -> None:
        """Rebuild the backing OrderedSet from its own contents to drop duplicates."""
        self.values = OrderedSet(list(self.values))

    def has_recursive_placeholder(self) -> bool:
        """True if any value equals ``baseType.RECURSIVE_PLACEHOLDER``."""
        return any(value == baseType.RECURSIVE_PLACEHOLDER for value in self.values)

    def remove_recursive_placeholder(self) -> None:
        """Remove every value equal to ``baseType.RECURSIVE_PLACEHOLDER``."""
        self.values = OrderedSet(
            [value for value in self.values if value != baseType.RECURSIVE_PLACEHOLDER]
        )

    def merge_other_record_result(self, other: Self) -> bool:
        """Merge the values of ``other`` into this RecordResult.

        Returns:
            True if at least one value not already present was added.
        """
        added_new_values = False
        for value in other:
            if value not in self.values:
                self.add_value(value)
                added_new_values = True
        return added_new_values

    def add_values(self, values: list[Type]) -> None:
        """Add each value in ``values`` via ``add_value``."""
        for value in values:
            self.add_value(value)

    def remove_value(self, value: Type) -> None:
        """Remove every stored value that compares equal to ``value``, preserving order."""
        self.values = OrderedSet([val for val in self.values if val != value])

    def add_value(self, value: object) -> None:
        """Add ``value`` if it is a ``Type``; otherwise log an error and ignore it."""
        if not isinstance(value, Type):
            logger.error(f"Invalid value type: {type(value)}")
            return
        self.values.add(value)

    def is_subset(self, other: Self) -> bool:
        """True if every value of this RecordResult is also in ``other``."""
        return all(value in other.get_all_values() for value in self.get_all_values())

    def get_type(self) -> str | None:
        """Return the storage kind recorded in ``_type``."""
        return self._type

    def set_type(self, str_type: str) -> None:
        """Set the storage kind recorded in ``_type``."""
        self._type = str_type
