#!/usr/bin/env python3
"""
Script to create three JSONL files from deepscaler_processed_output.jsonl:
- deepscaler40k_mentalese_cot.jsonl (using mentalese_cot)
- deepscaler40k_long_cot.jsonl (using long_cot)
- deepscaler40k_short_cot.jsonl (using short_cot)

Each file will have the format:
{
    "instruction": str,
    "output": "<think>cot_content</think><answer>final_answer</answer>",
    "history": [],
    "input": ""
}
"""

import contextlib
import json
import os
from pathlib import Path


def convert_record(record, cot_type):
    """Convert a single source record into the instruction-tuning format.

    Args:
        record: Mapping parsed from one input line. The keys read are
            ``"question"``, ``cot_type`` and ``"final_answer"``.
        cot_type: Name of the key holding the chain-of-thought text to embed
            (e.g. ``"long_cot"``).

    Returns:
        A new dict with the keys ``"instruction"``, ``"output"``, ``"history"``
        and ``"input"``. ``"output"`` is the CoT text wrapped in ``<think>``
        tags followed by the final answer wrapped in ``<answer>`` tags.
    """

    def _field(key):
        # A key that is present but holds JSON null comes back as None, which
        # an f-string would render as the literal text "None". Treat it the
        # same as a missing key.
        value = record.get(key, "")
        return "" if value is None else value

    thinking = _field(cot_type)
    final_answer = _field("final_answer")

    return {
        "instruction": _field("question"),
        "output": f"<think>{thinking}</think><answer>{final_answer}</answer>",
        "history": [],
        "input": "",
    }


def process_file(input_file, output_dir):
    """Split a JSONL file into one output file per chain-of-thought variant.

    Each non-blank line of ``input_file`` is parsed as JSON and converted with
    ``convert_record`` once per CoT type. The converted records are written as
    JSON lines to the matching file inside ``output_dir``, which is created if
    it does not exist. Lines that fail to parse or convert are reported on
    stdout, counted, and skipped. Progress and a final summary are printed.

    Args:
        input_file: Path of the UTF-8 encoded JSONL file to read.
        output_dir: Directory that receives the three output files.

    Raises:
        OSError: If the output directory cannot be created or any of the files
            cannot be opened.
    """
    os.makedirs(output_dir, exist_ok=True)

    # CoT field name -> output file name
    cot_configs = [
        ("mentalese_cot", "deepscaler40k_mentalese_cot.jsonl"),
        ("long_cot", "deepscaler40k_long_cot.jsonl"),
        ("short_cot", "deepscaler40k_short_cot.jsonl"),
    ]

    converted_count = 0
    error_count = 0

    # ExitStack closes every handle opened so far, even when a later open()
    # fails part-way through.
    with contextlib.ExitStack() as stack:
        output_files = {
            cot_type: stack.enter_context(
                open(os.path.join(output_dir, filename), "w", encoding="utf-8")
            )
            for cot_type, filename in cot_configs
        }
        infile = stack.enter_context(open(input_file, "r", encoding="utf-8"))

        for line_num, line in enumerate(infile, 1):
            line = line.strip()
            if not line:
                continue

            try:
                record = json.loads(line)

                # Serialize every variant before writing anything, so a record
                # that fails for one CoT type is skipped in all files and the
                # outputs stay line-aligned.
                payloads = [
                    (
                        output_files[cot_type],
                        json.dumps(
                            convert_record(record, cot_type), ensure_ascii=False
                        )
                        + "\n",
                    )
                    for cot_type, _ in cot_configs
                ]
                for handle, payload in payloads:
                    handle.write(payload)
            except json.JSONDecodeError as e:
                print(f"Error parsing JSON on line {line_num}: {e}")
                error_count += 1
                continue
            except Exception as e:
                # Best-effort batch job: report the bad record and keep going.
                print(f"Error processing line {line_num}: {e}")
                error_count += 1
                continue

            converted_count += 1

            # Progress indicator for large files
            if converted_count % 1000 == 0:
                print(f"Processed {converted_count} records...")

    print("Conversion complete!")
    print(f"Successfully processed: {converted_count} records")
    if error_count > 0:
        print(f"Errors encountered: {error_count} records")

    for _, filename in cot_configs:
        print(f"Created: {os.path.join(output_dir, filename)}")


def main():
    """Run the conversion on the configured paths.

    Returns:
        0 after processing the input file, or 1 if the input file does not
        exist (in which case nothing is processed).
    """
    source_path = "USER_PATH/deepscaler_processed_output.jsonl"
    target_dir = "USER_PATH/formatted_data"

    if os.path.exists(source_path):
        print(f"Processing '{source_path}'...")
        print(f"Output directory: '{target_dir}'")
        process_file(source_path, target_dir)
        return 0

    # Nothing to convert: report the missing input and signal failure.
    print(f"Error: Input file '{source_path}' does not exist.")
    return 1


if __name__ == "__main__":
    # Raise SystemExit directly instead of calling the bare `exit` helper:
    # that name is injected by the `site` module and is not available under
    # `python -S` or in some frozen/embedded interpreters.
    raise SystemExit(main())