import json
import beeprint
from absint_ai.Environment.types.Type import *
from absint_ai.utils.Util import *
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from absint_ai.Environment.Environment import Environment


def _contains_same_value(values: list, candidate) -> bool:
    """Return True if *candidate* already occurs in *values*.

    Like ``candidate in values`` (identity, then ``==``), except that a
    ``bool`` never matches a non-``bool``: plain ``in`` treats ``True`` as
    equal to ``1`` and ``False`` as equal to ``0`` (``bool`` subclasses
    ``int``), which would silently merge distinct boolean and numeric values.
    """
    candidate_is_bool = isinstance(candidate, bool)
    return any(
        isinstance(existing, bool) == candidate_is_bool
        and (existing is candidate or existing == candidate)
        for existing in values
    )


def pretty_print(
    env: "Environment",
    derive: bool = False,
    *,
    dump_path: "str | None" = "/tmp/env.json",
) -> str:
    """Render reachable variables and their values as indented JSON.

    Variables named ``module`` or ``exports`` are skipped. Each variable's
    values are de-duplicated (first occurrence wins, order preserved), and
    variables left with no values are omitted.

    Args:
        env: Environment to read from.
        derive: When True, values come from ``env.lookup_and_derive(key)``.
            Otherwise they come from ``env.lookup(key).get_all_values()``,
            with ``Primitive`` values unwrapped via ``get_value()``.
        dump_path: Keyword-only. File that also receives the JSON text
            (overwritten on each call). Defaults to the historical
            ``/tmp/env.json``; pass ``None`` to skip the write.

    Returns:
        The JSON text. The mapping is passed through ``serialize_keys``
        before ``json.dumps(..., indent=2)``.
    """
    new_result: dict[str, list] = {}
    for key in env.get_all_reachable_variable_names():
        if key == "module" or key == "exports":
            continue
        # Perform only the lookup whose result is actually used.
        if derive:
            lookup_results = env.lookup_and_derive(key)
        else:
            lookup_results = [
                value.get_value() if isinstance(value, Primitive) else value
                for value in env.lookup(key).get_all_values()
            ]
        unique_values: list = []
        for value in lookup_results:
            if not _contains_same_value(unique_values, value):
                unique_values.append(value)
        if unique_values:
            new_result[key] = unique_values

    result = json.dumps(serialize_keys(new_result), indent=2)
    if dump_path is not None:
        with open(dump_path, "w", encoding="utf-8") as f:
            f.write(result)
    return result


def pretty_print_variables(env: "Environment", variable_names: list[str]) -> str:
    """Render the named variables and all their current values as JSON.

    The name ``this`` is skipped. ``Primitive`` values are unwrapped with
    ``get_value()``; every other value is emitted unchanged. Values are not
    de-duplicated, and a variable with no values still appears (empty list).
    The mapping goes through ``serialize_keys`` before
    ``json.dumps(..., indent=2)``.
    """
    rendered: dict[str, list] = {
        name: [
            val.get_value() if isinstance(val, Primitive) else val
            for val in env.lookup(name).get_all_values()
        ]
        for name in variable_names
        if name != "this"
    }
    return json.dumps(serialize_keys(rendered), indent=2)


def pretty_print_allocation_sites(env: "Environment") -> str:
    """Return beeprint's formatted rendering of ``env.allocation_sites``.

    ``output=False`` asks beeprint to hand back the formatted text instead
    of printing it.
    """
    return beeprint.pp(env.allocation_sites, output=False)


def pretty_print_all_variables(env: "Environment") -> str:
    """Render every variable name in ``env`` with its derived values as JSON.

    Values come straight from ``env.lookup_and_derive(name)``. No names are
    filtered out, and no de-duplication or empty-result pruning is done.
    The mapping goes through ``serialize_keys`` before
    ``json.dumps(..., indent=2)``.
    """
    derived = {
        name: env.lookup_and_derive(name)
        for name in env.get_all_variable_names()
    }
    return json.dumps(serialize_keys(derived), indent=2)


