"""
Memory Agent for handling long trajectories with structured state memory
"""
import re
from typing import Dict, List, Any
from .utils import retrieve_with_bm25
from .retrieve import (
    retrieve_with_query,
    build_ordered_trajectory_from_turns
)
from .construct import construct_state_memory


class MemAgent:
    """
    General-purpose Memory Agent that converts long trajectories into retrievable structured state memory.

    The heavy lifting is delegated to the module-level helpers
    ``construct_state_memory`` (construct.py) and ``retrieve_with_query``
    (retrieve.py); this class holds the configuration and the artifacts those
    helpers produce, and owns their cleanup.

    State kept on the instance:
    - state_mem: Compressed state memory returned by construction
    - trajectory_data: Trajectory data returned by construction
    - embed_mem: Reserved slot; never populated by the code in this class
    - chunk_graph / chroma_client / collection / chroma_dir: chunk-based
      retrieval artifacts returned by construction (released by ``cleanup``)

    The instance can be used as a (synchronous) context manager; leaving the
    ``with`` block calls ``close()``.
    """

    def __init__(
        self,
        client=None,
        model: str = None,
        temperature: float = 0.7,
        max_tokens: int = 8192,
        call_llm_func=None,
        db_path: str = ":memory:",
        embed_engine=None,
        chunk: int = 8192,
        log_path: str = None,
        tool_mode_costimize: bool = True,
        ray_workers: List = None,
        enable_state_memory_summary: bool = True,
        max_iter: int = 3,
        embedding_model=None,
    ):
        """
        Initialize MemAgent with client configuration.

        Args:
            client: AsyncOpenAI client instance
            model: Model name
            temperature: Temperature for sampling
            max_tokens: Maximum tokens for response
            call_llm_func: Async function for LLM interaction with tool support
            db_path: Not used in current implementation (for future extensions)
            embed_engine: Optional function for embedding text
            chunk: Size of each chunk for processing long trajectories (default: 8192)
            log_path: Path to log file for this trajectory; the file is truncated on construction
            tool_mode_costimize: If True, use customized code generation mode for retrieval; if False, use predefined tools
            ray_workers: List of Ray worker actors for CPU-based code execution
            enable_state_memory_summary: If True (default), build state memory summary; if False, only do embedding
            max_iter: Maximum number of tool call iterations (default: 3)
            embedding_model: Optional SentenceTransformer model for chunk embeddings
        """
        import threading

        # LLM / embedding configuration
        self.client = client
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.call_llm_func = call_llm_func
        self.embed_engine = embed_engine
        self.chunk = chunk

        # Inputs of the most recent construct_state_memory() call
        self.trajectory = None
        self.task = None
        self.episode_id = None
        self.num_turns = 0

        # Outputs of the most recent construct_state_memory() call
        self.state_mem = None
        self.trajectory_data = None
        self.embed_mem = None

        self.retrieval_mode = "none"  # Default retrieval mode
        self.tool_mode_costimize = tool_mode_costimize  # Tool mode: True for code generation, False for predefined tools
        self.log_path = log_path
        self.ray_workers = ray_workers or []
        self.worker_idx_container = [0]  # Mutable container for worker round-robin
        self.enable_state_memory_summary = enable_state_memory_summary  # Control state memory summary generation
        self.max_iter = max_iter  # Maximum tool call iterations
        self.embedding_model = embedding_model  # SentenceTransformer model for chunk embeddings

        # Storage for chunk-based retrieval artifacts
        self.chunk_graph = None
        self.chroma_client = None
        self.collection = None
        self.chroma_dir = None

        # Serializes appends to the log file across threads
        self._log_lock = threading.Lock()

        # Start every run with an empty log file (overwrite mode)
        if self.log_path:
            with open(self.log_path, 'w', encoding='utf-8'):
                pass

    def _log(self, message: str):
        """Log a message with timestamp to file (thread-safe).

        Does nothing when no ``log_path`` was configured.

        Args:
            message: The key information to log for this stage
        """
        if not self.log_path:
            return

        from datetime import datetime
        time_str = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        log_line = f"[{time_str}] {message}\n"

        with self._log_lock:
            with open(self.log_path, 'a', encoding='utf-8') as f:
                f.write(log_line)

    async def construct_state_memory(self, trajectory: List[Dict[str, Any]],
                                task: str = "", episode_id: str = ""):
        """
        Convert a long trajectory into state memory and trajectory data.

        Stores the inputs on the instance, delegates to the module-level
        ``construct_state_memory`` helper, and copies the returned artifacts
        onto the instance for later retrieval.

        Args:
            trajectory: List of trajectory turns with turn_idx, action, observation
            task: Task description
            episode_id: Episode identifier
        """
        self.trajectory = trajectory
        self.task = task
        self.episode_id = episode_id

        # Inside a method the bare name resolves to the module-level function
        # imported from construct.py, not to this method.
        result = await construct_state_memory(
            trajectory=trajectory,
            task=task,
            episode_id=episode_id,
            chunk_size=self.chunk,
            enable_state_memory_summary=self.enable_state_memory_summary,
            call_llm_func=self.call_llm_func,
            embedding_model=self.embedding_model
        )

        # Required keys of the result
        self.state_mem = result['state_mem']
        self.trajectory_data = result['trajectory_data']
        self.num_turns = result['num_turns']

        # Optional chunk-based retrieval artifacts (None when absent)
        self.chunk_graph = result.get('chunk_graph')
        self.chroma_client = result.get('chroma_client')
        self.collection = result.get('collection')
        self.chroma_dir = result.get('chroma_dir')

    async def retrieve_with_query(self, query: str) -> str:
        """
        Retrieve information from memory and generate answer.

        Args:
            query: Natural language query about the trajectory

        Returns:
            Natural language answer to the query
        """
        # Inside a method the bare name resolves to the module-level function
        # imported from retrieve.py, not to this method.
        answer = await retrieve_with_query(
            query=query,
            keywords_info={},  # Not used in chunk-based retrieval
            relevant_turn_indices=[],  # Not used in chunk-based retrieval
            trajectory_data=self.trajectory_data,
            task=self.task,
            tool_mode_costimize=self.tool_mode_costimize,
            max_iter=self.max_iter,
            call_llm_func=self.call_llm_func,
            max_tokens=self.max_tokens,
            log_func=self._log,
            ray_workers=self.ray_workers,
            worker_idx_container=self.worker_idx_container,
            chunk_graph=self.chunk_graph,
            collection=self.collection,
            embedding_model=self.embedding_model
        )
        return answer

    async def answer(self, query: str) -> str:
        """
        Answer a query using chunk-based retrieval.

        The question, the answer (or the error) are written to the log file.
        Errors are not propagated: they are printed, logged, and replaced by a
        fixed fallback string so callers always receive a ``str``.

        Args:
            query: Natural language query

        Returns:
            Natural language answer, or a fallback message if retrieval failed
        """
        self._log(f"QUESTION: {query}")

        # Deliberate top-level boundary: any failure becomes a fallback answer.
        try:
            answer = await self.retrieve_with_query(query)
            self._log(f"ANSWER: {answer}")
            self._log("=" * 80)  # Separator for readability
            return answer
        except Exception as e:
            print(f"[MemAgent] Error during retrieval and answering: {e}")
            self._log(f"ERROR: {e}")
            import traceback
            traceback.print_exc()
            return "Unable to generate answer due to an error."

    def cleanup(self):
        """Clean up ChromaDB and GPU resources to free memory.

        Best-effort and idempotent: every step reports failures as a warning
        instead of raising, references to the ChromaDB artifacts are dropped
        once handled (so a second call is a no-op for them), and garbage
        collection always runs, even when an earlier step failed.
        """
        import gc
        import os
        import shutil
        import sys

        try:
            # Delete ChromaDB collection. getattr() keeps this safe on
            # instances whose __init__ did not run.
            collection = getattr(self, 'collection', None)
            client = getattr(self, 'chroma_client', None)
            if collection is not None:
                try:
                    collection_name = collection.name
                    if client is not None:
                        client.delete_collection(name=collection_name)
                        print(f"[MemAgent] Deleted ChromaDB collection: {collection_name}")
                except Exception as e:
                    print(f"[MemAgent] Warning: Failed to delete collection: {e}")
            # Drop the references so the objects can be collected and a
            # repeated cleanup() does not try to delete the collection again.
            self.collection = None
            self.chroma_client = None

            # Clean up ChromaDB directory
            chroma_dir = getattr(self, 'chroma_dir', None)
            if chroma_dir and os.path.exists(chroma_dir):
                try:
                    shutil.rmtree(chroma_dir)
                    print(f"[MemAgent] Removed ChromaDB directory: {chroma_dir}")
                except Exception as e:
                    print(f"[MemAgent] Warning: Failed to remove ChromaDB directory: {e}")
            self.chroma_dir = None

            # Clear GPU cache if using CUDA. Only touch torch when this
            # process has already imported it: otherwise there is no cache to
            # clear, and importing it here would be slow (or fail when torch
            # is not installed) and could initialise CUDA for nothing.
            torch = sys.modules.get("torch")
            if torch is not None:
                try:
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
                        torch.cuda.synchronize()
                        print(f"[MemAgent] Cleared GPU cache")
                except Exception as e:
                    print(f"[MemAgent] Warning: Failed to clear GPU cache: {e}")
        except Exception as e:
            print(f"[MemAgent] Warning: Cleanup failed: {e}")
        else:
            print(f"[MemAgent] Cleanup completed (ChromaDB storage)")
        finally:
            # Force garbage collection; several passes, as in the original
            # implementation, and regardless of earlier failures.
            for _ in range(3):
                gc.collect()

    def close(self):
        """Clean up resources (alias for ``cleanup``)."""
        self.cleanup()

    def __enter__(self):
        """Enter the context manager; returns this agent."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release resources on exit; exceptions from the block are not suppressed."""
        self.close()
