import json
from openai import OpenAI
from typing import List, Dict, Any
import os
from abc import ABC, abstractmethod

from ia_rag.prompts.workflow_prompts import Executor_Agent_Prompt as Interact_Exe_Prompt, \
    Reasoner_Agent_Prompt as Interact_Reason_Prompt, \
    First_Reasoner_Prompt as Interact_First_Reason_Prompt, \
    Plan_Agent_Prompt as Interact_Plan_Prompt


from ia_rag.prompts.tool_info import INTERACT_TOOL

from ia_rag.utils.rag_action import ActionClient, VDB_SERVER_URL
from ia_rag.utils.log_helper import transform_msg_to_json
from ia_rag.utils.llm_call import create_chat_completion_with_retry


# Maximum number of reason/act rounds WorkflowAgentBase.run performs before
# returning, even if the model has not yet produced a final (non-tool) answer.
MAX_ITERATIONS = 7


class WorkflowAgentBase(ABC):
    """Base class for an LLM plan -> reason -> act loop backed by a RAG action client.

    Subclasses supply the system prompts and tool schemas in
    ``_setup_prompts_and_tools``; ``run`` then drives the conversation and
    forwards every tool call the model makes to ``ActionClient``.
    """

    def __init__(self, llm_api_base: str, model_name: str, api_key="dummy"):
        """Create the LLM client, load prompts/tools, and connect the ActionClient.

        Args:
            llm_api_base: Base URL of the OpenAI-compatible chat endpoint.
            model_name: Model identifier sent with every completion request.
            api_key: API key for the endpoint (defaults to the placeholder "dummy").

        Raises:
            Exception: Any error raised while constructing ``ActionClient`` is
                printed and re-raised unchanged.
        """
        self.llm_client = OpenAI(base_url=llm_api_base, api_key=api_key)
        self.model_name = model_name
        self._setup_prompts_and_tools()

        try:
            print("Initializing ActionClient...")
            self.action_client = ActionClient(server_base_url=VDB_SERVER_URL)
        except Exception as e:
            print(f"Failed to initialize ActionClient: {e}")
            raise

    @abstractmethod
    def _setup_prompts_and_tools(self):
        """Populate the prompts and tool schemas used by ``run``.

        Implementations must set ``act_prompt``, ``thought_prompt``,
        ``plan_prompt``, ``first_thought`` and ``tools``.
        """
        self.act_prompt = None
        self.thought_prompt = None
        self.plan_prompt = None
        self.first_thought = None
        self.tools = []

    @staticmethod
    def _parse_tool_arguments(raw_arguments) -> tuple:
        """Decode a tool call's JSON argument string into a keyword dict.

        Returns:
            ``(args, None)`` on success, or ``(None, error_message)`` when the
            payload is not valid JSON or does not decode to a JSON object.
            The error text is meant to be sent back to the model as the tool
            result so the conversation can continue.
        """
        try:
            args = json.loads(raw_arguments)
        except (json.JSONDecodeError, TypeError) as e:
            # LLM-generated arguments can be truncated/malformed, or missing (None).
            return None, f"[_parse_tool_arguments] invalid tool arguments: {e}"
        if not isinstance(args, dict):
            # Arguments are splatted as **kwargs, so only a JSON object is usable.
            return None, (
                "[_parse_tool_arguments] tool arguments must be a JSON object, "
                f"got {type(args).__name__}"
            )
        return args, None

    def _perform_rag_action(self, args: Dict[str, Any]) -> str:
        """Execute a search plan with ``args`` as keyword arguments.

        Any exception is printed and converted into an error string, so a
        failing search is reported to the model instead of aborting the run.
        NOTE(review): the success value is whatever
        ``ActionClient.execute_search_plan`` returns; it is used as tool
        message content, so it presumably is a str — TODO confirm.
        """
        try:
            return self.action_client.execute_search_plan(**args)
        except Exception as e:
            print(f"[_perform_rag_action] Action exception: {e}")
            return f"[_perform_rag_action] exception: {e}"

    def _create_chat_completion_with_retry(self, **kwargs):
        """Call the shared retrying completion helper with this agent's client and model.

        Thinking mode is disabled via ``extra_body``. Other keyword arguments
        (messages, temperature, tools, ...) are forwarded unchanged. The return
        value is appended to the conversation as the assistant message.
        """
        return create_chat_completion_with_retry(
            llm_client=self.llm_client,
            model_name=self.model_name,
            extra_body={"enable_thinking": False},
            **kwargs
        )

    def run(self, question: str, plan_str: str = ""):
        """Answer ``question`` by alternating reasoning and tool-using action turns.

        Args:
            question: The user's query.
            plan_str: Optional pre-computed plan. When empty, a plan is
                generated with ``plan_prompt`` first.

        Returns:
            ``(json_messages, iterations)``. ``json_messages`` is the full
            conversation passed through ``transform_msg_to_json``.
            ``iterations`` is the number of reason/act rounds executed
            (at most ``MAX_ITERATIONS``).
        """
        user_input = "The user's query is:\n" + question

        # messages[0] is the system slot; its content is swapped per turn type.
        messages = [
            {"role": "system", "content": self.thought_prompt},
            {"role": "user", "content": user_input},
        ]

        if plan_str:
            messages.append({"role": "assistant", "content": plan_str})
        else:
            plan_msg = self._create_chat_completion_with_retry(
                messages=[{"role": "system", "content": self.plan_prompt},
                          {"role": "user", "content": user_input}],
                temperature=0.6,
            )
            messages.append(plan_msg)

        # Explicit counter: stays defined even if MAX_ITERATIONS is 0.
        iterations_used = 0
        for iteration in range(MAX_ITERATIONS):
            iterations_used = iteration + 1

            # Reasoning turn: the very first round uses a dedicated prompt.
            messages[0]["content"] = (
                self.first_thought if iteration == 0 else self.thought_prompt
            )
            thought_response_msg = self._create_chat_completion_with_retry(
                messages=messages,
                temperature=0.6,
            )
            messages.append(thought_response_msg)

            # Action turn: the model may call tools or answer directly.
            messages[0]["content"] = self.act_prompt
            act_response_msg = self._create_chat_completion_with_retry(
                messages=messages,
                temperature=0.6,
                tools=self.tools,
                tool_choice="auto",
            )
            messages.append(act_response_msg)

            tool_calls = getattr(act_response_msg, 'tool_calls', None)
            if tool_calls:
                for tool_call in tool_calls:
                    args, error = self._parse_tool_arguments(tool_call.function.arguments)
                    if error is not None:
                        print(error)
                        tool_context = error
                    else:
                        tool_context = self._perform_rag_action(args)

                    # Every tool_call_id must be answered, even on error,
                    # or the next chat request will be rejected.
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call.id,
                        "name": tool_call.function.name,
                        "content": tool_context,
                    })
            elif getattr(act_response_msg, 'content', None):
                # A plain-content action turn is the final answer.
                break

        json_messages = transform_msg_to_json(messages)
        return json_messages, iterations_used


