import esprima
from absint_ai.Environment.types.Type import *
from .Logger import logger  # type: ignore
from dotmap import DotMap
from ordered_set import OrderedSet
import re
import pythonmonkey as pm
import beeprint
import json
from typing import TYPE_CHECKING
from collections import Counter
from typing import Any
import json
import textwrap
from termcolor import colored

if TYPE_CHECKING:
    from absint_ai.Environment.Environment import Environment
    from absint_ai.Environment.types.Type import AbstractType
    from absint_ai.Environment.types.Type import Address
    from absint_ai.Environment.types.Type import Type


def is_integer(x: object) -> bool:
    """Return True only when ``x`` is a string consisting solely of digits.

    Non-string inputs (including real ``int`` values) and the empty string
    yield False; a leading sign or decimal point also yields False.
    """
    return isinstance(x, str) and x.isdigit()


def get_schema_from_expr(
    expr: esprima.nodes.Node, file_path: str, scope_type: str
) -> str:
    """Build a scope identifier for ``expr``.

    The identifier joins, with underscores: the file path (slashes replaced
    via ``convert_path_to_underscore``), the start line, the end line, the
    end column, and ``scope_type``.
    """
    path_part = convert_path_to_underscore(file_path)
    location = expr.loc
    first_line = location.start.line
    last_line = location.end.line
    last_column = location.end.column
    return f"{path_part}_{first_line}_{last_line}_{last_column}_{scope_type}"


def allocation_site_from_expr(expr: esprima.nodes.Node, file_path: str) -> str:
    """Build an allocation-site identifier for ``expr``.

    The identifier joins, with underscores: the file path (slashes replaced
    via ``convert_path_to_underscore``) followed by the start line, start
    column, end line and end column of the expression.
    """
    path_part = convert_path_to_underscore(file_path)
    begin = expr.loc.start
    finish = expr.loc.end
    return f"{path_part}_{begin.line}_{begin.column}_{finish.line}_{finish.column}"


def convert_path_to_underscore(path: str) -> str:
    """Return ``path`` with every forward slash replaced by an underscore."""
    segments = path.split("/")
    return "_".join(segments)


def serialize_keys(obj: object) -> object:
    """Recursively convert ``obj`` into a structure with string dict keys.

    * dicts: every key is passed through ``str`` and every value is converted
      recursively;
    * lists, tuples, sets and ``OrderedSet``s: returned as a list of
      recursively converted items;
    * ``Type`` instances: replaced by their ``str`` form;
    * anything else: returned unchanged.
    """
    if isinstance(obj, dict):
        converted = {}
        for key, value in obj.items():
            converted[str(key)] = serialize_keys(value)
        return converted
    if isinstance(obj, (list, tuple, set, OrderedSet)):
        return list(map(serialize_keys, obj))
    if isinstance(obj, Type):
        return str(obj)
    return obj


""" Addresses are in the form of "Abstract Address(0)" or "Concrete Address(0)" """


def address_from_string(addr_str: object) -> "Address":
    """Parse the textual form of an address back into an ``Address``.

    The accepted form is ``"Abstract Address(<value>)"`` or
    ``"Concrete Address(<value>)"``; ``<value>`` is passed to ``Address`` as a
    string together with ``"abstract"`` or ``"concrete"``.

    Raises:
        ValueError: if ``addr_str`` is not a string, does not contain exactly
            one ``"Address("`` marker, is missing the closing parenthesis, or
            has a prefix other than ``"Abstract "`` / ``"Concrete "``.
    """
    if not isinstance(addr_str, str):
        raise ValueError(f"Invalid address string: {addr_str}")

    pieces = addr_str.split("Address(")
    if len(pieces) != 2:
        raise ValueError(f"Invalid address string: {addr_str}")
    addr_type, addr_value = pieces

    # Only strip the final character when it really is the closing paren;
    # otherwise a truncated string would silently lose part of its value.
    if not addr_value.endswith(")"):
        raise ValueError(f"Invalid address string: {addr_str}")
    addr_value = addr_value[:-1]

    if addr_type == "Abstract ":
        return Address(addr_value, "abstract")
    if addr_type == "Concrete ":
        return Address(addr_value, "concrete")
    raise ValueError(f"Invalid address type: {addr_type}")


