from absint_ai.Environment.types.Type import *
from absint_ai.utils.Logger import logger  # type: ignore
from .RecordResult import RecordResult
from ordered_set import OrderedSet
from .AbstractHeap import AbstractHeap
from absint_ai.utils.Util import *

# Module-wide counter that ConcreteHeap.add uses to mint fresh string addresses.
# It is shared by every ConcreteHeap instance and reset by concrete_heap_id_reset().
# (A `global` statement at module scope is a no-op, so none is needed here.)
heap_id = 0


def concrete_heap_id_reset() -> None:
    """Rewind the module-level ``heap_id`` counter to zero.

    After this call, the next address handed out by ``ConcreteHeap.add`` is ``"0"``.
    """
    global heap_id
    heap_id = 0  # restart address numbering from the beginning


class ConcreteHeap:
    """Heap of concretely-addressed objects keyed by string ids.

    Each stored object is a dict that maps property keys to ``RecordResult``
    values. Property keys are plain keys or abstract keys such as
    ``baseType.NUMBER`` / ``baseType.STRING``. Each object may also have a
    ``"__meta__"`` entry, which is stored as-is.

    When ``track_modified_addresses`` is enabled, addresses that are added or
    written are recorded in ``modified_addresses``. Callers use this to see what
    must be merged after a conditional.
    """

    def __init__(
        self,
        concrete_heap: dict[str, dict[str | AbstractType, RecordResult]] | None = None,
        track_modified_addresses: bool = False,
    ):
        """Create a heap.

        Args:
            concrete_heap: Existing id -> object mapping to wrap. It is used by
                reference, not copied. A new empty dict is used when ``None``.
            track_modified_addresses: Whether ``add``, ``update``,
                ``overwrite_address`` and ``merge`` record touched addresses in
                ``modified_addresses``.
        """
        self.concrete_heap = {} if concrete_heap is None else concrete_heap
        self.track_modified_addresses = track_modified_addresses
        # Addresses written while tracking is enabled; used to decide what needs
        # to be merged after a conditional.
        self.modified_addresses = OrderedSet()

    def add(self, obj: dict) -> str:
        """Store ``obj`` under a fresh id and return that id.

        Every value except ``"__meta__"`` is normalised to a ``RecordResult``:

        - a list or OrderedSet becomes a local record of those values;
        - an existing ``RecordResult`` is kept as-is;
        - anything else is wrapped as a single-element local record.

        Ids come from the module-level ``heap_id`` counter. They are therefore
        unique across all ``ConcreteHeap`` instances until
        ``concrete_heap_id_reset`` is called.
        """
        global heap_id
        converted: dict = {}
        for key, value in obj.items():
            if key == "__meta__":
                converted[key] = value
            elif isinstance(value, (list, OrderedSet)):
                converted[key] = RecordResult("local", value)
            elif isinstance(value, RecordResult):
                converted[key] = value
            else:
                converted[key] = RecordResult("local", [value])
        address = str(heap_id)
        self.concrete_heap[address] = converted
        if self.track_modified_addresses:
            self.modified_addresses.add(address)
        heap_id += 1
        return address

    def contains(self, heap_id: int | str) -> bool:
        """Return whether ``heap_id`` names an object in this heap.

        Ids are stored as strings (see ``add``), so an integer id is converted
        before the lookup.
        """
        key = str(heap_id) if isinstance(heap_id, int) else heap_id
        return key in self.concrete_heap

    def merge(self, other: "ConcreteHeap") -> None:
        """Join ``other`` into this heap in place.

        How each part of ``other`` is handled:

        - An address present only in ``other`` is copied over by reference.
        - For a shared address, a property missing here (including
          ``"__meta__"``) is taken from ``other`` by reference.
        - A property present in both gets ``other``'s values added to this
          heap's ``RecordResult``, skipping values already present.
        - For ``"__meta__"``, the ``"enumerable_values"`` collections are
          unioned when both sides have one.
        - When tracking is enabled, ``other.modified_addresses`` is added to
          this heap's set.
        """
        for address in other.addresses():
            addr_key = address.get_value()
            incoming = other.get(address)
            if addr_key not in self.concrete_heap:
                self.concrete_heap[addr_key] = incoming
                continue
            target = self.concrete_heap[addr_key]
            for key, value in incoming.items():
                current_value = target.get(key)
                if current_value is None:
                    # Only ``other`` has this entry. Adopting it also covers a
                    # missing "__meta__", which would otherwise crash below.
                    target[key] = value
                    continue
                if key == "__meta__":
                    if (
                        "enumerable_values" in current_value
                        and "enumerable_values" in value
                    ):
                        ours = current_value["enumerable_values"]
                        # Same membership-check + append pattern as ``update``,
                        # so this works for both list and OrderedSet collections.
                        for name in value["enumerable_values"]:
                            if name not in ours:
                                ours.append(name)
                    continue
                for val in value.get_all_values():
                    # Re-read each time so duplicates within ``value`` are
                    # added only once.
                    if not item_contained_in_list(val, current_value.get_all_values()):
                        current_value.add_value(val)
        if self.track_modified_addresses:
            self.modified_addresses.update(other.modified_addresses)

    def _existing_key(self, addr: Address):
        """Return the heap key for ``addr`` after validating it.

        Raises:
            ValueError: if ``addr`` is not concrete or is not in this heap.
        """
        if addr.get_addr_type() != "concrete":
            raise ValueError(f"Address {addr} is not concrete")
        key = addr.get_value()
        if key not in self.concrete_heap:
            raise ValueError(f"Address {key} not found in concrete heap")
        return key

    def get(self, addr: Address) -> dict[str | AbstractType, RecordResult]:
        """Return the object stored at concrete address ``addr`` (not a copy).

        Raises:
            ValueError: if ``addr`` is not concrete or is not in this heap.
        """
        return self.concrete_heap[self._existing_key(addr)]

    def pop(self, addr: Address) -> None:
        """Remove the object at concrete address ``addr``.

        Raises:
            ValueError: if ``addr`` is not concrete or is not in this heap.
        """
        del self.concrete_heap[self._existing_key(addr)]

    def overwrite_address(
        self, addr: Address, value: dict[str | AbstractType, RecordResult]
    ) -> None:
        """Replace the whole object at concrete address ``addr`` with ``value``.

        Raises:
            ValueError: if ``addr`` is not concrete or is not in this heap.
        """
        key = self._existing_key(addr)
        if self.track_modified_addresses:
            self.modified_addresses.add(key)
        self.concrete_heap[key] = value

    @staticmethod
    def _join_abstract_field(record: RecordResult, values) -> None:
        """Add ``values`` to an abstract NUMBER/STRING ``record``.

        A non-Address value is skipped when it is already present. An Address
        value is always passed to ``add_value``.
        """
        for value in values:
            if isinstance(value, Address) or not item_contained_in_list(
                value, record.get_all_values()
            ):
                record.add_value(value)

    def update(
        self,
        heap_id: str,
        property: str,
        values: list[Type],
        overwrite: bool = False,
    ) -> None:
        """Write ``values`` into ``property`` of the object at ``heap_id``.

        If ``values`` is a list or OrderedSet, it is first wrapped in a local
        ``RecordResult``. The write then takes one of two paths:

        - A numeric ``property`` on an object with a ``baseType.NUMBER`` entry,
          or a string ``property`` on an object with a ``baseType.STRING``
          entry: the values are joined into that abstract entry, and
          ``overwrite`` is ignored.
        - Otherwise: ``property`` is recorded in
          ``__meta__["enumerable_values"]``. The property is then created,
          replaced (``overwrite=True``), or joined with its existing values
          (deduplicated).

        When tracking is enabled, the address is recorded in
        ``modified_addresses`` whichever path is taken.

        Raises:
            KeyError: if ``heap_id`` is not in this heap.
        """
        if isinstance(values, (list, OrderedSet)):
            values = RecordResult("local", values)  # type: ignore
        obj = self.concrete_heap[heap_id]
        if isinstance(property, (int, float)) and baseType.NUMBER in obj:
            self._join_abstract_field(obj[baseType.NUMBER], values)
        elif isinstance(property, str) and baseType.STRING in obj:
            self._join_abstract_field(obj[baseType.STRING], values)
        else:
            enumerable = obj["__meta__"]["enumerable_values"]
            if property not in enumerable:
                enumerable.append(property)
            if property not in obj:
                obj[property] = values
            elif overwrite:
                record = obj[property]
                record.clear_values()
                for value in values:
                    record.add_value(value)
            else:
                record = obj[property]
                for value in values:
                    if not item_contained_in_list(value, record.get_all_values()):
                        record.add_value(value)
        # Every branch above mutates the object, so record it for merging.
        if self.track_modified_addresses:
            self.modified_addresses.add(heap_id)

    def addresses(self) -> list[Address]:
        """Return a concrete ``Address`` for every object id in this heap."""
        return [
            Address(value=addr, addr_type="concrete")
            for addr in self.concrete_heap
        ]
