"""
MemAgentQA Wrapper for MemAgent
Provides unified interface for evaluation pipeline using MemAgent's state memory mechanisms.
"""
import json
import asyncio
from typing import List, Dict, Any, Optional
from openai import AsyncOpenAI
from .agent import MemAgent


class MemAgentQA:
    """
    Unified evaluation interface for MemAgent baseline.
    Wraps MemAgent's memory construction and retrieval pipeline.

    Lifecycle as used by the methods below: construct the wrapper, await
    ``build_memory`` once per trajectory, await ``answer_question`` for the
    questions, then call ``cleanup`` to release the Ray workers.
    """

    # Soft cap on the number of chat messages carried between tool rounds.
    MAX_MESSAGE_HISTORY = 20
    # Upper bound on the number of Ray workers created per instance.
    MAX_RAY_WORKERS = 8

    def __init__(
        self,
        num_questions: int,
        api_key: Optional[str] = None,
        model: Optional[str] = None,
        base_url: Optional[str] = None,
        retrieval_mode: str = "none",
        db_path: str = ":memory:",
        enable_thinking: Optional[bool] = None,
        temperature: float = 0.7,
        max_tokens: int = 1024,
        enable_tools: bool = True,
        log_path: Optional[str] = None,
        num_parallel_tasks: int = 1,
        enable_state_memory_summary: bool = True,
        embedding_model: Optional[Any] = None
    ):
        """
        Initialize MemAgentQA wrapper.

        Creates the async OpenAI-compatible client and a pool of Ray workers.
        The MemAgent itself is only created later, by ``build_memory``.

        Args:
            num_questions: Total number of questions for this trajectory
            api_key: OpenAI API key; falls back to "EMPTY" when not given
            model: LLM model name
            base_url: Custom OpenAI base URL (for vLLM)
            retrieval_mode: 'none' (default), 'bm25', 'llm', or 'embed'
            db_path: Database path (stored and forwarded to MemAgent)
            enable_thinking: Stored for interface compatibility; not read
                anywhere in this class
            temperature: Temperature for LLM sampling
            max_tokens: Maximum tokens for LLM response (default: 1024)
            enable_tools: Enable tool/function calling support
            log_path: Path to log file for this trajectory
            num_parallel_tasks: Number of parallel tasks for resource allocation
            enable_state_memory_summary: Forwarded to MemAgent unchanged
            embedding_model: Optional embedding model, forwarded to MemAgent
        """
        self.num_questions = num_questions
        self.model = model
        self.api_key = api_key or "EMPTY"
        self.base_url = base_url
        self.retrieval_mode = retrieval_mode
        self.db_path = db_path
        self.enable_thinking = enable_thinking
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.enable_tools = enable_tools
        self.log_path = log_path
        self.num_parallel_tasks = num_parallel_tasks
        self.enable_state_memory_summary = enable_state_memory_summary
        self.embedding_model = embedding_model

        # Client with generous timeouts and a larger connection pool.
        import httpx
        self.client = AsyncOpenAI(
            base_url=base_url,
            api_key=self.api_key,
            # 600 s default for each timeout phase, 60 s to establish a connection
            timeout=httpx.Timeout(600.0, connect=60.0),
            max_retries=3,
            http_client=httpx.AsyncClient(
                limits=httpx.Limits(
                    max_connections=100,
                    max_keepalive_connections=50
                )
            )
        )
        # NOTE(review): nothing in this class closes self.client; confirm
        # whether the caller is expected to do so.

        # Ray workers are handed to MemAgent in build_memory.
        from .utils import get_ray_worker_cls
        import os

        # Request max(num_questions, num_parallel_tasks) workers, cap the pool
        # at MAX_RAY_WORKERS, and never go below one so the per-worker
        # arithmetic below cannot divide by zero.
        base_workers = max(num_questions, num_parallel_tasks)
        total_workers = max(1, min(base_workers, self.MAX_RAY_WORKERS))

        RayWorker = get_ray_worker_cls(num_workers=total_workers)

        # Create worker pool
        self.ray_workers = [RayWorker.remote(i) for i in range(total_workers)]
        self.worker_idx = 0

        # os.cpu_count() may return None when the count cannot be determined.
        cpu_count = os.cpu_count() or 1
        print(f"\n[MemAgentQA] Initialized Ray workers: {total_workers} workers (capped at 8, requested: {base_workers})")
        # NOTE(review): this figure assumes get_ray_worker_cls gives each
        # worker (total_cpus * 0.6) / total_workers CPUs - verify in .utils.
        print(f"[MemAgentQA] Resource allocation: {cpu_count * 0.6 / total_workers:.2f} CPUs per worker")

        # Populated by build_memory.
        self.mem_agent = None
        self.memory = None

        print(f"[MemAgentQA] Initialized for {num_questions} questions")
        print(f"[MemAgentQA] Model: {model}")
        print(f"[MemAgentQA] Base URL: {base_url}")
        print(f"[MemAgentQA] Retrieval mode: {retrieval_mode}")
        print(f"[MemAgentQA] Tool support: {enable_tools}")

    async def call_llm_with_tools(
        self,
        prompt_text: str,
        tools: Optional[List[Dict[str, Any]]] = None,
        tool_choice: Optional[str] = None,
        max_tool_rounds: int = 5,
        trajectory_text_json: Optional[str] = None
    ):
        """
        Call LLM with multi-turn tool/function calling support (Hermes format).

        Flow:
        1. Send the user message, with tools when enabled and provided
        2. If the model calls tools, execute them
        3. Append the tool results to the message history
        4. Repeat until the model answers without tool calls, a request
           fails, or ``max_tool_rounds`` is reached

        Args:
            prompt_text: The prompt to send to the LLM
            tools: Optional list of tool definitions in OpenAI format
            tool_choice: Optional tool choice strategy, passed through as-is
            max_tool_rounds: Maximum number of tool calling rounds (default: 5)
            trajectory_text_json: Optional trajectory data for tool execution;
                without it every tool call yields an error result

        Returns:
            Tuple of (output_dpr, response_text).

            On normal completion ``output_dpr`` is None when no tool was
            called, otherwise a dict with 'tool_calls', 'tool_results' and
            'num_rounds'.

            When the loop ends early (request failure or round limit),
            ``output_dpr`` additionally has 'messages' and 'truncated': True,
            and ``response_text`` is the content of the last non-user message
            ("" if there is none, so the prompt is never echoed back).
        """
        messages = [{"role": "user", "content": prompt_text}]

        # Every tool call issued by the model and the result fed back to it.
        all_tool_calls = []
        all_tool_results = []

        request_failed = False

        for round_num in range(max_tool_rounds):
            request_params = {
                "model": self.model,
                "messages": messages,
                "temperature": self.temperature,
                "max_tokens": self.max_tokens
            }

            # Add tool parameters if provided and enabled
            if self.enable_tools and tools:
                request_params["tools"] = tools
                if tool_choice:
                    request_params["tool_choice"] = tool_choice

            response = await self._request_completion(request_params, round_num)
            if response is None:
                request_failed = True
                break

            message = response.choices[0].message
            assistant_msg = {"role": "assistant", "content": message.content}
            tool_calls = getattr(message, 'tool_calls', None)

            if not tool_calls:
                # No tool calls, conversation is complete
                messages.append(assistant_msg)
                response_text = message.content or ""

                # The message history is deliberately left out of the
                # payload here because it can be very large.
                output_dpr = None
                if all_tool_calls:
                    output_dpr = {
                        'tool_calls': all_tool_calls,
                        'tool_results': all_tool_results,
                        'num_rounds': round_num + 1,
                    }

                print(f"[Client] Completed after {round_num + 1} round(s)")
                return output_dpr, response_text

            # Record the assistant turn together with the calls it requested.
            assistant_msg["tool_calls"] = [
                {
                    "id": tc.id,
                    "type": tc.type,
                    "function": {
                        "name": tc.function.name,
                        "arguments": tc.function.arguments
                    }
                }
                for tc in tool_calls
            ]
            messages.append(assistant_msg)

            print(f"[Client] Round {round_num + 1}: Model called {len(tool_calls)} tool(s)")

            for tool_call in tool_calls:
                func_name = tool_call.function.name
                func_args_str = tool_call.function.arguments
                tool_call_id = tool_call.id

                result = self._execute_tool(func_name, func_args_str, trajectory_text_json)

                all_tool_calls.append({
                    "id": tool_call_id,
                    "function": {"name": func_name, "arguments": func_args_str}
                })
                all_tool_results.append({
                    "tool_call_id": tool_call_id,
                    "function_name": func_name,
                    "result": result
                })

                # Feed the result back to the model on the next round.
                messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call_id,
                    "name": func_name,
                    "content": result
                })

            # Bound the history between rounds to limit memory use.
            trimmed = self._trim_history(messages, self.MAX_MESSAGE_HISTORY)
            if len(trimmed) < len(messages):
                print(f"[Client] Trimmed message history to {len(trimmed)} messages")
            messages = trimmed

        if request_failed:
            # round_num is the index of the failed round, i.e. the number of
            # rounds that completed before it.
            rounds_done = round_num
            print(f"[Client] Stopped after request failure in round {round_num}")
        else:
            rounds_done = max_tool_rounds
            print(f"[Client] Max tool rounds ({max_tool_rounds}) reached")

        response_text = self._last_reply_text(messages)

        # NOTE(review): unlike the normal-completion payload, this one still
        # carries the message history. Kept so existing readers of
        # output_dpr['messages'] keep working - confirm whether any exist.
        output_dpr = {
            'tool_calls': all_tool_calls,
            'tool_results': all_tool_results,
            'num_rounds': rounds_done,
            'messages': messages,
            'truncated': True
        }

        return output_dpr, response_text

    async def _request_completion(self, request_params: Dict[str, Any], round_num: int):
        """Send one chat completion request; return the response or None on failure.

        If the error text mentions a token limit and contains a pattern of the
        form ``A > B - C``, the request is retried once with ``max_tokens``
        reduced to 80% of ``B - C`` (but at least 64). ``request_params`` is
        updated in place for that retry. All failures are logged, not raised.
        """
        import re
        import traceback

        retry_max_tokens = None
        try:
            return await self.client.chat.completions.create(**request_params)
        except Exception as e:
            error_msg = str(e)
            if "max_tokens" in error_msg or "max_completion_tokens" in error_msg:
                match = re.search(r'(\d+)\s*>\s*(\d+)\s*-\s*(\d+)', error_msg)
                if match:
                    max_context = int(match.group(2))
                    input_tokens = int(match.group(3))
                    available_tokens = max_context - input_tokens
                    # Use 80% of available tokens to be safe
                    retry_max_tokens = max(64, int(available_tokens * 0.8))
            if retry_max_tokens is None:
                print(f"[Client] Error in round {round_num}: {e}")
                traceback.print_exc()
                return None

        request_params["max_tokens"] = retry_max_tokens
        try:
            return await self.client.chat.completions.create(**request_params)
        except Exception as e2:
            print(f"[Client] Error in round {round_num} (retry failed): {e2}")
            traceback.print_exc()
            return None

    @staticmethod
    def _execute_tool(func_name: str, func_args_str: Any, trajectory_text_json: Optional[str]):
        """Parse the tool arguments and run the tool; return the result for the model.

        Unparseable arguments and missing trajectory data both produce a JSON
        error string instead of raising.
        """
        try:
            func_args = json.loads(func_args_str)
        except (json.JSONDecodeError, TypeError) as e:
            # TypeError covers a tool call whose arguments are not a string.
            print(f"[Client] Failed to parse tool arguments: {func_args_str}")
            return json.dumps({"error": f"Invalid JSON arguments: {str(e)}"})

        print(f"[Client] Executing: {func_name}({func_args})")

        if trajectory_text_json:
            # Imported here so the tool module is only needed when a tool runs.
            from .tool import execute_tool_call
            result = execute_tool_call(func_name, func_args, trajectory_text_json)
        else:
            result = json.dumps({"error": "No trajectory data provided"})

        print(f"[Client] Tool result: {result[:50]}...")
        return result

    @staticmethod
    def _trim_history(messages: List[Dict[str, Any]], limit: int) -> List[Dict[str, Any]]:
        """Keep the first message plus roughly the last ``limit - 1`` messages.

        The cut is moved back to the owning assistant message when it would
        otherwise start on a ``tool`` message, so a tool result is never left
        without the assistant turn that requested it. The result can therefore
        be slightly longer than ``limit``.
        """
        if limit < 2 or len(messages) <= limit:
            return messages

        cut = len(messages) - (limit - 1)
        while cut > 1 and messages[cut].get("role") == "tool":
            cut -= 1
        return [messages[0]] + messages[cut:]

    @staticmethod
    def _last_reply_text(messages: List[Dict[str, Any]]) -> str:
        """Return the content of the last message unless it is the user prompt."""
        last = messages[-1] if messages else {}
        if last.get("role") == "user":
            return ""
        return last.get("content") or ""

    async def build_memory(
        self,
        trajectory: List[Dict[str, Any]],
        task_description: str = ""
    ) -> None:
        """
        Build GLOBAL memory that applies to all questions.

        Creates a fresh MemAgent wired to this wrapper's client, LLM call
        function and Ray workers, runs ``construct_state_memory`` over the
        trajectory with episode id "global", and exposes the resulting
        ``state_mem`` as ``self.memory``.

        Args:
            trajectory: List of trajectory steps
            task_description: Description of the task/scenario
        """
        print(f"\n[MemAgentQA] Building GLOBAL memory")
        print(f"[MemAgentQA] Task: {task_description[:100] if task_description else 'No description'}...")
        print(f"[MemAgentQA] Trajectory length: {len(trajectory)}")

        self.mem_agent = MemAgent(
            client=self.client,
            model=self.model,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            call_llm_func=self.call_llm_with_tools,
            db_path=self.db_path,
            log_path=self.log_path,
            ray_workers=self.ray_workers,
            enable_state_memory_summary=self.enable_state_memory_summary,
            embedding_model=self.embedding_model
        )
        # retrieval_mode is set as an attribute rather than passed to the constructor.
        self.mem_agent.retrieval_mode = self.retrieval_mode

        await self.mem_agent.construct_state_memory(
            trajectory=trajectory,
            task=task_description,
            episode_id="global"
        )
        self.memory = self.mem_agent.state_mem

    async def build_memory_w_question(
        self,
        trajectory: List[Dict[str, Any]],
        question: str,
        question_idx: int,
        task_description: str = ""
    ) -> None:
        """
        Build QUESTION-AWARE (local) memory for a single question.

        Not implemented in this version: the method exists only so the
        interface is complete and always raises.

        Args:
            trajectory: List of trajectory steps
            question: Single question to answer
            question_idx: Index of this question (0-based)
            task_description: Description of the task/scenario

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Question-aware memory building not implemented in this version.")

    async def answer_question(self, questions: List[str]) -> List[str]:
        """
        Generate answers by directly calling MemAgent's answer method.

        Questions are answered one after another, in order.

        Args:
            questions: List of questions to answer

        Returns:
            List of answers in the same order as questions

        Raises:
            RuntimeError: If ``build_memory`` has not been called yet.
        """
        if self.mem_agent is None:
            raise RuntimeError("MemAgent is not initialized; call build_memory() before answer_question().")

        print(f"\n[MemAgentQA] Generating answers for {len(questions)} questions")

        answers = []

        for i, question in enumerate(questions):
            print(f"\n[MemAgentQA] Q{i+1}/{len(questions)}: {question}")

            answer = await self.mem_agent.answer(question)

            answers.append(answer)
            # str() keeps the log line from failing on a non-string answer.
            print(f"[MemAgentQA] A{i+1}: {str(answer)[:200]}...")

        print(f"\n[MemAgentQA] Completed answering all {len(questions)} questions")
        return answers

    def cleanup(self):
        """Clean up Ray workers and MemAgent resources to prevent OOM.

        Each stage (MemAgent, Ray workers, GPU cache, garbage collection) is
        guarded separately, so a failure in one does not skip the others.
        Failures are reported as warnings; the method does not raise.
        """
        import gc

        # Stage 1: MemAgent instance
        mem_agent = getattr(self, 'mem_agent', None)
        if mem_agent is not None:
            print(f"[MemAgentQA] Cleaning up MemAgent instance...")
            try:
                mem_agent.cleanup()
            except Exception as e:
                print(f"[MemAgentQA] Warning: MemAgent cleanup failed: {e}")
            finally:
                self.mem_agent = None

        # Stage 2: Ray workers
        ray_workers = getattr(self, 'ray_workers', None)
        if ray_workers:
            print(f"[MemAgentQA] Shutting down {len(ray_workers)} Ray workers...")
            try:
                import ray
                for worker in ray_workers:
                    try:
                        ray.kill(worker)
                    except Exception as e:
                        print(f"[MemAgentQA] Warning: Failed to kill Ray worker: {e}")
                self.ray_workers = []
                print(f"[MemAgentQA] Ray workers shutdown complete")
            except Exception as e:
                print(f"[MemAgentQA] Warning: Failed to shutdown Ray workers: {e}")

        # Stage 3: GPU cache; silently skipped when torch is not installed
        try:
            import torch
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                print(f"[MemAgentQA] Cleared GPU cache")
        except ImportError:
            pass
        except Exception as e:
            print(f"[MemAgentQA] Warning: Failed to clear GPU cache: {e}")

        # Stage 4: force garbage collection
        for _ in range(3):
            gc.collect()

        print(f"[MemAgentQA] Cleanup completed")


def create_mem_agent_qa(
    num_questions: int,
    model: str,
    tokenizer: Any,
    url: str,
    api_key: str,
    log_path: Optional[str] = None,
    num_parallel_tasks: int = 1,
    enable_state_memory_summary: bool = True,
    embedding_model: Optional[Any] = None
) -> MemAgentQA:
    """
    Factory returning a MemAgentQA configured with this module's fixed defaults.

    The caller chooses the model, endpoint, credentials and per-trajectory
    options; retrieval mode, sampling settings and tool support are pinned
    here (see the settings below).

    Args:
        num_questions: Number of questions for this trajectory
        model: Model name
        tokenizer: Accepted to match the expected signature; not used here
        url: API endpoint URL, passed to MemAgentQA as ``base_url``
        api_key: API key
        log_path: Path to log file for this trajectory
        num_parallel_tasks: Number of parallel tasks for resource allocation (default: 1)
        enable_state_memory_summary: Forwarded to MemAgentQA unchanged
        embedding_model: Optional embedding model, forwarded to MemAgentQA unchanged

    Returns:
        A newly constructed MemAgentQA instance
    """
    # Values supplied by the caller.
    settings = dict(
        num_questions=num_questions,
        api_key=api_key,
        model=model,
        base_url=url,
        log_path=log_path,
        num_parallel_tasks=num_parallel_tasks,
        enable_state_memory_summary=enable_state_memory_summary,
        embedding_model=embedding_model,
    )
    # Values fixed by this factory.
    settings.update(
        retrieval_mode="bm25",
        db_path=":memory:",
        enable_thinking=False,
        temperature=0.7,
        max_tokens=2048,
        enable_tools=True,
    )
    return MemAgentQA(**settings)

