import requests
import logging
from typing import List, Dict, Any, Optional
import json
import re

# Base URL of the backend search service; used by the example at the bottom of this file.
VDB_SERVER_URL = "http://localhost:8011"

# NOTE: this configures the root logger at import time: WARNING and above,
# each record prefixed with a timestamp and level name.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.WARNING,
)


class ActionClient:
    """
    Client for interacting with the backend RAG search service.
    Exposes a single unified 'execute_search_plan' method that directly maps to Agent's Function Call definition.

    'execute_search_plan' reports failures as strings prefixed with
    "[Service Error]" or "[Client Error]" instead of raising, so its result can
    always be handed back to the LLM as tool output.
    """

    # Matches every character that is neither a word character nor whitespace;
    # used to strip punctuation from LLM-supplied keywords and entity names.
    _UNSAFE_CHARS = re.compile(r'[^\w\s]')

    def __init__(self, server_base_url: str, timeout: Optional[float] = 60.0):
        """
        Args:
            server_base_url: Base URL of the search service; trailing '/' characters are removed.
            timeout: Seconds to wait for the server, forwarded to ``requests`` as
                ``timeout``. Without a timeout a stalled server would block the
                caller indefinitely. Pass ``None`` to wait forever.
        """
        self.base_url = server_base_url.rstrip('/')
        self.timeout = timeout
        self.session = requests.Session()

    @staticmethod
    def _text(record: Dict[str, Any], key: str) -> str:
        """Return ``record[key]`` as a string, or 'N/A' when the key is missing or its value is None."""
        value = record.get(key)
        return 'N/A' if value is None else str(value)

    @classmethod
    def _sanitize_term(cls, term: Any) -> str:
        """Remove punctuation and surrounding whitespace from one LLM-supplied term."""
        return cls._UNSAFE_CHARS.sub('', str(term)).strip()

    @staticmethod
    def _describe_error_response(response: "requests.Response") -> str:
        """
        Summarize an HTTP error response body for inclusion in an error message.

        Uses the 'detail' field of a JSON-object body when present; otherwise
        falls back to the raw response text.
        """
        try:
            body = response.json()
        except ValueError:
            # Depending on the installed libraries, requests raises json's or
            # simplejson's JSONDecodeError (or a subclass); all are ValueErrors.
            return f"\nServer Response (non-JSON): {response.text}"
        # A JSON body that is not an object (list, string, number) has no 'detail'.
        detail = body.get('detail', response.text) if isinstance(body, dict) else response.text
        return f"\nServer Response: {detail}"

    def _format_results_for_llm(self, response_data: Dict[str, Any]) -> str:
        """
        Format structured data received from the server into LLM-friendly string.

        Renders 'search_results' as numbered chunks (title, source doc id,
        scores, content) and, when present, 'entity_snippets' as numbered
        snippets. Missing or null text fields render as 'N/A' and a missing or
        null semantic score as 0.0, so one malformed record does not discard
        the whole response.
        """
        if not response_data:
            return "No information received from the search service."

        parts: List[str] = []

        # `or []` also covers an explicit JSON null sent by the server.
        search_results = response_data.get('search_results') or []
        if search_results:
            parts.append("[Relevant Doc Chunks]\n")
            for i, chunk in enumerate(search_results, start=1):
                semantic_score = chunk.get('semantic_score')
                if semantic_score is None:
                    semantic_score = 0.0
                bm25_score = chunk.get('bm25_score')

                parts.append(f"[Chunk {i}]\n")
                parts.append(f"Title: {self._text(chunk, 'title')}\n")
                parts.append(f"Source Doc ID: {self._text(chunk, 'doc_id')}\n")
                parts.append(f"Semantic Score: {semantic_score:.4f}\n")
                # The exact-match (BM25) score is shown only when present and non-negative.
                if bm25_score is not None and bm25_score >= 0:
                    parts.append(f"Exact Score: {bm25_score:.4f}\n")
                parts.append(f"Content: {self._text(chunk, 'content').strip()}\n\n")
        else:
            parts.append("[Relevant Documents]\nNo relevant documents found based on the search criteria.\n\n")

        entity_snippets = response_data.get('entity_snippets')
        if entity_snippets:
            parts.append("[Extra Entity Snippets]\n")
            for i, snippet in enumerate(entity_snippets, start=1):
                parts.append(f"[Snippet {i}]\n")
                parts.append(f"Title: {self._text(snippet, 'title')}\n")
                parts.append(f"Snippet: {self._text(snippet, 'snippet').strip()}\n\n")

        return "".join(parts).strip()

    def execute_search_plan(
        self,
        semantic_query: str,
        bm25_query_keywords: Optional[List[str]] = None,
        bm25_weight: float = 0.3,
        entity_match: Optional[str] = None,
        include_doc_ids: Optional[List[str]] = None,
        exclude_doc_ids: Optional[List[str]] = None,
        top_k: int = 3
    ) -> str:
        """
        Execute a comprehensive search plan and return formatted context for LLM use.
        Maps directly to the server's /execute_search endpoint.

        LLM-supplied arguments are sanitized before sending:
          * punctuation is stripped from BM25 keywords and `entity_match`, and
            terms that become empty are dropped;
          * `bm25_weight` is clamped to [0, 1], and set to 0.0 when no keywords remain;
          * `top_k` is clamped to [3, 5].
        Fields whose value ends up None are omitted from the request body.

        Returns:
            The formatted search context. On failure, returns an error description
            prefixed with "[Service Error]" (HTTP/network errors) or
            "[Client Error]" (anything else, e.g. a `top_k` that is not an integer).
        """
        logging.info("--- Executing search plan for query: '%s' ---", semantic_query)

        try:
            # Coercion lives inside the try so bad LLM arguments produce an error
            # string rather than an exception escaping to the agent loop.
            safe_bm25_keywords = None
            if bm25_query_keywords:
                cleaned = [self._sanitize_term(kw) for kw in bm25_query_keywords]
                safe_bm25_keywords = [kw for kw in cleaned if kw] or None
            safe_bm25_weight = max(0.0, min(float(bm25_weight), 1.0)) if safe_bm25_keywords else 0.0
            safe_entity = (self._sanitize_term(entity_match) or None) if entity_match else None
            safe_top_k = max(3, min(int(top_k), 5))

            payload = {
                "semantic_query": semantic_query,
                "bm25_query_keywords": safe_bm25_keywords,
                "bm25_weight": safe_bm25_weight,
                "entity_match": safe_entity,
                "include_doc_ids": include_doc_ids,
                "exclude_doc_ids": exclude_doc_ids,
                "top_k": safe_top_k
            }
            payload = {k: v for k, v in payload.items() if v is not None}

            # Skip the json.dumps work entirely when INFO logging is disabled.
            if logging.getLogger().isEnabledFor(logging.INFO):
                logging.info("Sending payload to server: %s", json.dumps(payload, indent=2))

            response = self.session.post(
                f"{self.base_url}/execute_search", json=payload, timeout=self.timeout)
            response.raise_for_status()

            response_data = response.json()
            logging.info(
                "Received %d documents and %d snippets.",
                len(response_data.get('search_results') or []),
                len(response_data.get('entity_snippets') or []),
            )

            return self._format_results_for_llm(response_data)

        except requests.exceptions.RequestException as e:
            error_message = f"Error calling RAG service: {e}"
            if e.response is not None:
                error_message += self._describe_error_response(e.response)
            logging.error(error_message)
            return f"[Service Error]\n{error_message}"
        except Exception as e:
            logging.exception("An unexpected error occurred in the client: %s", e)
            return f"[Client Error]\nAn unexpected error occurred: {e}"


if __name__ == '__main__':
    # Demo run against the local service; requires the server at VDB_SERVER_URL to be up.
    client = ActionClient(server_base_url=VDB_SERVER_URL)

    print("\n\n--- EXAMPLE: Search with Entity Match ---")
    # execute_search_plan clamps top_k to [3, 5], so request 3 explicitly rather
    # than a smaller value that would be silently raised to 3.
    result = client.execute_search_plan(
        semantic_query="Tell me about the movie directed by Christopher Nolan about dreams.",
        entity_match="Inception",
        bm25_query_keywords=["Christopher Nolan", "dream movie"],
        bm25_weight=0.3,
        top_k=3,
    )
    print(result)
