"""
State memory construction functions for Memory Agent
"""
import re
from typing import Any, Dict, List, Optional

from .utils import extract_state_memory_from_response, truncate_trajectory_text
from .prompt import COMPRESS_PROMPT_TEMPLATE


def format_single_turn(turn: Dict[str, Any]) -> str:
    """Render one trajectory turn as a newline-terminated, three-line text block.

    Args:
        turn: Mapping that may hold 'turn_idx', 'action' and 'observation'.
            Missing keys fall back to 0, '' and '' respectively.

    Returns:
        The header line "Turn <idx>:" followed by indented Action and
        Observation lines, each ending in a newline.
    """
    pieces = (
        f"Turn {turn.get('turn_idx', 0)}:",
        f"  Action: {turn.get('action', '')}",
        f"  Observation: {turn.get('observation', '')}",
    )
    # Trailing newline keeps consecutive turns separable when concatenated.
    return "\n".join(pieces) + "\n"


def format_trajectory_for_llm(trajectory: List[Dict[str, Any]], task: str) -> str:
    """Flatten a trajectory into plain text, three lines per turn.

    Args:
        trajectory: Turns, each a mapping that may hold 'turn_idx', 'action'
            and 'observation' (defaults: 0, '', '').
        task: Task description; accepted for interface symmetry but not used
            in the rendered text.

    Returns:
        All turns joined by newlines with no trailing newline; an empty
        string when the trajectory has no turns.
    """
    rendered_turns = (
        f"Turn {step.get('turn_idx', 0)}:\n"
        f"  Action: {step.get('action', '')}\n"
        f"  Observation: {step.get('observation', '')}"
        for step in trajectory
    )
    return "\n".join(rendered_turns)


async def process_with_truncation(trajectory_text: str, task: str, chunk_size: int) -> str:
    """Fallback that shortens the raw trajectory text instead of summarising it.

    Args:
        trajectory_text: Formatted trajectory text to shorten.
        task: Task description; not used by this function.
        chunk_size: Passed unchanged to truncate_trajectory_text as the
            target length (the full value, no safety margin is applied).

    Returns:
        Whatever truncate_trajectory_text returns for the given text and
        target length.
    """
    original_length = len(trajectory_text)
    print(f"[Construct] Truncating trajectory from {original_length} to ~{chunk_size} chars")
    return truncate_trajectory_text(trajectory_text, chunk_size)


# Matches the "Turn <n>:" header that introduces each turn in the formatted text.
_TURN_HEADER_RE = re.compile(r'Turn (\d+):')

# Number of characters before a chunk boundary searched for a turn header.
_TURN_LOOKBACK_CHARS = 100


def _split_text_into_chunks(trajectory_text: str, chunk_size: int) -> List[str]:
    """Cut the text into consecutive chunk_size-character slices, labelling continuations.

    Every chunk after the first is prefixed with a one-line marker so the LLM
    knows the chunk starts mid-stream. The caller must pass chunk_size > 0.
    """
    chunks = []
    for start_pos in range(0, len(trajectory_text), chunk_size):
        chunk_text = trajectory_text[start_pos:start_pos + chunk_size]

        if start_pos > 0:
            # Look a short distance backwards for the most recent turn header
            # so the marker can name the turn that is being continued.
            lookback = trajectory_text[max(0, start_pos - _TURN_LOOKBACK_CHARS):start_pos]
            turn_numbers = _TURN_HEADER_RE.findall(lookback)
            if turn_numbers:
                chunk_text = f"[Continuation of Turn {turn_numbers[-1]}, previous content truncated]\n{chunk_text}"
            else:
                chunk_text = f"[Continuation from previous chunk]\n{chunk_text}"

        chunks.append(chunk_text)
    return chunks


