import time
from litellm import completion
from litellm.exceptions import RateLimitError
from typing import List, Optional, Dict, Any

from src.agents.base import Agent
from src.envs.base import Env
from src.types import AgentRunResult
from src.utils import get_action

# Base system prompt for ToolCallingAgent. The agent appends its per-task
# `rule` text after this string (separated by a newline) to build the final
# system message. This text is sent to the model verbatim, so any edit here
# changes agent behavior.
TOOL_CALLING_INSTRUCTION = """# Instruction
- You are a DB agent that assists a user in retrieving data from an EHR database.
- You are currently engaged in a conversation with a user who wants to retrieve some data or statistics from an EHR database.
- If the user's request is ambiguous or lacks crucial details (e.g., filtering criteria), ask clarifying questions in plain language to better understand the user's request.
- You are provided with a set of tools to better assist the user.
- List of tools:
  - table_search: search for tables in the database
  - column_search: search for columns in a table
  - value_substring_search: search for values in a table by substring
  - value_similarity_search: search for values in a table by similarity
  - sql_execute: execute a SQL query
  - web_search: search the web for background clinical knowledge
- Before writing a SQL query, you need to explore the database schema and stored values to fill the gaps between the user's request and the database content.
- Use table_search and column_search to explore the database schema to find the most relevant tables and columns.
- Use value_substring_search and value_similarity_search to explore stored values to find the most relevant entries.
- Clinical concepts (e.g., diagnoses, procedures, medications, lab tests) in the database may not match the user's words exactly. Use value search tools to explore relevant entries. If value_substring_search returns no results, try value_similarity_search, and vice versa.
- If you need clinical knowledge beyond what you already know (e.g., mechanism of action of a drug, or alternative lab tests), use web_search to find background clinical knowledge.
- Never invent or assume information not provided by the user or retrieved using the tools.
- You may only make one tool call at a time. Do not send a user response in the same turn as a tool call.
- Only after you have gathered all necessary information from the user or database, use the sql_execute tool to write and execute a single valid SQL query that fully addresses the user's latest request.
- When you write a SQL query, you must use sql_execute to execute the SQL query and deliver the results to the user. The user does not have access to the database."""

# Allow-list of tool names exposed to the model. ToolCallingAgent drops every
# entry of its `tools_info` whose ["function"]["name"] is not listed here.
# NOTE(review): these are the same six tools enumerated in the prompt above;
# keep the two lists in sync when adding or removing a tool.
TOOL_SETS = ['table_search', 'column_search', 'sql_execute', 'value_substring_search', 'value_similarity_search', 'web_search']

class ToolCallingAgent(Agent):
    """Agent that serves EHR data requests by letting an LLM call database tools.

    On every turn the model (queried through ``get_action``) produces either a
    tool call or a ``respond`` action. The action is passed to the
    environment. A tool call is answered with a ``tool`` message carrying the
    environment's observation. A ``respond`` action gets the observation back
    as the next ``user`` message.
    """

    def __init__(
        self,
        tools_info: List[Dict[str, Any]],
        rule: str,
        model: str,
        api_base: Optional[str] = None,
        temperature: float = 0.0
    ):
        """Store model settings and build the system prompt.

        Args:
            tools_info: Tool schemas. Each entry must expose its name at
                ``entry["function"]["name"]``. Only entries whose name appears
                in ``TOOL_SETS`` are kept, in their original order.
            rule: Task-specific text appended, after a newline, to
                ``TOOL_CALLING_INSTRUCTION`` to form the system prompt.
            model: Model identifier forwarded to ``get_action``.
            api_base: Forwarded to ``get_action`` as ``api_base``.
            temperature: Forwarded to ``get_action`` as ``temperature``.
        """
        self.tools_info = [
            spec for spec in tools_info if spec["function"]["name"] in TOOL_SETS
        ]
        self.model = model
        self.api_base = api_base
        self.temperature = temperature
        self.rule = rule
        self.instruction = "\n".join((TOOL_CALLING_INSTRUCTION, rule))

    def run(
        self, env: Env, task_index: Optional[int] = None, max_num_steps: int = 30
    ) -> AgentRunResult:
        """Play one episode against ``env`` and return the transcript and outcome.

        The environment is reset with ``task_index``. Its initial observation
        becomes the first user message, placed right after the system prompt.
        The loop then runs at most ``max_num_steps`` model/environment
        round-trips. It stops early once either ``get_action`` or the
        environment step reports completion.

        Args:
            env: Environment to reset and step.
            task_index: Forwarded to ``env.reset``.
            max_num_steps: Upper bound on loop iterations. A value <= 0 means
                no step is taken.

        Returns:
            An AgentRunResult holding:
              * the reward from the last environment step (0.0 if none ran);
              * the full message list;
              * the sum of the per-step costs returned by ``get_action``,
                rounded to 8 decimal places;
              * the reset info merged with every step's info, where later
                keys win.
        """
        start = env.reset(task_index=task_index)
        first_observation = start.observation
        merged_info: Dict[str, Any] = start.info.model_dump()
        transcript: List[Dict[str, Any]] = [
            {"role": "system", "content": self.instruction},
            {"role": "user", "content": first_observation},
        ]
        total_cost = 0.0
        reward = 0.0

        for _ in range(max_num_steps):
            # NOTE(review): get_action is assumed to return
            # (assistant message dict, action with `.name`, done flag, cost).
            assistant_msg, action, agent_done, step_cost = get_action(
                model=self.model,
                messages=transcript,
                temperature=self.temperature,
                api_base=self.api_base,
                tools=self.tools_info,
            )
            total_cost += step_cost
            outcome = env.step(action)
            reward = outcome.reward
            # Newer step info overrides keys from earlier steps / the reset.
            merged_info = {**merged_info, **outcome.info.model_dump()}

            if action.name == 'respond':
                # The observation is fed back as the next user turn.
                follow_up = {"role": "user", "content": outcome.observation}
            else:
                # Keep only the first tool call so the transcript pairs exactly
                # one call with the single tool result appended below. The
                # system prompt also asks for one tool call at a time.
                assistant_msg["tool_calls"] = assistant_msg["tool_calls"][:1]
                call = assistant_msg["tool_calls"][0]
                follow_up = {
                    "role": "tool",
                    "tool_call_id": call["id"],
                    "name": call["function"]["name"],
                    "content": outcome.observation,
                }
            transcript.extend((assistant_msg, follow_up))

            if agent_done or outcome.done:
                break

        return AgentRunResult(
            reward=reward,
            messages=transcript,
            agent_cost=round(total_cost, 8),
            info=merged_info,
        )
