import ast
from typing import Any, Callable
import importlib.util
import os
from RoboMemory.BaseModules.BaseFunction import BaseFunction

class APIProcesser:
   

    def __init__(
        self,
        api_file_path: str,
        function_filter_list: list[str] = None
    ) -> None:
       
        self.api_file_path = api_file_path
        self.function_filter_list = function_filter_list
        self.api_module = self.__import_api_module()
        self.function_name_list = self.__get_functions_name_from_file(api_file_path)
        self.function_list = self.__get_all_callable()
        self.basefunction_list = [BaseFunction(func) for func in self.function_list]

    def get_base_fn_list(self) -> list:
      
        return self.basefunction_list

    def __get_api_description_list(self) -> list[dict]:
      
        return [function.get_function_infos() for function in self.basefunction_list]

    def get_api_description(self) -> list[dict]:
       
        action_list = self.__get_api_description_list()
        
        action_str_list = []
        
        for action in action_list:
            if 'params' in action:
                action_str_list.append(f"{action['function_expression']}:\n\t- {action['description']}\n\t- params: {action['params']}")
            else:
                action_str_list.append(f"{action['function_expression']}:\n\t- {action['description']}")
                
        
        actions = "\n\n".join(action_str_list)
        
        return actions

    def __get_functions_name_from_file(self, filename: str) -> list:
       
        with open(filename, "r") as fp:
            tree = ast.parse(fp.read())
        functions = [node.name for node in ast.walk(tree) if isinstance(node, ast.FunctionDef)]
        if isinstance(self.function_filter_list, list):
            functions = [fn for fn in functions if fn not in self.function_filter_list]
        return functions

    def __get_all_callable(self) -> list:
       
        return [self.__get_callable_from_module(name) for name in self.function_name_list]

    def __get_callable_from_module(self, callable_name: str) -> Callable[..., Any]:
     
        try:
            callable_obj = getattr(self.api_module, callable_name)
            if callable(callable_obj):
                return callable_obj
            else:
                raise ValueError(f"{callable_name} is not callable")
        except AttributeError as e:
            raise AttributeError(f"Function {callable_name} not found in module {self.api_module}") from e

    def __import_api_module(self) -> Any:
     
        module_name = os.path.splitext(os.path.basename(self.api_file_path))[0]
        spec = importlib.util.spec_from_file_location(module_name, self.api_file_path)
        if spec is None:
            raise ImportError(f"Cannot load module from {self.api_file_path}")
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module