import os
import ast
import importlib.util
import types
from typing import Callable, Union, List, Set, Tuple
from collections import deque, defaultdict


class BaseFunctionManager:
    def __init__(self, function_path: str, class_name: str = 'FunctionCollector'):

        self.function_path = function_path  
        self.class_name = class_name        
        os.makedirs(self.function_path, exist_ok=True)  #       
        self.function_name_list: List[str] = []  #           
        self.namespace = {}  #                  
        self._update_function_list()  #           
        self._reload_namespace()      #           

    def _update_function_list(self) -> None:
  
        self.function_name_list = []
        for filename in os.listdir(self.function_path):
            #     .py     __init__.py 
            if filename.endswith('.py') and filename != '__init__.py':
                func_name = filename[:-3]  #   .py  
                self.function_name_list.append(func_name)

    def _validate_function_str(self, func_str: str) -> Tuple[bool, str]:
        """Check that ``func_str`` is source code defining exactly one function.

        Only three kinds of top-level statement are accepted: ``def``
        statements, ``import`` / ``from ... import`` statements, and bare
        string literals (docstrings). Any other top-level statement makes the
        source invalid.

        Args:
            func_str: Python source text to validate.

        Returns:
            ``(True, "")`` when the source is acceptable, otherwise
            ``(False, reason)`` where ``reason`` is a short message.
        """
        try:
            tree = ast.parse(func_str)
        except (SyntaxError, ValueError) as e:
            # ValueError is included because some interpreter versions raise it
            # instead of SyntaxError for unparsable input (e.g. NUL bytes).
            return False, f"grammer error: {e}"

        function_defs = []
        for node in tree.body:
            if isinstance(node, ast.FunctionDef):
                function_defs.append(node)
            elif isinstance(node, (ast.Import, ast.ImportFrom)):
                continue  # imports are allowed alongside the function
            elif (
                isinstance(node, ast.Expr)
                and isinstance(node.value, ast.Constant)
                and isinstance(node.value.value, str)
            ):
                # Bare string literal (docstring). ast.Constant is used because
                # ast.Str is deprecated and absent from newer interpreters,
                # where referencing it raises AttributeError.
                continue
            else:
                return False, "illigle"

        # Exactly one top-level function definition is required.
        if len(function_defs) != 1:
            return False, "must have one def"
        return True, ""

    def _collect_dependent_functions(self, func_names: List[str]) -> Set[str]:

        reverse_deps = defaultdict(list)  #      
        for func_name in self.function_name_list:
            file_path = os.path.join(self.function_path, f"{func_name}.py")
            if not os.path.exists(file_path):
                continue
            with open(file_path, 'r', encoding='utf-8') as f:
                code = f.read()
            try:
                tree = ast.parse(code)
            except SyntaxError:
                continue
            
            #         
            called_funcs = set()
            for node in ast.walk(tree):
                if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
                    called_funcs.add(node.func.id)
            
            #        
            for called in called_funcs:
                reverse_deps[called].append(func_name)
        
        # BFS        
        visited = set()
        queue = deque(func_names)
        while queue:
            current = queue.popleft()
            if current not in visited:
                visited.add(current)
                for dependent in reverse_deps.get(current, []):
                    if dependent not in visited:
                        queue.append(dependent)
        return visited

    def _reload_namespace(self) -> None:
 
        self.namespace.clear()  #         
        dep_graph = defaultdict(list)  #      
        func_names = self.function_name_list.copy()
        
        #      
        for func_name in func_names:
            file_path = os.path.join(self.function_path, f"{func_name}.py")
            if not os.path.exists(file_path):
                continue
            with open(file_path, 'r', encoding='utf-8') as f:
                code = f.read()
            try:
                tree = ast.parse(code)
            except SyntaxError:
                continue
            
            #         
            called_funcs = set()
            for node in ast.walk(tree):
                if isinstance(node, ast.Call) and isinstance(node.func, ast.Name):
                    called = node.func.id
                    if called in func_names:  #           
                        called_funcs.add(called)
            dep_graph[func_name] = list(called_funcs)
        
        #     
        visited = set()
        result = []
        def visit(node):
       
            
            if node in visited:
                return
            visited.add(node)
            for neighbor in dep_graph[node]:
                visit(neighbor)
            result.append(node)
        
        for node in func_names:
            if node not in visited:
                visit(node)
        
        #          
        for func_name in result:
            file_path = os.path.join(self.function_path, f"{func_name}.py")
            if not os.path.exists(file_path):
                continue
            with open(file_path, 'r', encoding='utf-8') as f:
                code = f.read()
            try:
                exec(code, self.namespace)  #         
            except Exception as e:
                raise RuntimeError(f"load function {func_name} error: {e}")

    def add_function(self, function_strs: Union[str, List[str]]) -> Union[bool, List[bool]]:
 
        if isinstance(function_strs, str):
            function_strs = [function_strs]
        results = []
        for func_str in function_strs:
            #       
            valid, msg = self._validate_function_str(func_str)
            if not valid:
                results.append(False)
                continue
            
            try:
                #      
                tree = ast.parse(func_str)
                func_def = next(n for n in ast.walk(tree) if isinstance(n, ast.FunctionDef))
                func_name = func_def.name
            except:
                results.append(False)
                continue
            
            #     
            file_path = os.path.join(self.function_path, f"{func_name}.py")
            try:
                with open(file_path, 'w', encoding='utf-8') as f:
                    f.write(func_str)
            except:
                results.append(False)
                continue
            
            #        
            try:
                exec(func_str, self.namespace)
                results.append(True)
            except Exception as e:
                os.remove(file_path)  #          
                results.append(False)
        
        self._update_function_list()  #       
        return results if len(results) > 1 else results[0]

    def del_function(self, function_names: Union[List[str], str]) -> None:
  
        if isinstance(function_names, str):
            function_names = [function_names]
        
        #                   
        to_delete = self._collect_dependent_functions(function_names)
        
        #              
        for func_name in to_delete:
            file_path = os.path.join(self.function_path, f"{func_name}.py")
            if os.path.exists(file_path):
                os.remove(file_path)
            if func_name in self.namespace:
                del self.namespace[func_name]
        
        self._update_function_list()  #       

    def get_function(self, function_names: Union[List[str], str]) -> Union[Callable, List[Callable]]:
        """Look up loaded objects in the namespace by name.

        Args:
            function_names: A single name or a list of names.

        Returns:
            The object itself when exactly one name was resolved (a plain
            string or a one-element list); otherwise a list of the resolved
            objects in request order (empty for an empty request).

        Raises:
            ValueError: If a requested name is not present in the namespace.
        """
        if isinstance(function_names, str):
            requested = [function_names]
        else:
            requested = function_names

        resolved = []
        for wanted in requested:
            if wanted in self.namespace:
                resolved.append(self.namespace[wanted])
                continue
            raise ValueError(f"Function {wanted} not exit")

        # A lone result is unwrapped rather than returned as a list.
        if len(resolved) == 1:
            return resolved[0]
        return resolved

    def check_function_exist(self, function_names: Union[List[str], str]) -> Union[bool, List[bool]]:
        """Report whether the given name(s) appear in ``function_name_list``.

        Args:
            function_names: A single name or a list of names.

        Returns:
            A ``bool`` for a single string, otherwise a list with one ``bool``
            per requested name, in request order.
        """
        stored = self.function_name_list
        if not isinstance(function_names, str):
            return [candidate in stored for candidate in function_names]
        return function_names in stored