#!/usr/bin/env python3
"""
Extract successful records from a complex results.jsonl file, including
problem, CoT, answer, caption, and figure fields.
Only keep records whose status is "success".
"""

import json
import os
import sys
from pathlib import Path
import argparse
import shutil


def _nested_get(record, *keys, default=""):
    """Look up record[k1][k2]... and return *default* if the chain breaks.

    Unlike chained ``dict.get(key, {})`` calls, this also tolerates
    intermediate values that are present but not dicts (e.g. JSON ``null``).
    The final value is returned as stored, even if it is ``None``.
    """
    current = record
    for key in keys[:-1]:
        if not isinstance(current, dict):
            return default
        current = current.get(key)
    if not isinstance(current, dict):
        return default
    return current.get(keys[-1], default)


def _load_records(input_file):
    """Parse a JSONL file and return the lines that decode to JSON objects."""
    records = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line_number, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:  # Skip empty lines
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"Warning: Skipping invalid JSON line: {e}", file=sys.stderr)
                continue
            # Valid JSON that is not an object (list, number, ...) has no
            # fields to extract and would break every later ``.get`` call.
            if not isinstance(item, dict):
                print(
                    f"Warning: Skipping non-object JSON record on line {line_number}",
                    file=sys.stderr,
                )
                continue
            records.append(item)
    return records


def _unused_filename(filename, used_names):
    """Return *filename*, or a numbered variant of it not yet in *used_names*."""
    if filename not in used_names:
        return filename
    original = Path(filename)
    counter = 1
    candidate = f"{original.stem}_{counter}{original.suffix}"
    while candidate in used_names:
        counter += 1
        candidate = f"{original.stem}_{counter}{original.suffix}"
    return candidate


def extract_results(input_file, output_dir):
    """
    Extract successful records from results.jsonl, including qa, CoT,
    caption, and figure fields.

    Each record's figure (``plotting.figure_path``) is MOVED, not copied,
    into ``<output_dir>/figures`` and the extracted records are written to
    ``<output_dir>/extracted_results.json``. A summary is printed to stdout;
    warnings go to stderr.

    Lines that are not valid JSON, or that decode to something other than a
    JSON object, are skipped with a warning and are not counted as records.
    Missing or null nested fields yield an empty string instead of an error.

    Args:
        input_file: Path to the input results.jsonl file.
        output_dir: Path to the output directory.

    Returns:
        A list of dicts with the keys ``index``, ``image_path``, ``qa``,
        ``cot``, ``answer`` and ``caption``, one per successful record.
        ``image_path`` is ``figures/<name>`` when the image was moved, the
        original path when it could not be moved or does not exist, and
        ``""`` when the record has no figure path.
    """
    # Create output directory and the "figures" subdirectory
    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)
    figures_path = output_path / "figures"
    figures_path.mkdir(parents=True, exist_ok=True)

    # Read the input file (JSONL format, one JSON object per line)
    results = _load_records(input_file)

    # Keep only records whose status is "success"
    success_results = [
        item for item in results if item.get("status", "") == "success"
    ]

    # Extract data and move images
    extracted_data = []
    moved_count = 0
    # File names already written to figures/ during THIS run. Files left by an
    # earlier run are still overwritten as before; only clashes between two
    # records of the same run get a numeric suffix, so no image is lost.
    used_names = set()
    for item in success_results:
        original_image_path = _nested_get(item, "plotting", "figure_path")
        new_image_path = ""

        # If the image path exists, move it to the "figures" subdirectory in the output folder
        if original_image_path:
            source_path = Path(original_image_path)
            if source_path.exists() and source_path.is_file():
                # Keep the original file name unless another record took it
                target_filename = _unused_filename(source_path.name, used_names)
                target_path = figures_path / target_filename

                try:
                    shutil.move(str(source_path), str(target_path))
                except OSError as e:  # shutil.Error is an OSError subclass
                    print(
                        f"Warning: Failed to move image "
                        f"[index={item.get('index', 'unknown')}]: "
                        f"{original_image_path} -> {e}",
                        file=sys.stderr,
                    )
                    new_image_path = original_image_path  # Keep original path
                else:
                    used_names.add(target_filename)
                    new_image_path = f"figures/{target_filename}"  # Relative path
                    moved_count += 1
            else:
                print(
                    f"Warning: Image file does not exist "
                    f"[index={item.get('index', 'unknown')}]: {original_image_path}",
                    file=sys.stderr,
                )
                new_image_path = original_image_path  # Keep original path

        extracted_data.append({
            "index": item.get("index", None),
            "image_path": new_image_path,
            "qa": _nested_get(item, "visualize_qa", "question"),
            "cot": _nested_get(item, "visualize_qa", "cot"),
            "answer": _nested_get(item, "generation", "answer"),
            "caption": _nested_get(
                item, "image_quality", "caption_generation", "caption"
            ),
        })

    # Write results to JSON file
    output_file = output_path / "extracted_results.json"
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(extracted_data, f, ensure_ascii=False, indent=2)

    print("Extraction completed!")
    print(f"  Input file: {input_file}")
    print(f"  Total records: {len(results)}")
    print(f"  Successful records: {len(success_results)}")
    print(f"  Failed records: {len(results) - len(success_results)}")
    print(f"  Output file: {output_file}")
    print(f"  Extracted {len(extracted_data)} successful records")
    print(f"  Successfully moved {moved_count} images to {figures_path}")

    return extracted_data


def main():
    """Parse command-line arguments and run the extraction.

    Requires ``-i/--input`` (the results.jsonl file) and ``-o/--output``
    (the output directory). Prints an error to stderr and exits with
    status 1 when the input path does not exist or is a directory.
    """
    parser = argparse.ArgumentParser(
        description="Extract successful results from the input file and generate an output JSON."
    )

    parser.add_argument(
        "-i", "--input",
        required=True,
        help="Path to input file (e.g., data/output/results.jsonl)",
    )

    parser.add_argument(
        "-o", "--output",
        required=True,
        help="Path to output directory (e.g., data/output/extracted)",
    )

    args = parser.parse_args()

    input_file = args.input
    output_dir = args.output

    # Check input file
    if not os.path.exists(input_file):
        print(f"Error: input file {input_file} does not exist!", file=sys.stderr)
        sys.exit(1)
    # A directory passes the existence check but would only fail later inside
    # open() with a traceback, so reject it here with a clear message.
    if os.path.isdir(input_file):
        print(f"Error: input path {input_file} is a directory, not a file!", file=sys.stderr)
        sys.exit(1)

    # Run extraction
    extract_results(input_file, output_dir)

# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()