def pretty_print_primitives(env: "Environment") -> str:
    """Print, then return, JSON of each reachable variable's primitive values.

    Only ``Primitive`` values are kept, unwrapped with ``get_value()``. A
    variable with no primitive values still appears with an empty list.
    """
    primitives_by_name: dict[str, list] = {
        name: [
            val.get_value()
            for val in env.lookup(name).get_all_values()
            if isinstance(val, Primitive)
        ]
        for name in env.get_all_reachable_variable_names()
    }
    rendered = json.dumps(serialize_keys(primitives_by_name), indent=2)
    print(rendered)
    return rendered


def get_all_object_names(env: "Environment") -> list[str]:
    """Collect the keys of the heap frame behind every stack frame.

    Iterates every frame sequence in ``env.stack.values()`` and resolves
    each frame's heap frame via ``env.lookup_address``. Every key except
    ``__meta__`` and ``__proto__`` is returned, in traversal order.
    Duplicates across frames are kept.
    """
    return [
        name
        for frames in env.stack.values()
        for frame in frames
        for name in env.lookup_address(frame.get_heap_frame_address())
        if name != "__meta__" and name != "__proto__"
    ]


def pretty_print_concrete_heap(
    env: "Environment",
    partial: bool = True,
    *,
    dump_path: "str | None" = "/tmp/concrete_heap.json",
) -> str:
    """Render the concrete heap as indented JSON.

    Args:
        env: Environment whose ``concrete_heap`` is rendered.
        partial: When True (the default), heap frames (as reported by
            ``is_heap_frame``) are skipped entirely, and the ``__meta__``,
            ``__proto__``, ``exports`` and ``module`` fields are dropped.
            When False, every address is visited and ``__meta__`` is copied
            through unconverted.
        dump_path: Keyword-only. File that also receives the JSON text
            (overwritten on each call). Defaults to the historical
            ``/tmp/concrete_heap.json``; pass ``None`` to skip the write.

    Returns:
        JSON mapping each address to ``{field name: [values]}``.
        ``Address`` values are kept as-is, ``Primitive`` values are
        unwrapped with ``get_value()``, and any other value kind is omitted.
        ``Type`` field keys are unwrapped with ``get_value()``, and a
        ``Type`` key whose value is ``this`` is dropped. Addresses whose
        collected entries have a total length of zero are left out.
    """
    hidden_fields = ("__meta__", "__proto__", "exports", "module")
    result: dict[Address, dict[str, list]] = {}
    for addr in env.concrete_heap.addresses():
        addr_result = env.concrete_heap.get(addr)
        if partial and is_heap_frame(addr_result):
            continue
        addr_values: dict[str, list] = {}
        for k, v in addr_result.items():
            if partial and any(k == name for name in hidden_fields):
                continue
            if isinstance(k, Type) and k.get_value() == "this":
                continue
            if not partial and k == "__meta__":
                # Full dumps keep the raw metadata object for inspection.
                addr_values[k] = v
                continue
            all_values = v.get_all_values()
            key = k.get_value() if isinstance(k, Type) else k
            collected: list = []
            for value in all_values:
                if isinstance(value, Address):
                    collected.append(value)
                elif isinstance(value, Primitive):
                    collected.append(value.get_value())
            addr_values[key] = collected

        # Only emit addresses that contributed at least one entry.
        if sum(len(values) for values in addr_values.values()) == 0:
            continue
        result[addr] = addr_values

    result_str = json.dumps(serialize_keys(result), indent=2)
    if dump_path is not None:
        with open(dump_path, "w", encoding="utf-8") as f:
            f.write(result_str)
    return result_str


