#!/usr/bin/env python3

import json
import argparse
import sys
from pathlib import Path


def _transform_record(data):
    """Apply the prompt-standardization rules to one decoded JSONL record.

    The record is updated in place and returned.

    Raises:
        ValueError: If the record is not a JSON object, or if neither rule
            applies (the prompt has no leading newline and there is no
            'question' field to build a prompt from).
    """
    if not isinstance(data, dict):
        raise ValueError(f"Expected a JSON object, got {type(data).__name__}")

    original_prompt = data.get('prompt', '')

    # A non-string prompt (e.g. null) cannot start with a newline; let it
    # fall through to the 'question' rule instead of crashing on .startswith.
    if isinstance(original_prompt, str) and original_prompt.startswith('\n'):
        # Rule 1: prompt starts with a newline -> just strip it
        data['prompt'] = original_prompt.strip()
    elif 'question' in data:
        # Rule 2: construct a new prompt around the question
        question = data['question']
        data['prompt'] = f'Solve the following math problem step by step:\n\n{question}\n\nIn the end, provide only the final numerical answer.'
    else:
        raise ValueError(f"Invalid prompt: {original_prompt}")

    return data


def process_jsonl_file(input_file: str, output_file: str) -> None:
    """
    Process a JSONL file according to the specified rules:
    - If 'prompt' field starts with '\\n', just strip the prompt
    - Otherwise, construct a new prompt from the 'question' field

    Blank lines are skipped. Lines that are not valid JSON are reported as a
    warning and copied to the output unchanged. A record that matches neither
    rule aborts the run before anything is written.

    Args:
        input_file: Path to the JSONL file to read.
        output_file: Path to the JSONL file to write.

    Raises:
        SystemExit: With status 1 if the input file does not exist, cannot be
            read, contains a record no rule applies to, or if the output file
            cannot be written. Diagnostics are printed to stderr.
    """
    input_path = Path(input_file)
    output_path = Path(output_file)

    if not input_path.exists():
        print(f"Error: Input file {input_path} does not exist.", file=sys.stderr)
        sys.exit(1)

    if input_path.suffix != '.jsonl':
        print(
            f"Warning: Input file {input_path} does not have .jsonl extension.",
            file=sys.stderr,
        )

    # Everything is buffered before writing so that a failure while reading
    # never leaves a half-written output file (and input == output is safe).
    processed_lines = []
    modified_count = 0

    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:  # Skip empty lines
                    continue

                try:
                    data = json.loads(line)
                except json.JSONDecodeError as e:
                    print(
                        f"Warning: Invalid JSON on line {line_num}: {e}",
                        file=sys.stderr,
                    )
                    # Keep the original line if it's not valid JSON
                    processed_lines.append(line)
                    continue

                try:
                    data = _transform_record(data)
                except ValueError as e:
                    # Report the offending line instead of a generic read error.
                    print(
                        f"Error: Invalid record on line {line_num} of {input_path}: {e}",
                        file=sys.stderr,
                    )
                    sys.exit(1)

                modified_count += 1
                processed_lines.append(json.dumps(data, ensure_ascii=False))

    except (OSError, UnicodeError) as e:
        print(f"Error reading file {input_path}: {e}", file=sys.stderr)
        sys.exit(1)

    # Write the processed data to the output file
    try:
        with open(output_path, 'w', encoding='utf-8') as f:
            f.writelines(f"{line}\n" for line in processed_lines)
    except (OSError, UnicodeError) as e:
        # UnicodeError covers lone surrogates that cannot be encoded as UTF-8.
        print(f"Error writing to file {output_path}: {e}", file=sys.stderr)
        sys.exit(1)

    print(f"Successfully processed {input_path} -> {output_path}")
    print(f"Modified {modified_count} lines out of {len(processed_lines)} total lines")


def main():
    """Parse the two positional CLI paths and run the JSONL conversion."""
    # Usage examples shown verbatim at the bottom of --help.
    usage_examples = (
        "\n"
        "Examples:\n"
        "  python fix_to_same_format.py input.jsonl output.jsonl\n"
        "  python fix_to_same_format.py /path/to/input.jsonl /path/to/output.jsonl\n"
        "        "
    )

    cli = argparse.ArgumentParser(
        description="Process JSONL file to standardize prompt format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=usage_examples,
    )

    # Positional arguments, registered in the order they appear on the command line.
    positionals = (
        ('input_file', 'Path to the input JSONL file to process'),
        ('output_file', 'Path to the output JSONL file to write the processed data'),
    )
    for arg_name, arg_help in positionals:
        cli.add_argument(arg_name, help=arg_help)

    options = cli.parse_args()

    # Hand the parsed paths to the file processor.
    process_jsonl_file(options.input_file, options.output_file)


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