def is_address(addr_str: object) -> bool:
    """Return True when ``address_from_string`` can parse ``addr_str``.

    Any exception raised while parsing is treated as "not an address".
    """
    try:
        address_from_string(addr_str)
    except Exception:
        return False
    else:
        return True


def array_to_object(arr: list) -> dict:
    """Return a dict mapping each integer index of ``arr`` to its element."""
    return dict(enumerate(arr))


def get_chat_history_from_paths(paths: list[tuple[str, str]]) -> list[tuple[str, str]]:
    """Load ``(prompt, response)`` text pairs from pairs of file paths.

    Each file is read in text mode and stripped of leading/trailing
    whitespace. The returned list preserves the order of ``paths``.
    """

    def _read_stripped(file_path: str) -> str:
        # Read one file fully and trim surrounding whitespace.
        with open(file_path, "r") as handle:
            return handle.read().strip()

    return [
        (_read_stripped(prompt_path), _read_stripped(response_path))
        for prompt_path, response_path in paths
    ]


def item_contained_in_list(item: object, lst: list) -> bool:
    """Return whether ``item`` is covered by the collection ``lst``.

    ``lst`` may contain concrete values as well as abstract markers (the
    strings ``"TOP"``, ``"NULL"``, ``"STRING"``, ``"NUMBER"``, ``"BOOL"`` or
    the corresponding ``baseType`` members). The checks, in order:

    * ``lst`` must be a list, set or ``OrderedSet``; otherwise False.
    * An empty list/set/``OrderedSet`` ``item`` is covered only by an empty
      ``lst``.
    * If the string ``"TOP"`` is in ``lst``, everything is covered.
    * A ``Primitive`` is covered if it is itself in ``lst``; otherwise it is
      unwrapped with ``get_value()`` and the checks continue on the value.
    * A collection ``item`` is covered when every element is covered.
    * Abstract markers and concrete strings/bools/numbers are covered by the
      matching abstract marker or by direct membership.
    * A dict is covered when ``object_is_superset`` holds for some element.
    * An ``Address`` is covered by direct membership or by an element whose
      string form parses (via ``address_from_string``) to an equal address.

    Any other kind of ``item`` is reported as not contained.
    """
    if (
        not isinstance(lst, list)
        and not isinstance(lst, set)
        and not isinstance(lst, OrderedSet)
    ):
        return False
    # An empty collection is only contained in an empty collection.
    if isinstance(item, list) or isinstance(item, set) or isinstance(item, OrderedSet):
        if len(item) == 0:
            return len(lst) == 0
    # NOTE(review): only the string "TOP" short-circuits here; baseType.TOP in
    # lst is not checked at this point — confirm this is intentional.
    if "TOP" in lst:
        return True
    # A Primitive may be present as-is; otherwise compare its unwrapped value.
    if isinstance(item, Primitive):
        if item in lst:
            return True
        item = item.get_value()

    # For a collection, every element must individually be contained in lst.
    if isinstance(item, list) or isinstance(item, OrderedSet) or isinstance(item, set):
        return all([item_contained_in_list(obj, lst) for obj in item])
    elif item == "TOP" or item == baseType.TOP:
        return "TOP" in lst or baseType.TOP in lst
    elif item == "NULL" or item == baseType.NULL:
        return "NULL" in lst or baseType.NULL in lst or "TOP" in lst
    elif item == "STRING" or item == baseType.STRING:
        return "STRING" in lst or baseType.STRING in lst or "TOP" in lst
    elif item == "NUMBER" or item == baseType.NUMBER:
        return "NUMBER" in lst or baseType.NUMBER in lst
    # NOTE(review): item was already unwrapped above, so the Primitive half of
    # this condition presumably only matters if get_value() returned another
    # Primitive — confirm.
    elif isinstance(item, str) or (
        isinstance(item, Primitive) and isinstance(item.get_value(), str)
    ):
        # Any concrete string is covered by the abstract STRING marker.
        if "STRING" in lst or baseType.STRING in lst:
            return True
        if item in lst:
            return True
        return False
    # bool is tested before int/float because bool is a subclass of int.
    elif isinstance(item, bool):
        if "BOOL" in lst or baseType.BOOLEAN in lst:
            return True
        if item in lst:
            return True
        return False
    elif isinstance(item, (int, float)):
        # Any concrete number is covered by the abstract NUMBER marker.
        if "NUMBER" in lst or baseType.NUMBER in lst:
            return True
        if item in lst:
            return True
        return False
    elif isinstance(item, dict):
        # A dict is covered when some element of lst is a superset of it.
        for obj in lst:
            if object_is_superset(item, obj):
                return True
        return False
    elif isinstance(item, Address):
        if item in lst:
            return True
        # Elements may be address strings; parse them before comparing.
        for obj in lst:
            if is_address(obj):
                addr = address_from_string(obj)
                if item == addr:
                    return True
    return False


