import esprima  # type: ignore
from cProfile import Profile
import pstats
import matplotlib.pyplot as plt
import networkx as nx
import pydot
import cython  # type: ignore
from absint_ai.Environment.Environment import Environment
import copy
from absint_ai.utils.Util import *
from absint_ai.Environment.types.Type import *
from absint_ai.Environment.memory.ConcreteHeap import *
from absint_ai.Environment.memory.AbstractHeap import *
from absint_ai.Interpreter.visitors.expr_visitor import ExprVisitor
from absint_ai.Interpreter.visitors.statement_visitor import StatementVisitor
import subprocess
from absint_ai.utils.Logger import logger, set_console_handler, set_file_handler  # type: ignore
import json
import os
import matplotlib.patches as mpatches
import absint_ai.Environment.abstractions.LLMAbstractions as LLMAbstractions
import pythonmonkey as pm
from dotmap import DotMap
from ordered_set import OrderedSet
import time
import socket

# DotMap's stock ``__getitem__`` does not fail on an unregistered key: it hands
# back a fresh ``__class__()`` instance instead. Route subscripting through
# ``DotMap.get`` so that unknown keys produce ``get``'s default (presumably
# ``None``) rather than silently materializing empty DotMaps.
setattr(DotMap, "__getitem__", DotMap.get)


