"""
Retrieval functions for Memory Agent - Turn-wise approach with BM25 pre-filtering
"""
import json
import os
import re
import torch
import gc
import numpy as np
from typing import Dict, List, Any, Optional
from rank_bm25 import BM25Okapi
from .utils import extract_code_from_response
from .prompt import (
    CODE_GENERATION_PROMPT_TEMPLATE,
    CHUNK_SUFFICIENCY_JUDGMENT_PROMPT_TEMPLATE,
    ANSWER_WITH_RETRIEVAL_PROMPT_TEMPLATE,
    ANSWER_WITHOUT_RETRIEVAL_PROMPT_TEMPLATE,
    CHUNKED_ANSWER_PROMPT_TEMPLATE
)


def build_ordered_trajectory_from_turns(
    trajectory_data: Dict[str, Any],
    relevant_turn_indices: List[int],
    max_chars: int = 16384
) -> str:
    """
    Build a chronologically ordered text rendering of the selected turns.

    Turns are matched by their ``turn_idx`` field (not by list position); if
    several entries share a ``turn_idx`` the first one wins. Requested indices
    with no matching turn are skipped and repeated indices are rendered once.

    Args:
        trajectory_data: Dict whose 'trajectory' entry is the list of turn dicts.
        relevant_turn_indices: ``turn_idx`` values to include, in any order.
        max_chars: Upper bound on the length of the returned string (default:
            16384). Longer renderings are cut and suffixed with
            "\\n...(truncated)"; the marker counts toward the limit.

    Returns:
        The rendered turns separated by blank lines, or "" when nothing matched.
    """
    if not relevant_turn_indices:
        return ""

    trajectory = trajectory_data.get('trajectory', [])

    # Index turns once instead of rescanning the list per requested index.
    # setdefault keeps the first occurrence, matching a linear first-match scan.
    turns_by_idx: Dict[Any, Dict[str, Any]] = {}
    for t in trajectory:
        turns_by_idx.setdefault(t.get('turn_idx'), t)

    ordered_parts = []
    # set() drops repeated requests so a turn is never rendered twice.
    for turn_idx in sorted(set(relevant_turn_indices)):
        turn = turns_by_idx.get(turn_idx)
        if turn is None:
            continue
        turn_text = f"Turn {turn_idx}:\n"
        turn_text += f"  Action: {turn.get('action', '')}\n"
        turn_text += f"  Observation: {turn.get('observation', '')}\n"
        ordered_parts.append(turn_text)

    full_text = "\n".join(ordered_parts)

    truncation_marker = "\n...(truncated)"
    if len(full_text) > max_chars:
        # Reserve room for the marker so the result never exceeds max_chars.
        budget = max(0, max_chars)
        head = full_text[:max(0, budget - len(truncation_marker))]
        full_text = (head + truncation_marker)[:budget]

    return full_text


async def execute_code(code: str, trajectory_text_json: str, ray_workers: List, worker_idx_container: List[int], timeout: float = 40.0) -> Optional[Any]:
    """
    Execute generated Python code against a serialized trajectory.

    When ``ray_workers`` is non-empty the code is first sent, round-robin, to a
    Ray worker. If that yields no usable output (empty, ``"timeout"``, or a
    string starting with ``"error:"``), the code is executed in-process instead.

    SECURITY NOTE: ``code`` is typically LLM-generated. The in-process fallback
    runs it with ``exec`` and full builtins inside this interpreter; only use it
    where that risk is acceptable.

    Args:
        code: Python source. It can read ``trajectory_json`` (a JSON string) and
            should either print its answer or assign it to ``result``. In-process
            it also gets ``json`` and ``re`` pre-bound; in the Ray script ``json``,
            ``base64`` and ``sys`` are imported.
        trajectory_text_json: JSON text exposed to the code as ``trajectory_json``.
        ray_workers: Ray actor handles exposing ``run.remote(script=..., timeout=...)``;
            may be empty.
        worker_idx_container: Single-element list holding the round-robin counter;
            incremented in place on every Ray dispatch.
        timeout: Seconds allowed for the Ray worker run. The in-process fallback
            is NOT time-limited.

    Returns:
        ``{'output': str, 'result': Any}`` on success. For Ray runs ``result`` is the
        trailing JSON value of the output if one parses, else the output text. For
        in-process runs it is the code's ``result`` variable. Returns ``None`` if
        in-process execution raised (including ``SystemExit``).
    """
    # Preferred path: run the code in an isolated Ray worker.
    if ray_workers:
        import base64
        from .utils import _await_ray_object_ref

        worker = ray_workers[worker_idx_container[0] % len(ray_workers)]
        worker_idx_container[0] += 1
        # Base64 output only contains [A-Za-z0-9+/=], so it can sit inside the
        # script's triple-quoted literal regardless of the JSON's content.
        trajectory_b64 = base64.b64encode(trajectory_text_json.encode('utf-8')).decode('ascii')

        script = f"""
import json
import base64
import sys
trajectory_b64 = '''{trajectory_b64}'''
trajectory_json = base64.b64decode(trajectory_b64).decode('utf-8')

# Track if user code printed anything
_original_stdout_write = sys.stdout.write
_user_printed = False

def _tracking_write(text):
    global _user_printed
    if text and text.strip():
        _user_printed = True
    return _original_stdout_write(text)

sys.stdout.write = _tracking_write

# Execute user code
{code}

# Restore original stdout
sys.stdout.write = _original_stdout_write

# If user didn't print anything, auto-print result as fallback
if not _user_printed and 'result' in locals():
    try:
        print(json.dumps(result, indent=2, ensure_ascii=False))
    except (TypeError, ValueError):
        print(str(result))
"""
        obj_ref = worker.run.remote(script=script, timeout=timeout)
        output = await _await_ray_object_ref(obj_ref, timeout_seconds=timeout + 2.0)

        if output and output != "timeout" and not output.startswith("error:"):
            output_stripped = output.strip()
            try:
                # Look for a trailing JSON value by trying ever-longer suffixes of
                # the output (last line, last two lines, ...). The first suffix
                # that parses wins.
                lines = output_stripped.split('\n')
                for i in range(len(lines) - 1, -1, -1):
                    try:
                        potential_json = '\n'.join(lines[i:])
                        parsed = json.loads(potential_json)
                        print("[Retrieve] Successfully parsed JSON from output")
                        return {'output': output_stripped, 'result': parsed}
                    except json.JSONDecodeError:
                        continue
                print("[Retrieve] No JSON in output, returning full output")
                return {'output': output_stripped, 'result': output_stripped}
            except Exception as e:
                print(f"[Retrieve] Error parsing output: {e}, returning as-is")
                return {'output': output_stripped, 'result': output_stripped}
        # No output, a timeout, or a worker error: fall through to in-process execution.

    # Fallback: execute in-process, capturing anything the code prints.
    print("[Retrieve] Using direct code execution")
    import contextlib
    import io

    captured_output = io.StringIO()
    exec_globals = {'trajectory_json': trajectory_text_json, 'json': json, 're': re, 'result': None}
    try:
        # redirect_stdout restores sys.stdout even when the code raises. That
        # includes SystemExit from a stray exit()/sys.exit(), which would
        # otherwise leave stdout swallowed for the rest of the process.
        with contextlib.redirect_stdout(captured_output):
            exec(code, exec_globals)
        result = exec_globals.get('result')
        output = captured_output.getvalue()

        # Mirror the Ray script: if nothing was printed, show `result` instead.
        if not output.strip() and result is not None:
            try:
                output = json.dumps(result, indent=2, ensure_ascii=False)
            except (TypeError, ValueError):
                output = str(result)

        return {'output': output.strip(), 'result': result}
    except (Exception, SystemExit) as e:
        # SystemExit is caught so generated code cannot terminate the agent process.
        print(f"[Retrieve] Code execution error: {e}")
        return None