def object_is_superset(obj1: object, obj2: dict[str, list]) -> bool:
    """Return whether ``obj2`` is a superset of ``obj1``.

    Every key of ``obj1`` (other than ``"__proto__"`` and ``"__meta__"``)
    must have its value covered, per ``item_contained_in_list``, by the
    values stored in ``obj2`` under the first of:

    1. the same key,
    2. ``str(key)``,
    3. for numeric-looking keys, ``"NUMBER"`` and then ``"TOP"``,
    4. for string keys, ``"STRING"`` and then ``"TOP"``.

    A numeric or string key with no matching entry makes the result False.
    Keys that fit none of the categories are ignored.

    Returns False when ``obj2`` is not a dict.

    Raises:
        Exception: if ``obj2`` is a dict but ``obj1`` is not.
    """
    if not isinstance(obj2, dict):
        return False
    if not isinstance(obj1, dict):
        raise Exception(f"Object {obj1} is not a dictionary")

    for key in obj1:
        if key == "__proto__" or key == "__meta__":
            continue

        if key in obj2:
            candidates = obj2[key]
        elif str(key) in obj2:
            candidates = obj2[str(key)]
        else:
            # No exact entry: fall back to the abstract entry for this kind
            # of key, then to "TOP".
            if (
                isinstance(key, (int, float))
                or is_integer(key)
                or key == baseType.NUMBER
            ):
                abstract_name = "NUMBER"
            elif isinstance(key, str) or key == baseType.STRING:
                abstract_name = "STRING"
            else:
                continue

            if abstract_name in obj2:
                candidates = obj2[abstract_name]
            elif "TOP" in obj2:
                candidates = obj2["TOP"]
            else:
                return False

        if not item_contained_in_list(obj1[key], candidates):
            return False
    return True


def is_superset(obj1_list: list[object], obj2_list: list[object]) -> bool:
    """Return whether ``obj2_list`` is a superset of ``obj1_list``.

    Unlike ``object_is_superset`` this also works for primitives: each
    element of ``obj1_list`` must be covered by ``obj2_list`` according to
    ``item_contained_in_list``. Checking stops at the first uncovered element.
    An empty ``obj1_list`` is always covered.
    """
    return all(
        item_contained_in_list(candidate, obj2_list) for candidate in obj1_list
    )


def field_in_object(field: object, obj: dict) -> object:
    """Return the key of ``obj`` that stands for ``field``, or None.

    If ``field`` itself is a key it is returned. Otherwise an abstraction is
    looked up: numeric-looking fields try ``"NUMBER"``, ``baseType.NUMBER``,
    ``"TOP"``, ``baseType.TOP`` in that order; string fields try ``"STRING"``,
    ``baseType.STRING``, ``"TOP"``, ``baseType.TOP``. None is returned when
    nothing matches.
    """
    if field in obj:
        return field

    if isinstance(field, (int, float)) or is_integer(field) or field == baseType.NUMBER:
        family = "NUMBER"
    elif isinstance(field, str) or field == baseType.STRING:
        family = "STRING"
    else:
        return None

    # Try the specific abstraction first, then TOP; for each one the string
    # marker is checked before the baseType member.
    for label in (family, "TOP"):
        if label in obj:
            return label
        member = getattr(baseType, label)
        if member in obj:
            return member
    return None


