# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Adapted for custom InteractTool by the user.

import json
import logging
import os
import threading
import traceback
from contextlib import ExitStack
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar
from enum import Enum
from uuid import uuid4
import math

import numpy as np
import ray
import ray.actor
import requests

from verl.tools.base_tool import BaseTool
from verl.tools.schemas import OpenAIFunctionToolSchema

logger = logging.getLogger(__name__)
# Module verbosity comes from the VERL_LOGGING_LEVEL environment variable (WARN if unset).
logger.setLevel(os.environ.get("VERL_LOGGING_LEVEL", "WARN"))


# Generic type variable; not referenced in the visible part of this module.
T = TypeVar("T")


# Adapted from verl/tools/sandbox_fusion_tools.py
class PoolMode(Enum):
    """How ``init_interact_execution_pool`` should execute tool calls.

    Only ``ThreadMode`` is implemented; requesting ``ProcessMode`` raises
    ``NotImplementedError`` in the pool factory.
    """

    ThreadMode = 1  # single Ray actor, concurrency via its max_concurrency threads
    ProcessMode = 2  # reserved, not implemented


class ActionClient:
    """Client for the backend vector-query (retrieval) HTTP service.

    Wraps three retrieval endpoints (``/search_query``, ``/search_query_with_ids``
    and ``/search_query_with_titles``) and combines their results into one
    LLM-readable text block via :meth:`execute_search_plan`.
    """

    def __init__(self, server_base_url: str, request_timeout: Optional[float] = 60.0):
        """Create a client bound to ``server_base_url``.

        Args:
            server_base_url: Base URL of the retrieval service; a trailing ``/`` is stripped.
            request_timeout: Seconds to wait for each HTTP request before giving up
                (``None`` waits forever). Without a bound, a hung server would block
                the calling thread indefinitely.
        """
        self.base_url = server_base_url.rstrip('/')
        self.session = requests.Session()
        self.n_results_default = 3
        self.request_timeout = request_timeout

    # ------------------------------------------------------------------ HTTP

    def _post_json(self, endpoint: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
        """POST ``payload`` to ``/<endpoint>`` and return the decoded list of chunks.

        Returns ``[]`` (after logging) on transport/HTTP errors, an undecodable
        body, or a JSON body that is not a list, so callers can always iterate.
        """
        try:
            response = self.session.post(
                f"{self.base_url}/{endpoint}", json=payload, timeout=self.request_timeout)
            response.raise_for_status()
            data = response.json()
        except (requests.exceptions.RequestException, ValueError) as e:
            logger.error(f"Error calling /{endpoint}: {e}")
            return []
        if not isinstance(data, list):
            logger.error(
                f"Unexpected response from /{endpoint}: expected a list, got {type(data).__name__}")
            return []
        return data

    def _search_query(self,
                      query: str,
                      n_results: int,
                      include_domain_keywords: Optional[List[str]] = None,
                      exclude_domain_keywords: Optional[List[str]] = None,
                      alpha: Optional[float] = None
                      ) -> List[Dict[str, Any]]:
        """Corpus-wide search via ``POST /search_query``.

        Keyword filters are only sent when non-empty; ``alpha`` is sent as
        ``rerank_alpha`` when not ``None``. Returns ``[]`` on request failure.
        """
        json_payload: Dict[str, Any] = {"query_text": query, "n_results": n_results}
        if include_domain_keywords:
            json_payload["include_domain_keywords"] = include_domain_keywords
        if exclude_domain_keywords:
            json_payload["exclude_domain_keywords"] = exclude_domain_keywords
        if alpha is not None:
            json_payload["rerank_alpha"] = alpha
        return self._post_json("search_query", json_payload)

    def _search_query_with_ids(self, query: str, id_list: List[str], n_results: int) -> List[Dict[str, Any]]:
        """Search restricted to the documents in ``id_list`` via ``POST /search_query_with_ids``."""
        return self._post_json(
            "search_query_with_ids",
            {"query_text": query, "id_list": id_list, "n_results": n_results})

    def _search_query_with_titles(self, query: str, title_text: str, title_num: int = 10, n_per_title: int = 2) -> List[Dict[str, Any]]:
        """Search documents whose title matches ``title_text`` via ``POST /search_query_with_titles``."""
        return self._post_json(
            "search_query_with_titles",
            {"query_text": query, "title_text": title_text,
             "title_num": title_num, "n_per_title": n_per_title})

    # --------------------------------------------------------------- helpers

    @staticmethod
    def _chunk_key(chunk: Dict[str, Any]) -> str:
        """Return the dedup key ``"<doc_id>_<chunk_index_in_doc>"`` for a chunk."""
        metadata = chunk.get('metadata') or {}
        return f"{chunk.get('doc_id')}_{metadata.get('chunk_index_in_doc')}"

    @staticmethod
    def _as_list(value: Any) -> Optional[List[Any]]:
        """Normalise an optional list-valued tool argument.

        Falsy values become ``None``; a bare string (or any other scalar) becomes a
        one-element list, so e.g. ``exclude_doc_ids="doc1"`` does not degrade into
        a substring test; tuples and sets become lists.
        """
        if not value:
            return None
        if isinstance(value, (list, tuple, set)):
            return list(value)
        return [value]

    @staticmethod
    def _allocate_budget(n_prefer: int, n_candidates: int, n_results: int) -> Tuple[int, int]:
        """Split the output budget between preferred and main-search chunks.

        At most ``ceil(1.5 * n_results)`` chunks are returned. Up to half of that
        (rounded up) is reserved for main-search candidates, but only as many slots
        as there are candidates, so unused reservation goes to preferred chunks.
        Main-search chunks are additionally capped at ``n_results``.

        Returns:
            ``(prefer_count, query_count)`` — how many of each list to keep.
        """
        max_total_chunks = math.ceil(1.5 * n_results)
        reserved_for_query = min(math.ceil(max_total_chunks / 2), n_candidates)
        prefer_count = min(n_prefer, max_total_chunks - reserved_for_query)
        query_count = min(n_candidates, max_total_chunks - prefer_count, n_results)
        return prefer_count, query_count

    def _format_chunks_for_llm(self, text_chunks: List[Dict[str, Any]]) -> str:
        """Render chunks as numbered ``[Chunk i]`` sections (title, doc_id, keywords, content)."""
        if not text_chunks:
            return "No relevant documents found based on the provided criteria."
        parts = []
        for i, chunk in enumerate(text_chunks, start=1):
            # The service may send explicit nulls; treat them like missing fields.
            metadata = chunk.get("metadata") or {}
            keywords = metadata.get("keywords") or []
            if isinstance(keywords, str):
                keywords = [keywords]
            content = chunk.get("content")
            content = "N/A" if content is None else str(content).strip()
            parts.append(
                f"[Chunk {i}]\n"
                f"title: {chunk.get('title', 'N/A')}\n"
                f"doc_id: {chunk.get('doc_id', 'N/A')}\n"
                f"domain_keywords: {', '.join(str(k) for k in keywords)}\n"
                f"content: {content}\n\n"
            )
        return "".join(parts)

    # ------------------------------------------------------------ entry point

    def execute_search_plan(self, **kwargs) -> str:
        """Run a multi-stage retrieval plan and return LLM-ready text.

        Recognised kwargs: ``query`` (required), ``retrieval_scale`` (``'large'``
        gives 5 results, anything else ``n_results_default``), ``include_doc_ids``,
        ``exclude_doc_ids``, ``match_title``, ``include_domain_keywords``,
        ``exclude_domain_keywords``.

        Stages:
          1. ``include_doc_ids`` -> search inside those documents (preferred chunks).
          2. ``match_title`` -> title-matched chunks scoring >= 0.8, best first,
             interleaved into the preferred list at odd positions (max ``n_results``).
          3. Main search with keyword filters, deduplicated against preferred chunks.
          4. Drop every chunk whose ``doc_id`` is in ``exclude_doc_ids``.
          5. Trim with :meth:`_allocate_budget` and format.

        Returns:
            Formatted text, or a JSON ``{"error": ...}`` string if ``query`` is
            missing or an unexpected exception occurs.
        """
        query = kwargs.get("query")
        if not query:
            return json.dumps({"error": "Query parameter is missing."})

        try:
            retrieval_scale = kwargs.get('retrieval_scale', 'moderate')
            include_doc_ids = self._as_list(kwargs.get('include_doc_ids'))
            exclude_doc_ids = self._as_list(kwargs.get('exclude_doc_ids'))
            match_title = kwargs.get('match_title')
            include_domain_keywords = self._as_list(kwargs.get('include_domain_keywords'))
            exclude_domain_keywords = self._as_list(kwargs.get('exclude_domain_keywords'))
            # NOTE(review): no rerank_alpha is sent to /search_query, so the
            # server-side default applies — confirm this is intended.

            n_results = 5 if retrieval_scale == 'large' else self.n_results_default

            prefer_chunks: List[Dict[str, Any]] = []
            prefer_chunk_ids = set()

            if include_doc_ids:
                logger.info(
                    f"Fetching chunks specifically from documents: {include_doc_ids}")
                for chunk in self._search_query_with_ids(query, include_doc_ids, n_results=n_results):
                    key = self._chunk_key(chunk)
                    if key not in prefer_chunk_ids:
                        prefer_chunks.append(chunk)
                        prefer_chunk_ids.add(key)

            match_title_message = ""
            if match_title:
                logger.info(
                    f"Fetching chunks by matching title: '{match_title}'")
                title_match_chunks = self._search_query_with_titles(query, match_title)
                if not title_match_chunks:
                    match_title_message = f"No documents found matching the title '{match_title}'."

                def _score(c: Dict[str, Any]) -> float:
                    return c.get('score') or 0.0

                # Keep only confident title matches, highest score first (stable).
                title_match_chunks = sorted(
                    (c for c in title_match_chunks if _score(c) >= 0.8),
                    key=_score, reverse=True)

                added_count = 0
                for chunk in title_match_chunks:
                    if added_count >= n_results:
                        break
                    key = self._chunk_key(chunk)
                    if key in prefer_chunk_ids:
                        continue
                    # Interleave at positions 1, 3, 5, ...; list.insert appends
                    # when the index is past the end.
                    prefer_chunks.insert(2 * added_count + 1, chunk)
                    prefer_chunk_ids.add(key)
                    added_count += 1
                logger.info(
                    f"Added {added_count} chunks from title match after filtering and deduplication.")

            logger.info(
                f"Executing main search for a pool of {n_results} chunks.")
            candidate_chunks = [
                c for c in self._search_query(
                    query,
                    n_results=n_results,
                    include_domain_keywords=include_domain_keywords,
                    exclude_domain_keywords=exclude_domain_keywords,
                )
                if self._chunk_key(c) not in prefer_chunk_ids
            ]

            if exclude_doc_ids:
                excluded = set(exclude_doc_ids)
                present_doc_ids = {c.get('doc_id') for c in candidate_chunks + prefer_chunks}
                need_to_exclude = excluded & present_doc_ids
                candidate_chunks = [c for c in candidate_chunks if c.get('doc_id') not in excluded]
                prefer_chunks = [c for c in prefer_chunks if c.get('doc_id') not in excluded]
                logger.info(
                    f"Applied exclusion filter for docs: {exclude_doc_ids}. filtered nums: {len(need_to_exclude)}")

            prefer_count, query_count = self._allocate_budget(
                len(prefer_chunks), len(candidate_chunks), n_results)
            final_results = prefer_chunks[:prefer_count] + candidate_chunks[:query_count]

            logger.info(
                f"--- Execution finished. Returning {len(final_results)} results. ---")

            final_text = self._format_chunks_for_llm(final_results)
            if match_title_message:
                final_text = f"[Info]: {match_title_message}\n\n{final_text}"
            return final_text.strip()

        except Exception as e:
            logger.error(
                f"Error during search plan execution: {e}\n{traceback.format_exc()}")
            return json.dumps({"error": f"An internal error occurred during search plan execution: {str(e)}"})

# --- Ray concurrent execution related classes (adapted from Verl examples) ---


@ray.remote(concurrency_groups={"acquire": 1, "release": 10})
class TokenBucketWorker:
    """Ray actor that caps how many tool calls may run at the same time.

    Despite the name, this is a counting semaphore rather than a time-refilled
    token bucket. ``acquire`` blocks until one of ``rate_limit`` permits is free,
    and ``release`` returns a permit.

    ``acquire`` runs in a single-thread concurrency group. ``release`` runs in a
    separate group of up to 10 threads, so a release is not queued behind an
    ``acquire`` that is blocked on the semaphore.
    """

    def __init__(self, rate_limit: int):
        self.rate_limit = rate_limit
        self.current_count = 0
        self._semaphore = threading.Semaphore(rate_limit)
        # acquire() and release() execute on different actor threads, so the
        # read-modify-write updates of current_count must be serialized.
        self._count_lock = threading.Lock()

    @ray.method(concurrency_group="acquire")
    def acquire(self):
        """Block until a permit is available, then take it."""
        self._semaphore.acquire()
        with self._count_lock:
            self.current_count += 1

    @ray.method(concurrency_group="release")
    def release(self):
        """Return a permit previously obtained with ``acquire``."""
        with self._count_lock:
            self.current_count -= 1
        self._semaphore.release()

    def get_current_count(self):
        """Return the number of permits currently held (diagnostic only)."""
        return self.current_count


class InteractExecutionWorker:
    """Runs arbitrary callables, optionally gated by a cluster-wide limiter.

    Meant to be wrapped with ``ray.remote`` (see ``init_interact_execution_pool``),
    so several calls execute concurrently on the actor's threads. When the
    global limiter is enabled, each call holds one permit of the named
    ``"global-rate-limiter"`` actor for its whole duration.
    """

    def __init__(self, enable_global_rate_limit=True, rate_limit=10):
        self.rate_limit_worker = self._init_rate_limit(
            rate_limit) if enable_global_rate_limit else None

    def _init_rate_limit(self, rate_limit):
        """Get or create the named limiter actor.

        Because of ``get_if_exists=True``, all callers share one limiter, and
        ``rate_limit`` only matters for the caller that creates it.
        """
        return TokenBucketWorker.options(name="global-rate-limiter", get_if_exists=True).remote(rate_limit)

    def ping(self):
        """Health check method."""
        return True

    def execute(self, fn: Callable[..., Any], *fn_args, **fn_kwargs) -> Any:
        """Call ``fn(*fn_args, **fn_kwargs)`` and return its result.

        With the limiter enabled, this first blocks until a permit is granted.
        The permit is returned once ``fn`` finishes, whether it succeeds or raises.
        A permit is only returned if it was actually obtained, so a failed
        ``acquire`` no longer inflates the limiter.
        """
        if self.rate_limit_worker is None:
            return fn(*fn_args, **fn_kwargs)

        ray.get(self.rate_limit_worker.acquire.remote())
        try:
            return fn(*fn_args, **fn_kwargs)
        finally:
            # Fire-and-forget: no need to wait for the limiter to process it.
            self.rate_limit_worker.release.remote()


def init_interact_execution_pool(num_workers: int, enable_global_rate_limit=True, rate_limit=10, mode: PoolMode = PoolMode.ThreadMode):
    """Create the Ray actor that executes tool calls.

    Args:
        num_workers: ``max_concurrency`` of the actor, i.e. how many calls it runs at once.
        enable_global_rate_limit: Whether calls go through the shared limiter actor.
        rate_limit: Permit count used if this call creates the limiter.
        mode: Only ``PoolMode.ThreadMode`` is supported.

    Raises:
        NotImplementedError: for any mode other than ``PoolMode.ThreadMode``.
    """
    if mode != PoolMode.ThreadMode:
        raise NotImplementedError("Process mode is not implemented yet")

    remote_worker_cls = ray.remote(InteractExecutionWorker)
    configured = remote_worker_cls.options(max_concurrency=num_workers)
    return configured.remote(
        enable_global_rate_limit=enable_global_rate_limit,
        rate_limit=rate_limit,
    )


class InteractTool(BaseTool):
    """verl tool that answers search requests via ``ActionClient.execute_search_plan``.

    Calls are dispatched to a Ray actor created by ``init_interact_execution_pool``.
    The HTTP work therefore runs there with bounded concurrency and an optional
    global rate limit, while ``execute`` awaits the result.

    Recognised ``config`` keys: ``server_base_url`` (default ``"xxx"``),
    ``num_workers`` (20), ``rate_limit`` (20), ``enable_global_rate_limit`` (True).
    """

    def __init__(self, config: dict, tool_schema: OpenAIFunctionToolSchema):
        super().__init__(config, tool_schema)
        # instance_id -> {"response": str, "reward": list of stripped tool responses}
        self._instance_dict = {}

        # NOTE(review): "xxx" is a placeholder; set server_base_url in the tool config.
        self.action_client = ActionClient(
            server_base_url=config.get("server_base_url", "xxx")
        )

        self.num_workers = config.get("num_workers", 20)
        self.rate_limit = config.get("rate_limit", 20)
        self.enable_global_rate_limit = config.get(
            "enable_global_rate_limit", True)

        self.execution_pool = init_interact_execution_pool(
            num_workers=self.num_workers,
            enable_global_rate_limit=self.enable_global_rate_limit,
            rate_limit=self.rate_limit
        )
        logger.info(
            f"Initialized InteractTool with a concurrent worker (max_concurrency={self.num_workers}).")

    def get_openai_tool_schema(self) -> OpenAIFunctionToolSchema:
        return self.tool_schema

    async def create(self, instance_id: Optional[str] = None, **kwargs) -> str:
        """Register a new tool instance (random UUID if none given) and return its id."""
        if instance_id is None:
            instance_id = str(uuid4())
        self._instance_dict[instance_id] = {"response": "", "reward": []}
        return instance_id

    @staticmethod
    def _run_search_plan(action_client: ActionClient, instance_id: str, parameters: dict) -> Tuple[str, dict]:
        """Run the search plan and package the result as ``(text, metadata)``.

        This is the function shipped to the Ray worker. It is a staticmethod
        taking only the client, so the whole tool (including ``_instance_dict``)
        is not serialized on every call.
        """
        logger.info(
            f"Executing search plan for instance {instance_id} with query: '{parameters.get('query')}'")
        result_text = action_client.execute_search_plan(**parameters)

        metadata = {
            "query": parameters.get("query"),
            "char_length": len(result_text),
            "retrieval_scale": parameters.get('retrieval_scale', 'moderate')
        }
        return result_text, metadata

    def _execute_search_plan_wrapper(self, instance_id: str, parameters: dict) -> Tuple[str, dict]:
        """Backward-compatible entry point; delegates to ``_run_search_plan``."""
        return self._run_search_plan(self.action_client, instance_id, parameters)

    async def execute(self, instance_id: str, parameters: dict[str, Any], **kwargs) -> Tuple[str, float, dict]:
        """Execute the tool.

        Args:
            instance_id: The instance ID of the tool.
            parameters: Tool parameters; ``query`` is required.

        Returns:
            ``(response_text, reward, metrics)``. The reward is always ``0.0``.
            On a missing query or execution failure, ``response_text`` is a JSON
            ``{"result": ...}`` message.
        """
        if not isinstance(parameters, dict) or not parameters.get("query"):
            error_msg = "Error: 'query' parameter is required and cannot be empty."
            logger.error(
                f"[InteractTool] {error_msg} Received parameters: {parameters}")
            return json.dumps({"result": error_msg}), 0.0, {}

        try:
            tool_response, tool_metrics = await self.execution_pool.execute.remote(
                self._run_search_plan,
                self.action_client,
                instance_id,
                parameters
            )
        except Exception as e:
            error_result = json.dumps(
                {"result": f"Tool execution failed: {e}"})
            logger.error(
                f"[InteractTool] Execution failed: {e}\n{traceback.format_exc()}")
            return error_result, 0.0, {"error": str(e)}

        # Record the response for calc_reward. An unknown instance must not turn
        # a successful call into a reported failure.
        instance = self._instance_dict.get(instance_id)
        if instance is None:
            logger.warning(
                f"[InteractTool] Unknown instance_id {instance_id!r}; response not recorded.")
        else:
            instance["reward"].append(tool_response.strip())

        tool_reward_score = 0.0
        return tool_response, tool_reward_score, tool_metrics

    async def calc_reward(self, instance_id: str, **kwargs) -> List[str]:
        """Return the stripped tool responses recorded for ``instance_id``.

        Raises:
            KeyError: if ``instance_id`` was never created or was already released.
        """
        return self._instance_dict[instance_id]["reward"]

    async def release(self, instance_id: str, **kwargs) -> None:
        """Forget ``instance_id``; unknown ids are ignored."""
        self._instance_dict.pop(instance_id, None)