def estimate_tokens(text: str) -> int:
    """Return a rough token count for *text*, assuming about 4 characters per token."""
    chars_per_token = 4
    char_count = len(text)
    return char_count // chars_per_token


def expand_turns_with_context(
    top_turn_indices: List[int],
    trajectory_data: Dict[str, Any],
    max_tokens: int = 8192
) -> List[int]:
    """
    Grow a contiguous window of turns around the single best-ranked turn.

    Only ``top_turn_indices[0]`` is used as the anchor; the remaining ranked
    indices are not consulted. Starting from the anchor, each round tries to
    add one earlier turn and then one later turn. A turn is added only if the
    running estimate (via ``estimate_tokens``) stays within ``max_tokens``. A
    side stops growing once its next turn does not fit or the trajectory ends.
    Expansion stops when a round adds nothing. The anchor is always kept, even
    if it alone exceeds the budget.

    Turns are addressed by their position in ``trajectory_data['trajectory']``.

    Args:
        top_turn_indices: Turn positions ordered by relevance (best first).
        trajectory_data: Dict whose 'trajectory' entry is the list of turns.
        max_tokens: Token budget for the window (default: 8192).

    Returns:
        The selected turn positions in ascending order, or [] when either
        ``top_turn_indices`` or the trajectory is empty.
    """
    turns = trajectory_data.get('trajectory', [])
    turn_count = len(turns)

    if not top_turn_indices or not turns:
        return []

    anchor = top_turn_indices[0]
    window = {anchor}

    def render(position):
        # Positions outside the trajectory render as "" and therefore cost 0 tokens.
        if not 0 <= position < turn_count:
            return ""
        entry = turns[position]
        return (
            f"Turn {entry.get('turn_idx', position)}:\n"
            f"  Action: {entry.get('action', '')}\n"
            f"  Observation: {entry.get('observation', '')}"
        )

    used_tokens = estimate_tokens(render(anchor))
    steps_back = 0
    steps_forward = 0

    def admit(position, within_bounds):
        """Add *position* to the window if allowed and affordable; report whether it was added."""
        nonlocal used_tokens
        if not within_bounds or position in window:
            return False
        cost = estimate_tokens(render(position))
        if used_tokens + cost > max_tokens:
            return False
        window.add(position)
        used_tokens += cost
        return True

    while used_tokens < max_tokens:
        earlier = anchor - steps_back - 1
        grew_back = admit(earlier, earlier >= 0)
        if grew_back:
            steps_back += 1

        later = anchor + steps_forward + 1
        grew_forward = admit(later, later < turn_count)
        if grew_forward:
            steps_forward += 1

        if not (grew_back or grew_forward):
            break

    ordered = sorted(window)

    print(f"[Retrieve] Expanded from best turn {anchor} with depth_before={steps_back}, depth_after={steps_forward}")
    print(f"[Retrieve] Total {len(ordered)} turns, estimated {used_tokens} tokens")

    return ordered