def value_to_abstract_type(value) -> AbstractType:
    """Map a concrete Python value to its ``baseType`` abstraction.

    Strings map to STRING, bools to BOOLEAN, ints/floats to NUMBER, ``None``
    to NULL and everything else to TOP.
    """
    if value is None:
        return baseType.NULL
    if isinstance(value, str):
        return baseType.STRING
    # bool must be tested before int/float because bool is a subclass of int.
    if isinstance(value, bool):
        return baseType.BOOLEAN
    if isinstance(value, (int, float)):
        return baseType.NUMBER
    return baseType.TOP


def order_dict(d: dict) -> dict:
    """Return a copy of ``d`` in which every list value is sorted.

    Nested dict values are processed recursively; list values are replaced by
    ``sorted`` copies; all other values are carried over unchanged. Key order
    is preserved.
    """

    def _normalise(value):
        # Recurse into dicts, sort lists, pass anything else through.
        if isinstance(value, dict):
            return order_dict(value)
        if isinstance(value, list):
            return sorted(value)
        return value

    return {key: _normalise(value) for key, value in d.items()}


def get_identifiers_from_expr(expr, identifiers=None):
    """Collect identifier names and literal values found in ``expr``.

    The node is walked depth-first: ``Identifier`` nodes contribute their
    ``name`` and ``Literal`` nodes their ``value``. Children are found in list
    values (except under the ``"range"`` key) and in ``DotMap`` values.

    ``identifiers`` is an optional accumulator list that is appended to in
    place. The return value is the accumulated entries converted with ``str``
    and joined by ``"."``.
    """
    collected = identifiers if identifiers is not None else []
    if expr.type == "Identifier":
        collected.append(expr.name)
    if expr.type == "Literal":
        collected.append(expr.value)
    for key in expr:
        child = expr[key]
        if key != "range" and isinstance(child, list):
            for element in child:
                get_identifiers_from_expr(element, collected)
        elif isinstance(child, DotMap):
            get_identifiers_from_expr(child, collected)
    return ".".join(map(str, collected))


def is_heap_frame(obj: dict):
    """Return True when ``obj["__meta__"]`` exists and contains ``"__parent__"``."""
    has_meta = "__meta__" in obj
    return has_meta and "__parent__" in obj["__meta__"]


def get_num_tokens(tokenizer, message_history):
    """Return the total token count over all messages in ``message_history``.

    ``tokenizer`` is called on each message's ``"content"`` and the length of
    the ``"input_ids"`` entry of its result is added to the total.
    """
    return sum(
        len(tokenizer(entry["content"])["input_ids"]) for entry in message_history
    )


def parse_json_from_message(message: str):
    """Extract and parse the JSON payload of a ```` ```json ```` fenced block.

    The first fenced block is tried first (lazy match), so a message with
    several fenced blocks yields the first one. If that text is not valid
    JSON — for example because the JSON itself contains triple backticks
    inside a string — the widest span up to the last closing fence is tried
    as a fallback.

    Returns the decoded JSON value, or ``{}`` when no fenced block is present
    or none of the candidates parses.
    """
    # Lazy pattern first, greedy pattern as the backward-compatible fallback.
    for pattern in (r"```json(.*?)```", r"```json(.*)```"):
        found = re.search(pattern, message, re.DOTALL)
        if found is None:
            # Both patterns share the same anchors: if one misses, both do.
            return {}
        try:
            return json.loads(found.group(1))
        except (ValueError, RecursionError):
            # Best effort: invalid JSON in this candidate, try the next one.
            continue
    return {}


def is_property_builtin(prop_name: str) -> bool:
    """Return True when ``prop_name`` is one of the known built-in properties.

    The names checked are the property/method names hard-coded in the tuple
    below.
    """
    known_builtin_names = (
        "addEventListener",
        "appendChild",
        "blur",
        "click",
        "createElement",
        "getElementById",
        "getElementsByClassName",
        "getElementsByTagName",
        "innerHTML",
        "innerText",
        "length",
        "location",
        "log",
        "querySelector",
        "removeChild",
        "removeEventListener",
        "replace",
        "setAttribute",
        "src",
        "style",
        "text",
        "textContent",
        "then",
        "value",
        "window",
    )
    return prop_name in known_builtin_names