class AbsIntAI:
    def __init__(
        self,
        simplify_method: str = "allocation_sites",
        global_simplify_method: str = "allocation_sites",
        model: str = "gpt-4o-mini",
        loop_max: int = 20,
        timeout: int = 60,
        is_module: bool = True,
        should_abstract: bool = True,
        log_lines_run: bool = False,
        log_console: bool = True,
        log_file: bool = True,
        check_pointers: bool = True,  # make sure there's no pointers from abstract to concrete heap
        debug: bool = False,
        use_agent: bool = True
    ) -> None:
        """Configure an abstract interpreter and its expression/statement visitors.

        Args:
            simplify_method: Name of a simplification method; stored as-is and
                logged by ``run()``.
            global_simplify_method: Strategy ``run()`` applies to the whole heap
                after each pass: ``"allocation_sites"``, ``"recency"``/``"depth"``,
                ``"manual"``/``"llm"``, or ``"None"``.
            model: LLM model name; also copied onto the environment.
            loop_max: Stored as-is and logged by ``run()``. NOTE(review):
                presumably bounds loop iterations inside the visitors — confirm.
            timeout: Multiplied by 60 before being stored and handed to the
                ``StatementVisitor`` (presumably minutes -> seconds; TODO confirm).
            is_module: When true, ``run()`` initializes a module scope for the script.
            should_abstract: When false, ``run()`` performs a single pass only.
            log_lines_run: Stored as-is; not read in this class.
            log_console: When true, ``run()`` installs the console log handler.
            log_file: When true, ``run()`` installs a per-script file log handler.
            check_pointers: When true, ``visit_statement()`` calls
                ``env.check_no_pointers_from_abstract_to_concrete()`` before each
                statement.
            debug: Debug flag; also switched on by ``set_test_mode()``.
            use_agent: Stored as-is; not read in this class.
        """
        self.model = model
        self.env = Environment()
        # The environment keeps its own copy of the model name.
        self.env.model = model
        self.loop_max = loop_max
        self.simplify_method = simplify_method
        self.global_simplify_method = global_simplify_method
        self.should_abstract = should_abstract
        self.check_pointers = check_pointers
        # De-duplicated and logged at the end of run(); presumably populated by the visitors.
        self.buggy_lines: list[str] = []
        self.buggy_line_numbers: list[str] = []
        # Label of the previously visited statement, used for visualization.
        self.prev_line: str = ""
        self.should_visualize: bool = False
        self.should_visualize_abstraction: bool = False
        self.is_module = is_module
        # Incremented once per visit_statement() call.
        self.executed_lines = 0
        self.log_lines_run = log_lines_run
        self.keep_looping = True
        # When true, run() wraps the analysis in a cProfile profiler.
        self.should_profile = False
        self.called_builtins_count = 0
        # run() logs len(set(self.trace)) as the number of unique lines executed.
        self.trace: list = []
        self.timeout = timeout * 60
        # The visitors receive a reference to this interpreter here, before the
        # attributes assigned further down exist; anything their constructors
        # read from the interpreter must be assigned above this point.
        self.expr_visitor = ExprVisitor(self.env, self)
        self.statement_visitor = StatementVisitor(self.env, self, timeout=self.timeout)
        # The two visitors recurse into each other, so each needs the other.
        self.expr_visitor.set_statement_visitor(self.statement_visitor)
        self.statement_visitor.set_expr_visitor(self.expr_visitor)
        self.test_mode = False
        self.debug = debug
        self.identifier_info: dict = {}
        self.log_console = log_console
        self.log_file = log_file
        self.use_agent=use_agent
        self.llm_time = 0



    def set_test_mode(self):
        """Switch the interpreter into test mode; debug mode is enabled too.

        In test mode ``run()`` stops after the first pass has moved all objects
        to the abstract heap, before any global simplification is applied.
        """
        # Chained assignment binds left to right: debug first, then test_mode.
        self.debug = self.test_mode = True

    def get_ast_from_espree(self, scriptPath: str) -> tuple:
        """Parse ``scriptPath`` with the bundled ``get_ast_from_espree.js`` helper.

        The helper (located next to this module) is executed with the Node.js
        binary named by the ``node_path`` environment variable. Its stdout must
        contain two JSON documents separated by a run of 50 ``#`` characters:
        the AST first, then the schema.

        Args:
            scriptPath: Path of the JavaScript file to parse.

        Returns:
            ``(ast, schema)`` -- the two decoded JSON documents.

        Raises:
            RuntimeError: ``node_path`` is unset or empty. (Previously an unset
                variable escaped as a bare ``KeyError`` and the friendly message
                was never shown.)
            ValueError: the helper's output does not contain exactly one
                separator, e.g. because node failed before printing anything.
        """
        # Set the ``node_path`` environment variable to your nodejs binary.
        node_path = os.environ.get("node_path")
        if not node_path:
            raise RuntimeError("Node path not found!")
        helper_script = os.path.join(
            os.path.dirname(__file__), "get_ast_from_espree.js"
        )
        # stderr is deliberately not captured so node's diagnostics still reach
        # the terminal, as before.
        completed = subprocess.run(
            [node_path, helper_script, scriptPath],
            stdout=subprocess.PIPE,
        )
        separator = "#" * 50
        parts = completed.stdout.decode("utf-8").split(separator)
        if len(parts) != 2:
            raise ValueError(
                f"Unexpected output from get_ast_from_espree.js for {scriptPath} "
                f"(exit code {completed.returncode}): expected exactly one "
                f"separator, found {len(parts) - 1}"
            )
        ast, schema = parts
        return json.loads(ast), json.loads(schema)

    def set_simplify_method(self, method: str) -> None:
        """Replace the stored simplification method name with ``method``.

        The value lives on ``self.simplify_method`` (``run()`` logs it).
        """
        chosen_method = method
        self.simplify_method = chosen_method

    def get_loc(self, expr: esprima.nodes.Node) -> int:
        """Return the line number on which ``expr`` starts (1-based, as used by ``cur_line``)."""
        start_position = expr.loc.start
        return start_position.line

    def get_loc_id(self, expr: esprima.nodes.Node) -> str:
        """Build a textual location id from ``expr``'s start line/column and end line.

        NOTE(review): the end column is not part of the id. Two nodes that start
        at the same position and end on the same line (for example a call and
        its callee identifier on one line) therefore receive the same id.
        """
        location = expr.loc
        return (
            f"start line: {location.start.line} "
            f"start column: {location.start.column} "
            f"end line: {location.end.line}"
        )

    def initialize_openai(self, openai_key: str, base_url: str) -> None:
        """Delegate to ``self.env.initialize_openai`` with the API key and base URL."""
        environment = self.env
        environment.initialize_openai(openai_key, base_url)

    def run(
        self, scriptPath: str, reset_heap_ids: bool = True, recurse: bool = True
    ) -> Environment:
        """Analyze ``scriptPath`` and return the resulting environment.

        The module body, followed by every function in
        ``self.env.entrypoint_functions``, is interpreted repeatedly. After each
        pass the heap is garbage-collected, everything is moved to the abstract
        heap, and it is simplified according to ``self.global_simplify_method``.
        Iteration stops once no allocation site has any summary address, or once
        nothing except ``"module"`` changed relative to the previous pass's
        snapshot (a fixpoint).

        Only a single pass is made when ``recurse`` is false or
        ``self.should_abstract`` is false. Test mode stops after the first pass
        has moved objects to the abstract heap.

        Args:
            scriptPath: JavaScript file to analyze.
            reset_heap_ids: Reset the concrete/abstract heap id counters after a
                successful run.
            recurse: Iterate to a fixpoint (true) or run a single pass (false).

        Returns:
            ``self.env``, the final environment.

        Raises:
            Exception: ``"Too many iterations"`` if 26 passes run without
                reaching a fixpoint.
            Exception: ``"keyboardInterrupt"`` (chained to the original
                ``KeyboardInterrupt``) after the partial summary has been logged.
        """
        file_name = os.path.basename(scriptPath)
        if self.log_console:
            set_console_handler()
        if self.log_file:
            set_file_handler(file_name)
        logger.info(f"RUNNING on {scriptPath}")
        ast, schema = self.get_ast_from_espree(scriptPath)
        self.env.add_schema(schema)
        ast = DotMap(ast)
        start_time = time.time()
        logger.info(
            f"running with {self.simplify_method}, {self.should_abstract}, {self.model}, {self.loop_max}"
        )

        profiler = None
        if self.should_profile:
            profiler = Profile()
            profiler.enable()
        try:
            self.start_time = time.time()
            last_snapshot = {}
            reached_fixpoint = False
            count = 0
            while not reached_fixpoint:
                logger.info("ITERATING AGAIN")
                # NOTE: self.timeout is not enforced here; the wall-clock check
                # against self.start_time is currently disabled.
                if count > 25:
                    raise Exception("Too many iterations")

                self.env.update_global_env(self.env.get_schema_for_scope_id("global"))
                if self.is_module:
                    # TODO Add module support
                    schema_key = f"{convert_path_to_underscore(scriptPath)}_module"
                    self.env.initialize_from_schema(
                        self.env.get_schema_for_scope_id(schema_key),
                        scriptPath,
                        schema_key,
                        self.env.global_stack_frame.get_heap_frame_address(),
                        is_module=True,
                        is_function=False,
                    )
                self.env.should_visualize = self.should_visualize
                for statement in ast["body"]:
                    self.visit_statement(statement, scriptPath)
                for entrypoint_function in self.env.entrypoint_functions:
                    logger.info(f"RUNNING ENTRYPOINT FUNCTION {entrypoint_function}")
                    self.execute_function_entry_point(entrypoint_function, scriptPath)
                if not recurse or not self.should_abstract:
                    break
                if self.should_visualize_abstraction:
                    logger.info(self.env.pretty_print_allocation_sites())
                    self.visualize("before final abstraction")
                self.env.garbage_collect(ignore_memoized=True)
                self.env.move_all_objects_to_abstract_heap()
                if self.test_mode:
                    # Sometimes we only want to test a single iteration without abstractions.
                    break

                if self.global_simplify_method == "allocation_sites":
                    self.env.simplify_all_allocation_sites()
                elif self.global_simplify_method in ("recency", "depth"):
                    self.env.simplify_all_objects_with_recency()
                    self.env.simplify_all_function_allocation_sites()
                # NOTE This is a naive implementation of getting the LLM to merge objects. It's not what we actually want to do, just a baseline
                elif self.global_simplify_method in ("manual", "llm"):
                    global_vars = self.env.get_all_reachable_variable_names()
                    global_allocation_sites = self.env.get_reachable_allocation_sites(
                        global_vars, ignore_functions=True
                    )
                    # Only every fifth pass asks for merging strategies; the
                    # source text is read only when it is actually needed.
                    if count > 0 and count % 5 == 0:
                        self.env.generate_global_merging_strategies(
                            self.global_simplify_method,
                            self.get_code(scriptPath),
                            global_vars,
                            global_allocation_sites,
                            num_iterations=count,
                        )
                        self.env.simplify_global(
                            self.global_simplify_method,
                            global_vars,
                            global_allocation_sites,
                        )
                    else:
                        self.env.simplify_all_objects_with_recency()
                        self.env.simplify_all_function_allocation_sites()
                elif self.global_simplify_method == "None":
                    pass  # explicitly: no global simplification
                # (A disabled per-variable "llm_naive"/"llm" strategy used to sit
                # here as a dead string literal; see version control history.)

                should_simplify_module = (
                    self.global_simplify_method == "allocation_sites"
                )
                self.env.finish_iteration(should_simplify_module)
                if self.should_visualize_abstraction:
                    self.visualize("after final abstraction")

                allocation_site_snapshot = self.env.allocation_sites_snapshot()
                # If no allocation site has a summary address, no abstraction was
                # made, so another pass cannot change anything.
                if all(
                    len(allocation_site_snapshot[allocation_site]["summary_addresses"])
                    == 0
                    for allocation_site in allocation_site_snapshot
                ):
                    break

                reached_fixpoint = True
                changed_variables = self.env.changed_from_snapshot(last_snapshot)
                # A change confined to the "module" entry alone is not progress.
                if len(changed_variables) and not (
                    len(changed_variables) == 1 and changed_variables[0] == "module"
                ):
                    logger.info(f"changed variables {changed_variables}")
                    reached_fixpoint = False

                if not reached_fixpoint:
                    last_snapshot = self.env.snapshot()
                count += 1

            if self.should_profile and profiler is not None:
                profiler.disable()
                stats = pstats.Stats(profiler).sort_stats("cumtime")
                stats.print_stats("absint_ai")
            if reset_heap_ids:
                concrete_heap_id_reset()
                abstract_heap_id_reset()
            logger.info("FINAL ENVIRONMENT")
            # Store exactly the elapsed time that was logged.
            self.end_time = self._log_run_summary(start_time)
            return self.env  # return final environment
        except KeyboardInterrupt as interrupt:
            self._log_run_summary(start_time)
            logger.info(f"SHOULD PROFILE {self.should_profile}")
            if self.should_profile and profiler is not None:
                profiler.disable()
                stats = pstats.Stats(profiler).sort_stats("cumtime")
                logger.info("HERE")
                stats.print_stats("absint_ai")
            raise Exception("keyboardInterrupt") from interrupt

    def _log_run_summary(self, start_time: float) -> float:
        """Log the bug report and run statistics; return seconds elapsed since ``start_time``.

        Buggy lines and line numbers are de-duplicated and sorted, so the report
        is identical across runs. (Iteration order of a set of strings varies
        between processes because of hash randomization.)
        """
        buggy_line_numbers = sorted(set(self.buggy_line_numbers))
        buggy_lines = sorted(set(self.buggy_lines))
        logger.info("\n" + "\n".join(buggy_lines))
        logger.info("\n" + "\n".join(buggy_line_numbers))
        elapsed = time.time() - start_time
        logger.info(f"end_time: {elapsed}")
        logger.info(f"There are {len(buggy_line_numbers)} possible buggy lines")
        logger.info(f"Total lines executed: {self.executed_lines}")
        logger.info(f"Number of unique lines executed: {len(set(self.trace))}")
        logger.info(
            f"Size of the concrete heap: {len(self.env.concrete_heap.concrete_heap)}"
        )
        logger.info(
            f"Size of the abstract heap: {len(self.env.abstract_heap.abstract_heap)}"
        )
        logger.info(f"Number of array builtins called: {self.called_builtins_count}")
        return elapsed

    def cur_line(self, expr: esprima.nodes.Node, file_path: str) -> str:
        """Return ``"<file_path>:<source text>"`` for the line on which ``expr`` starts.

        Everything from the first ``//`` onward is dropped, even when that ``//``
        is inside a string literal. Surrounding whitespace is stripped. The
        whole file is re-read on every call.
        """
        line_index = self.get_loc(expr) - 1  # loc lines are 1-based
        with open(file_path, "r") as source_file:
            source_lines = source_file.read().split("\n")
        code_part, _, _ = source_lines[line_index].partition("//")
        return f"{file_path}:{code_part.strip()}"

    def visit_statement(
        self,
        expr: esprima.nodes.Node,
        file_path: str,
    ) -> None:
        """Interpret one statement node via ``self.statement_visitor``.

        Falsy nodes are ignored. Every other node bumps ``executed_lines``.
        When ``should_visualize`` is set, a visualization labelled with the
        previous and current source lines is rendered first. When
        ``check_pointers`` is set, ``env.check_no_pointers_from_abstract_to_concrete()``
        runs before the statement is visited.
        """
        if not expr:
            return
        self.executed_lines += 1

        if self.should_visualize:
            label = self.cur_line(expr, file_path)
            self.visualize(self.prev_line, label)
            self.prev_line = label

        if self.check_pointers:
            self.env.check_no_pointers_from_abstract_to_concrete()

        self.statement_visitor.visit(expr, file_path)

        # If the visitor recursed into nested statements, prev_line may now name
        # one of them; point it back at this statement.
        if self.should_visualize:
            self.prev_line = self.cur_line(expr, file_path)

    def visit_expression(
        self, expr: esprima.nodes.Node, file_path: str
    ) -> OrderedSet[Type]:
        """Evaluate ``expr`` with ``self.expr_visitor``; a falsy node yields an empty ``OrderedSet``."""
        return self.expr_visitor.visit(expr, file_path) if expr else OrderedSet()

    def execute_function_entry_point(self, func: Address, file_path: str) -> None:
        """Interpret the function at ``func`` as a standalone entry point.

        A scope is initialized from the function's schema. Its parent is the
        recorded ``__parent__`` frame, or the current heap frame when none is
        recorded. The body is then interpreted: statement by statement for a
        ``BlockStatement`` body, otherwise as one expression (e.g. an arrow
        function with an expression body). Finally the function returns ``None``.

        Args:
            func: Address of the function; its metadata comes from
                ``self.env.get_meta(func)``.
            file_path: File currently being analyzed; used whenever the
                function's metadata carries no ``file_path`` of its own.
        """
        func_info = self.env.get_meta(func)
        # A function defined in another file records that file's path. Its AST
        # locations refer to that file, so both the scope and the body must be
        # interpreted against it. (Previously the body used the caller's path.)
        function_file_path = func_info.get("file_path") or file_path
        self.env.initialize_from_schema(
            func_info["schema"],
            function_file_path,
            allocation_site=func_info["schema_id"],
            # ``or`` keeps get_current_heap_frame() lazy: it is only consulted
            # when no parent frame was recorded.
            parent_address=func_info.get("__parent__")
            or self.env.get_current_heap_frame(),
            abstract_parents=True,
        )
        body = func_info["body"]
        if body.type == "BlockStatement":
            for statement in body.body:
                self.visit_statement(statement, function_file_path)
        else:
            self.visit_expression(body, function_file_path)
        # NOTE(review): still passes the caller's file_path, as before; confirm
        # whether return_from_function expects the callee's file instead.
        self.env.return_from_function(None, file_path)

    def get_address(self, expr: esprima.nodes.Node, file_path: str) -> OrderedSet[Type]:
        """Resolve the values an identifier or member expression refers to.

        - Identifier: every value bound to that name.
        - Member expression on ``this``: every value bound to ``this``.
        - Any other member expression: the result of evaluating the whole expression.
        - Anything else: ``OrderedSet([baseType.TOP])``.

        NOTE(review): ``this.x`` yields the object (``this``), but ``obj.x``
        yields the evaluated property value. Confirm this asymmetry is intended.
        """
        node_type = expr.type
        if node_type == "Identifier":
            return self.env.lookup(expr.name).get_all_values()
        if node_type != "MemberExpression":
            return OrderedSet([baseType.TOP])
        if expr.object.type == "ThisExpression":
            return self.env.lookup("this").get_all_values()
        return self.visit_expression(expr, file_path)

    def get_num_input_tokens(self):
        """Return ``self.env.num_input_tokens`` (presumably LLM prompt tokens used so far)."""
        environment = self.env
        return environment.num_input_tokens

    def get_num_output_tokens(self):
        """Return ``self.env.num_output_tokens`` (presumably LLM completion tokens used so far)."""
        environment = self.env
        return environment.num_output_tokens

    def visualize(self, prev_statement: str = "", cur_statement: str = "") -> None:
        """Render the environment via ``self.env.visualize`` with two statement labels.

        Each label is cut down to the text after its last ``/``. For a
        ``"path/to/file.js:code"`` label this drops the directory part.

        NOTE(review): a label whose code part itself contains ``/`` (division,
        a regex, ``</tag>``) is truncated to whatever follows that slash.
        Keeping labels slash-free may be deliberate (e.g. if they end up in
        file names); confirm before changing.
        """

        def after_last_slash(label: str) -> str:
            return label.rsplit("/", 1)[-1]

        self.env.visualize(
            prev_statement=after_last_slash(prev_statement),
            cur_statement=after_last_slash(cur_statement),
        )

    def get_code_window_around_expr(
        self, expr: esprima.nodes.Node, file_path: str, window_size: int
    ):
        """Return the lines spanned by ``expr`` plus up to ``window_size`` lines on each side.

        ``expr.loc`` line numbers are 1-based (``cur_line`` indexes with
        ``line - 1``). The expression's first line is therefore list index
        ``start_line - 1``. The previous code sliced from ``start_line -
        window_size``, which gave one line less before the expression than
        after it. With ``window_size == 0`` it also dropped the expression's own
        first line, so a single-line expression produced ``""``. The window is
        clamped to the file and keeps the original line endings.
        """
        expr_start = expr.loc.start.line
        expr_end = expr.loc.end.line
        with open(file_path, "r") as f:
            lines = f.readlines()
        # Half-open index range [start, end) covering 1-based lines
        # expr_start - window_size .. expr_end + window_size.
        start = max(0, expr_start - 1 - window_size)
        end = min(len(lines), expr_end + window_size)
        return "".join(lines[start:end])

    def get_code(self, script_path: str):
        """Return the complete text of the file at ``script_path``."""
        with open(script_path, "r") as source_file:
            contents = source_file.read()
        return contents