async def process_trajectory(
    trajectory: List[Dict[str, Any]],
    trajectory_text: str,
    task: str,
    chunk_size: int,
    call_llm_func
) -> Optional[str]:
    """Compress a formatted trajectory into state memory, chunking when needed.

    Text that fits within chunk_size is sent to the LLM in one prompt. Longer
    text is split by character count and processed sequentially, each prompt
    carrying the state memory produced so far.

    Args:
        trajectory: Full trajectory data (not read here; kept for interface
            compatibility).
        trajectory_text: Formatted trajectory text.
        task: Task description inserted into the compression prompt.
        chunk_size: Maximum number of characters of trajectory text per chunk.
        call_llm_func: Awaitable callable taking the prompt; its result is
            unpacked as a pair whose second item is the LLM response.

    Returns:
        State memory extracted from the LLM response(s). In the chunked path,
        None when no chunk yielded any state memory; in the single-chunk path,
        whatever extract_state_memory_from_response returns.

    Raises:
        ValueError: If chunking is required but chunk_size is not positive
            (the split could never make progress).
    """
    total_chars = len(trajectory_text)

    # Single chunk processing
    if total_chars <= chunk_size:
        print(f"[Construct] Processing single chunk ({total_chars} chars)")
        compress_prompt = COMPRESS_PROMPT_TEMPLATE.format(
            task=task,
            trajectory_text=trajectory_text,
            previous_state_text=""
        )

        _, llm_response = await call_llm_func(compress_prompt)
        return extract_state_memory_from_response(llm_response)

    # A non-positive chunk size would make the split loop forever or yield nothing.
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")

    # Multi-chunk processing: split trajectory_text by character count
    print(f"[Construct] Processing with chunking ({total_chars} chars, chunk_size={chunk_size})")
    chunks = _split_text_into_chunks(trajectory_text, chunk_size)
    print(f"[Construct] Split into {len(chunks)} text chunks")

    # Process chunks sequentially, accumulating state
    accumulated_state = ""

    for i, chunk_text in enumerate(chunks):
        # Include previous state in prompt
        previous_state_text = f"Previous State Memory:\n{accumulated_state}" if accumulated_state else ""

        compress_prompt = COMPRESS_PROMPT_TEMPLATE.format(
            task=task,
            trajectory_text=chunk_text,
            previous_state_text=previous_state_text
        )
        print(f"[Construct] Processing chunk {i+1}/{len(chunks)}, chunk_text: {len(chunk_text)} chars, compress_prompt: {len(compress_prompt)} chars")
        _, llm_response = await call_llm_func(compress_prompt)

        state_mem = extract_state_memory_from_response(llm_response)
        if state_mem:
            accumulated_state = state_mem
        else:
            # Keep what earlier chunks produced rather than wiping it out
            # because one extraction came back empty.
            print(f"[Construct] Chunk {i+1} produced no state memory; keeping previous state")
        print(f"[Construct] Chunk {i+1} processed. Accumulated state length: {len(accumulated_state) if accumulated_state else 0} chars")

    return accumulated_state if accumulated_state else None


async def construct_state_memory(
    trajectory: List[Dict[str, Any]],
    task: str,
    episode_id: str,
    chunk_size: int,
    enable_state_memory_summary: bool,
    call_llm_func
) -> Dict[str, Any]:
    """
    Build the state-memory record for one episode's trajectory.

    Args:
        trajectory: List of trajectory turns with turn_idx, action, observation
        task: Task description
        episode_id: Episode identifier
        chunk_size: Character budget handed to the summarisation and truncation steps
        enable_state_memory_summary: If True, summarise via process_trajectory and
            fall back to process_with_truncation when that yields nothing; if False,
            the full formatted trajectory text is used as the state memory
        call_llm_func: LLM calling function forwarded to process_trajectory

    Returns:
        Dictionary with 'state_mem', 'trajectory_data' (trajectory, task,
        episode_id) and 'num_turns'
    """
    num_turns = len(trajectory)
    trajectory_text = format_trajectory_for_llm(trajectory, task)

    if not enable_state_memory_summary:
        # Summary disabled: the raw formatted text stands in for state memory.
        state_mem = trajectory_text
    else:
        state_mem = await process_trajectory(trajectory, trajectory_text, task, chunk_size, call_llm_func)
        if not state_mem:
            # Nothing came back from summarisation, so fall back to truncation.
            state_mem = await process_with_truncation(trajectory_text, task, chunk_size)

    return {
        'state_mem': state_mem,
        'trajectory_data': {
            'trajectory': trajectory,
            'task': task,
            'episode_id': episode_id,
        },
        'num_turns': num_turns,
    }