def get_variables_in_backticks(message: str) -> list[str]:
    """Return every substring of ``message`` enclosed in single backticks.

    Matching is non-greedy and does not span newlines; an empty pair of
    backticks yields an empty string.
    """
    backtick_span = re.compile(r"`(.*?)`")
    return backtick_span.findall(message)


class ObjectKey(str):
    """A string key that carries a ``may_must`` tag alongside its text.

    The instance compares and hashes like the plain string it wraps; the tag
    is stored in the ``may_must`` attribute.
    """

    def __new__(cls, value, may_must: str):
        obj = super().__new__(cls, value)
        obj.may_must = may_must
        return obj

    def __repr__(self):
        # The tag lives in ``may_must``; no ``tag`` attribute is ever set.
        return f"{super().__repr__()}(may_must={self.may_must})"

    def upper(self):
        """Return an upper-cased ``ObjectKey`` that keeps the ``may_must`` tag."""
        return ObjectKey(super().upper(), self.may_must)  # Preserve metadata


def hash_address_values(addresses: dict[Address, object]) -> int:
    """Return a hash of the address-to-value mapping.

    The mapping is rendered with ``str`` and that text is hashed, so the
    result depends on the ``str`` form of every key and value.
    """
    rendered = str(addresses)
    return hash(rendered)


def is_member_expression_static(node: esprima.nodes.Node) -> bool:
    """Return True when the chain of ``.object`` links ends in an Identifier.

    Starting from ``node.object``, nested ``MemberExpression`` nodes are
    followed through their own ``object``; the result is True if the chain
    reaches an ``Identifier`` and False for any other node type.
    """
    current = node.object
    while True:
        if current.type == "Identifier":
            return True
        if current.type != "MemberExpression":
            return False
        current = current.object


def get_identifier_id(node: esprima.nodes.Node) -> str:
    """Return ``"<name>_<start line - 1>_<start column>"`` for ``node``."""
    identifier_name = node.name
    start = node.loc.start
    zero_based_line = start.line - 1
    return f"{identifier_name}_{zero_based_line}_{start.column}"


def is_builtin_object(obj: esprima.nodes.Node) -> bool:
    """Return True when ``obj`` is rooted at a recognised built-in global.

    ``MemberExpression`` nodes are followed through ``.object`` until a
    different node type is reached. The result is True only when that node is
    an ``Identifier`` named ``String`` or ``Math``.
    """
    root = obj
    while root.type == "MemberExpression":
        root = root.object
    if root.type != "Identifier":
        return False
    # Other candidates noted by the author but not enabled:
    # ["Math", "Date", "JSON", "console", "String", "Number"]
    return root.name in ["String", "Math"]


def static_return_for_abstract_type(method: str) -> Type:
    """Return the abstract result type of calling ``method``.

    Every method name — including ``"floor"``, ``"ceil"`` and ``"max"`` —
    maps to ``baseType.NUMBER``.
    """
    # The argument does not influence the result: all names yield NUMBER.
    return baseType.NUMBER


# these two are slightly different, they're used to coerce between types for executing builtins
def type_to_value(typ: Type, env: "Environment", ignore_proto=False) -> object:
    """Convert an interpreter ``Type`` into a plain value.

    * ``Primitive``: its ``get_value()``; regex primitives are returned as
      ``"@REGEX:"`` followed by the value.
    * ``Address``: the result of ``env.lookup_and_derive_address`` with
      ``ignore_proto`` forwarded. The derived object is returned as-is, with
      no array flattening.
    * ``AbstractType``: its ``get_value()``.

    Raises:
        ValueError: if ``typ`` is none of the kinds above.
    """
    if isinstance(typ, Primitive):
        if typ.is_regex:
            return "@REGEX:" + typ.get_value()
        return typ.get_value()
    if isinstance(typ, Address):
        return env.lookup_and_derive_address(typ, ignore_proto=ignore_proto)
    if isinstance(typ, AbstractType):
        return typ.get_value()
    raise ValueError(f"Unknown type: {typ}")


