"""
State memory construction functions for Memory Agent - Turn-wise embedding approach
"""
from typing import Dict, List, Any, Tuple
import uuid
import tempfile
import os
import torch
import gc
import chromadb
from .utils import extract_state_memory_from_response, truncate_trajectory_text
from .prompt import COMPRESS_PROMPT_TEMPLATE


def format_single_turn(turn: Dict[str, Any]) -> str:
    """Render one trajectory turn as a three-line, newline-terminated block.

    Args:
        turn: Mapping that may contain ``turn_idx``, ``action`` and
            ``observation``. A missing ``turn_idx`` is rendered as ``0``;
            a missing ``action`` or ``observation`` as an empty string.

    Returns:
        Text of the form ``"Turn N:\\n  Action: ...\\n  Observation: ...\\n"``.
    """
    header = f"Turn {turn.get('turn_idx', 0)}:"
    action_line = f"  Action: {turn.get('action', '')}"
    observation_line = f"  Observation: {turn.get('observation', '')}"
    # The trailing newline is part of the contract of this formatter.
    return "\n".join((header, action_line, observation_line)) + "\n"


def format_trajectory_for_llm(trajectory: List[Dict[str, Any]], task: str) -> str:
    """Flatten a trajectory into one newline-joined text block.

    Every turn contributes three lines (``Turn N:``, ``  Action: ...``,
    ``  Observation: ...``). No trailing newline is appended, and an empty
    trajectory yields an empty string.

    Args:
        trajectory: Turn dictionaries; a missing ``turn_idx`` is rendered as
            ``0`` and a missing ``action`` / ``observation`` as ``''``.
        task: Task description. It is accepted for interface compatibility
            but is not read by this function.

    Returns:
        The formatted trajectory text.
    """
    return "\n".join(
        f"Turn {step.get('turn_idx', 0)}:\n"
        f"  Action: {step.get('action', '')}\n"
        f"  Observation: {step.get('observation', '')}"
        for step in trajectory
    )


def trajectory_to_documents_turnwise(trajectory: List[Dict[str, Any]]) -> List[str]:
    """
    Produce one text document per trajectory step (turn-wise approach).

    Args:
        trajectory: List of trajectory steps; each step may provide
            ``turn_idx``, ``action`` and ``observation``. Any missing key is
            rendered as the literal ``N/A``.

    Returns:
        List of document strings in trajectory order, one per step, without
        a trailing newline.
    """
    return [
        f"Turn {step.get('turn_idx', 'N/A')}:\n"
        f"  Action: {step.get('action', 'N/A')}\n"
        f"  Observation: {step.get('observation', 'N/A')}"
        for step in trajectory
    ]