class WorkflowInteract(WorkflowAgentBase):
    """Workflow agent wired to the 'interact' prompt set and the INTERACT_TOOL schema."""

    def _setup_prompts_and_tools(self):
        """Bind the interact-mode prompts and expose INTERACT_TOOL as the only tool."""
        (self.act_prompt,
         self.thought_prompt,
         self.plan_prompt,
         self.first_thought) = (Interact_Exe_Prompt,
                                Interact_Reason_Prompt,
                                Interact_Plan_Prompt,
                                Interact_First_Reason_Prompt)
        self.tools = [INTERACT_TOOL]
        print("WorkflowInteract initialized in 'interact' mode.")


if __name__ == "__main__":

    # Endpoint and model may be overridden from the environment; the
    # defaults are the original local placeholders.
    LLM_BASE_URL = os.environ.get("LLM_BASE_URL", "http://localhost:8001/v1")
    MODEL_NAME = os.environ.get("MODEL_NAME", "xxx")

    complex_question = 'Who wears the football boot made by Nike, said to be for traction and agility, designed for deceptive players, and is a forward for Premier League club Tottenham Hotspur?'

    print("\n" + "*"*20 + " RUNNING WorkflowInteract " + "*"*20)
    agent_interact = WorkflowInteract(llm_api_base=LLM_BASE_URL,
                                      model_name=MODEL_NAME)

    json_messages_interact, iterations_interact = agent_interact.run(
        question=complex_question)

    # Show the outcome so the demo run can actually be inspected.
    print(f"WorkflowInteract finished after {iterations_interact} iteration(s).")
    print(json_messages_interact)