def pretty_print_addresses(env: "Environment", addresses: list[Address]) -> str:
    """Render the objects stored at the given addresses as indented JSON.

    Each address is resolved with ``env.lookup_address``. The ``__meta__``,
    ``__proto__``, ``exports`` and ``module`` fields are dropped, as is any
    ``Type`` key whose value is ``this``. Remaining ``Type`` keys are
    unwrapped with ``get_value()``. Field values keep ``Address`` entries
    as-is, unwrap ``Primitive`` entries with ``get_value()``, and omit
    everything else. Every requested address appears in the output, even
    if it has no fields left.

    NOTE(review): an earlier comment said heap frames are ignored here, but
    no heap-frame check is performed; heap-frame addresses are rendered like
    any other. Confirm which behavior callers expect.
    """
    skipped_fields = ("__meta__", "__proto__", "exports", "module")
    rendered: dict[Address, dict[str, list]] = {}
    for address in addresses:
        fields: dict[str, list] = {}
        for field, slot in env.lookup_address(address).items():
            if any(field == hidden for hidden in skipped_fields):
                continue
            is_type_key = isinstance(field, Type)
            if is_type_key and field.get_value() == "this":
                continue
            stored = slot.get_all_values()
            name = field.get_value() if is_type_key else field
            fields[name] = bucket = []
            for item in stored:
                if isinstance(item, Address):
                    bucket.append(item)
                elif isinstance(item, Primitive):
                    bucket.append(item.get_value())
        rendered[address] = fields
    return json.dumps(serialize_keys(rendered), indent=2)


def pretty_print_abstract_heap(
    env: "Environment",
    partial: bool = True,
    *,
    dump_path: "str | None" = "/tmp/abstract_heap.json",
) -> str:
    """Render the abstract heap as indented JSON.

    Args:
        env: Environment whose ``abstract_heap`` is rendered.
        partial: When True (the default), heap frames (as reported by
            ``is_heap_frame``) are skipped entirely, and the ``__meta__``,
            ``__proto__``, ``exports`` and ``module`` fields are dropped.
            When False, every address is visited and ``__meta__`` is copied
            through unconverted.
        dump_path: Keyword-only. File that also receives the JSON text
            (overwritten on each call). Defaults to the historical
            ``/tmp/abstract_heap.json``; pass ``None`` to skip the write.

    Returns:
        JSON mapping each address to ``{field name: [values]}``.
        ``Address`` and ``AbstractType`` values are kept as-is, ``Primitive``
        values are unwrapped with ``get_value()``, and any other value kind
        is omitted. ``Type`` field keys are unwrapped with ``get_value()``,
        and a ``Type`` key whose value is ``this`` is dropped. Addresses
        whose collected entries have a total length of zero are left out.
    """
    hidden_fields = ("__meta__", "__proto__", "exports", "module")
    result: dict[Address, dict[str, list]] = {}
    for addr in env.abstract_heap.addresses():
        addr_result = env.abstract_heap.get(addr)
        if partial and is_heap_frame(addr_result):
            continue
        addr_values: dict[str, list] = {}
        for k, v in addr_result.items():
            if partial and any(k == name for name in hidden_fields):
                continue
            if isinstance(k, Type) and k.get_value() == "this":
                continue
            if not partial and k == "__meta__":
                # Full dumps keep the raw metadata object for inspection.
                addr_values[k] = v
                continue
            all_values = v.get_all_values()
            key = k.get_value() if isinstance(k, Type) else k
            collected: list = []
            for value in all_values:
                # Branch order matters: Address wins over Primitive, and
                # Primitive is unwrapped before the AbstractType check.
                if isinstance(value, Address):
                    collected.append(value)
                elif isinstance(value, Primitive):
                    collected.append(value.get_value())
                elif isinstance(value, AbstractType):
                    collected.append(value)
            addr_values[key] = collected

        # Only emit addresses that contributed at least one entry.
        if sum(len(values) for values in addr_values.values()) == 0:
            continue
        result[addr] = addr_values

    result_str = json.dumps(serialize_keys(result), indent=2)
    if dump_path is not None:
        with open(dump_path, "w", encoding="utf-8") as f:
            f.write(result_str)
    return result_str