def split_trajectory_into_chunks(trajectory_text: str, chunk_size: int = 2000) -> List[Tuple[int, str, int, int, List[int], int]]:
    """
    Split trajectory text into chunks of approximately chunk_size characters.

    Chunks are cut greedily from left to right. When a chunk does not reach
    the end of the text and contains more than one ``"\\nTurn N:"`` marker,
    the cut is moved back to the last marker, provided that marker lies past
    the first half of ``chunk_size``.

    Args:
        trajectory_text: Full trajectory text
        chunk_size: Target size for each chunk (default: 2000). Must be
            positive.

    Returns:
        List of tuples (chunk_id, chunk_text, start_char, end_char, turn_indices, chunk_in_turn)
        where ``chunk_text`` is stripped of surrounding whitespace,
        ``start_char`` / ``end_char`` are offsets into ``trajectory_text``,
        ``turn_indices`` lists the turn numbers overlapping the chunk, and
        ``chunk_in_turn`` counts the earlier chunks that already contain the
        first turn of this chunk.

    Raises:
        ValueError: If ``chunk_size`` is not positive. A non-positive size
            would never advance the cursor and the loop would not terminate.
    """
    import re

    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    total_chars = len(trajectory_text)

    # First pass: find all turn boundaries as (start offset, turn number).
    turn_positions = [
        (match.start(), int(match.group(1)))
        for match in re.finditer(r'Turn (\d+):', trajectory_text)
    ]

    # Add end position for last turn (-1 marks the end-of-text sentinel).
    turn_positions.append((total_chars, -1))

    # The extent of each turn does not depend on the chunk being built, so it
    # is resolved once here instead of once per chunk. A turn ends at the
    # first recorded position whose turn number is larger (or at the
    # sentinel); the sentinel guarantees that next() always finds a value.
    turn_spans = []
    for turn_pos, turn_num in turn_positions:
        if turn_num == -1:
            continue
        next_turn_pos = next(
            pos for pos, num in turn_positions if num > turn_num or num == -1
        )
        turn_spans.append((turn_pos, turn_num, next_turn_pos))

    turn_boundary = re.compile(r'\nTurn \d+:')

    chunks = []
    start_pos = 0

    while start_pos < total_chars:
        end_pos = min(start_pos + chunk_size, total_chars)

        # Try to break at a turn boundary if possible.
        if end_pos < total_chars:
            boundaries = [
                m.start()
                for m in turn_boundary.finditer(trajectory_text[start_pos:end_pos])
            ]
            # Require more than one marker so the chunk keeps at least one
            # turn, and only break if at least 50% of chunk_size is covered.
            if len(boundaries) > 1 and boundaries[-1] > chunk_size * 0.5:
                end_pos = start_pos + boundaries[-1]

        chunk_text = trajectory_text[start_pos:end_pos]

        # Determine which turns overlap [start_pos, end_pos).
        turn_indices = []
        for turn_pos, turn_num, next_turn_pos in turn_spans:
            overlaps = turn_pos < end_pos and start_pos < next_turn_pos
            if overlaps and turn_num not in turn_indices:
                turn_indices.append(turn_num)

        # Position of this chunk within its first turn: the number of earlier
        # chunks that already include that turn (for multi-chunk turns).
        chunk_in_turn = 0
        if turn_indices:
            first_turn = turn_indices[0]
            chunk_in_turn = sum(1 for prev in chunks if first_turn in prev[4])

        chunks.append((len(chunks), chunk_text.strip(), start_pos, end_pos, turn_indices, chunk_in_turn))
        start_pos = end_pos

    return chunks


def build_chunk_graph(chunks: List[Tuple[int, str, int, int, List[int], int]]) -> Dict[str, Any]:
    """
    Assemble a linear graph in which each chunk links to its predecessor.

    Args:
        chunks: List of (chunk_id, chunk_text, start_char, end_char, turn_indices, chunk_in_turn)

    Returns:
        Dictionary with ``nodes`` (one per chunk, carrying the tuple fields
        plus ``length`` and ``order``), ``edges`` (one ``sequential`` edge
        from ``chunk_id - 1`` to ``chunk_id`` for every chunk whose id is
        greater than zero) and ``num_chunks``.
    """
    node_list = [
        {
            'id': cid,
            'text': text,
            'start_char': begin,
            'end_char': finish,
            'length': len(text),
            'turn_indices': turns,
            'chunk_in_turn': part,
            'order': cid,
        }
        for cid, text, begin, finish, turns, part in chunks
    ]

    # Edges are derived from the chunk ids themselves, not from list position.
    edge_list = [
        {'from': node['id'] - 1, 'to': node['id'], 'type': 'sequential'}
        for node in node_list
        if node['id'] > 0
    ]

    return {
        'nodes': node_list,
        'edges': edge_list,
        'num_chunks': len(chunks),
    }