def is_int(obj: object) -> bool:
    """Return True when ``obj`` is an integer or a base-10 integer string.

    Real ``int`` values are accepted directly (``int(obj, 10)`` rejects
    non-string input with ``TypeError``, which would wrongly report integers
    as non-integers). ``bool`` values are not treated as integers. Everything
    else is accepted only if ``int(obj, 10)`` can parse it, which covers
    strings such as ``"12"``, ``"-3"`` or ``" 7 "``.
    """
    if isinstance(obj, bool):
        return False
    if isinstance(obj, int):
        return True
    try:
        int(obj, 10)
    except (TypeError, ValueError):
        return False
    return True


def obj_is_array(obj: object) -> bool:
    """Return True when every key of ``obj`` passes ``is_int``.

    An object with no keys counts as an array.
    """
    return all(is_int(index) for index in obj)


def value_to_type(
    value: object, env: "Environment", allocation_site: str = None
) -> Type:
    """Convert a plain value into an interpreter ``Type`` via ``env``.

    Lists are first turned into index-keyed dicts with ``array_to_object``;
    the conversion itself is delegated to ``env.value_to_type`` with
    ``allocation_site`` forwarded.
    """
    prepared = array_to_object(value) if isinstance(value, list) else value
    return env.value_to_type(prepared, allocation_site=allocation_site)


def is_truthy(value):
    """Return False for values treated as falsy, True for everything else.

    Falsy values are ``None``; anything equal to ``"undefined"``, ``"null"``,
    ``0``, ``baseType.NULL``, ``""``, ``"nan"``, ``"NaN"`` or ``False``; and
    anything equal to a ``Primitive`` wrapping ``False``, ``0``, ``""``,
    ``None``, ``"undefined"``, ``"nan"`` or ``"NaN"``.
    """
    if value is None:
        return False
    # The comparisons below run in a fixed order and stop at the first match.
    for literal in ("undefined", "null", 0):
        if value == literal:
            return False
    if value == baseType.NULL:
        return False
    for literal in ("", "nan", "NaN", False):
        if value == literal:
            return False
    # Wrapped falsy values; each Primitive is built only when it is needed.
    for raw in (False, 0, "", None, "undefined", "nan", "NaN"):
        if value == Primitive(raw):
            return False
    return True


def contains_truthy_value(values: OrderedSet[object]) -> bool:
    """Return True when at least one of ``values`` may be truthy.

    The abstract ``baseType.BOOLEAN`` counts as possibly truthy; otherwise
    each value is tested with ``is_truthy``.
    """
    if baseType.BOOLEAN in values:
        return True
    for candidate in values:
        if is_truthy(candidate):
            return True
    return False


def only_contains_truthy_values(values: OrderedSet[object]) -> bool:
    """Return True when every one of ``values`` is definitely truthy.

    The abstract ``baseType.BOOLEAN`` may be false, so its presence makes the
    result False; otherwise each value must pass ``is_truthy``.
    """
    if baseType.BOOLEAN in values:
        return False
    for candidate in values:
        if not is_truthy(candidate):
            return False
    return True


def prop_in_obj(prop: Type, obj: dict) -> bool:
    """Return whether the concrete property ``prop`` is a key of ``obj``.

    An ``AbstractType`` property is never reported as present. A ``Primitive``
    is unwrapped with ``get_value()`` before the membership test.
    """
    if isinstance(prop, AbstractType):
        return False
    lookup_key = prop.get_value() if isinstance(prop, Primitive) else prop
    return lookup_key in obj


def escape_string(string: str):
    """Return ``string`` with newline, carriage return and tab escaped.

    Each of those characters is replaced by its two-character backslash
    sequence (``\\n``, ``\\r``, ``\\t``); all other characters are untouched.
    """
    escapes = str.maketrans({"\n": "\\n", "\r": "\\r", "\t": "\\t"})
    return string.translate(escapes)


