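"""Function-call helpers for the abstract interpreter.

Covers memoizing calls and replaying their results, re-allocating heap objects
returned from memoized calls, recursive-call placeholders, and tracking of the
addresses and primitives touched by each stack frame.
"""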

import copy
from typing import TYPE_CHECKING
from absint_ai.Environment.types.Type import *
from absint_ai.Environment.types.MemoizedFunction import MemoizedFunction
from absint_ai.utils.Util import *
from absint_ai.Environment.memory.RecordResult import RecordResult

if TYPE_CHECKING:
    from absint_ai.Environment.Environment import Environment


def copy_value_at_address(
    env: "Environment", address: Address, parent: Address = None
) -> Address:
    """Duplicate the heap object stored at ``address`` and return the copy's address.

    The object is deep-copied out of the heap that matches the address kind
    (concrete or abstract) and re-added to that same heap. Its
    ``__meta__["allocation_site"]`` is forwarded when present.

    Args:
        env: Environment owning the concrete and abstract heaps.
        address: Address of the object to duplicate.
        parent: When truthy, written to the copy's ``__meta__["__parent__"]``.

    Returns:
        The address of the newly added copy.

    Raises:
        Exception: If ``address`` is neither concrete nor abstract.
    """
    kind = address.get_addr_type()
    if kind == "concrete":
        source_heap, allocate = env.concrete_heap, env.add_object_to_heap
    elif kind == "abstract":
        source_heap, allocate = env.abstract_heap, env.add_object_to_abstract_heap
    else:
        raise Exception(f"Unknown address type {kind}")

    duplicate = copy.deepcopy(source_heap.get(address))
    meta = duplicate["__meta__"]
    site = meta.get("allocation_site")
    if parent:
        meta["__parent__"] = parent
    return allocate(duplicate, allocation_site=site)


# Given the addresses allocated by a memoized function, recursively walk the object returned by that function and replace references to those addresses with fresh copies.
def reallocate_addresses_for_object(
    env: "Environment",
    addr: Address,
    allocated_addresses: list[Address],
    searched_addresses: list[Address],
    reallocated: "dict[Address, Address] | None" = None,
) -> None:
    """Recursively swap references to memoized allocations for fresh copies.

    Walks the heap object at ``addr``. The walk applies three rules:

    * Every RecordResult value pointing at an address in
      ``allocated_addresses`` is replaced by a copy made with
      ``copy_value_at_address``, and the copy is then walked.
    * A ``__meta__["__parent__"]`` that points at an allocated address is
      redirected to that parent's copy.
    * Other addresses, whether values or parents, are walked but not replaced.

    NOTE(review): walking a non-allocated address can still rewrite that
    object's own references to allocated addresses, as the original did.

    Args:
        env: Environment owning the concrete and abstract heaps.
        addr: Address of the object to walk; a falsy value is a no-op.
        allocated_addresses: Addresses allocated by the memoized function.
        searched_addresses: Addresses already visited; extended in place.
        reallocated: Optional map from an allocated address to its copy. It is
            shared through the recursion so that each allocated address is
            copied at most once. Every reference to it, including cyclic ones,
            then resolves to the same copy.

    Raises:
        Exception: If an address is neither concrete nor abstract.
    """
    if not addr:
        return
    if reallocated is None:
        reallocated = {}
    if addr not in searched_addresses:
        searched_addresses.append(addr)

    addr_type = addr.get_addr_type()
    if addr_type == "concrete":
        obj = env.concrete_heap.get(addr)
    elif addr_type == "abstract":
        obj = env.abstract_heap.get(addr)
    else:
        raise Exception(f"Unknown address type {addr_type}")

    def relocate(original: Address) -> Address:
        """Return the copy of ``original``, creating and walking it on first use."""
        if original in reallocated:
            return reallocated[original]
        new_address = copy_value_at_address(env, original)
        # Register before walking so cycles back to ``original`` reuse this copy.
        reallocated[original] = new_address
        if original not in searched_addresses:
            searched_addresses.append(original)
        reallocate_addresses_for_object(
            env,
            new_address,
            allocated_addresses,
            searched_addresses,
            reallocated=reallocated,
        )
        return new_address

    for record_result in obj.values():
        if not record_result:
            continue
        if isinstance(record_result, dict):  # this is the __meta__ entry
            parent = record_result.get("__parent__")
            if not isinstance(parent, Address):
                continue
            if parent in allocated_addresses:
                # Link the object to the parent's copy; previously the copy was
                # made but never attached, leaving the old parent in place.
                record_result["__parent__"] = relocate(parent)
            elif parent not in searched_addresses:
                reallocate_addresses_for_object(
                    env,
                    parent,
                    allocated_addresses,
                    searched_addresses,
                    reallocated=reallocated,
                )
            continue

        # Snapshot the values: the loop swaps entries inside record_result,
        # and mutating a collection while iterating it can skip elements.
        for value in list(record_result.get_all_values()):
            if not isinstance(value, Address):
                continue
            if value in allocated_addresses:
                new_address = relocate(value)
                record_result.remove_value(value)
                record_result.add_value(new_address)
            elif value not in searched_addresses:
                reallocate_addresses_for_object(
                    env,
                    value,
                    allocated_addresses,
                    searched_addresses,
                    reallocated=reallocated,
                )


# Re-allocates anything that was allocated by the original function
# FIXME Need to add any reallocated objects to the allocation sites object
def return_subheap(
    env: "Environment", memoized_function: MemoizedFunction
) -> list[Type]:
    """Return a memoized call's return values, re-allocating what the call allocated.

    Return values the memoized call allocated itself (per
    ``get_allocated_addresses``) are copied. The allocated objects reachable
    from them are re-allocated too (``reallocate_addresses_for_object``).
    Non-address values, and addresses the call did not allocate, are returned
    unchanged.

    For an allocated function whose ``__parent__`` was also allocated, the
    parent is copied first and the function copy is attached to it.

    Args:
        env: Environment owning the heaps.
        memoized_function: The memoized entry being replayed.

    Returns:
        One entry per return value, in the original order.
    """
    return_values = memoized_function.get_return_values()
    allocated_addresses = memoized_function.get_allocated_addresses()
    result: list[Type] = []
    for return_value in return_values:
        if (
            not isinstance(return_value, Address)
            or return_value not in allocated_addresses
        ):
            result.append(return_value)
            continue

        object_type = env.get_object_type(return_value)
        if object_type == "object":
            new_address = copy_value_at_address(env, return_value)
            reallocate_addresses_for_object(env, new_address, allocated_addresses, [])
            result.append(new_address)
        elif object_type == "function":
            parent = env.get_meta(return_value).get("__parent__")
            if parent in allocated_addresses:
                new_parent = copy_value_at_address(env, parent)  # type: ignore
                reallocate_addresses_for_object(
                    env, new_parent, allocated_addresses, []
                )
                result.append(
                    copy_value_at_address(env, return_value, parent=new_parent)
                )
            else:
                new_function = copy_value_at_address(env, return_value)
                reallocate_addresses_for_object(
                    env, new_function, allocated_addresses, []
                )
                result.append(new_function)
        else:
            # Previously any other object type was silently dropped from the
            # result. Keep it rather than lose a return value.
            # NOTE(review): it is not re-allocated; confirm whether other
            # object types need copying too.
            result.append(return_value)
    return result


def get_memoized_function(
    env: "Environment", func: Address, caller_params: dict
) -> list[Type] | None:
    """Find a reusable memoized result for calling ``func`` with ``caller_params``.

    An entry qualifies when all of the following hold:

    * its function id matches ``func``'s ``function_id``;
    * its params equal ``caller_params``;
    * it is not a recursive-call placeholder;
    * neither its touched address values nor its touched primitives hash
      differently now than when the entry was recorded.

    The primitives are only checked when the addresses are unchanged.

    Returns:
        ``return_subheap`` of the first qualifying entry, or ``None``.
    """
    assert env.get_object_type(func) == "function"
    target_id = env.get_meta(func)["function_id"]
    for candidate in env.memoized_functions:
        same_call = (
            target_id == candidate.get_function_id()
            and caller_params == candidate.get_params()
        )
        if not same_call or candidate.is_recursive_call():
            continue
        if touched_addresses_changed(
            env,
            candidate.get_touched_address_values(),
            candidate.get_touched_addresses_hash(),
        ):
            continue
        if touched_primitives_changed(
            env,
            candidate.get_touched_primitives(),
            candidate.get_touched_primitives_hash(),
        ):
            continue
        return return_subheap(env, candidate)
    return None


def memoize_and_return_from_function(
    env: "Environment",
    func: Address,
    params_to_memoize: dict[str, list[Type]],
    function_file_path: str,
    call_site_file_path: str,
) -> None:
    """Record the finished call of ``func`` as memoized, then pop its frame.

    The new MemoizedFunction is keyed by ``func``'s ``function_id`` and
    ``params_to_memoize``. It stores the current frame's touched addresses and
    touched primitives with their hashes, the frame's allocated addresses, and
    ``env.get_return_values()``.

    If ``is_recursive_call`` reports a placeholder for ``func``, matching
    placeholders are removed before the new entry is appended. Then:

    1. The callee frame is popped from the stack for ``function_file_path``.
    2. The top frame for ``call_site_file_path`` becomes current.
    3. The callee's touched addresses are re-added to that frame.
    """
    frame_addresses = get_touched_addresses(env)
    frame_addresses_hash = hash_address_values(frame_addresses)
    frame_primitives = get_touched_primitives(env)
    frame_primitives_hash = hash(str(frame_primitives))
    frame_allocations = env.cur_stack_frame.get_allocated_addresses()
    callee_id = env.get_meta(func)["function_id"]
    entry = MemoizedFunction(
        func=func,
        function_id=callee_id,
        params=params_to_memoize,
        touched_address_values=frame_addresses,
        touched_addresses_hash=frame_addresses_hash,
        touched_primitives=frame_primitives,
        touched_primitives_hash=frame_primitives_hash,
        allocated_addresses=frame_allocations,
        return_values=env.get_return_values(),
    )

    callee_stack_key = convert_path_to_underscore(function_file_path)
    caller_stack_key = convert_path_to_underscore(call_site_file_path)

    if is_recursive_call(env, func, params_to_memoize):
        remove_recursive_calls(env, func, params_to_memoize)
    env.memoized_functions.append(entry)

    # Capture the callee's trace before its frame is popped.
    inherited_addresses = get_touched_addresses(env)
    env.stack[callee_stack_key].pop()
    env.cur_stack_frame = env.stack[caller_stack_key][-1]
    # TODO should we also add touched primitives to the current stack frame?
    for touched in inherited_addresses:
        add_touched_address(env, touched)


def is_recursive_call(env: "Environment", func: Address, caller_params: dict) -> bool:
    """Report whether any memoized entry for ``func`` is a recursive-call placeholder.

    Entries are matched on ``function_id`` only; ``caller_params`` is accepted
    but not compared.
    """
    # TODO we are ignoring parameters for now, because they can be ever-increasing.
    target_id = env.get_meta(func)["function_id"]
    return any(
        entry.is_recursive_call()
        for entry in env.memoized_functions
        if target_id == entry.get_function_id()
    )


def add_recursive_call(env: "Environment", func: Address, params: dict) -> None:
    """Append a placeholder memoized entry for an in-progress call of ``func``.

    Only ``func``, its ``function_id`` and ``params`` are supplied; the remaining
    six positional constructor arguments are ``None``. Presumably this is what
    ``MemoizedFunction.is_recursive_call()`` detects; confirm in MemoizedFunction.
    """
    unset_fields = (None,) * 6
    placeholder = MemoizedFunction(func, env.get_meta(func)["function_id"], params, *unset_fields)  # type: ignore
    env.memoized_functions.append(placeholder)


def remove_recursive_calls(
    env: "Environment", func: Address, caller_params: dict
) -> None:
    """Remove every recursive-call placeholder recorded for ``func`` and ``caller_params``.

    An entry is removed when its function id matches ``func``'s
    ``function_id``, its params equal ``caller_params``, and
    ``is_recursive_call()`` is true. ``env.memoized_functions`` is modified in
    place.
    """
    caller_function_id = env.get_meta(func)["function_id"]
    # Iterate over a snapshot. Removing from the list being iterated shifts the
    # next element into the current slot, so the loop would skip it and leave
    # back-to-back placeholders behind.
    for memoized_function in list(env.memoized_functions):
        if (
            caller_function_id == memoized_function.get_function_id()
            and caller_params == memoized_function.get_params()
            and memoized_function.is_recursive_call()
        ):
            env.memoized_functions.remove(memoized_function)


def set_has_recursive_call(env: "Environment", has_recursive_call: bool) -> None:
    """Set the recursive-call flag on the environment's current stack frame."""
    frame = env.cur_stack_frame
    frame.set_has_recursive_call(has_recursive_call)


def replace_recursive_placeholders_for_address(
    env: "Environment",
    return_value: Type,
    replacement: list[Type],
    checked_addresses: list | None = None,
) -> None:
    """Replace recursive placeholders in an object and its ``__parent__`` chain.

    For the heap object at ``return_value``, every non-meta entry reporting
    ``has_recursive_placeholder()`` has the placeholder removed and
    ``replacement`` added. The object's ``__meta__["__parent__"]`` (when it is
    an unvisited Address) is processed the same way, recursively. Non-address
    ``return_value``s are ignored.

    Args:
        env: Environment owning the heaps.
        return_value: Value whose heap object should be processed.
        replacement: Values added in place of each placeholder.
        checked_addresses: Addresses already processed; extended in place.
            Defaults to a fresh list per call. The previous shared ``[]``
            default persisted across calls, so parents visited in earlier calls
            were skipped forever and their placeholders never replaced.

    Raises:
        Exception: If an address is neither concrete nor abstract.
    """
    if checked_addresses is None:
        checked_addresses = []
    if not isinstance(return_value, Address):
        return
    checked_addresses.append(return_value)
    if return_value.get_addr_type() == "concrete":
        heap_frame = env.concrete_heap.get(return_value)
    elif return_value.get_addr_type() == "abstract":
        heap_frame = env.abstract_heap.get(return_value)
    else:
        raise Exception(f"Unknown address type {return_value.get_addr_type()}")
    for key, entry in heap_frame.items():
        if key == "__meta__":
            parent = entry.get("__parent__") if "__parent__" in entry else None
            if isinstance(parent, Address) and parent not in checked_addresses:
                replace_recursive_placeholders_for_address(
                    env,
                    parent,
                    replacement,
                    checked_addresses=checked_addresses,
                )
        elif entry.has_recursive_placeholder():
            logger.info(
                f"Replacing recursive placeholder in {entry} with {replacement}"
            )
            entry.remove_recursive_placeholder()
            entry.add_values(replacement)


def get_return_values(
    env: "Environment",
) -> list[Type]:  # Handle recursive calls, we can backfill them in here
    """Return the current frame's return values, back-filling recursive placeholders.

    When the current frame has ``has_recursive_call`` set, each address among
    the return values (and its ``__parent__`` chain) has its recursive
    placeholders replaced by the full list of return values.
    """
    return_values = env.cur_stack_frame.get_return_values()
    if env.cur_stack_frame.has_recursive_call:
        # Pass an explicit visited list scoped to this call instead of relying
        # on the callee's default. That way addresses visited during earlier
        # calls are never skipped, and parents shared by several return values
        # are processed once.
        checked_addresses: list = []
        # TODO This is very slow. We should find a way to optimize this, since it needs to recurse very far
        for return_value in return_values:
            if isinstance(return_value, Address):
                replace_recursive_placeholders_for_address(
                    env,
                    return_value,
                    return_values,
                    checked_addresses=checked_addresses,
                )
    return return_values


def return_from_function(
    env: "Environment", func: Address | None, file_path: str
) -> None:
    """Pop the current frame for ``file_path`` and hand its trace to the new top frame.

    The popped frame's touched addresses are re-added to the frame that becomes
    current. ``func`` is not used here.

    NOTE(review): unlike ``memoize_and_return_from_function``, the pop and the
    new current frame use the same file's stack. Confirm this is intended when
    callee and call site live in different files.
    """
    # Capture the trace before the frame is discarded.
    inherited_addresses = get_touched_addresses(env)
    frames = env.stack[convert_path_to_underscore(file_path)]
    frames.pop()
    env.cur_stack_frame = frames[-1]
    for touched in inherited_addresses:
        add_touched_address(env, touched)


def add_touched_record_result(
    env: "Environment",
    name: str,
    record_result: RecordResult,
    add_primitives: bool = False,
) -> None:
    """Mark every value held by ``record_result`` as touched in the current frame.

    Address values are always recorded. Other values are recorded as
    primitives under ``name`` only when ``add_primitives`` is true.
    """
    for item in record_result.get_all_values():
        if not isinstance(item, Address):
            if add_primitives:
                add_touched_primitive(env, name, item)
            continue
        add_touched_address(env, item)


def add_touched_address(env: "Environment", address: Address) -> None:
    """Record ``address`` and its current derived value as touched in the current frame."""
    frame = env.cur_stack_frame
    frame.add_touched_address(address, env.lookup_and_derive_address(address))


def add_touched_primitive(env: "Environment", name: str, value: Type) -> None:
    """Record the primitive ``value`` under ``name`` as touched in the current frame."""
    frame = env.cur_stack_frame
    frame.add_touched_primitive(name, value)


def get_touched_addresses(env: "Environment") -> dict[Address, object]:
    """Return the current frame's touched addresses, as reported by the frame."""
    frame = env.cur_stack_frame
    return frame.get_touched_addresses()


def get_touched_primitives(env: "Environment") -> dict[str, Type]:
    """Return the current frame's touched primitives, as reported by the frame.

    NOTE(review): ``touched_primitives_changed`` annotates this mapping as
    ``dict[str, list[Type]]``; confirm which value shape the frame stores.
    """
    frame = env.cur_stack_frame
    return frame.get_touched_primitives()


def touched_addresses_changed(
    env: "Environment", touched_addresses: dict[Address, object], prev_hash: int
) -> bool:
    """Report whether the touched addresses now derive to values with a new hash.

    Each key of ``touched_addresses`` is re-derived through
    ``env.lookup_and_derive_address``. The resulting mapping is hashed with
    ``hash_address_values`` and compared against ``prev_hash``.
    """
    current_values = {
        address: env.lookup_and_derive_address(address)
        for address in touched_addresses
    }
    return hash_address_values(current_values) != prev_hash


def touched_primitives_changed(
    env: "Environment", touched_primitives: dict[str, list[Type]], prev_hash: int
) -> bool:
    """Report whether the touched primitives now look up to values with a new hash.

    For each name in ``touched_primitives``, the non-address values currently
    bound to it (looked up with ``add_touched=False``) are collected. The hash
    of the resulting mapping's ``str`` is compared with ``prev_hash``.

    NOTE(review): ``memoize_and_return_from_function`` stores
    ``hash(str(frame.get_touched_primitives()))``. The two hashes only agree if
    the frame keeps exactly ``{name: [non-address values]}`` in the same order;
    confirm.
    """
    current_values = {
        name: [
            item
            for item in env.lookup(name, add_touched=False).get_all_values()
            if not isinstance(item, Address)
        ]
        for name in touched_primitives
    }
    return hash(str(current_values)) != prev_hash