async def build_turn_embeddings(
    trajectory: List[Dict[str, Any]],
    embedding_model,
    chroma_client,
    collection_name: str
) -> chromadb.Collection:
    """
    Build embeddings for turns using ChromaDB - turn-wise approach.

    Each turn is turned into one document, encoded in very small batches and
    written to the collection immediately, so that at most one batch of
    embeddings is held at a time. If a batch raises
    ``torch.cuda.OutOfMemoryError``, that batch is retried on CPU and every
    later batch is encoded on CPU as well.

    Args:
        trajectory: List of trajectory turns
        embedding_model: Model used for encoding. It must expose
            ``encode(docs, convert_to_numpy=..., show_progress_bar=...,
            device=..., batch_size=...)`` returning an object with
            ``tolist()`` (presumably a SentenceTransformer - confirm against
            the caller).
        chroma_client: ChromaDB client
        collection_name: Name for the collection

    Returns:
        ChromaDB collection with turn embeddings

    Note:
        Although declared ``async``, this function contains no ``await``;
        encoding and the short ``time.sleep`` pauses run synchronously.
    """
    import time

    # Convert trajectory to documents (one per turn)
    documents = trajectory_to_documents_turnwise(trajectory)
    num_documents = len(documents)

    print(f"[Construct] Building embeddings for {len(documents)} turns using ChromaDB")

    # Create collection
    collection = chroma_client.get_or_create_collection(
        name=collection_name,
        metadata={"hnsw:space": "cosine"}
    )

    # Prepare data. NOTE(review): ids and the 'turn_idx' metadata use the
    # position in the trajectory, not the step's own 'turn_idx' field -
    # confirm that this is what the retrieval side expects.
    ids = [f"turn_{i}" for i in range(num_documents)]
    metadatas = [
        {
            'turn_idx': i,
            'type': 'trajectory_turn'
        }
        for i in range(num_documents)
    ]

    # Auto-detect device
    if hasattr(embedding_model, 'device'):
        device = str(embedding_model.device)
    elif hasattr(embedding_model, '_target_device'):
        device = str(embedding_model._target_device)
    else:
        device = 'cpu'

    # Evaluated once: whether the CUDA cache housekeeping below applies.
    uses_cuda = torch.cuda.is_available() and 'cuda' in device

    # Set CUDA allocator
    # NOTE(review): this environment variable may have no effect if the CUDA
    # allocator was already initialised before this point - confirm.
    if uses_cuda:
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        gc.collect()

    # Encode documents in batches and store immediately.
    # Use extremely small batches to minimize memory footprint and avoid OOM
    # (per the original author, the GPU is shared with a vLLM server).
    encoding_batch_size = 2  # Documents passed to each encode() call
    internal_batch_size = 1  # batch_size forwarded to encode()
    use_cpu_fallback = False
    total_batches = (num_documents + encoding_batch_size - 1) // encoding_batch_size

    for i in range(0, num_documents, encoding_batch_size):
        end_idx = min(i + encoding_batch_size, num_documents)
        batch_number = i // encoding_batch_size + 1
        batch_docs = documents[i:end_idx]
        encoding_device = 'cpu' if use_cpu_fallback else device

        try:
            # Encode batch
            batch_embeddings = embedding_model.encode(
                batch_docs,
                convert_to_numpy=True,
                show_progress_bar=False,
                device=encoding_device,
                batch_size=internal_batch_size
            )

        except torch.cuda.OutOfMemoryError:
            print(f"[Construct] GPU OOM on batch {batch_number}, switching to CPU for remaining batches")
            # Clear GPU cache
            torch.cuda.empty_cache()
            gc.collect()

            # Retry this batch on CPU
            batch_embeddings = embedding_model.encode(
                batch_docs,
                convert_to_numpy=True,
                show_progress_bar=False,
                device='cpu',
                batch_size=internal_batch_size
            )

            # Use CPU for all remaining batches
            use_cpu_fallback = True

        # Store this batch immediately to ChromaDB to avoid memory accumulation
        collection.add(
            ids=ids[i:end_idx],
            documents=batch_docs,
            embeddings=batch_embeddings.tolist(),
            metadatas=metadatas[i:end_idx]
        )

        # Drop the references so the batch can be reclaimed before the next one
        batch_embeddings = None
        batch_docs = None

        # Aggressive GPU cache clearing after each batch
        if uses_cuda:
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
            gc.collect()
            # Give GPU time to release memory
            time.sleep(0.1)
        else:
            gc.collect()

        # Report where the batch actually ran. Without the CUDA check a model
        # that lives on the CPU from the start would be reported as "GPU".
        if use_cpu_fallback:
            device_label = "CPU"
        elif 'cuda' in device:
            device_label = "GPU"
        else:
            device_label = device.upper()
        print(f"[Construct] Encoded and stored batch {batch_number}/{total_batches} on {device_label}")

    print(f"[Construct] ✅ Stored all {len(documents)} turn embeddings in ChromaDB")

    return collection


