import ast

import os
import sys
# project_path = os.path.abspath(
#     os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, os.pardir, os.pardir, os.pardir)
# )
# if project_path not in sys.path:
#     sys.path.append(project_path)



from fuzzywuzzy import process as fuzzy_process
from agent.nesy_agent.nesy_utils.tag_enum import Attraction_Tags, Hotel_Tags

class FunctionValueTracker(ast.NodeVisitor):
    """
    Static analysis tool that follows the return value of one named function
    through a piece of code and verifies that every constant it is compared
    against belongs to a set of legal values.

    Note that this tool uses many idealized assumptions, including but not limited to:
    - Variable names in code will not be repeatedly defined or overwritten
    - Elements added to iterable objects will not be modified or deleted
    - Iterable objects are List or Set and use append or add methods to add elements
    - Comparison operands are constant values, lists, or sets
    - Bool functions without special comparison
    - ...
    """

    # Container methods that propagate a tracked value into their receiver,
    # mapped to (usage type, key under which the receiver name is recorded).
    _CONTAINER_METHODS = {
        "append": ("list_append", "list"),
        "add": ("set_add", "set"),
    }

    def __init__(self, target_func, valid_values):
        """
        Args:
            target_func: Name of the function whose return value is tracked.
            valid_values: Iterable of values the return value may legally be
                compared against.
        """
        self.target_func = target_func  # The target function name to track
        self.valid_values = set(valid_values)  # The allowed legal value set
        self.comparisons = []  # Every comparison involving the tracked value
        self.errors = []  # Comparisons that used at least one illegal value
        self.usage_paths = []  # Where the return value is called / stored
        self.assignments = {}  # Tracked variable name -> source that tracked it

    def reset(self):
        """
        Reset the analyzer state so the instance can visit another tree.
        """
        self.comparisons = []
        self.errors = []
        self.usage_paths = []
        self.assignments = {}

    def _is_target_call(self, node):
        """Return True if `node` is a plain-name call of the target function."""
        # The isinstance check on `func` matters: method calls (ast.Attribute),
        # subscripted calls and lambda calls have no `.id` attribute.
        return (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Name)
            and node.func.id == self.target_func
        )

    def _is_tracked_name(self, node):
        """Return True if `node` is a variable already known to hold a tracked value."""
        return isinstance(node, ast.Name) and node.id in self.assignments

    def _carries_target_value(self, node):
        """Return True if `node` is the target call itself or a tracked variable."""
        return self._is_target_call(node) or self._is_tracked_name(node)

    def visit_Assign(self, node):
        """
        Start tracking variables assigned from the target function's return
        value, or from a variable that is already tracked.
        """
        if self._carries_target_value(node.value):
            statement = ast.unparse(node)
            for target in node.targets:
                # Only simple names are tracked; attributes, subscripts and
                # tuple targets are ignored.
                if isinstance(target, ast.Name):
                    self.assignments[target.id] = statement

        self.generic_visit(node)

    def visit_Compare(self, node):
        """
        Record comparisons in which the tracked value takes part, on either
        side, and report the constants on the opposite side that are not legal.
        """
        tracked_on_left = self._carries_target_value(node.left)
        tracked_on_right = any(
            self._carries_target_value(comparator) for comparator in node.comparators
        )

        if tracked_on_left or tracked_on_right:
            # The constants are read from the side opposite to the tracked
            # value. When both sides carry it, the right side takes precedence
            # and the left operand is inspected.
            if tracked_on_right:
                compared_values = self._extract_compared_values([node.left])
            else:
                compared_values = self._extract_compared_values(node.comparators)

            source = ast.unparse(node)
            self.comparisons.append(
                {"code": source, "compared_values": compared_values}
            )
            invalid_values = compared_values - self.valid_values
            if invalid_values:
                self.errors.append(
                    {"code": source, "invalid_values": invalid_values}
                )

        self.generic_visit(node)

    def visit_Call(self, node):
        """
        Record how the target function's return value is used: direct calls,
        and tracked values stored through `append` / `add` on a named container.
        """
        if self._is_target_call(node):
            self.usage_paths.append(
                {
                    "type": "function_call",
                    "code": ast.unparse(node),
                }
            )

        func = node.func
        if (
            isinstance(func, ast.Attribute)
            and func.attr in self._CONTAINER_METHODS
            and isinstance(func.value, ast.Name)
            and node.args
            and self._carries_target_value(node.args[0])
        ):
            usage_type, container_key = self._CONTAINER_METHODS[func.attr]
            source = ast.unparse(node)
            self.usage_paths.append(
                {
                    "type": usage_type,
                    container_key: func.value.id,
                    "value": ast.unparse(node.args[0]),
                    "code": source,
                }
            )
            # The container now holds a tracked value, so comparisons against
            # the container itself are checked from here on.
            self.assignments[func.value.id] = source

        self.generic_visit(node)

    def _extract_compared_values(self, comparators):
        """
        Extract constant values from comparison operands, supporting list
        literals, set literals, and single constants.

        Non-constant operands and non-constant elements are skipped.

        Returns:
            A set with every constant value found.
        """
        compared_values = set()
        for comparator in comparators:
            if isinstance(comparator, (ast.List, ast.Set)):
                compared_values.update(
                    elt.value
                    for elt in comparator.elts
                    if isinstance(elt, ast.Constant)
                )
            elif isinstance(comparator, ast.Constant):
                compared_values.add(comparator.value)
        return compared_values