def get_allocation_site_value_str(env: "Environment", allocation_site_id: str) -> str:
    """Pretty-print the values and points-to info of one allocation site.

    The values come from ``env.get_allocation_site_values`` (heap frames
    ignored) and are shown under the readable name returned by
    ``env.get_readable_allocation_site``. A ``"points_to"`` entry holding the
    list from ``env.points_to_info_for_allocation_site`` is added to the
    site's values before formatting with ``beeprint.pp``.
    """
    raw_values = env.get_allocation_site_values(
        [allocation_site_id], ignore_heap_frames=True
    )
    readable_id = env.get_readable_allocation_site(allocation_site_id)
    labelled_values = {readable_id: raw_values[allocation_site_id]}
    points_to = env.points_to_info_for_allocation_site(allocation_site_id)
    # Written through raw_values; labelled_values holds the same entry object,
    # so the new field shows up in the printed output.
    raw_values[allocation_site_id]["points_to"] = list(points_to)
    return beeprint.pp(labelled_values, output=False)


def canonicalize(value: Any) -> Any:
    """Reduce ``value`` to a key- and order-insensitive canonical form.

    Dicts become a ``frozenset`` of their canonicalized values (keys are
    dropped); lists become a ``frozenset`` of their canonicalized items.
    Because a ``frozenset`` is used, duplicates collapse and ordering is
    lost. Any other value is returned unchanged.
    """
    if isinstance(value, dict):
        members = value.values()
    elif isinstance(value, list):
        members = value
    else:
        # Base case for primitives.
        return value
    return frozenset(map(canonicalize, members))


def dicts_equal_modulo_keys(a: dict, b: dict) -> bool:
    """Return whether ``a`` and ``b`` have the same ``canonicalize`` form.

    Keys are ignored; only the canonicalized values are compared.
    """
    canonical_a = canonicalize(a)
    canonical_b = canonicalize(b)
    return canonical_a == canonical_b


def dotmap_to_dict(obj):
    """Recursively convert ``DotMap`` instances inside ``obj`` to plain dicts.

    Lists are rebuilt with each element converted; any other value is
    returned unchanged.
    """
    if isinstance(obj, DotMap):
        plain = {}
        for key, value in obj.items():
            plain[key] = dotmap_to_dict(value)
        return plain
    if isinstance(obj, list):
        return list(map(dotmap_to_dict, obj))
    return obj


def ast_to_str(node: esprima.nodes.Node) -> str:
    """Render an AST node back to source text.

    The node is converted to plain dicts, serialised to JSON and passed to
    the function loaded from ``../js_utils/escodegen_wrapper.js`` via
    pythonmonkey.
    """
    generate_source = pm.require("../js_utils/escodegen_wrapper.js")
    serialized_ast = json.dumps(dotmap_to_dict(node))
    return generate_source(serialized_ast)


def pretty_print_messages(messages, indent=2, width=80):
    """Print a chat transcript to stdout with coloured role headers.

    Each message gets a header ``[index] ROLE`` (cyan for user, green for
    assistant, yellow otherwise). Assistant messages carrying ``tool_calls``
    list each call's name and decoded arguments; ``tool`` messages show the
    tool name and wrapped content; everything else prints its wrapped
    ``content``. ``indent`` is the number of leading spaces for body text and
    ``width`` is the wrap width.
    """
    header_colors = {"user": "cyan", "assistant": "green"}
    for position, message in enumerate(messages):
        role = message["role"]
        header = f"\n[{position}] {role.upper()}"
        print(colored(header, header_colors.get(role, "yellow")))

        if role == "assistant" and "tool_calls" in message:
            for call in message["tool_calls"]:
                print(colored(f"  🔧 Tool Call: {call['function']['name']}", "magenta"))
                try:
                    arguments = json.loads(call["function"]["arguments"])
                    for arg_name, arg_value in arguments.items():
                        print(f"{' ' * indent}- {arg_name}: {arg_value}")
                except json.JSONDecodeError:
                    print("  (invalid JSON args)")
            continue

        if role == "tool":
            print(colored(f"  🛠️  Tool Response for `{message['name']}`", "magenta"))
            body = message["content"]
        else:
            body = message.get("content", "")
        wrapped = textwrap.fill(body, width)
        print(textwrap.indent(wrapped, " " * indent))


class AbsIntAIEncoder(json.JSONEncoder):
    """JSON encoder that serialises ``Address`` objects as their ``str`` form."""

    def default(self, obj):
        """Return ``str(obj)`` for an ``Address``; defer to the base class otherwise."""
        if not isinstance(obj, Address):
            return super().default(obj)
        return str(obj)