async def construct_state_memory(
    trajectory: List[Dict[str, Any]],
    task: str,
    episode_id: str,
    chunk_size: int,
    enable_state_memory_summary: bool,
    call_llm_func,
    embedding_model=None
) -> Dict[str, Any]:
    """
    Convert a long trajectory into turn-wise memory structure with embeddings.
    Uses turn-wise approach where each turn is embedded separately.

    Args:
        trajectory: List of trajectory turns with turn_idx, action, observation.
            Only the first 5000 turns are used; longer input is truncated.
        task: Task description
        episode_id: Episode identifier
        chunk_size: Size of each chunk (deprecated, kept for compatibility;
            not read by this function)
        enable_state_memory_summary: Kept for compatibility; not read by this
            function. The state memory is always the full formatted
            trajectory text.
        call_llm_func: LLM calling function (kept for compatibility; not
            called by this function)
        embedding_model: Shared embedding model for encoding turns. When
            ``None``, no ChromaDB client, collection or directory is created.

    Returns:
        Dictionary with 'state_mem', 'trajectory_data', 'num_turns', 'chunk_graph',
        'chroma_client', 'collection', 'chroma_dir'. The last three are
        ``None`` when no embedding model is given. 'chroma_dir' is a
        temporary directory created with ``tempfile.mkdtemp``; this function
        does not delete it on success.

    Raises:
        Exception: Any error raised while creating the ChromaDB client or
            building the embeddings is re-raised after the temporary
            directory has been removed.
    """
    import shutil

    # Limit trajectory length to prevent OOM
    MAX_TURNS = 5000
    original_num_turns = len(trajectory)
    if original_num_turns > MAX_TURNS:
        print(f"[Construct] ⚠️  Warning: Trajectory has {original_num_turns} turns, truncating to {MAX_TURNS} for memory safety")
        trajectory = trajectory[:MAX_TURNS]

    num_turns = len(trajectory)
    trajectory_data = {
        'trajectory': trajectory,
        'task': task,
        'episode_id': episode_id
    }
    trajectory_text = format_trajectory_for_llm(trajectory, task)

    print(f"[Construct] Using turn-wise embedding approach with {num_turns} turns")

    # Build embeddings if embedding_model is provided
    chroma_client = None
    collection = None
    chroma_dir = None

    if embedding_model is not None:
        # Create the on-disk location for the ChromaDB client
        chroma_dir = tempfile.mkdtemp(prefix="ama_turn_chroma_")
        try:
            chroma_client = chromadb.PersistentClient(path=chroma_dir)

            # Create unique collection name
            collection_name = f"ama_turns_{uuid.uuid4().hex[:8]}"
            collection = await build_turn_embeddings(trajectory, embedding_model, chroma_client, collection_name)
        except Exception:
            # On failure the caller never receives chroma_dir, so nobody else
            # could clean it up; remove it here (best effort) and re-raise.
            shutil.rmtree(chroma_dir, ignore_errors=True)
            raise
        print(f"[Construct] Created ChromaDB collection: {collection_name}")
    else:
        print(f"[Construct] No embedding model provided, skipping turn embeddings")

    # Use full trajectory text as state memory
    state_mem = trajectory_text

    # Create a simple chunk_graph for compatibility (turn-based, not chunk-based)
    # Each turn is a "chunk" with sequential edges. Character offsets and
    # length are placeholders (always 0) in this turn-based variant.
    chunk_graph = {
        'nodes': [
            {
                'id': i,
                'text': f"Turn {turn.get('turn_idx', i)}:\n  Action: {turn.get('action', '')}\n  Observation: {turn.get('observation', '')}",
                'start_char': 0,
                'end_char': 0,
                'length': 0,
                'turn_indices': [i],
                'chunk_in_turn': 0,
                'order': i
            }
            for i, turn in enumerate(trajectory)
        ],
        'edges': [
            {
                'from': i - 1,
                'to': i,
                'type': 'sequential'
            }
            for i in range(1, num_turns)
        ],
        'num_chunks': num_turns
    }

    return {
        'state_mem': state_mem,
        'trajectory_data': trajectory_data,
        'num_turns': num_turns,
        'chunk_graph': chunk_graph,
        'chroma_client': chroma_client,
        'collection': collection,
        'chroma_dir': chroma_dir
    }