def select_best_chunk_by_bm25(
    query: str,
    turn_indices: List[int],
    trajectory_data: Dict[str, Any],
    chunk_size: int = 2000,
    max_tokens: int = 8192
) -> List[int]:
    """
    Pick the turns covered by the fixed-size text chunk that best matches the query.

    The candidate turns (addressed by position in ``trajectory_data['trajectory']``;
    out-of-range and duplicate indices are dropped) are rendered, joined with
    newlines, and cut into consecutive ``chunk_size``-character chunks. Each chunk
    is scored against the query with BM25 over lower-cased whitespace tokens. The
    turns overlapping the highest-scoring chunk are returned; ties go to the
    earliest chunk.

    Args:
        query: User query.
        turn_indices: Candidate turn positions.
        trajectory_data: Dict whose 'trajectory' entry is the list of turns.
        chunk_size: Chunk length in characters (default: 2000); must be positive.
        max_tokens: Currently unused and kept for interface compatibility. The
            returned turns are NOT trimmed to this budget.

    Returns:
        Turn positions overlapping the best chunk, in ascending order. Returns []
        if no candidate index refers to an existing turn.

    Raises:
        ValueError: If ``chunk_size`` is not positive; chunking could never advance.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    trajectory = trajectory_data.get('trajectory', [])

    # Render each valid turn once; set() prevents a turn's text appearing twice.
    turn_texts = []
    for idx in sorted(set(turn_indices)):
        if 0 <= idx < len(trajectory):
            turn = trajectory[idx]
            turn_text = f"Turn {turn.get('turn_idx', idx)}:\n  Action: {turn.get('action', '')}\n  Observation: {turn.get('observation', '')}"
            turn_texts.append((idx, turn_text))

    if not turn_texts:
        # BM25Okapi divides by the corpus size, so an empty corpus would crash.
        print("[Retrieve] No valid turns to chunk, skipping BM25 selection")
        return []

    # Character span [start, end) of each turn inside full_text; +1 covers the joining newline.
    turn_spans = []
    char_pos = 0
    for idx, text in turn_texts:
        turn_end = char_pos + len(text) + 1
        turn_spans.append((idx, char_pos, turn_end))
        char_pos = turn_end

    full_text = "\n".join(text for _, text in turn_texts)

    chunks = []
    chunk_turn_map = []  # chunk position -> turn indices overlapping that chunk
    for chunk_start in range(0, len(full_text), chunk_size):
        chunk_end = min(chunk_start + chunk_size, len(full_text))
        chunks.append(full_text[chunk_start:chunk_end])
        chunk_turn_map.append([
            idx for idx, turn_start, turn_end in turn_spans
            if turn_start < chunk_end and turn_end > chunk_start
        ])

    tokenized_chunks = [chunk.lower().split() for chunk in chunks]
    bm25 = BM25Okapi(tokenized_chunks)
    scores = bm25.get_scores(query.lower().split())

    # argmax returns the first maximum, so ties resolve to the earliest chunk.
    best_chunk_idx = int(scores.argmax())
    best_turns = chunk_turn_map[best_chunk_idx]

    print(f"[Retrieve] Split into {len(chunks)} chunks, selected chunk {best_chunk_idx} with {len(best_turns)} turns")

    return sorted(best_turns)


def retrieve_turns_by_embedding(
    query: str,
    collection,
    embedding_model,
    top_k: int = 5
) -> List[Dict[str, Any]]:
    """
    Return the ``top_k`` stored turns whose embeddings are closest to the query.

    The query is encoded with ``embedding_model.encode``. The device comes from
    the model's ``device`` or ``_target_device`` attribute, else 'cpu'. On a CUDA
    out-of-memory error encoding is retried on 'cpu'. The embedding is then used
    to query ``collection``.

    Args:
        query: User query.
        collection: ChromaDB-style collection exposing ``count()`` and ``query()``;
            may be None.
        embedding_model: Model exposing a SentenceTransformer-like ``encode(...)``
            (assumed from the keyword arguments used here).
        top_k: Number of turns to retrieve (default: 5). Values below 1 yield [].

    Returns:
        One dict per hit, in the order the collection returns them (ChromaDB sorts
        nearest first). Keys: 'chunk_id', 'turn_idx', 'text', 'distance',
        'metadata', 'order', 'turn_indices', 'chunk_in_turn' and 'relevance_rank'.
        ``turn_idx`` comes from the hit's metadata and falls back to the hit's rank
        when absent. 'metadata' is always a dict.
    """
    turn_count = collection.count() if collection is not None else 0
    if turn_count == 0:
        print("[Retrieve] No turns available")
        return []
    if top_k < 1:
        # A vector query needs a positive n_results; nothing to retrieve.
        print(f"[Retrieve] top_k={top_k} < 1, nothing to retrieve")
        return []

    print(f"[Retrieve] Retrieving top-{top_k} turns by embedding")

    # Auto-detect device
    if hasattr(embedding_model, 'device'):
        device = str(embedding_model.device)
    elif hasattr(embedding_model, '_target_device'):
        device = str(embedding_model._target_device)
    else:
        device = 'cpu'

    on_cuda = torch.cuda.is_available() and 'cuda' in device
    if on_cuda:
        # NOTE(review): PyTorch reads PYTORCH_CUDA_ALLOC_CONF when the CUDA caching
        # allocator initialises. If CUDA is already in use (e.g. the model is on
        # the GPU) this assignment likely has no effect. Confirm, and consider
        # setting it at process start-up.
        os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        gc.collect()

    try:
        query_embedding = embedding_model.encode(
            [query],
            convert_to_numpy=True,
            show_progress_bar=False,
            device=device,
            batch_size=1
        )[0]
    except torch.cuda.OutOfMemoryError:
        print("[Retrieve] GPU OOM, falling back to CPU")
        torch.cuda.empty_cache()
        gc.collect()
        query_embedding = embedding_model.encode(
            [query],
            convert_to_numpy=True,
            show_progress_bar=False,
            device='cpu',
            batch_size=1
        )[0]

    if on_cuda:
        torch.cuda.empty_cache()
        gc.collect()

    results = collection.query(
        query_embeddings=[query_embedding.tolist()],
        n_results=min(top_k, turn_count),
        include=['documents', 'metadatas', 'distances']
    )

    retrieved_turns = []
    if results and 'documents' in results and results['documents']:
        for rank, (doc, metadata, distance) in enumerate(zip(
            results['documents'][0],
            results['metadatas'][0],
            results['distances'][0]
        )):
            # Metadata can be None for entries stored without it; normalise so
            # .get() below cannot fail.
            metadata = metadata or {}
            turn_idx = metadata.get('turn_idx', rank)

            retrieved_turns.append({
                'chunk_id': turn_idx,  # Keep for compatibility
                'turn_idx': turn_idx,
                'text': doc,
                'distance': distance,
                'metadata': metadata,
                'order': turn_idx,
                'turn_indices': [turn_idx],  # Single turn
                'chunk_in_turn': 0,  # Always 0 for turn-wise
                'relevance_rank': rank  # 0 = most relevant
            })

    print(f"[Retrieve] Retrieved {len(retrieved_turns)} turns (sorted by relevance)")
    return retrieved_turns


def parse_turn_retrieval_specs(judgment_response: str) -> Dict[str, Any]:
    """
    Parse an LLM's turn retrieval specification from a NEED_GRAPH response.

    Matching is case-insensitive. Supported formats:
    A. Adjacent turns with before/after:
       - "turn_5 before=2 after=1"
       - "turn_5 before=2 after=0, turn_8 before=0 after=3"

    B. Turn ranges:
       - "turns 5-10"
       - "turns 3-8, turns 15-20"

    C. Specific individual turns:
       - "turns 3, 7, 12"
       - "turns 5, 8, 15"

    D. Mixed formats:
       - "turn_5 before=2 after=1, turns 15-20, turns 25, 30"

    Returns:
        Dictionary with:
        - 'adjacent_specs': {turn_id: {'before': N, 'after': M}} (a repeated
          turn_id keeps its last spec)
        - 'range_specs': [(start, end), ...] in order of appearance
        - 'specific_turns': sorted, de-duplicated [turn_id, ...]
    """
    result = {
        'adjacent_specs': {},
        'range_specs': [],
        'specific_turns': []
    }
    flags = re.IGNORECASE

    # Pattern A: turn_<id> before=<N> after=<M>
    pattern_adjacent = r'turn_(\d+)\s+before=(\d+)\s+after=(\d+)'
    for turn_id, before, after in re.findall(pattern_adjacent, judgment_response, flags):
        result['adjacent_specs'][int(turn_id)] = {'before': int(before), 'after': int(after)}

    # Pattern B: turns X-Y (ranges)
    pattern_range = r'turns?\s+(\d+)\s*-\s*(\d+)'
    for start, end in re.findall(pattern_range, judgment_response, flags):
        result['range_specs'].append((int(start), int(end)))

    # Pattern C: turns X, Y, Z (specific turns).
    # Strip the A/B matches first so e.g. "turns 5-10" does not also yield turn 5.
    # The removal must use the same case-insensitive flag as the findall calls,
    # otherwise "Turns 5-10" survives and is double-counted. re.sub's 4th
    # positional argument is `count`, so flags must be passed by keyword.
    remaining = re.sub(pattern_adjacent, '', judgment_response, flags=flags)
    remaining = re.sub(pattern_range, '', remaining, flags=flags)

    pattern_specific = r'turns?\s+((?:\d+\s*,\s*)*\d+)'
    specific = set()
    for match in re.findall(pattern_specific, remaining, flags):
        specific.update(int(x.strip()) for x in match.split(',') if x.strip().isdigit())

    result['specific_turns'] = sorted(specific)

    return result


def expand_turn_specs_to_indices(
    turn_specs: Dict[str, Any],
    trajectory_data: Dict[str, Any],
    initial_turn_indices: List[int]
) -> List[int]:
    """
    Expand turn retrieval specifications into a list of turn indices.

    Every kind of spec is clipped to the valid positions
    ``0 .. len(trajectory) - 1``:
    - adjacent ``{X: {'before': N, 'after': M}}`` -> positions X-N .. X+M
    - range ``(X, Y)`` -> positions X .. Y (inclusive)
    - specific ``X`` -> position X

    Args:
        turn_specs: Output of parse_turn_retrieval_specs. Missing keys are treated
            as empty, and a missing 'before'/'after' counts as 0.
        trajectory_data: Dict whose 'trajectory' entry is the list of turns.
        initial_turn_indices: Returned unchanged (same list object) when no spec
            selects any valid position.

    Returns:
        Sorted, de-duplicated valid turn positions, or ``initial_turn_indices``
        as the fallback.
    """
    num_turns = len(trajectory_data.get('trajectory', []))

    all_indices = set()

    # Adjacent specs (turn_X before=N after=M). Clip both ends so a center
    # beyond the trajectory cannot inject invalid positions.
    for turn_id, spec in turn_specs.get('adjacent_specs', {}).items():
        first = max(0, turn_id - spec.get('before', 0))
        last = min(num_turns - 1, turn_id + spec.get('after', 0))
        all_indices.update(range(first, last + 1))

    # Range specs (turns X-Y), inclusive of Y.
    for start, end in turn_specs.get('range_specs', []):
        all_indices.update(range(max(0, start), min(num_turns, end + 1)))

    # Specific turns
    for turn_id in turn_specs.get('specific_turns', []):
        if 0 <= turn_id < num_turns:
            all_indices.add(turn_id)

    # Nothing valid was requested: fall back to the initially retrieved turns.
    if not all_indices:
        return initial_turn_indices

    return sorted(all_indices)


def get_adjacent_chunks(
    chunk_ids: List[int],
    chunk_graph: Dict[str, Any],
    direction: str = 'both',
    max_chunks: int = 5,
    chunk_specs: Optional[Dict[int, Dict[str, int]]] = None
) -> List[Dict[str, Any]]:
    """
    Collect graph neighbours of the given chunks, walking edges a bounded number of hops.

    An edge ``{'from': a, 'to': b}`` makes ``a`` a "before" neighbour of ``b`` and
    ``b`` an "after" neighbour of ``a``. For each center chunk the walk follows
    "before" edges ``num_before`` hops and "after" edges ``num_after`` hops,
    level by level, and collects every node reached. A direction stops early
    when a level has no neighbours. Center chunks are included only if some
    walk reaches them.

    Args:
        chunk_ids: Center chunk IDs. IDs that appear in no edge are skipped.
        chunk_graph: Dict with 'nodes' and 'edges'. Each node needs 'id', 'text',
            'start_char', 'end_char' and 'length', and may have 'order',
            'turn_indices' and 'chunk_in_turn'. Each edge has 'from' and 'to'.
        direction: Used for chunks without an entry in ``chunk_specs``: 'before'
            walks 1 hop back, 'after' 1 hop forward, anything else 1 hop each way.
        max_chunks: Maximum number of chunks returned; the earliest by 'order' are kept.
        chunk_specs: Optional per-chunk hop counts, e.g.
            ``{5: {'before': 2, 'after': 1}}``; a missing key counts as 0 hops.

    Returns:
        Chunk dicts sorted by 'order' and truncated to ``max_chunks``.
    """
    if not chunk_graph or not chunk_ids:
        return []

    nodes = chunk_graph.get('nodes', [])
    edges = chunk_graph.get('edges', [])

    # Adjacency in both directions, built from the directed edges.
    adjacent_map: Dict[Any, Dict[str, List[Any]]] = {}
    for edge in edges:
        from_id = edge['from']
        to_id = edge['to']
        adjacent_map.setdefault(from_id, {'before': [], 'after': []})['after'].append(to_id)
        adjacent_map.setdefault(to_id, {'before': [], 'after': []})['before'].append(from_id)

    adjacent_ids = set()

    def walk(start_id, key, hops):
        """Follow `key` edges from start_id for up to `hops` levels, collecting reached IDs."""
        frontier = [start_id]
        for _ in range(hops):
            next_ids = []
            for cid in frontier:
                if cid in adjacent_map:
                    next_ids.extend(adjacent_map[cid][key])
            if not next_ids:
                break
            adjacent_ids.update(next_ids)
            # De-duplicate the frontier. With repeated or converging edges the raw
            # list would otherwise grow geometrically per hop without adding nodes.
            frontier = list(dict.fromkeys(next_ids))

    for chunk_id in chunk_ids:
        if chunk_id not in adjacent_map:
            continue

        if chunk_specs and chunk_id in chunk_specs:
            num_before = chunk_specs[chunk_id].get('before', 0)
            num_after = chunk_specs[chunk_id].get('after', 0)
        elif direction == 'before':
            num_before, num_after = 1, 0
        elif direction == 'after':
            num_before, num_after = 0, 1
        else:  # 'both' (or any unrecognised value)
            num_before, num_after = 1, 1

        walk(chunk_id, 'before', num_before)
        walk(chunk_id, 'after', num_after)

    adjacent_chunks = []
    for node in nodes:
        if node['id'] in adjacent_ids:
            adjacent_chunks.append({
                'chunk_id': node['id'],
                'text': node['text'],
                'order': node.get('order', node['id']),
                'turn_indices': node.get('turn_indices', []),
                'chunk_in_turn': node.get('chunk_in_turn', 0),
                'metadata': {
                    'start_char': node['start_char'],
                    'end_char': node['end_char'],
                    'length': node['length'],
                    'turn_indices': node.get('turn_indices', []),
                    'chunk_in_turn': node.get('chunk_in_turn', 0),
                    'order': node.get('order', node['id'])
                }
            })

    # Sort by order to maintain sequential flow, then apply the global limit.
    return sorted(adjacent_chunks, key=lambda x: x['order'])[:max_chunks]


async def retrieve_with_query(
    query: str,
    keywords_info: Dict[str, Any],
    relevant_turn_indices: List[int],
    trajectory_data: Dict[str, Any],
    task: str,
    tool_mode_costimize: bool,
    max_iter: int,
    call_llm_func,
    max_tokens: int,
    log_func,
    ray_workers: List,
    worker_idx_container: List[int],
    chunk_graph: Dict[str, Any] = None,
    collection = None,
    embedding_model = None
) -> str:
    """
    Multi-step turn retrieval with answer generation:
    Step 1: Retrieve top-5 turns by embedding similarity
    Step 2: Check token count of top-5 turns
        - If < 8K: Expand context around best turn until reaching 8K
        - If >= 8K: Use BM25 to select best chunk from these turns
    Step 3: Sort final turns by chronological order
    Step 4: Judge if SUFFICIENT, NEED_GRAPH, or NEED_CODE
    Step 5: Generate answer

    Returns:
        Natural language answer to the query
    """

    # ============================================================================
    # STEP 1: Retrieve top-5 turns by embedding similarity
    # ============================================================================
    print(f"[Retrieve] Step 1: Retrieving top-5 turns by embedding")
    top_turns = retrieve_turns_by_embedding(
        query,
        collection,
        embedding_model,
        top_k=5
    )

    if not top_turns:
        print("[Retrieve] No turns retrieved, returning empty answer")
        return "Unable to retrieve relevant information from trajectory."

    # ============================================================================
    # STEP 2: Check token count and expand/filter as needed
    # ============================================================================
    # Get turn indices from top retrieved turns
    top_turn_indices = [turn['turn_idx'] for turn in top_turns]

    # Calculate total tokens of top-5 turns
    top_turns_text = "\n".join([turn['text'] for turn in top_turns])
    top_turns_tokens = estimate_tokens(top_turns_text)

    print(f"[Retrieve] Step 2: Top-5 turns have ~{top_turns_tokens} tokens")

    max_input_tokens = 8192
    final_turn_indices = []

    if top_turns_tokens < max_input_tokens:
        # Expand context around the best turn (first in list = most relevant)
        print(f"[Retrieve] Top-5 turns < {max_input_tokens} tokens, expanding context")
        final_turn_indices = expand_turns_with_context(
            top_turn_indices,
            trajectory_data,
            max_tokens=max_input_tokens
        )
    else:
        # Use BM25 to select best chunk from top-5 turns
        print(f"[Retrieve] Top-5 turns >= {max_input_tokens} tokens, using BM25 to select best chunk")
        final_turn_indices = select_best_chunk_by_bm25(
            query,
            top_turn_indices,
            trajectory_data,
            chunk_size=2000,
            max_tokens=max_input_tokens
        )

    # ============================================================================
    # STEP 3: Build final turn texts sorted by chronological order
    # ============================================================================
    trajectory = trajectory_data.get('trajectory', [])
    final_turns = []
    for turn_idx in final_turn_indices:
        if 0 <= turn_idx < len(trajectory):
            turn = trajectory[turn_idx]
            turn_text = f"Turn {turn.get('turn_idx', turn_idx)}:\n  Action: {turn.get('action', '')}\n  Observation: {turn.get('observation', '')}"
            final_turns.append({
                'turn_idx': turn_idx,
                'text': turn_text,
                'order': turn_idx
            })

    # Format turns for display
    chunks_text = ""
    for turn in final_turns:
        turn_idx = turn['turn_idx']
        chunks_text += f"\n--- Turn {turn_idx} ---\n"
        chunks_text += f"{turn['text']}\n"

    print(f"[Retrieve] Step 3: Final retrieval has {len(final_turns)} turns, sorted by chronological order")

    # Use final_turns as top_chunks for compatibility with rest of the code
    top_chunks = final_turns

    # ============================================================================
    # STEP 4: Judge sufficiency (SUFFICIENT / NEED_GRAPH / NEED_CODE)
    # ============================================================================
    print(f"[Retrieve] Step 4: Judging information sufficiency")
    judgment_prompt = CHUNK_SUFFICIENCY_JUDGMENT_PROMPT_TEMPLATE.format(
        query=query,
        retrieved_chunks=chunks_text
    )

    _, judgment_response = await call_llm_func(judgment_prompt)
    # ============================================================================
    # STEP 5: Take action based on judgment
    # ============================================================================

    # Case A: SUFFICIENT - Extract answer directly from judgment response
    if ("**SUFFICIENT**" in judgment_response):
        print(f"[Retrieve] Step 5: Information sufficient, extracting answer from judgment")

        # Extract answer from judgment response (LLM should have provided it)
        judgment_response_clean = re.sub(r'<think>.*?</think>', '', judgment_response, flags=re.DOTALL).strip()
        answer_match = re.search(r'ANSWER:\s*(.+)', judgment_response_clean, re.IGNORECASE | re.DOTALL)

        if answer_match:
            answer = answer_match.group(1).strip()
            log_func(f"ANSWER: {answer}")
            return answer
        else:
            # Fallback: If LLM didn't provide answer in judgment, generate it
            print(f"[Retrieve] Warning: SUFFICIENT but no ANSWER found, generating answer")
            relevant_mem = f"Retrieved information:\n{chunks_text}"

            # Check if information is too large for a single answer generation
            chunk_size = 12000 # 24K chars for final 
            # Generate answer with retrieved information
            answer_prompt = ANSWER_WITH_RETRIEVAL_PROMPT_TEMPLATE.format(
                query=query,
                relevant_mem=relevant_mem
            )
            _, llm_response = await call_llm_func(answer_prompt)

            if llm_response:
                llm_response_clean = re.sub(r'<think>.*?</think>', '', llm_response, flags=re.DOTALL).strip()
                answer_match = re.search(r'ANSWER:\s*(.+)', llm_response_clean, re.IGNORECASE | re.DOTALL)
                if answer_match:
                    answer = answer_match.group(1).strip()
                    log_func(f"ANSWER: {answer}")
                    return answer
                return llm_response_clean
            return "Unable to generate answer from retrieved information."

    # Case B: NEED_GRAPH - Get adjacent/specific turns and generate answer
    elif ("**NEED_GRAPH**") in judgment_response:
        print(f"[Retrieve] Step 5: Need graph navigation, retrieving additional turns")

        # Parse turn retrieval specifications (supports multiple formats)
        turn_specs = parse_turn_retrieval_specs(judgment_response)

        # Get current turn indices
        current_turn_indices = [turn['turn_idx'] for turn in top_chunks]

        # Check if specs were parsed
        has_specs = (turn_specs['adjacent_specs'] or
                     turn_specs['range_specs'] or
                     turn_specs['specific_turns'])

        if has_specs:
            # Use parsed specifications to expand turn indices
            print(f"[Retrieve] Using turn specifications: {turn_specs}")
            additional_turn_indices = expand_turn_specs_to_indices(
                turn_specs,
                trajectory_data,
                current_turn_indices
            )
        else:
            # Fallback: Use old adjacent chunk logic with direction
            print(f"[Retrieve] No specs parsed, using fallback adjacent turn retrieval")
            direction = 'both'
            if 'before' in judgment_response.lower() and 'after' not in judgment_response.lower():
                direction = 'before'
            elif 'after' in judgment_response.lower() and 'before' not in judgment_response.lower():
                direction = 'after'

            turn_ids = current_turn_indices
            adjacent = get_adjacent_chunks(turn_ids, chunk_graph, direction=direction, max_chunks=5)

            # Convert adjacent chunks to turn indices
            additional_turn_indices = [adj['order'] for adj in adjacent]

        # Build texts for additional turns
        if additional_turn_indices:
            additional_turns = []
            for turn_idx in additional_turn_indices:
                if 0 <= turn_idx < len(trajectory):
                    turn = trajectory[turn_idx]
                    turn_text = f"Turn {turn.get('turn_idx', turn_idx)}:\n  Action: {turn.get('action', '')}\n  Observation: {turn.get('observation', '')}"
                    additional_turns.append({
                        'turn_idx': turn_idx,
                        'text': turn_text,
                        'order': turn_idx
                    })

            additional_text = "\n\n--- Additional Retrieved Turns ---\n"
            for turn in additional_turns:
                turn_idx = turn['turn_idx']
                additional_text += f"\n[Turn {turn_idx}]\n{turn['text']}\n"

            relevant_mem = f"Initial retrieved turns:\n{chunks_text}\n{additional_text}"
            print(f"[Retrieve] Retrieved {len(additional_turns)} additional turns")
        else:
            relevant_mem = f"Initial retrieved turns:\n{chunks_text}\n\n(No additional turns found)"
            print(f"[Retrieve] No additional turns retrieved")

        # Check if information is too large for a single answer generation
        chunk_size = 24576  # 24K chars for final answering
        if len(relevant_mem) > chunk_size:
            return await answer_with_chunks(query, relevant_mem, chunk_size, call_llm_func, log_func)

        # Generate answer with retrieved information
        answer_prompt = ANSWER_WITH_RETRIEVAL_PROMPT_TEMPLATE.format(
            query=query,
            relevant_mem=relevant_mem
        )
        _, llm_response = await call_llm_func(answer_prompt)

        if llm_response:
            llm_response_clean = re.sub(r'<think>.*?</think>', '', llm_response, flags=re.DOTALL).strip()
            answer_match = re.search(r'ANSWER:\s*(.+)', llm_response_clean, re.IGNORECASE | re.DOTALL)
            if answer_match:
                answer = answer_match.group(1).strip()
                log_func(f"ANSWER: {answer}")
                return answer
            return llm_response_clean
        return "Unable to generate answer from retrieved information."

    # Case C: NEED_CODE - Generate and execute code on full trajectory
    elif ("**NEED_CODE**") in judgment_response:
        print(f"[Retrieve] Step 5: Need code generation, processing full trajectory")

        # Prepare trajectory data
        trajectory_text_json = json.dumps(trajectory_data)
        trajectory = trajectory_data['trajectory']

        # Build trajectory sample for prompt
        trajectory_sample = "\n".join([
            f"Turn {turn.get('turn_idx', i)}:\n  Action: ...\n  Observation: ..."
            for i, turn in enumerate(trajectory[:1])
        ])

        # Generate code prompt
        code_prompt = CODE_GENERATION_PROMPT_TEMPLATE.format(
            query=query,
            task=task,
            trajectory_sample=trajectory_sample
        )
        code_prompt += f"\n\n**Context from Initial Turn Retrieval:**\n{chunks_text[:1000]}\n\nThe above turns were retrieved but are not sufficient. Generate Python code to analyze the full trajectory and extract the needed information.\n"
        print(f"[Retrieve] Generating code with prompt length {len(code_prompt)} chars")
       
        # Generate code
        _, llm_response = await call_llm_func(code_prompt)
        code = extract_code_from_response(llm_response) if llm_response else None

        if code:
            print(f"[Retrieve] Executing generated code")
            print(f"GENERATED_CODE:\n{code}")

            # Execute code
            exec_result = await execute_code(code, trajectory_text_json, ray_workers, worker_idx_container, timeout=80.0)

            if exec_result is not None:
                
                # Extract output and result from execution
                if isinstance(exec_result, dict) and 'output' in exec_result:
                    output_str = exec_result['output']
                    print(f"success: {exec_result['output']}")
                    result_value = exec_result.get('result')
                else:
                    # Backward compatibility: if exec_result is not dict, treat as result
                    output_str = str(exec_result)
                    result_value = exec_result

                # Include code, output, and result for better context
                relevant_mem = f"""Initial turns:
{chunks_text}

