import importlib
import json
from typing import List, Tuple

from langchain_core.tools import BaseTool

from kgot.controller import ControllerInterface
from kgot.controller.networkX.directRetrieve.llm_invocation_handle import (
    define_final_solution,
    define_forced_retrieve_queries,
    define_math_tool_call,
    define_need_for_math_before_parsing,
    define_next_step,
    define_tool_calls,
    define_write_query_given_new_information,
    fix_code,
    generate_forced_solution,
    merge_reasons_to_insert,
    parse_solution_with_llm,
)
from kgot.tools.PythonCodeTool import RunPythonCodeTool
from kgot.utils import State
from kgot.utils.utils import ensure_file_path_exists, is_empty_solution


class Controller(ControllerInterface):
    """
    Controller class for managing the interaction between the LLM and the NetworkX graph storage.
    This class is responsible for executing tasks, retrieving information, and managing the state of the knowledge graph.

    The retrieve method is direct retrieve, meaning it directly loads the content of the NetworkX graph storage into the LLM context.

    Attributes:
        graph (State.knowledge_graph): The knowledge graph instance for interacting with the NetworkX graph storage.
        usage_statistics (State.usage_statistics): The usage statistics instance for tracking usage.
        tool_manager (ToolManager): The tool manager instance for managing tools.
        tools (list[BaseTool]): The list of tools available for use.
        tool_names (dict): A dictionary mapping tool names to their corresponding tool instances.
        llm_execution (BaseTool): The LLM execution tool instance.
        llm_math_executor (BaseTool): The LLM math executor tool instance.
    
    For other attributes, see ControllerInterface.
    """

    def __init__(self,
                 python_executor_uri: str,
                 llm_execution_model: str ,
                 llm_execution_temperature: float,
                 statistics_file_name: str,
                 db_choice: str = "networkX",
                 tool_choice: str = "tools_v2_3",
                 *args, **kwargs) -> None:
        """
        Initializes the Controller with the specified parameters.
        Args:
            python_executor_uri (str): The URI of the Python tool executor.
            llm_execution_model (str): The model used for executing tasks.
            llm_execution_temperature (float): The temperature parameter for the execution model.
            statistics_file_name (str): The name of the file for storing usage statistics.
            db_choice (str): The choice of database to use.
            tool_choice (str): The choice of tool to use.
            *args: Additional arguments. See ControllerInterface for details.
            **kwargs: Additional keyword arguments. See ControllerInterface for details.
            """
        super().__init__(llm_execution_model=llm_execution_model, llm_execution_temperature=llm_execution_temperature, *args, **kwargs)

        ensure_file_path_exists(statistics_file_name)
        self.graph = State.knowledge_graph(db_choice)
        self.usage_statistics = State.usage_statistics(statistics_file_name)

        # Set up the tools
        tool_manager = importlib.import_module(f"kgot.tools.{tool_choice}").ToolManager
        self.tool_manager = tool_manager(self.usage_statistics, python_executor_uri=python_executor_uri)
        self.tools = self.tool_manager.get_tools()
        # Create a map between the tools and their names
        self.tool_names = {}
        for curr_tool in self.tools:
            self.tool_names[curr_tool.name.lower()] = curr_tool
            self.logger.info(f"Provided Tool: {curr_tool} {curr_tool.name} {curr_tool.args}")

        # Bind the LLM execution to the tools (it's the main difference with the LLM planning)
        if self.tools:
            self.llm_execution = self.llm_execution.bind_tools(self.tools, tool_choice="required")

        pythonTool = RunPythonCodeTool(try_to_fix=True,
                                                   times_to_fix=3,
                                                   model_name=llm_execution_model,
                                                   temperature=llm_execution_temperature,
                                                   python_executor_uri=python_executor_uri,
                                                   usage_statistics=self.usage_statistics)
        controller_python: list[BaseTool] = [pythonTool]
        
        self.llm_math_executor = self.llm_math_executor.bind_tools(controller_python, tool_choice="required")

    def _iterative_next_step_logic(self, problem: str, *args, **kwargs) -> Tuple[str, int]:
        solution: str = ""  # Final solution
        raw_solutions: List[str] = []
        existing_entities_and_relationships = ""
        tool_calls_made = []
        current_iteration = 0

        while current_iteration < self.max_iterations:
            retrieve_next_step = {
                "RETRIEVE": 0,
                "INSERT": 0,
                "RETRIEVE_CONTENT": [],
                "INSERT_CONTENT": []
            }
            reason_to_insert = ""
            query_type = ""
            retrieve_query = ""

            for i in range(self.num_next_steps_decision):
                retrieve_query, query_type = define_next_step(self.llm_planning, problem,
                                                              existing_entities_and_relationships, tool_calls_made,
                                                              self.usage_statistics)
                print(f"returned next step {query_type}, {retrieve_query}")
                try:
                    retrieve_next_step[query_type] += 1
                    retrieve_next_step[query_type + "_CONTENT"].append(retrieve_query)
                except KeyError:
                    self.logger.error(f"Unknown query type for next step: {query_type} at iteration {i}")

            if retrieve_next_step["RETRIEVE"] > retrieve_next_step["INSERT"]:
                for sol in retrieve_next_step["RETRIEVE_CONTENT"]:
                    self.logger.info(f"Current partial solution for retrieve: {sol}")
                    raw_solutions.append(sol)
                break

            reason_to_insert = retrieve_next_step["INSERT_CONTENT"][0] if retrieve_next_step["INSERT"] > 0 else ""
            if retrieve_next_step["INSERT"] > 1:
                reason_to_insert = merge_reasons_to_insert(self.llm_planning, retrieve_next_step["INSERT_CONTENT"],
                                                            self.usage_statistics)
            print(f"Reason to insert: {reason_to_insert}")

            existing_entities_and_relationships = self._insert_logic(problem, reason_to_insert, tool_calls_made,
                                   existing_entities_and_relationships)

            current_iteration += 1
            print(f"Current iteration: {current_iteration}")

        solution = self._retrieve_logic(problem, existing_entities_and_relationships, current_iteration, raw_solutions)

        return solution, current_iteration
          
    def _insert_logic(self, query: str, reason_to_insert: str, tool_calls_made: List[str], existing_entities_and_relationships: str, *args, **kwargs) -> str:
        tool_calls = define_tool_calls(self.llm_execution, query, existing_entities_and_relationships,
                                               reason_to_insert, tool_calls_made, self.usage_statistics)
        print(f"Tool_calls: {tool_calls}")

        tools_results = self._invoke_tools_after_llm_response(tool_calls)
        tool_calls_made.extend(tool_calls)
        for call, result in zip(tool_calls, tools_results):
            new_information = f"function '{call}' returned: '{result}'"

            new_information_code_queries = define_write_query_given_new_information(
                self.llm_planning,
                query,
                existing_entities_and_relationships,
                new_information,
                reason_to_insert,
                self.usage_statistics)

            for single_query in new_information_code_queries:
                write_response = self.graph.write_query(single_query)
                self.logger.info(f"Write query result: {write_response}")
                retry_i = 0
                while not write_response[
                    0] and retry_i < self.max_cypher_fixing_retry:  # Failed the insert query
                    retry_i += 1
                    self.logger.info(
                        f"Failed the write query. Retry number: {retry_i} out of {self.max_cypher_fixing_retry}")

                    self.logger.error(
                        f"Trying to fix error encountered when executing Python query: {single_query}\nError: {write_response[1]}")
                    single_query = fix_code(self.llm_planning, single_query, write_response[1], existing_entities_and_relationships, self.usage_statistics)
                    write_response = self.graph.write_query(single_query)
                    self.logger.info(f"Write query result after fixing: {write_response}")

            existing_entities_and_relationships = self.graph.get_current_graph_state()

            print(f"All nodes and relationships after {call}:\n {existing_entities_and_relationships}")
        
        return existing_entities_and_relationships

    def _retrieve_logic(self, query: str, existing_entities_and_relationships: str, current_iteration: int, solutions: List[str]) -> str:
        if current_iteration == self.max_iterations and len(solutions) == 0:
            self.logger.info("Maximum iterations reached without finding a solution. Forcing generation of retrieve queries.")
            print("Maximum iterations reached without finding a solution. Forcing generation of retrieve queries.")

            for i in range(self.num_next_steps_decision):
                retrieve_query = define_forced_retrieve_queries(self.llm_planning, query,
                                                                existing_entities_and_relationships,
                                                                self.usage_statistics)
                solutions.append(retrieve_query)
        # Now, proceed to parse solutions if not empty and choose the best one
        if not is_empty_solution(solutions):
            array_parsed_solutions = []
            for sol in solutions:
                self.logger.info(f"Current partial solution for math need: {sol}")
                need_math = define_need_for_math_before_parsing(self.llm_planning, query, sol, self.usage_statistics)
                if need_math:
                    sol = self._get_math_response(query, sol)

                for i in range(self.max_final_solution_parsing):
                    array_parsed_solutions.append(
                        parse_solution_with_llm(self.llm_planning, query, sol, self.gaia_formatter, self.usage_statistics))
            # Check if all the parsed solutions are empty
            if all(not parsed_sol.strip() for parsed_sol in array_parsed_solutions if parsed_sol):
                self.logger.info("All parsed solutions are empty. Forcing generation of a solution.")
                print("All parsed solutions are empty. Forcing generation of a solution.")
                # Force generation of a solution
                forced_solution = generate_forced_solution(self.llm_planning, query, existing_entities_and_relationships, self.usage_statistics)
                solution = parse_solution_with_llm(self.llm_planning, query, forced_solution, self.gaia_formatter, self.usage_statistics)
            else:
                # We have a series of solutions, we need to choose the best one
                self.logger.info(f"Solution list for final solution choose: {str(solutions)} {str(array_parsed_solutions)}")
                solution = define_final_solution(self.llm_planning, query, str(solutions), array_parsed_solutions,
                                                 self.usage_statistics)
        else:
            # Returned empty retrieves, force generation of a (textual) solution
            self.logger.info("No solutions found after maximum iterations and forced additional retrieve attempts. Forcing generation of a solution.")
            print("No solutions found after maximum iterations and forced additional retrieve attempts. Forcing generation of a solution.")

            forced_solution = generate_forced_solution(self.llm_planning, query, existing_entities_and_relationships, self.usage_statistics)
            solution = parse_solution_with_llm(self.llm_planning, query, forced_solution, self.gaia_formatter, self.usage_statistics)

        return solution

    def _get_math_response(self, query, solution):
        python_tool_call = define_math_tool_call(self.llm_math_executor, query, solution, self.usage_statistics)
        math_solution = None
        math_solution = self._invoke_tools_after_llm_response(python_tool_call)

        self.logger.info(
            f"Retrieve solution parsing from Python math: {math_solution}")
        if len(math_solution) > 0 and math_solution[0] is not None:
            solution = f"{solution}\n In addition, this is the response given by Python after calculations. Use the numbers and the logic as you see fit. <python_math_solution>{math_solution}<\\python_math_solution>."

        return solution

    def _invoke_tools_after_llm_response(self, tool_calls):
        outputs = []

        # Iterate through each tool call and invoke the appropriate tool or answer using the cached answer
        for tool_call in tool_calls:
            self.logger.info(f"Current tool_call: {tool_call}")
            tool_name = tool_call['name'].lower()
            tool_args = tool_call['args']
            self.logger.info(f"Current tool_args: {tool_args}")

            # Create a cache key based on tool name and arguments
            tool_call_key = (tool_name, json.dumps(tool_args, sort_keys=True))

            # Check if result is in cache
            if tool_call_key in self.tool_call_results_cache:
                tool_output = self.tool_call_results_cache[tool_call_key]
                self.logger.info(f"Retrieved tool output from cache for tool '{tool_name}' with args: {tool_args}")
            else:
                selected_tool = self.tool_names.get(tool_name)
                if selected_tool:
                    tool_output = self._invoke_tool_with_retry(selected_tool, tool_args)

                    self.logger.info(f"Tool '{tool_name}' output: {tool_output}")

                    if not tool_name == 'extract_zip':
                        self.tool_call_results_cache[tool_call_key] = tool_output
                else:
                    self.logger.warning(f"Tool '{tool_name}' not found.")
                    tool_output = None

            outputs.append(tool_output)

        return outputs
