#!/usr/bin/env python3
"""
Automated LLM Fine-tuning Pipeline with OpenHands SDK

This pipeline automates:
1. Data cleaning with LLM assistance
2. Model training with LlamaFactory
3. Evaluation with OpenCompass
4. Iterative refinement based on validation results
"""

import sys
from pathlib import Path

# Put this file's parent directory on sys.path before the project imports
# below run. Presumably this lets the absolute `finetune_pipeline.*` imports
# resolve when the file is executed directly as a script — TODO confirm the
# package layout.
_current_dir = Path(__file__).resolve().parent
_parent_dir = _current_dir.parent
if str(_parent_dir) not in sys.path:
    sys.path.insert(0, str(_parent_dir))

import argparse
import json
import os
import subprocess
import time
from datetime import datetime
from pathlib import Path
from typing import Callable

import yaml

import litellm
from litellm import completion
from pydantic import SecretStr

from openhands.sdk import LLM, Agent, Conversation, Tool
from openhands.sdk.tool import register_tool
from openhands.tools.file_editor import FileEditorTool
from openhands.tools.terminal import TerminalTool

from finetune_pipeline.config import Config
from finetune_pipeline.logger import setup_logger, OutputCapture
from finetune_pipeline.tools.llama_factory import LlamaFactoryTool
from finetune_pipeline.tools.opencompass_eval import OpenCompassTool


# Module-level accumulator (seconds) of the time spent sleeping between retry
# attempts in run_conversation_with_retry(). run_pipeline() subtracts this
# value from the elapsed wall-clock time, so retry waits are not counted
# towards the pipeline timeout.
_total_retry_wait_time = 0.0


def get_retry_wait_time() -> float:
    """Return the seconds accumulated so far while sleeping between retries."""
    accumulated_seconds = _total_retry_wait_time
    return accumulated_seconds


def reset_retry_wait_time() -> None:
    """Zero the accumulated retry wait time.

    Intended to be called once when a pipeline run starts, so waits from an
    earlier run are not carried over.
    """
    global _total_retry_wait_time

    _total_retry_wait_time = 0.0


def run_conversation_with_retry(
    config: "Config",
    create_agent_fn: "Callable[[LLM], Agent]",
    workspace: str,
    message: str,
    max_retries: int = 100,
    logger=None,
) -> "Conversation":
    """Run a conversation, retrying automatically on transient LLM errors.

    Every attempt builds a fresh LLM/Agent/Conversation through ``setup_llm``,
    so each retry may use a different ``base_url``/model (``setup_llm`` asks
    the config for them on every call).

    Errors that trigger a retry:
    - ``LLMTimeoutError``, ``LLMRateLimitError``, ``LLMServiceUnavailableError``
    - ``ConversationRunError`` whose message mentions a timeout, 504/503/429,
      a rate limit, or "service unavailable"

    Any other ``ConversationRunError`` is re-raised immediately.

    Args:
        config: Pipeline config used to create the LLM.
        create_agent_fn: Function building an agent from an LLM
            (e.g. ``create_data_agent``).
        workspace: Workspace path for the conversation.
        message: Message sent to the conversation before it is run.
        max_retries: Maximum number of retries after the first attempt
            (default: 100). Negative values are treated as 0 so that at least
            one attempt is always made.
        logger: Optional logger; when omitted, retry notices are printed.

    Returns:
        The Conversation object whose ``run()`` completed successfully.

    Raises:
        The last transient error once all attempts are exhausted, or the
        original ``ConversationRunError`` when it is not retryable.

    Retry delay: linear backoff (1s, 2s, 3s, 4s, ...).
    Note: the time slept between retries is added to the module-level
    ``_total_retry_wait_time`` so it can be excluded from the pipeline timeout.
    """
    global _total_retry_wait_time

    from openhands.sdk.conversation.exceptions import ConversationRunError
    from openhands.sdk.llm.exceptions.types import (
        LLMTimeoutError,
        LLMRateLimitError,
        LLMServiceUnavailableError,
    )

    # Substrings (matched case-insensitively) that mark a wrapped
    # ConversationRunError as transient.
    retryable_keywords = ("timeout", "504", "503", "429", "rate limit", "service unavailable")

    # A negative budget used to skip the loop entirely and return None.
    max_retries = max(max_retries, 0)
    total_attempts = max_retries + 1

    for attempt in range(total_attempts):
        # Create new LLM/Agent/Conversation for each attempt (load balancing)
        llm, base_url, model = setup_llm(config)
        agent = create_agent_fn(llm)
        conv = Conversation(agent=agent, workspace=workspace)
        conv.send_message(message)

        if logger and attempt > 0:
            logger.info(f"Retry attempt {attempt + 1}: endpoint={base_url}, model={model}")

        progress = f"(attempt {attempt + 1}/{total_attempts})"
        try:
            conv.run()
            return conv  # Success
        except (LLMTimeoutError, LLMRateLimitError, LLMServiceUnavailableError) as e:
            if attempt >= max_retries:
                if logger:
                    logger.error(f"LLM API error after {total_attempts} attempts: {e}")
                raise
            log_prefix = f"LLM API error {progress}: {type(e).__name__}. "
            print_prefix = f"[Retry] LLM API error {progress}: {type(e).__name__}. "
        except ConversationRunError as e:
            # Check if it's a retryable error wrapped in ConversationRunError
            error_msg = str(e).lower()
            if not any(keyword in error_msg for keyword in retryable_keywords):
                raise
            if attempt >= max_retries:
                if logger:
                    logger.error(f"Retryable error after {total_attempts} attempts: {e}")
                raise
            log_prefix = f"Retryable error {progress}: {type(e).__name__}. "
            print_prefix = f"[Retry] Retryable error {progress}. "

        # Only reached after a transient failure with retries left.
        wait_time = attempt + 1  # Linear backoff: 1s, 2s, 3s, 4s, ...
        if logger:
            logger.warning(
                f"{log_prefix}"
                f"Retrying with new endpoint in {wait_time}s... (retry time not counted towards timeout)"
            )
        else:
            print(f"{print_prefix}Retrying in {wait_time}s...")
        time.sleep(wait_time)
        _total_retry_wait_time += wait_time  # Track retry wait time

    # The final attempt always returns or raises, so this is defensive only.
    raise RuntimeError("Retry loop ended without a result")


def setup_hf_token(config: Config) -> None:
    """Export the HuggingFace token from the config, or warn when none is found.

    When ``config.huggingface.get_token()`` returns a token it is written to
    the ``HF_TOKEN`` environment variable. Otherwise the function looks for a
    token already known to ``huggingface_hub`` (e.g. from
    ``huggingface-cli login``) and prints a warning only if that lookup also
    finds nothing.

    Args:
        config: Pipeline config exposing ``huggingface.get_token()``.
    """
    hf_token = config.huggingface.get_token()
    if hf_token:
        os.environ["HF_TOKEN"] = hf_token
        print("HuggingFace token configured from config file.")
        return

    # Check if already logged in via huggingface-cli
    try:
        import huggingface_hub

        # Prefer the top-level get_token() helper; HfFolder is the legacy
        # accessor and is kept only as a fallback for older installations.
        read_cached_token = getattr(huggingface_hub, "get_token", None)
        if read_cached_token is None:
            read_cached_token = huggingface_hub.HfFolder.get_token
        if read_cached_token():
            return  # Token available from cache
    except Exception:
        # Deliberately best-effort: a missing or broken huggingface_hub must
        # not abort the pipeline, it only means the warning below is shown.
        pass

    print(
        "WARNING: No HuggingFace token found. "
        "Gated datasets (e.g., ChemCoTBench) may fail.\n"
        "Set huggingface.hf_token in config or run: huggingface-cli login"
    )


def setup_llm(config: Config) -> tuple[LLM, str, str]:
    """Build the LLM used for agent conversations.

    Args:
        config: Pipeline config providing ``llm.get_api_key()``,
            ``llm.get_base_url()`` and ``llm.get_model()``.

    Returns:
        tuple: ``(llm, base_url, model)`` - the LLM object together with the
        base_url and model that were actually used to build it.

    Raises:
        ValueError: If no API key is configured.
    """
    api_key = config.llm.get_api_key()
    if not api_key:
        missing_key_message = (
            "LLM API key is required. Set it in config.yaml (llm.api_key) "
            "or via LLM_API_KEY environment variable."
        )
        raise ValueError(missing_key_message)

    # Resolve both values up front (random selection happens here) so the
    # caller is told exactly which endpoint/model this LLM was built with.
    base_url = config.llm.get_base_url()
    model = config.llm.get_model()

    llm_kwargs = {
        "model": model,
        "api_key": SecretStr(api_key),
        "base_url": base_url,
        "usage_id": "finetune_pipeline",
        # Disable prompt cache retention not supported by Azure
        "prompt_cache_retention": None,
    }
    return LLM(**llm_kwargs), base_url, model


def check_log_for_errors(log_content: str, config: Config) -> dict:
    """Use an LLM to analyze log content and detect errors.

    Only the last 3000 characters of the log are sent to the model. When the
    model cannot be reached, or its reply is not a JSON object containing
    ``has_error``, a rule-based fallback is used instead: the log is reported
    as a ``runtime_error`` if it contains the text ``Traceback``.

    Args:
        log_content: Full text of the script's log.
        config: Pipeline config; ``model_pool`` settings are preferred and
            ``llm`` settings are used when the pool leaves a value empty.

    Returns:
        dict with keys:
        - has_error: bool
        - error_type: str (e.g., "runtime_error", "stuck", "none")
        - summary: str (brief description)
    """
    litellm.suppress_debug_info = True

    log_tail = log_content[-3000:]

    if not log_tail.strip():
        return {"has_error": False, "error_type": "none", "summary": "Log is empty"}

    prompt = f"""Analyze the following Python script execution log and determine if an error occurred.

## Log Content
```
{log_tail}
```

## Criteria
1. **runtime_error**: Python exceptions (Traceback), module import failures, file not found, etc.
2. **stuck**: Long periods without output, signs of infinite loops
3. **none**: Running normally, no errors

## Output Format (Strict JSON)
{{"has_error": true/false, "error_type": "runtime_error"/"stuck"/"none", "summary": "Brief description"}}

Only output JSON, no other content."""

    # Why the LLM verdict could not be used; reported in the fallback summary.
    fallback_reason = "LLM response did not contain a usable JSON object"
    try:
        resp = completion(
            model=config.model_pool.weak_models[0] if config.model_pool.weak_models else config.llm.get_model(),
            messages=[{"role": "user", "content": prompt}],
            api_base=config.model_pool.api_base or config.llm.get_base_url(),
            api_key=config.model_pool.api_key or config.llm.get_api_key(),
            timeout=30,
        )
        # The message content may be None; treat that as an empty reply.
        result_text = (resp.choices[0].message.content or "").strip()
        # Extract JSON (the model may wrap it in prose or code fences)
        json_start = result_text.find("{")
        json_end = result_text.rfind("}")
        if json_start != -1 and json_end > json_start:
            verdict = json.loads(result_text[json_start:json_end + 1])
            if isinstance(verdict, dict) and "has_error" in verdict:
                # Guarantee the documented keys so callers can index safely.
                verdict.setdefault("error_type", "runtime_error" if verdict["has_error"] else "none")
                verdict.setdefault("summary", "")
                return verdict
    except Exception as e:
        # Best-effort: any failure of the LLM call or of parsing its reply
        # falls through to the rule-based detection below.
        fallback_reason = str(e)

    # Fall back to simple rule-based detection when no LLM verdict is available
    if "Traceback" in log_content:
        return {
            "has_error": True,
            "error_type": "runtime_error",
            "summary": f"Detected Traceback (LLM fallback: {fallback_reason})",
        }

    return {"has_error": False, "error_type": "none", "summary": "No errors detected"}


def create_data_agent(llm: LLM) -> Agent:
    """Build the agent used for data cleaning.

    The agent is given a file editor to write the Python cleaning script and
    a terminal for data exploration (head, wc, etc.). Running the cleaning
    script itself is the pipeline's job, NOT the agent's.
    """
    file_editor_tool = Tool(name=FileEditorTool.name)
    # 8 min timeout for data exploration
    terminal_tool = Tool(
        name=TerminalTool.name,
        params={"no_change_timeout_seconds": 480},
    )
    return Agent(llm=llm, tools=[file_editor_tool, terminal_tool])


def create_train_agent(llm: LLM) -> Agent:
    """Build the agent used for training.

    Registers the "TrainToolSet" tool factory (backed by
    ``LlamaFactoryTool.create``) and returns an agent that has that tool set
    plus a terminal.
    """

    def make_train_tools(conv_state):
        # Copy the created tools into a fresh list for the registry.
        return [*LlamaFactoryTool.create(conv_state)]

    register_tool("TrainToolSet", make_train_tools)

    train_tool_set = Tool(name="TrainToolSet")
    # 8 min timeout
    terminal_tool = Tool(
        name=TerminalTool.name,
        params={"no_change_timeout_seconds": 480},
    )
    return Agent(llm=llm, tools=[train_tool_set, terminal_tool])


def create_eval_agent(llm: LLM) -> Agent:
    """Build the agent used for evaluation.

    Registers the "EvalToolSet" tool factory (backed by
    ``OpenCompassTool.create``) and returns an agent that has that tool set
    plus a terminal.
    """

    def make_eval_tools(conv_state):
        # Copy the created tools into a fresh list for the registry.
        return [*OpenCompassTool.create(conv_state)]

    register_tool("EvalToolSet", make_eval_tools)

    eval_tool_set = Tool(name="EvalToolSet")
    # 8 min timeout
    terminal_tool = Tool(
        name=TerminalTool.name,
        params={"no_change_timeout_seconds": 480},
    )
    return Agent(llm=llm, tools=[eval_tool_set, terminal_tool])


def start_model_server(model_path: str, port: int = 8000) -> subprocess.Popen | None:
    """Start vLLM server for the trained model."""
    cmd = [
        "python", "-m", "vllm.entrypoints.openai.api_server",
        "--model", model_path,
        "--port", str(port),
        "--trust-remote-code",
    ]

    try:
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        )
        # Wait for server to start
        time.sleep(30)
        return process
    except Exception as e:
        print(f"Failed to start model server: {e}")
        return None


def stop_model_server(process: subprocess.Popen | None) -> None:
    """Stop the model server."""
    if process:
        process.terminate()
        process.wait(timeout=10)


def parse_llm_selection(conv: Conversation) -> int | None:
    """Parse the model selection result from the LLM.

    The most recent agent ``MessageEvent`` is inspected. All parts of its
    message content are considered, so a ``SELECTED_ITERATION`` marker is
    found even when it is not in the first part.

    Args:
        conv: Conversation object containing the LLM response.

    Returns:
        Selected iteration number, or None if no agent message exists or no
        iteration number can be parsed from it.
    """
    import re

    # Get the last MessageEvent from events (agent's response)
    response = ""
    for event in reversed(list(conv.state.events)):
        if event.kind == "MessageEvent" and event.source == "agent":
            llm_message = getattr(event, "llm_message", None)
            content = llm_message.content if llm_message else None
            if isinstance(content, str):
                response = content
            elif content:
                # llm_message.content is a list of parts; text parts expose
                # `.text`, anything else is stringified.
                texts = []
                for part in content:
                    text = part.text if hasattr(part, "text") else str(part)
                    if text is not None:
                        texts.append(str(text))
                response = "\n".join(texts)
            break  # Only the latest agent message counts

    if not response:
        return None

    # Explicit marker first, then the looser "iteration X" wording.
    patterns = (
        r'SELECTED_ITERATION:\s*(\d+)',
        r'iteration\s*(\d+)',
    )
    for pattern in patterns:
        match = re.search(pattern, response, re.IGNORECASE)
        if match:
            return int(match.group(1))

    return None


def format_duration(seconds: float) -> str:
    """Render a duration given in seconds as a short human-readable string.

    Values under a minute are shown in seconds, values under an hour in
    minutes, and everything else in hours.
    """
    if seconds < 60:
        return f"{seconds:.1f}s"
    if seconds < 3600:
        return f"{seconds / 60:.1f}m"
    return f"{seconds / 3600:.2f}h"


def load_dataset_readme(data_path: str) -> str | None:
    """Load README.md for the dataset.
    
    Searches for README.md in:
    1. Same directory as data file
    2. Parent directory of data file
    3. Up to 2 levels up from data file
    
    Returns the README content if found, None otherwise.
    """
    path = Path(data_path).resolve()
    
    # Search in current dir and up to 2 levels up
    search_dirs = [path.parent]
    if path.parent.parent:
        search_dirs.append(path.parent.parent)
    if path.parent.parent and path.parent.parent.parent:
        search_dirs.append(path.parent.parent.parent)
    
    for dir_path in search_dirs:
        readme_path = dir_path / "README.md"
        if readme_path.exists():
            try:
                content = readme_path.read_text(encoding="utf-8")
                return content
            except Exception:
                continue
    
    return None


def load_data_preview(data_path: str, num_samples: int = 3) -> str:
    """Load preview of data from various formats (JSON, JSONL, CSV, Parquet).

    The format is chosen from the file suffix. Files with any other suffix are
    treated as a JSON array when their first non-whitespace character is
    ``[`` and as JSON Lines otherwise. Blank lines in JSON Lines input are
    skipped and do not count towards ``num_samples``.

    Args:
        data_path: Path to the data file.
        num_samples: Maximum number of records to include.

    Returns:
        A JSON string representation of the first num_samples records, or a
        message starting with "can not read" when the file cannot be read.
    """
    import csv
    from itertools import islice

    def read_json_lines(handle) -> list:
        """Parse up to num_samples non-blank lines of JSON Lines input."""
        records = []
        for line in handle:
            if len(records) >= num_samples:
                break
            if line.strip():
                records.append(json.loads(line))
        return records

    path = Path(data_path)
    suffix = path.suffix.lower()

    try:
        if suffix == ".parquet":
            try:
                import pandas as pd
                df = pd.read_parquet(data_path)
                raw_data = df.head(num_samples).to_dict(orient="records")
            except ImportError:
                return "can not read Parquet files: need pandas and pyarrow"
        elif suffix == ".csv":
            # newline="" lets the csv module handle embedded line breaks.
            with open(data_path, "r", encoding="utf-8", newline="") as f:
                reader = csv.DictReader(f)
                raw_data = [dict(row) for row in islice(reader, max(num_samples, 0))]
        elif suffix == ".jsonl":
            with open(data_path, "r", encoding="utf-8") as f:
                raw_data = read_json_lines(f)
        else:
            with open(data_path, "r", encoding="utf-8") as f:
                # Sniff the first meaningful character; leading blank lines or
                # indentation must not hide a JSON array.
                first_char = f.read(1)
                while first_char and first_char.isspace():
                    first_char = f.read(1)
                f.seek(0)
                if first_char == "[":
                    raw_data = json.load(f)[:num_samples]
                else:
                    raw_data = read_json_lines(f)

        # default=str keeps the preview working for values JSON cannot encode
        # natively (e.g. timestamps or arrays coming out of a Parquet file).
        return json.dumps(raw_data, ensure_ascii=False, indent=2, default=str)
    except Exception as e:
        # The preview only feeds a prompt, so any failure is reported as text
        # instead of aborting the pipeline.
        return f"can not read raw data ({suffix}): {e}"


def get_data_stats(data_path: str) -> dict:
    """Get basic statistics of the cleaned data (a JSON Lines file).

    Supports both Alpaca format (instruction/input/output) and
    ShareGPT format (messages with role/content, or conversations with
    from/value). Lengths are measured on the model outputs only: the
    ``output`` field, or every assistant/gpt turn.

    Null outputs count as length 0 and records that are not JSON objects are
    counted in ``count`` but contribute no length.

    Args:
        data_path: Path to the JSON Lines file.

    Returns:
        dict with ``count`` (number of records) and ``avg_len``, ``max_len``,
        ``p50_len``, ``p90_len`` of the output lengths (all 0 when there are
        no records or no outputs).
    """

    def text_len(value) -> int:
        """Length of an output value, tolerating null and non-sized values."""
        if value is None:
            return 0
        try:
            return len(value)
        except TypeError:
            return len(str(value))

    samples = []
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            if line.strip():
                samples.append(json.loads(line))

    if not samples:
        return {"count": 0, "avg_len": 0, "max_len": 0, "p50_len": 0, "p90_len": 0}

    # Calculate output length distribution - support both formats
    lengths = []
    for s in samples:
        if not isinstance(s, dict):
            continue
        if "output" in s:
            # Alpaca format: {"instruction": ..., "input": ..., "output": ...}
            lengths.append(text_len(s["output"]))
        elif "messages" in s:
            # ShareGPT format: {"messages": [{"role": "user", ...}, {"role": "assistant", ...}]}
            for msg in s["messages"] or []:
                if isinstance(msg, dict) and msg.get("role") == "assistant":
                    lengths.append(text_len(msg.get("content")))
        elif "conversations" in s:
            # Alternative ShareGPT format with "conversations" key
            for msg in s["conversations"] or []:
                if not isinstance(msg, dict):
                    continue
                if msg.get("from") == "gpt" or msg.get("role") == "assistant":
                    text = msg.get("value")
                    if text is None:
                        text = msg.get("content")
                    lengths.append(text_len(text))

    if not lengths:
        return {"count": len(samples), "avg_len": 0, "max_len": 0, "p50_len": 0, "p90_len": 0}

    lengths.sort()

    return {
        "count": len(samples),
        "avg_len": sum(lengths) // len(lengths),
        "max_len": lengths[-1],
        "p50_len": lengths[len(lengths) // 2],
        "p90_len": lengths[int(len(lengths) * 0.9)],
    }


def save_conversation_log(conv: Conversation, output_path: str, name: str) -> None:
    """Dump a conversation's events to ``<output_path>/<name>_conversation.json``.

    Saving is best-effort: any failure is reported as a printed warning and
    never propagates to the caller.
    """
    try:
        recorded_events = list(conv.state.events)
        payload = dict(
            conversation_id=str(conv.id),
            name=name,
            event_count=len(recorded_events),
            events=[recorded.model_dump() for recorded in recorded_events],
        )
        log_file = Path(output_path) / f"{name}_conversation.json"
        with open(log_file, "w", encoding="utf-8") as handle:
            # default=str stringifies any value json cannot encode natively.
            json.dump(payload, handle, indent=2, ensure_ascii=False, default=str)
    except Exception as e:
        print(f"  Warning: Failed to save conversation log: {e}")
    else:
        print(f"  Saved conversation log: {log_file}")


def load_checkpoint(resume_path: Path) -> dict | None:
    """Load checkpoint from a previous run."""
    results_file = resume_path / "pipeline_results.json"
    if not results_file.exists():
        print(f"No checkpoint found at {results_file}")
        return None

    with open(results_file) as f:
        checkpoint = json.load(f)

    print(f"Loaded checkpoint from {resume_path}")
    print(f"  - Completed iterations: {len(checkpoint.get('iterations', []))}")
    print(f"  - Best iteration so far: {checkpoint.get('best_iteration_so_far')}")

    return checkpoint


class PipelineTimeout(Exception):
    """Signals that the pipeline ran longer than its configured timeout."""


def check_timeout(start_time: float, timeout_hours: float) -> None:
    """Raise if the pipeline has been running longer than allowed.

    Args:
        start_time: Pipeline start time, as returned by time.time()
        timeout_hours: Allowed run time in hours (0 or less = no timeout)

    Raises:
        PipelineTimeout: If more than timeout_hours have elapsed
    """
    timeout_disabled = timeout_hours <= 0
    if timeout_disabled:
        return

    limit_seconds = timeout_hours * 3600
    elapsed_seconds = time.time() - start_time
    if elapsed_seconds > limit_seconds:
        message = (
            f"Pipeline exceeded {timeout_hours}h timeout "
            f"(elapsed: {format_duration(elapsed_seconds)})"
        )
        raise PipelineTimeout(message)


def run_pipeline(config: Config, resume_from: str | None = None) -> dict:
    """Run the full fine-tuning pipeline with iterative refinement."""
    pipeline_start_time = time.time()
    timeout_hours = config.pipeline.timeout_hours

    # Setup workspace
    workspace = config.workspace or str(Path.cwd())

    # Set LLM Judge environment variables from config (for OpenCompass cascade evaluators)
    if config.evaluation.judge_model:
        os.environ["OC_JUDGE_MODEL"] = config.evaluation.judge_model
    if config.evaluation.judge_api_base:
        os.environ["OC_JUDGE_API_BASE"] = config.evaluation.judge_api_base
    if config.evaluation.judge_api_key:
        os.environ["OC_JUDGE_API_KEY"] = config.evaluation.judge_api_key

    # Note: LLM and agents are created per-iteration for load balancing
    # (randomly selects from multiple base_urls and models each iteration)

    # Initialize state
    iteration = 0
    best_model = None  # Will be set by LLM model selection phase
    best_iteration_so_far = None  # Current best iteration number (1-based), selected periodically by LLM

    # Handle resume
    if resume_from:
        output_base = Path(resume_from)
        if not output_base.exists():
            # Try resolving relative to workspace (support both old and new path formats)
            # Old format: outputs/{timestamp}
            # New format: outputs/{benchmark}/{timestamp}
            output_base = Path(workspace) / "outputs" / resume_from
        if not output_base.exists():
            # Try finding in benchmark subdirectory
            benchmark_name = config.evaluation.benchmarks[0]
            safe_benchmark_name = benchmark_name.replace("/", "_").replace(":", "_")
            output_base = Path(workspace) / "outputs" / safe_benchmark_name / resume_from
        if not output_base.exists():
            raise ValueError(f"Resume path not found: {resume_from}")

        checkpoint = load_checkpoint(output_base)
        if checkpoint:
            # Restore state from checkpoint
            iteration = len(checkpoint.get("iterations", []))
            best_iteration_so_far = checkpoint.get("best_iteration_so_far")  # Restore best iteration
            # best_model will be determined by LLM model selection phase

            results = checkpoint
            # Will log resume info after logger initialization
        else:
            raise ValueError(f"Could not load checkpoint from {output_base}")
    else:
        # Create new output directory with benchmark subdirectory
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        # Use benchmark name as subdirectory
        benchmark_name = config.evaluation.benchmarks[0]  # Use first benchmark name
        # Clean benchmark name (remove special characters)
        safe_benchmark_name = benchmark_name.replace("/", "_").replace(":", "_")
        output_base = Path(workspace) / "outputs" / safe_benchmark_name / timestamp
        output_base.mkdir(parents=True, exist_ok=True)

        results = {
            "iterations": [],
            "best_iteration_so_far": None,  # Current best iteration, selected periodically by LLM
            "test_output": None,
            "best_model": None,
            "config": {
                "max_iterations": config.pipeline.max_iterations,
            },
        }

    # Setup logger after output_base is determined
    logger = setup_logger(output_base, name="pipeline")

    # Start capturing all stdout/stderr to log file
    # This captures OpenHands SDK output (rich logging, tool calls, etc.)
    output_capture = OutputCapture(output_base / "pipeline.log")
    output_capture.start()

    # Log welcome banner
    logger.info("=" * 60)
    logger.info("Automated LLM Fine-tuning Pipeline")
    logger.info(f"Started at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    if timeout_hours > 0:
        logger.info(f"Timeout: {timeout_hours}h")
    logger.info("=" * 60)
    logger.info(f"Pipeline output directory: {output_base}")
    logger.info(f"Log file: {output_base / 'pipeline.log'}")

    # Log resume info if applicable
    if resume_from:
        logger.info(f"Resuming from iteration {iteration + 1}")

    # Load dataset mapping and resolve path
    datasets_config_path = Path(workspace) / config.data.datasets_config
    if not datasets_config_path.exists():
        raise FileNotFoundError(
            f"Datasets config not found: {datasets_config_path}\n"
            f"Please create a mapping file at: {datasets_config_path}"
        )

    with open(datasets_config_path, "r", encoding="utf-8") as f:
        datasets_mapping = json.load(f)

    dataset_name = config.data.dataset
    if dataset_name not in datasets_mapping:
        available = ", ".join(datasets_mapping.keys())
        raise KeyError(
            f"Dataset '{dataset_name}' not found in {datasets_config_path}\n"
            f"Available datasets: {available}"
        )

    # Resolve dataset path (can be absolute or relative to workspace)
    dataset_path_str = datasets_mapping[dataset_name]
    dataset_path = Path(dataset_path_str)
    if not dataset_path.is_absolute():
        dataset_path = Path(workspace) / dataset_path
    # Normalize path to remove ".." and make it cleaner for the agent
    dataset_path = dataset_path.resolve()

    if not dataset_path.exists():
        raise FileNotFoundError(
            f"Dataset file not found: {dataset_path}\n"
            f"Please check the path in {datasets_config_path}"
        )
    logger.info(f"Dataset: {dataset_name} -> {dataset_path}")

    timeout_triggered = False
    reset_retry_wait_time()  # Reset retry wait time counter at pipeline start
    
    while iteration < config.pipeline.max_iterations:
        # Check timeout at start of each iteration (excluding retry wait time)
        if timeout_hours > 0:
            elapsed = time.time() - pipeline_start_time - get_retry_wait_time()
            remaining = timeout_hours * 3600 - elapsed
            if remaining <= 0:
                msg = f"TIMEOUT: Pipeline exceeded {timeout_hours}h limit (effective elapsed: {format_duration(elapsed)}, retry wait excluded: {format_duration(get_retry_wait_time())})"
                logger.warning(msg)
                logger.info("=" * 60)
                logger.info(msg)
                logger.info("=" * 60)
                timeout_triggered = True
                break

        iteration += 1
        iter_output = output_base / f"iteration_{iteration}"
        iter_output.mkdir(parents=True, exist_ok=True)

        # Note: LLM/Agent creation is now handled by run_conversation_with_retry
        # which creates new LLM/Agent for each conversation (and each retry)

        logger.info(f"=== Starting iteration {iteration}/{config.pipeline.max_iterations} ===")
        logger.info("=" * 60)
        logger.info(f"Iteration {iteration} / {config.pipeline.max_iterations}")
        if timeout_hours > 0:
            effective_elapsed = time.time() - pipeline_start_time - get_retry_wait_time()
            remaining = timeout_hours * 3600 - effective_elapsed
            logger.info(f"Time remaining: {format_duration(remaining)} (retry wait excluded: {format_duration(get_retry_wait_time())})")
        logger.info("=" * 60)

        iter_start_time = time.time()
        iter_result = {"iteration": iteration}

        # Phase 1: Data Cleaning
        logger.info("--- Phase 1: Data Cleaning ---")
        cleaned_data_path = str(iter_output / "cleaned_data.jsonl")
        clean_script_path = str(iter_output / "clean_data.py")

        # Build context for data cleaning
        clean_context = f"""
Raw data: {dataset_path}
Output file: {cleaned_data_path}
Script save path: {clean_script_path}

Current iteration: {iteration} / {config.pipeline.max_iterations}

Input sample limit: {config.data.max_samples}
Please ensure the final output sample count does not exceed this limit (use all if insufficient data).
"""

        # Load and append dataset README if available
        dataset_readme = load_dataset_readme(str(dataset_path))
        if dataset_readme:
            clean_context += f"""
## Dataset Documentation (README)

Below is the complete documentation for this dataset. Please read carefully to understand data format, field meanings, task types, and processing recommendations:

```markdown
{dataset_readme}
```

Please use the information from the above README to understand data structure and processing requirements.
"""

        # Provide previous iteration's cleaned data path (if exists)
        if iteration > 1:
            prev_cleaned = output_base / f"iteration_{iteration-1}" / "cleaned_data.jsonl"
            if prev_cleaned.exists():
                clean_context += f"""
Previous iteration cleaned data: {prev_cleaned}

## Data Source Selection
You can choose to use either "raw data" or "previous iteration cleaned data" as input:
- If the previous iteration performed well, continue optimizing on that basis
- If the previous iteration performed poorly or you want to try a new strategy, recommend starting from raw data
Please judge based on the previous evaluation results.
"""

        # Pass best iteration's script for reference (only if it has evaluation results)
        if best_iteration_so_far and best_iteration_so_far < iteration:
            best_script = output_base / f"iteration_{best_iteration_so_far}" / "clean_data.py"
            if best_script.exists():
                clean_context += f"\nBest iteration (Iteration {best_iteration_so_far}) cleaning script: {best_script}\n"
                clean_context += "(You can read and refer to the successful cleaning strategy)\n"

        # Pass simplified historical evaluation results (best iteration + last 2 iterations)
        if len(results["iterations"]) > 0:
            clean_context += "\n## Historical Evaluation Results Summary\n"
            clean_context += "Note: Only showing current best iteration + last 2 iterations\n\n"

            # Determine which iterations to show
            iterations_to_show = []
            if best_iteration_so_far and best_iteration_so_far <= len(results["iterations"]):
                iterations_to_show.append(best_iteration_so_far)
            # Last 2 iterations
            recent_start = max(1, len(results["iterations"]) - 1)
            for i in range(recent_start, len(results["iterations"]) + 1):
                if i not in iterations_to_show:
                    iterations_to_show.append(i)

            clean_context += "| Iteration | Subtask | Score | Metric Type | Note |\n"
            clean_context += "|-----------|---------|-------|-------------|------|\n"
            for idx in sorted(iterations_to_show):
                iter_data = results["iterations"][idx - 1]
                eval_data = iter_data.get("evaluation", {})
                scores = eval_data.get("scores", {})
                metrics = eval_data.get("metrics", {})
                note = "Current Best" if idx == best_iteration_so_far else ""
                for dataset, score in scores.items():
                    metric_type = metrics.get(dataset, "unknown")
                    clean_context += f"| {idx} | {dataset} | {score} | {metric_type} | {note} |\n"
                    note = ""  
            clean_context += "\n"

        # Pass previous evaluation details to help Data Agent analyze
        if results["iterations"]:
            prev_eval = results["iterations"][-1].get("evaluation", {})
            prev_scores = prev_eval.get("scores", {})
            prev_metrics = prev_eval.get("metrics", {})

            # Add detailed metric table
            if prev_scores:
                clean_context += "\n## Previous Evaluation Detailed Results\n"
                clean_context += "| Subtask | Score | Metric Type |\n"
                clean_context += "|---------|-------|-------------|\n"
                for dataset, score in prev_scores.items():
                    metric_type = prev_metrics.get(dataset, "unknown")
                    clean_context += f"| {dataset} | {score} | {metric_type} |\n"
                clean_context += "\nPlease analyze improvement directions based on metric types and performance of each subtask.\n"

            # Pass previous error samples to help Data Agent analyze failure patterns
            prev_errors = prev_eval.get("error_samples", [])
            if prev_errors:
                clean_context += "\n## Previous Model Error Samples (First 3)\n"
                for i, err in enumerate(prev_errors[:3], 1):
                    clean_context += f"{i}. Question: {err.get('question', '')[:100]}...\n"
                    clean_context += f"   Gold Answer: {err.get('gold', '')}\n"
                    clean_context += f"   Model Output: {err.get('prediction', '')[:100]}...\n\n"
                clean_context += "Please analyze these errors and adjust the data cleaning strategy.\n"

        # Data processing methods description
        processing_methods = """
## Available Data Processing Methods

**General Methods**:
- Quality filtering: Filter low-quality samples based on length and coherence
- Deduplication: N-gram or embedding-based deduplication
- Diversity sampling: Diverse sampling
- Format normalization: Format standardization

**Reasoning Task-Specific Methods**:
- Difficulty-based filtering: Focus on \"boundary difficulty\" problems
- Answer-consistency filtering: Keep samples with consistent answers
- CoT quality scoring: Evaluate reasoning quality
- Structural health check: Check reasoning structure (progressive depth, backtracking, verification)
"""

        # Simplified CoT guide
        cot_guide = """
## CoT Format Requirements

**All training data output must include reasoning process**

Generation methods:
1. Let LLM provide step-by-step reasoning
2. Prompt example: "Think step by step, then give the final answer"
3. Don't request `<think>` tags in prompt (model may refuse)
4. If `<think>` format needed, wrap with code

**Verification**: Ensure output contains reasoning process, not just the direct answer.
"""

        # Enhanced LLM API guide with call_llm implementation
        llm_api_guide = f"""
## LLM API Call Guide

In autonomous finetune scenarios, data quality varies greatly. **Strongly recommended** to use LLM in the following cases:
- **CoT generation**: Generate step-by-step reasoning for samples
- **Quality scoring**: Evaluate sample quality, filter low-quality data
- **Answer verification**: Verify if generated answers are correct
- **Data rewriting/enhancement**: Improve expression of low-quality samples

### Model Pool Configuration
```python
import os, json
import litellm; litellm.suppress_debug_info = True
from litellm import completion

STRONG_MODELS = {config.model_pool.strong_models}  # CoT generation, complex reasoning
WEAK_MODELS = {config.model_pool.weak_models}      # Simple filtering, quality scoring
API_BASE = "{config.model_pool.api_base}"
API_KEY = "{config.model_pool.api_key}"
API_TIMEOUT = {config.model_pool.timeout}
```

### call_llm function (recommended)
```python
def call_llm(messages, models, start_idx=0, timeout=API_TIMEOUT):
    \"\"\"Load-balanced LLM call with timeout. Use start_idx to distribute across models.\"\"\"
    if not models:
        raise RuntimeError("Model pool is empty.")
    last_err = None
    for i in range(len(models)):
        model = models[(start_idx + i) % len(models)]
        try:
            resp = completion(
                model=model,
                messages=messages,
                api_base=API_BASE,
                api_key=API_KEY,
                drop_params=True,
                timeout=timeout
            )
            return resp.choices[0].message.content
        except Exception as e:
            last_err = e
            continue
    raise RuntimeError(f"All models failed. Last error: {{last_err}}")
```

### Concurrent processing (improve efficiency)
```python
from concurrent.futures import ThreadPoolExecutor

with ThreadPoolExecutor(max_workers={config.model_pool.max_workers}) as executor:
    futures = {{executor.submit(process_sample, i, sample, i % len(STRONG_MODELS)): i
              for i, sample in enumerate(samples)}}
```

### Usage Suggestions
- **STRONG_MODELS**: For CoT generation, complex reasoning tasks
- **WEAK_MODELS**: For simple quality scoring, format checking
- **Load balancing**: Use `start_idx=i % len(models)` for round-robin request distribution
- **Timeout handling**: Default {config.model_pool.timeout}s, skip sample and continue after timeout

### When to Use LLM (Autonomous Decision)
1. Raw data lacks reasoning process → Call LLM to generate CoT
2. Low quality/messy format → Call LLM to rewrite
3. Need to verify answer correctness → Call LLM to validate
4. Poor previous evaluation results → Consider more aggressive LLM data enhancement

Please autonomously decide whether to call LLM based on actual data situation.
"""

        # Stats output format
        stats_format = """
## Output Statistics

At script end, print:
```
========== SUMMARY ==========
Total output samples: {{count}}
Output file: {{path}}
=============================
```
"""

        # Data preview format
        preview_format = """
## Data Preview Requirements

When analyzing data, please **fully display** at least 3 sample contents:

```python
for i, item in enumerate(data[:3]):
    print(f"--- item {{i}} ---")
    for key, value in item.items():
        print(f"{{key}}:")
        if isinstance(value, str):
            print(value)
        else:
            print(json.dumps(value, ensure_ascii=False, indent=2))
    print()
```
"""

        clean_prompt = f"""
Please analyze the raw data and write a Python script for data cleaning.

**Important**: You only need to write the script and save it to the specified path, you don't need to execute the script. Pipeline will execute automatically.

{clean_context}

{preview_format}

{processing_methods}

{cot_guide}

{llm_api_guide}

## Output Format

- Alpaca format: {{"instruction": "question", "input": "", "output": "reasoning process + answer"}}
- Conversation format: {{"messages": [{{"role": "user", "content": "question"}}, {{"role": "assistant", "content": "reasoning process + answer"}}]}}
- output must include reasoning process

{stats_format}

## Important Notes

1. Use file_editor tool to create cleaning script
2. Script path: {clean_script_path}
3. **Do not try to execute the script** - Pipeline will execute automatically
4. Ensure script is complete and directly runnable Python file
5. Terminal tool **only for data exploration**, **prohibited for executing cleaning scripts**

## Terminal Usage Guide (Important)

- **Set timeout=480** (8 minutes): Data files are large, default timeout may not be enough
- **Prioritize lightweight commands**:
  - View first N lines: `head -n 100 file.json`
  - Count lines: `wc -l file.json`
  - View JSON structure: `head -c 5000 file.json` (view first 5000 bytes)
  - Use jq: `cat file.json | head -c 100000 | jq '.[0]'`
- **Avoid**: Loading entire large file in terminal with Python then print, this will cause timeout

Please autonomously decide processing strategy and write the script based on data situation.
"""
        data_conv = run_conversation_with_retry(
            config=config,
            create_agent_fn=create_data_agent,
            workspace=workspace,
            message=clean_prompt,
            logger=logger,
        )
        save_conversation_log(data_conv, str(iter_output), "data_cleaning")

        # ========== Pipeline executes cleaning script (with monitoring and auto-repair) ==========
        script_path = Path(clean_script_path)
        if script_path.exists():
            script_log_path = iter_output / "clean_data.log"
            max_retries = 3
            check_interval = 30  # Check log every 30 seconds
            max_runtime = 3600   # Maximum runtime 1 hour
            script_success = False

            for attempt in range(max_retries):
                logger.info(f"Executing cleaning script (attempt {attempt + 1}/{max_retries}): {script_path}")
                logger.info(f"Real-time log: tail -f {script_log_path}")

                try:
                    with open(script_log_path, "w") as log_file:
                        process = subprocess.Popen(
                            ["python3", "-u", str(script_path)],  # -u disables buffering
                            cwd=str(iter_output),
                            stdout=log_file,
                            stderr=subprocess.STDOUT,
                        )

                        start_time = time.time()
                        error_detected = False
                        last_log_size = 0

                        # Monitoring loop: periodically check logs
                        while process.poll() is None:
                            elapsed = time.time() - start_time
                            if elapsed > max_runtime:
                                process.kill()
                                process.wait()
                                logger.error("Script execution timeout (exceeded 1 hour)")
                                error_detected = True
                                break

                            time.sleep(check_interval)

                            # Use LLM to check for errors in logs
                            if script_log_path.exists():
                                current_size = script_log_path.stat().st_size
                                if current_size > last_log_size:
                                    log_content = script_log_path.read_text()
                                    last_log_size = current_size
                                    # Call LLM to analyze logs
                                    error_check = check_log_for_errors(log_content, config)
                                    if error_check.get("has_error"):
                                        # Wait for complete error output
                                        time.sleep(3)
                                        process.kill()
                                        process.wait()
                                        error_detected = True
                                        logger.warning(f"LLM detected error [{error_check.get('error_type')}]: {error_check.get('summary')}")
                                        break

                        return_code = process.returncode

                    # Read complete log
                    log_content = script_log_path.read_text() if script_log_path.exists() else ""

                    if return_code == 0 and not error_detected:
                        logger.info("Script execution completed (exit code 0)")
                        if log_content:
                            logger.info(f"Script output (last 2000 characters):\n{log_content[-2000:]}")
                        
                        # Check if output file is empty (script may succeed but produce empty output due to logic issues)
                        output_empty = False
                        if Path(cleaned_data_path).exists():
                            with open(cleaned_data_path, "r", encoding="utf-8") as f:
                                output_sample_count = sum(1 for line in f if line.strip())
                            if output_sample_count == 0:
                                output_empty = True
                                logger.warning(f"Script succeeded but output file is empty! This usually indicates a logic issue in the script.")
                        else:
                            output_empty = True
                            logger.warning(f"Script succeeded but output file does not exist!")
                        
                        if not output_empty:
                            script_success = True
                            break
                        else:
                            # Output is empty, need to fix script
                            logger.error("Script succeeded but output is empty, need to fix script logic")
                            if attempt < max_retries - 1:
                                logger.info("Calling Agent to analyze why output is empty and fix script...")
                                script_content = script_path.read_text()
                                
                                # Read raw data preview (supports JSON, JSONL, CSV, Parquet)
                                raw_data_preview = load_data_preview(str(dataset_path), num_samples=5)
                                
                                fix_prompt = f"""
Script execution succeeded (exit code 0) but output file is empty, which means there's a logic issue in the script. Please analyze and fix it.

## Raw Data Preview (first 3 samples)
```json
{raw_data_preview}
```

## Script Log
```
{log_content[-4000:]}
```

## Current Script Content
```python
{script_content}
```

## Requirements
1. Carefully analyze the raw data structure and find the correct field names
2. Check logic in functions like normalize_item / quality_check
3. Fix issues in the script
4. Use file_editor tool to save the fixed script to: {clean_script_path}
5. Do not execute the script, Pipeline will automatically re-execute it
"""
                                fix_conv = run_conversation_with_retry(
                                    config=config,
                                    create_agent_fn=create_data_agent,
                                    workspace=workspace,
                                    message=fix_prompt,
                                    logger=logger,
                                )
                                save_conversation_log(fix_conv, str(iter_output), f"fix_attempt_{attempt + 1}")
                                logger.info("Agent has attempted to fix the script, preparing to retry...")
                            else:
                                logger.error(f"Script output is empty, reached maximum retry attempts ({max_retries})")
                    else:
                        # Execution failed
                        logger.error(f"Script execution failed (exit code {return_code}, error_detected={error_detected})")
                        if log_content:
                            logger.error(f"Error log (last 3000 characters):\n{log_content[-3000:]}")

                        # Clean up incomplete output file
                        if Path(cleaned_data_path).exists():
                            Path(cleaned_data_path).unlink()
                            logger.warning(f"Deleted incomplete output file: {cleaned_data_path}")

                        # Still have retry opportunities, let Agent fix the script
                        if attempt < max_retries - 1:
                            logger.info("Calling Agent to analyze error and fix script...")
                            script_content = script_path.read_text()
                            
                            # Read raw data preview (supports JSON, JSONL, CSV, Parquet)
                            raw_data_preview = load_data_preview(str(dataset_path), num_samples=5)
                            
                            fix_prompt = f"""
Script execution failed, please analyze the error and fix the script.

## Raw Data Preview (first 3 samples)
```json
{raw_data_preview}
```

## Error Log
```
{log_content[-4000:]}
```

## Current Script Content
```python
{script_content}
```

## Requirements
1. Carefully analyze the error cause and raw data structure
2. Fix issues in the script
3. Use file_editor tool to save the fixed script to: {clean_script_path}
4. Do not execute the script, Pipeline will automatically re-execute it
"""
                            fix_conv = run_conversation_with_retry(
                                config=config,
                                create_agent_fn=create_data_agent,
                                workspace=workspace,
                                message=fix_prompt,
                                logger=logger,
                            )
                            save_conversation_log(fix_conv, str(iter_output), f"fix_attempt_{attempt + 1}")
                            logger.info("Agent has attempted to fix the script, preparing to retry...")
                        else:
                            logger.error(f"Script execution failed, reached maximum retry attempts ({max_retries})")

                except Exception as e:
                    logger.error(f"Script execution exception: {e}")
                    if attempt >= max_retries - 1:
                        break

            if not script_success:
                logger.error("Cleaning script final execution failed")
        else:
            logger.warning(f"Cleaning script does not exist: {script_path}")
        # ========== Script execution ended ==========

        # Validate cleaned data
        cleaned_path = Path(cleaned_data_path)
        if not cleaned_path.exists():
            error_msg = f"Data cleaning failed: output file does not exist {cleaned_data_path}"
            logger.error(error_msg)
            iter_result["data_cleaning"] = {
                "output_path": cleaned_data_path,
                "script_path": clean_script_path,
                "success": False,
                "error": error_msg,
            }
            results["iterations"].append(iter_result)
            logger.error(f"ERROR: {error_msg}")
            continue  # Skip to next iteration

        # Count samples
        with open(cleaned_path, "r", encoding="utf-8") as f:
            sample_count = sum(1 for line in f if line.strip())

        if sample_count == 0:
            error_msg = f"Data cleaning failed: output file is empty {cleaned_data_path}"
            logger.error(error_msg)
            iter_result["data_cleaning"] = {
                "output_path": cleaned_data_path,
                "script_path": clean_script_path,
                "success": False,
                "error": error_msg,
                "sample_count": 0,
            }
            results["iterations"].append(iter_result)
            logger.error(f"ERROR: {error_msg}")
            continue  # Skip to next iteration

        logger.info(f"Data cleaning completed: {sample_count} samples")
        logger.info(f"Cleaned data: {sample_count} samples")

        iter_result["data_cleaning"] = {
            "output_path": cleaned_data_path,
            "script_path": clean_script_path,
            "success": True,
            "sample_count": sample_count,
        }

        # Phase 2: Training
        logger.info("--- Phase 2: Training ---")
        model_output = str(iter_output / "model")

        # Get data statistics for smarter parameter selection
        data_stats = get_data_stats(cleaned_data_path)
        logger.info(f"Data stats: {data_stats['count']} samples, avg_len={data_stats['avg_len']}")

        # Build training context with data statistics
        gpu_ids_str = ",".join(str(g) for g in config.training.gpu_ids) if config.training.gpu_ids else "auto-select"
        train_context = f"""
Training data: {cleaned_data_path}
Output directory: {model_output}
Base model: {config.training.base_model}
Available GPUs: {config.training.num_gpus}
Specified GPU IDs: {gpu_ids_str}

Current iteration: {iteration} / {config.pipeline.max_iterations}

## GPU Usage Tips
- Run `nvidia-smi` to check GPU status
- Training will use GPUs specified in config (passed via gpu_ids parameter)
- Actual GPUs used will be saved to training results

## Data Statistics
- Sample count: {data_stats['count']}
- Average length: {data_stats['avg_len']} characters
- P90 length: {data_stats['p90_len']} characters
- Max length: {data_stats['max_len']} characters
"""

        # Pass simplified historical evaluation results (best iteration + last 2 iterations)
        if len(results["iterations"]) > 0:
            train_context += "\n## Historical Evaluation Results Summary\n"
            train_context += "Note: Only showing current best iteration + last 2 iterations\n\n"

            # Determine which iterations to show
            iterations_to_show = []
            if best_iteration_so_far and best_iteration_so_far <= len(results["iterations"]):
                iterations_to_show.append(best_iteration_so_far)
            # Last 2 iterations
            recent_start = max(1, len(results["iterations"]) - 1)
            for i in range(recent_start, len(results["iterations"]) + 1):
                if i not in iterations_to_show:
                    iterations_to_show.append(i)

            train_context += "| Iteration | Subtask | Score | Metric Type | Note |\n"
            train_context += "|-----------|---------|-------|-------------|------|\n"
            for idx in sorted(iterations_to_show):
                iter_data = results["iterations"][idx - 1]
                eval_data = iter_data.get("evaluation", {})
                scores = eval_data.get("scores", {})
                metrics = eval_data.get("metrics", {})
                note = "Current Best" if idx == best_iteration_so_far else ""
                for dataset, score in scores.items():
                    metric_type = metrics.get(dataset, "unknown")
                    train_context += f"| {idx} | {dataset} | {score} | {metric_type} | {note} |\n"
                    note = ""  # Only show note in first row
            train_context += "\n"

        # Pass previous training params for intelligent adjustment
        if results["iterations"]:
            prev_training = results["iterations"][-1].get("training", {})
            prev_params = prev_training.get("params", {})
            if prev_params:
                train_context += f"\n## Previous Training Parameters\n"
                for k, v in prev_params.items():
                    if v is not None:
                        train_context += f"- {k}: {v}\n"

            # Add detailed metric info for Train Agent
            prev_eval = results["iterations"][-1].get("evaluation", {})
            prev_scores = prev_eval.get("scores", {})
            prev_metrics = prev_eval.get("metrics", {})
            if prev_scores:
                train_context += "\n## Previous Evaluation Detailed Results\n"
                train_context += "| Subtask | Score | Metric Type |\n"
                train_context += "|---------|-------|-------------|\n"
                for dataset, score in prev_scores.items():
                    metric_type = prev_metrics.get(dataset, "unknown")
                    train_context += f"| {dataset} | {score} | {metric_type} |\n"
                train_context += "\nPlease adjust training strategy based on metric types and performance of each subtask.\n"

        train_prompt = f"""
Please use LlamaFactory to train the model.

{train_context}

## Training Method Selection

Choose based on data size and GPU resources:

| Condition | Recommended Method |
|-----------|--------------------|
| Samples < 1000 or GPU < 2 | LoRA |
| Samples >= 1000 and GPU >= 2 | Full SFT (if VRAM sufficient) |
| Uncertain | Default LoRA (safer) |

## Parameter Selection Guide

**Based on data size**:
| Sample Count | epochs | batch_size | lora_rank (if using LoRA) |
|--------------|--------|------------|---------------------------|
| <500         | 5-10   | 2          | 8                         |
| 500-2k       | 3-5    | 4          | 16                        |
| 2k-10k       | 2-3    | 8          | 32                        |
| >10k         | 1-2    | 16         | 64                        |

**Based on length (cutoff_len)**:
- cutoff_len should be >= P90 length
- Recommended: max(2048, P90 length * 1.2)

**Fixed parameters** (MVP phase):
- learning_rate: 1e-4
- lora_alpha: lora_rank * 2
- warmup_ratio: 0.1

**Iterative adjustments**:
- Score decreased → halve epochs, or switch to LoRA
- Score stagnant → double lora_rank, or increase epochs

## Validation

After training, confirm:
1. Model files exist in output directory
2. Training log has no ERRORs

## Tool Usage Example
Use llama_factory tool, set:
- gpu_ids: {config.training.gpu_ids}  # Use GPUs specified in config

Please use llama_factory tool to train and explain the reasoning for parameter choices.
"""
        train_conv = run_conversation_with_retry(
            config=config,
            create_agent_fn=create_train_agent,
            workspace=workspace,
            message=train_prompt,
            logger=logger,
        )
        save_conversation_log(train_conv, str(iter_output), "training")

        # Read actual training params from train_config.yaml
        train_config_file = Path(model_output) / "train_config.yaml"
        if train_config_file.exists():
            with open(train_config_file, "r") as f:
                train_config = yaml.safe_load(f)
            iter_result["training"] = {
                "model_path": model_output,
                "params": {
                    "finetuning_type": train_config.get("finetuning_type", "lora"),
                    "lora_rank": train_config.get("lora_rank"),
                    "lora_alpha": train_config.get("lora_alpha"),
                    "epochs": train_config.get("num_train_epochs"),
                    "batch_size": train_config.get("per_device_train_batch_size"),
                    "learning_rate": train_config.get("learning_rate"),
                    "cutoff_len": train_config.get("cutoff_len"),
                },
            }
        else:
            iter_result["training"] = {"model_path": model_output, "params": {}}

        # Phase 3: Start Model Server
        logger.info("--- Phase 3: Starting Model Server ---")
        # For MVP, we assume the model is already served or use existing API
        # In production, uncomment the following:
        # server_process = start_model_server(model_output)

        # Phase 4: Evaluation (Validation Set)
        logger.info("--- Phase 4: Evaluation (Validation) ---")
        eval_output = str(iter_output / "eval_validation")

        # Set fallback data_range for validation (in case Agent forgets to pass it)
        os.environ["OPENCOMPASS_DATA_RANGE"] = config.evaluation.validation_range
        logger.info(f"Set OPENCOMPASS_DATA_RANGE={config.evaluation.validation_range} (validation fallback)")

        eval_prompt = f"""
Please use OpenCompass to evaluate the trained model (validation set).

## Model Information
- Trained model path: {model_output}
- Base model: {config.training.base_model}
- Model type: Auto-detect (LoRA adapter or SFT full model)

## Evaluation Configuration
- Evaluation mode: vllm (local inference)
- Benchmarks: {config.evaluation.benchmarks}
- Output directory: {eval_output}
- GPU IDs: {config.training.gpu_ids}
- CoT post-processing: Enabled (automatically extract answers in <think>...</think> format)

## ⚠️ Important: Data Range Configuration (Must be strictly followed)
**data_range must be precisely set to**: `{config.evaluation.validation_range}`
This is the validation set data range, used to select the first half of the dataset for validation.
Do not use other values, otherwise validation and test sets will overlap!

## LLM Judge Configuration (for llm_judge type evaluation)
- judge_model: "{config.evaluation.judge_model}"
- judge_api_base: "{config.evaluation.judge_api_base}"
- judge_api_key: "{config.evaluation.judge_api_key}"

## Tool Invocation (Please strictly follow these parameters)
Use opencompass tool, set:
- mode: "vllm"
- model_path: "{model_output}"
- base_model: "{config.training.base_model}"
- benchmarks: {config.evaluation.benchmarks}
- data_range: "{config.evaluation.validation_range}"
- output_dir: "{eval_output}"
- use_cot_postprocessor: true
- gpu_ids: {config.training.gpu_ids}
- judge_model: "{config.evaluation.judge_model}"
- judge_api_base: "{config.evaluation.judge_api_base}"
- judge_api_key: "{config.evaluation.judge_api_key}"

Please immediately use opencompass tool for evaluation, data_range parameter must be precisely set to "{config.evaluation.validation_range}".
"""
        eval_conv = run_conversation_with_retry(
            config=config,
            create_agent_fn=create_eval_agent,
            workspace=workspace,
            message=eval_prompt,
            logger=logger,
        )
        save_conversation_log(eval_conv, str(iter_output), "evaluation")

        # Parse evaluation results
        results_file = Path(eval_output) / "results.json"
        scores_detail = {}
        metrics_detail = {}
        error_samples = []
        if results_file.exists():
            with open(results_file) as f:
                eval_results = json.load(f)
                scores_detail = eval_results.get("scores", {})
                metrics_detail = eval_results.get("metrics", {})
                error_samples = eval_results.get("error_samples", [])

        iter_result["evaluation"] = {
            "scores": scores_detail,  # Detailed scores for each subtask
            "metrics": metrics_detail,  # Metric type for each subtask
            "output_path": eval_output,
            "error_samples": error_samples,
        }

        logger.info(f"Iteration {iteration} evaluation completed, scores: {list(scores_detail.keys())}")

        # Phase 5: Record model path, final selection decided by LLM
        logger.info("--- Phase 5: Record Model ---")
        iter_result["model_path"] = model_output
        logger.info(f"Model saved to: {model_output}")

        # Record iteration duration
        iter_duration = time.time() - iter_start_time
        iter_result["duration_seconds"] = iter_duration
        logger.info(f"Iteration {iteration} completed in {format_duration(iter_duration)}")

        results["iterations"].append(iter_result)

        # ========== Perform checkpoint selection every 3 iterations ==========
        if iteration % 3 == 0 and iteration > 0:
            logger.info("--- Checkpoint: LLM Selecting Best Iteration So Far ---")

            # Build candidate iterations for comparison (best iteration + last 3 iterations), only include those with evaluation results
            candidates = []
            if best_iteration_so_far and best_iteration_so_far <= len(results["iterations"]):
                iter_data = results["iterations"][best_iteration_so_far - 1]
                if iter_data.get("evaluation", {}).get("scores"):  # Only include those with evaluation results
                    candidates.append(best_iteration_so_far)
            # Last 3 iterations, also only include those with evaluation results
            recent_start = max(1, iteration - 2)
            for i in range(recent_start, iteration + 1):
                if i not in candidates:
                    iter_data = results["iterations"][i - 1]
                    if iter_data.get("evaluation", {}).get("scores"):  # Only include those with evaluation results
                        candidates.append(i)

            checkpoint_context = "## Iterations to Compare\n\n"
            for idx in sorted(candidates):
                iter_data = results["iterations"][idx - 1]
                eval_data = iter_data.get("evaluation", {})
                scores = eval_data.get("scores", {})
                metrics = eval_data.get("metrics", {})

                label = "(Current Best)" if idx == best_iteration_so_far else ""
                checkpoint_context += f"### Iteration {idx} {label}\n"
                checkpoint_context += "| Subtask | Score | Metric Type |\n"
                checkpoint_context += "|---------|-------|-------------|\n"
                for dataset, score in scores.items():
                    metric_type = metrics.get(dataset, "unknown")
                    checkpoint_context += f"| {dataset} | {score} | {metric_type} |\n"
                checkpoint_context += "\n"

            checkpoint_prompt = f"""
Please select the best performing iteration from the following candidates.

{checkpoint_context}

**Note**: Different metrics have different meanings:
- **accuracy**, **f1**, **recall**, **precision**: higher is better
- **mae**, **mse**, **rmse**: lower is better
- **rouge**, **bleu**: higher is better
- **tanimoto_similarity_larger_means_better**: higher is better

Output format:
SELECTED_ITERATION: <number>
REASON: <brief explanation>
"""

            checkpoint_conv = run_conversation_with_retry(
                config=config,
                create_agent_fn=create_eval_agent,
                workspace=workspace,
                message=checkpoint_prompt,
                logger=logger,
            )
            save_conversation_log(checkpoint_conv, str(output_base), f"checkpoint_iter{iteration}")

            # Parse result
            selected = parse_llm_selection(checkpoint_conv)
            if selected and selected in candidates:
                best_iteration_so_far = selected
                results["best_iteration_so_far"] = best_iteration_so_far
                logger.info(f"Checkpoint: Best iteration so far = {best_iteration_so_far}")
            else:
                logger.warning(f"Checkpoint selection failed, keeping best_iteration_so_far = {best_iteration_so_far}")

            # Save results after checkpoint
            with open(output_base / "pipeline_results.json", "w") as f:
                json.dump(results, f, indent=2)

        # Stop model server
        # stop_model_server(server_process)

    # ==================== LLM Model Selection ====================
    logger.info("=" * 60)
    logger.info("LLM Model Selection")
    logger.info("=" * 60)

    # Determine iterations to compare (best iteration + iterations since last checkpoint), only include those with evaluation results
    candidates = []
    if best_iteration_so_far and best_iteration_so_far <= len(results["iterations"]):
        iter_data = results["iterations"][best_iteration_so_far - 1]
        if iter_data.get("evaluation", {}).get("scores"):  # Only include those with evaluation results
            candidates.append(best_iteration_so_far)
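    # Recent iterations: everything after the latest multiple-of-3 boundary (where a
    # checkpoint selection is scheduled). `total_iters - 2` can never exceed
    # `last_checkpoint`, so that term does not change the window. Only iterations
    # with evaluation results are included.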
    total_iters = len(results["iterations"])
    last_checkpoint = (total_iters // 3) * 3
    recent_start = max(1, last_checkpoint + 1, total_iters - 2)
    for i in range(recent_start, total_iters + 1):
        if i not in candidates:
            iter_data = results["iterations"][i - 1]
            if iter_data.get("evaluation", {}).get("scores"):  # Only include those with evaluation results
                candidates.append(i)

    logger.info(f"Final selection candidates: {sorted(candidates)} (best_so_far={best_iteration_so_far})")

    # Build concise evaluation results summary
    selection_context = "## Iterations to Compare\n\n"
    for idx in sorted(candidates):
        iter_data = results["iterations"][idx - 1]
        eval_data = iter_data.get("evaluation", {})
        scores = eval_data.get("scores", {})
        metrics = eval_data.get("metrics", {})
        model_path = iter_data.get("model_path", "")

        label = "(Current Best)" if idx == best_iteration_so_far else ""
        selection_context += f"### Iteration {idx} {label}\n"
        selection_context += f"- Model path: {model_path}\n"
        selection_context += "| Subtask | Score | Metric Type |\n"
        selection_context += "|---------|-------|-------------|\n"
        for dataset, score in scores.items():
            metric_type = metrics.get(dataset, "unknown")
            selection_context += f"| {dataset} | {score} | {metric_type} |\n"
        selection_context += "\n"

    # Let LLM select the best model
    selection_prompt = f"""
Based on the evaluation results of the following candidate iterations, please select the best model.

{selection_context}

Please consider the performance of all subtasks comprehensively and select the best model based on the meaning of each metric.

**Note**: Different metrics have different meanings:
- **accuracy**, **f1**, **recall**, **precision**: higher is better
- **mae**, **mse**, **rmse**: lower is better
- **rouge**, **bleu**: higher is better
- **tanimoto_similarity_larger_means_better**: higher is better

Please carefully analyze the performance of each subtask and select the iteration with the best overall performance.

Output format:
SELECTED_ITERATION: <number>
REASON: <brief explanation>
"""

    selection_conv = run_conversation_with_retry(
        config=config,
        create_agent_fn=create_eval_agent,
        workspace=workspace,
        message=selection_prompt,
        logger=logger,
    )
    save_conversation_log(selection_conv, str(output_base), "model_selection")

    # Parse LLM selection result
    selected_iteration = parse_llm_selection(selection_conv)

    if selected_iteration and selected_iteration in candidates:
        selected_iter_data = results["iterations"][selected_iteration - 1]
        best_model = selected_iter_data.get("model_path")
        logger.info(f"LLM selected iteration {selected_iteration}: {best_model}")
    else:
        # Fallback: select best_iteration_so_far or last valid model
        if best_iteration_so_far and best_iteration_so_far <= len(results["iterations"]):
            selected_iter_data = results["iterations"][best_iteration_so_far - 1]
            best_model = selected_iter_data.get("model_path")
            selected_iteration = best_iteration_so_far
            logger.warning(f"LLM selection failed, fallback to best_iteration_so_far={best_iteration_so_far}: {best_model}")
        else:
            for i, iter_data in enumerate(reversed(results["iterations"]), 1):
                if iter_data.get("model_path"):
                    best_model = iter_data["model_path"]
                    selected_iteration = len(results["iterations"]) - i + 1
                    break
            logger.warning(f"LLM selection failed, fallback to last valid model: {best_model}")

    # ==================== Final Test Set Evaluation ====================
    # Always run test evaluation with best_model (regardless of validation result)
    test_output = None

    if best_model:
        logger.info("=" * 60)
        logger.info("Final Test Set Evaluation")
        logger.info("=" * 60)
        logger.info(f"Using best model: {best_model}")

        test_output = str(output_base / "final_test")

        # Set fallback data_range for test (in case Agent forgets to pass it)
        os.environ["OPENCOMPASS_DATA_RANGE"] = config.evaluation.test_range
        logger.info(f"Set OPENCOMPASS_DATA_RANGE={config.evaluation.test_range} (test fallback)")

        test_prompt = f"""
Please use OpenCompass to evaluate the final test set.

## Model Information
- Trained model path: {best_model}
- Base model: {config.training.base_model}
- Model type: Auto-detect (LoRA adapter or SFT full model)

## Evaluation Configuration
- Evaluation mode: vllm (local inference)
- Benchmarks: {config.evaluation.benchmarks}
- Output directory: {test_output}
- GPU IDs: {config.training.gpu_ids}
- CoT post-processing: Enabled

## ⚠️ Important: Data Range Configuration (Must be strictly followed)
**data_range must be precisely set to**: `{config.evaluation.test_range}`
This is the test set data range, used to select the second half of the dataset for final testing.
Do not use other values, otherwise validation and test sets will overlap!

## LLM Judge Configuration (for llm_judge type evaluation)
- judge_model: "{config.evaluation.judge_model}"
- judge_api_base: "{config.evaluation.judge_api_base}"
- judge_api_key: "{config.evaluation.judge_api_key}"

## Tool Invocation (Please strictly follow these parameters)
Use opencompass tool, set:
- mode: "vllm"
- model_path: "{best_model}"
- base_model: "{config.training.base_model}"
- benchmarks: {config.evaluation.benchmarks}
- data_range: "{config.evaluation.test_range}"
- output_dir: "{test_output}"
- use_cot_postprocessor: true
- gpu_ids: {config.training.gpu_ids}
- judge_model: "{config.evaluation.judge_model}"
- judge_api_base: "{config.evaluation.judge_api_base}"
- judge_api_key: "{config.evaluation.judge_api_key}"

Please immediately use opencompass tool for evaluation, data_range parameter must be precisely set to "{config.evaluation.test_range}".
"""
        test_conv = run_conversation_with_retry(
            config=config,
            create_agent_fn=create_eval_agent,
            workspace=workspace,
            message=test_prompt,
            logger=logger,
        )
        save_conversation_log(test_conv, str(output_base), "final_test_evaluation")

        # Parse test results
        test_results_file = Path(test_output) / "results.json"
        test_scores_detail = {}
        test_metrics_detail = {}
        if test_results_file.exists():
            with open(test_results_file) as f:
                test_eval_results = json.load(f)
                test_scores_detail = test_eval_results.get("scores", {})
                test_metrics_detail = test_eval_results.get("metrics", {})

        logger.info(f"Test evaluation completed, scores: {list(test_scores_detail.keys())}")
    else:
        logger.warning("No model trained successfully, skipping test evaluation")
        test_scores_detail = {}
        test_metrics_detail = {}

    # Calculate total duration
    total_duration = time.time() - pipeline_start_time

    # Final summary
    results["selected_iteration"] = selected_iteration  # Selected iteration number
    results["best_validation_result"] = None  # Complete validation result
    if selected_iteration and 1 <= selected_iteration <= len(results["iterations"]):
        selected_eval = results["iterations"][selected_iteration - 1].get("evaluation", {})
        results["best_validation_result"] = {
            "scores": selected_eval.get("scores", {}),
            "metrics": selected_eval.get("metrics", {}),
        }
    results["test_result"] = {
        "scores": test_scores_detail,
        "metrics": test_metrics_detail,
    } if test_scores_detail else None
    results["test_output"] = test_output
    results["best_model"] = best_model
    results["total_duration_seconds"] = total_duration
    results["timeout_triggered"] = timeout_triggered
    results["end_time"] = datetime.now().isoformat()

    # Save full results
    with open(output_base / "pipeline_results.json", "w") as f:
        json.dump(results, f, indent=2)

    # Save concise summary to configured path
    summary = {
        "run_id": output_base.name,
        "run_path": str(output_base),
        "selected_iteration": selected_iteration,
        "best_model": best_model,
        "test_output": test_output,
        "total_iterations": iteration,
        "timeout_triggered": timeout_triggered,
        "benchmarks": config.evaluation.benchmarks,
        "base_model": config.training.base_model,
        "total_duration": format_duration(total_duration),
        "end_time": datetime.now().isoformat(),
    }

    # Determine summary path
    summary_path = Path(config.pipeline.results_summary_path)
    if not summary_path.is_absolute():
        summary_path = Path(workspace) / summary_path
    summary_path.parent.mkdir(parents=True, exist_ok=True)

    with open(summary_path, "w") as f:
        json.dump(summary, f, indent=2)

    # Create 'latest' symlink in benchmark subdirectory
    benchmark_outputs_dir = output_base.parent  # outputs/{benchmark_name}/
    latest_link = benchmark_outputs_dir / "latest"
    if latest_link.exists() or latest_link.is_symlink():
        latest_link.unlink()
    latest_link.symlink_to(output_base.name)

    # Log final summary
    logger.info("=" * 60)
    logger.info("Pipeline Complete")
    logger.info("=" * 60)
    logger.info(f"Total iterations: {iteration}")
    logger.info(f"Selected iteration: {selected_iteration}")
    logger.info(f"Best model: {best_model}")
    logger.info(f"Results saved to: {output_base}")
    logger.info(f"Summary saved to: {summary_path}")
    logger.info(f"Latest link: {latest_link} -> {output_base.name}")
    logger.info(f"Timeout triggered: {timeout_triggered}")
    logger.info(f"Total runtime: {format_duration(total_duration)}")
    retry_wait = get_retry_wait_time()
    if retry_wait > 0:
        effective_runtime = total_duration - retry_wait
        logger.info(f"  - Effective runtime: {format_duration(effective_runtime)}")
        logger.info(f"  - Retry wait time (excluded from timeout): {format_duration(retry_wait)}")

    # Note: LLM cost tracking is not available with per-conversation load balancing
    # as each conversation creates a new LLM instance
    logger.info("=" * 60)

    # Stop capturing stdout/stderr
    output_capture.stop()

    return results


def main():
    """Command-line entry point for the fine-tuning pipeline.

    Parses CLI arguments, loads the YAML configuration (falling back to a
    default ``Config()`` when the file does not exist), resolves the workspace
    directory relative to the config file location, applies command-line
    overrides, and then calls ``run_pipeline``.

    A malformed ``--gpu-ids`` value terminates the process through
    ``parser.error`` (usage message on stderr, exit status 2).
    """
    parser = argparse.ArgumentParser(description="Automated LLM Fine-tuning Pipeline")
    parser.add_argument(
        "--config",
        type=str,
        default="configs/default.yaml",
        help="Path to configuration file",
    )
    parser.add_argument(
        "--max-iterations",
        type=int,
        default=None,
        help="Override max iterations from config",
    )
    parser.add_argument(
        "--resume",
        type=str,
        default=None,
        help="Resume from a previous run. Can be full path or timestamp (e.g., '20260115_084758')",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=None,
        help="Override timeout hours from config (0 = no timeout)",
    )
    parser.add_argument(
        "--dataset",
        type=str,
        default=None,
        help="Dataset file name in datasets/ directory (e.g., 'train_dataset.json')",
    )
    parser.add_argument(
        "--gpu-ids",
        type=str,
        default=None,
        help="GPU IDs to use, comma-separated (e.g., '0,1' or '2'). Overrides config.",
    )
    parser.add_argument(
        "--max-samples",
        type=int,
        default=None,
        help="Maximum number of samples to select for cleaning (default: 2000). Overrides config.",
    )
    args = parser.parse_args()

    # Validate --gpu-ids up front so a malformed value fails fast with a usage
    # error instead of an unhandled ValueError after the config has been loaded.
    gpu_ids_override = None
    if args.gpu_ids is not None:
        # Parse "0,1,2" -> [0, 1, 2]; blank tokens (e.g. from a trailing comma)
        # are skipped. int() itself tolerates surrounding whitespace.
        try:
            gpu_ids_override = [
                int(token) for token in args.gpu_ids.split(",") if token.strip()
            ]
        except ValueError:
            parser.error(
                f"--gpu-ids must be a comma-separated list of integers, got {args.gpu_ids!r}"
            )
        if not gpu_ids_override:
            parser.error("--gpu-ids must contain at least one GPU ID")

    # Load config
    config_path = Path(args.config).resolve()
    if config_path.exists():
        config = Config.from_yaml(config_path)
    else:
        print(f"Config file not found: {config_path}, using defaults")
        config = Config()

    # Set workspace based on config file location
    # Config files are in {workspace}/configs/, so workspace = config_path.parent.parent
    # This allows running from any directory
    if not config.workspace or config.workspace == ".":
        # Config is at {workspace}/configs/xxx.yaml, so parent.parent = workspace
        config.workspace = str(config_path.parent.parent)
    elif not Path(config.workspace).is_absolute():
        # Resolve relative workspace path against config file's parent's parent
        config.workspace = str((config_path.parent.parent / config.workspace).resolve())

    # Override from command line (None means "flag not given", so 0 is a valid override)
    if args.max_iterations is not None:
        config.pipeline.max_iterations = args.max_iterations
    if args.timeout is not None:
        config.pipeline.timeout_hours = args.timeout
    if args.dataset is not None:
        config.data.dataset = args.dataset
    if gpu_ids_override is not None:
        config.training.gpu_ids = gpu_ids_override
    if args.max_samples is not None:
        config.data.max_samples = args.max_samples

    # Setup HuggingFace token for gated datasets
    setup_hf_token(config)

    # Run pipeline
    run_pipeline(config, resume_from=args.resume)


# Run the CLI entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