Generated Code:
```python
{code}
```

Code Execution Output (with explanation):
{output_str}"""
            else:
                relevant_mem = f"""Initial turns:
{chunks_text}

Generated Code:
```python
{code}
```

Code Execution Output:
(Code execution failed or returned None)"""
        else:
            relevant_mem = f"Initial turns:\n{chunks_text}\n\n(Failed to generate code)"

        # Check if information is too large for a single answer generation
        chunk_size = 24576  # 24K chars for final answering
        if len(relevant_mem) > chunk_size:
            print(f"[Retrieve] Retrieved info too large ({len(relevant_mem)} chars), using chunked answering")
            return await answer_with_chunks(query, relevant_mem, chunk_size, call_llm_func, log_func)

        # Generate answer with retrieved information
        answer_prompt = ANSWER_WITH_RETRIEVAL_PROMPT_TEMPLATE.format(
            query=query,
            relevant_mem=relevant_mem
        )
        _, llm_response = await call_llm_func(answer_prompt)

        if llm_response:
            llm_response_clean = re.sub(r'<think>.*?</think>', '', llm_response, flags=re.DOTALL).strip()
            answer_match = re.search(r'ANSWER:\s*(.+)', llm_response_clean, re.IGNORECASE | re.DOTALL)
            if answer_match:
                answer = answer_match.group(1).strip()
                log_func(f"ANSWER: {answer}")
                return answer
            return llm_response_clean
        return "Unable to generate answer from retrieved information."

    # Fallback: Unknown judgment, generate answer from initial turns
    else:
        print(f"[Retrieve] Unknown judgment, using initial turns for answer")
        relevant_mem = f"Retrieved information:\n{chunks_text}"

        # Check if information is too large for a single answer generation
        chunk_size = 24576  # 24K chars for final answering
        if  len(relevant_mem) > chunk_size:
            print(f"[Retrieve] Retrieved info too large ({len(relevant_mem)} chars), using chunked answering")
            return await answer_with_chunks(query, relevant_mem, chunk_size, call_llm_func, log_func)

        # Generate answer with retrieved information
        answer_prompt = ANSWER_WITH_RETRIEVAL_PROMPT_TEMPLATE.format(
            query=query,
            relevant_mem=relevant_mem
        )
        _, llm_response = await call_llm_func(answer_prompt)

        if llm_response:
            llm_response_clean = re.sub(r'<think>.*?</think>', '', llm_response, flags=re.DOTALL).strip()
            answer_match = re.search(r'ANSWER:\s*(.+)', llm_response_clean, re.IGNORECASE | re.DOTALL)
            if answer_match:
                answer = answer_match.group(1).strip()
                log_func(f"ANSWER: {answer}")
                return answer
            return llm_response_clean
        return "Unable to generate answer from retrieved information."


async def answer_with_chunks(
    query: str,
    relevant_mem: str,
    chunk_size: int,
    call_llm_func,
    log_func
) -> str:
    """
    Generate an answer by scanning oversized retrieved memory chunk by chunk.

    The memory is split into fixed-size character chunks. Each chunk is sent
    to the LLM together with the compressed notes gathered from earlier
    chunks. For each chunk the LLM reply is interpreted as:
      - ``ANSWER: ...``      -> the answer is returned immediately;
      - ``COMPRESS: ...``    -> the notes are accumulated for later chunks;
      - ``NO_RELEVANT_INFO`` -> the chunk is skipped.
    If no chunk yields an answer, a final answer is forced from the
    accumulated notes, or without retrieved context if no notes were
    collected.

    Args:
        query: User query.
        relevant_mem: Retrieved relevant memory (too large for one prompt).
        chunk_size: Maximum number of characters per chunk; must be > 0.
        call_llm_func: Async callable taking a prompt and returning a
            2-tuple whose second element is the response text. A falsy
            response is treated as a failed call.
        log_func: Logging function taking a single string.

    Returns:
        The final answer string, or a fallback message if none could be
        generated.

    Raises:
        ValueError: If ``chunk_size`` is not positive. Otherwise the split
            loop below would never terminate.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    # Prefix the retrieved memory with a section header before chunking.
    combined_memory = f"=== Retrieved Detailed Information ===\n{relevant_mem}"

    print(f"[Retrieve] Combined memory size: {len(combined_memory)} chars, chunk size: {chunk_size}")

    # Split combined memory into fixed-size character chunks.
    chunks = [
        combined_memory[pos:pos + chunk_size]
        for pos in range(0, len(combined_memory), chunk_size)
    ]
    total_chunks = len(chunks)

    print(f"[Retrieve] Split combined memory into {total_chunks} chunks")

    # Patterns used to interpret each LLM reply (hoisted out of the loop).
    think_re = re.compile(r'<think>.*?</think>', re.DOTALL)
    answer_re = re.compile(r'ANSWER:\s*(.+)', re.IGNORECASE | re.DOTALL)
    compress_re = re.compile(r'COMPRESS:\s*(.+?)(?=\n\n|$)', re.IGNORECASE | re.DOTALL)
    no_info_re = re.compile(r'NO_RELEVANT_INFO', re.IGNORECASE)

    # Track compressed information from previous chunks.
    compressed_info = ""
    MAX_COMPRESSED_INFO_SIZE = 8192  # Prevent OOM from accumulated compressed info (~2K tokens)

    for i, chunk in enumerate(chunks):
        chunk_no = i + 1
        print(f"[Retrieve] Processing chunk {chunk_no}/{total_chunks} for answering...")

        # Build the compressed-info section for the prompt.
        compressed_info_section = ""
        if compressed_info:
            compressed_info_section = f"**Compressed Relevant Information from Previous Chunks:**\n{compressed_info}\n\n"

        chunk_prompt = CHUNKED_ANSWER_PROMPT_TEMPLATE.format(
            query=query,
            compressed_info_section=compressed_info_section,
            chunk_idx=chunk_no,
            total_chunks=total_chunks,
            chunk_content=chunk
        )

        _, llm_response = await call_llm_func(chunk_prompt)

        # Guard against a failed/empty LLM call; re.sub would raise on None.
        if not llm_response:
            print(f"[Retrieve] Empty LLM response for chunk {chunk_no}/{total_chunks}, skipping...")
            continue

        llm_response_clean = think_re.sub('', llm_response).strip()

        # Did the LLM provide a complete answer?
        answer_match = answer_re.search(llm_response_clean)
        if answer_match:
            extracted_answer = answer_match.group(1).strip()
            print(f"[Retrieve] Complete answer provided after chunk {chunk_no}/{total_chunks}")
            log_func(f"Answer generated at chunk {chunk_no}/{total_chunks}")
            return extracted_answer

        # Did the LLM provide compressed information?
        compress_match = compress_re.search(llm_response_clean)
        if compress_match:
            new_compressed_info = compress_match.group(1).strip()
            print(f"[Retrieve] Chunk {chunk_no}/{total_chunks} has relevant info, compressing...")
            if compressed_info:
                compressed_info += f"\n\n[From Chunk {chunk_no}] {new_compressed_info}"
            else:
                compressed_info = f"[From Chunk {chunk_no}] {new_compressed_info}"

            # Limit compressed_info size to prevent OOM.
            if len(compressed_info) > MAX_COMPRESSED_INFO_SIZE:
                # Keep only the most recent info (last 75%) to maintain continuity.
                keep_size = int(MAX_COMPRESSED_INFO_SIZE * 0.75)
                compressed_info = "...(previous info truncated)\n" + compressed_info[-keep_size:]
                print(f"[Retrieve] Trimmed compressed_info to {len(compressed_info)} chars")
            continue

        if no_info_re.search(llm_response_clean):
            print(f"[Retrieve] Chunk {chunk_no}/{total_chunks} has no relevant info, continuing...")
            continue

        # Unrecognized reply format: nothing to accumulate from this chunk.

    # No chunk produced an answer: force one from everything gathered. This
    # runs after the loop so it also happens when the last chunk replied
    # COMPRESS or NO_RELEVANT_INFO.
    print(f"[Retrieve] Last chunk reached, forcing answer generation...")

    if compressed_info:
        final_prompt = ANSWER_WITH_RETRIEVAL_PROMPT_TEMPLATE.format(
            query=query,
            relevant_mem=compressed_info
        )
    else:
        # No relevant info found in any chunk.
        final_prompt = ANSWER_WITHOUT_RETRIEVAL_PROMPT_TEMPLATE.format(
            query=query
        )

    _, final_response = await call_llm_func(final_prompt)

    if final_response:
        final_response_clean = think_re.sub('', final_response).strip()
        final_match = answer_re.search(final_response_clean)
        if final_match:
            return final_match.group(1).strip()
        return final_response_clean

    return "Unable to generate answer from retrieved information."