class CodeBlockChecker:
    """
    Runs one FunctionValueTracker per tracked function over a code block and
    turns the illegal comparisons they report into readable error messages.
    """

    def __init__(self, func_name_list, valid_values_list, need_fuzzy_list=None):
        """
        Args:
            func_name_list: Names of the functions whose return values are tracked.
            valid_values_list: Legal values for each function, aligned by
                index with `func_name_list`.
            need_fuzzy_list: Names of functions whose error messages should
                also list the closest valid values. Defaults to no functions.
        """
        assert len(func_name_list) == len(valid_values_list)
        self.func_name_list = func_name_list
        self.valid_values_list = valid_values_list
        # Build a fresh list per instance; a mutable default argument would be
        # shared by every checker created without this parameter.
        self.need_fuzzy_list = [] if need_fuzzy_list is None else need_fuzzy_list
        self.trackers = {
            func_name: FunctionValueTracker(func_name, valid_values)
            for func_name, valid_values in zip(func_name_list, valid_values_list)
        }

    def reset(self):
        """Reset the state of every tracker."""
        for tracker in self.trackers.values():
            tracker.reset()

    def check(self, code):
        """
        Check the function calls and comparisons in a code block.

        Args:
            code: Python source text. It is parsed with `ast.parse`, so
                invalid syntax raises `SyntaxError`.

        Returns:
            A tuple `(error_info, errors_by_func)`: a list with one message
            per illegal comparison, and a dict mapping each tracked function
            name to its tracker's raw error records.
        """
        tree = ast.parse(code)
        error_info = []
        for func_name, tracker in self.trackers.items():
            # Each tracker walks the same tree independently, looking only at
            # its own target function.
            tracker.reset()
            tracker.visit(tree)
            for error in tracker.errors:
                error_code = error["code"]
                invalid_values = error["invalid_values"]
                error_info_str = f"invalid compare value in code: {error_code}, the invalid values compared with func {tracker.target_func}'s return value: {invalid_values}. Either change the func or the compare values."
                if func_name in self.need_fuzzy_list:  # Need fuzzy matching
                    valid_values_list = list(tracker.valid_values)
                    for invalid_value in invalid_values:
                        fuzzy_result = fuzzy_process.extract(
                            invalid_value,
                            valid_values_list,
                            limit=len(valid_values_list),
                        )
                        if not fuzzy_result:
                            # Nothing to suggest (e.g. no valid values at all).
                            continue
                        best_score = fuzzy_result[0][1]
                        # Keep the first ten candidates, plus any further ones
                        # that tie with the best score.
                        res_fuzzy_result = []
                        for result in fuzzy_result:
                            if result[1] == best_score or len(res_fuzzy_result) < 10:
                                res_fuzzy_result.append(result[0])
                        error_info_str += f"For invalid value: {invalid_value}, the most similar valid values are: {res_fuzzy_result}"
                error_info.append(error_info_str)
        return error_info, {
            func_name: tracker.errors for func_name, tracker in self.trackers.items()
        }


class HardLogicPyChecker(CodeBlockChecker):
    """
    CodeBlockChecker preconfigured with the functions whose return values
    must only be compared against predefined legal values.
    """

    def __init__(self, locale):
        """
        Args:
            locale: Key used to select the attraction tag set from
                `Attraction_Tags`.
        """
        # Tracked function name -> values its return value may be compared to.
        # Insertion order is kept, so names and value lists stay aligned.
        legal_values_by_func = {
            "activity_type": ["poi", "hotel", "transportation"],
            "attraction_type": Attraction_Tags[locale],
            "hotel_type": Hotel_Tags,
            # Inner city transportation type
            "innercity_transport_type": ["metro", "taxi", "walk"],
            # Inter city transportation type
            "intercity_transport_type": ["train", "flight"],
        }

        super().__init__(
            list(legal_values_by_func.keys()),
            list(legal_values_by_func.values()),
        )




