"""
DeepScaleR 40K Dataset Generation Script

This script processes the DeepScaleR dataset through GPT-4 API to generate multiple
chain-of-thought reasoning formats. It's designed to be clean, modular, and focused
specifically on DeepScaleR dataset processing.

Key Features:
- Parallel processing with configurable workers
- Individual file processing for fault tolerance
- Auto-resume capability from existing files
- Comprehensive error handling and logging
- Batch processing support
- Failed sample tracking and reprocessing

Usage:
    python deepscaleR_40k_generation.py --test                          # Test with 10 samples
    python deepscaleR_40k_generation.py --full                          # Process full dataset
    python deepscaleR_40k_generation.py --batch 0 10000                 # Process batch
    python deepscaleR_40k_generation.py --workers 40 --full             # Use 40 workers (configurable; combine with --test/--full/--batch)
    python deepscaleR_40k_generation.py --no-parallel --test            # Disable parallel processing (combine with a mode flag)
    python deepscaleR_40k_generation.py --no-resume --full              # Disable auto-resume (combine with a mode flag)
    
Configuration:
    Edit config.json to modify OpenAI settings, processing parameters, and dataset options
    
Directory Structure:
    - Input data: deepscaleR_input_data/
    - Output data: deepscaleR_output_data/
    - Individual files: deepscaleR_output_data/individual_files_*/
"""

import json
import tiktoken
import time
import os
import sys
import logging
from typing import Dict, Any, Optional, List, Tuple
import concurrent.futures
from tqdm import tqdm
from openai import OpenAI
from utility_functions import safe_count_tokens

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load configuration from config.json located next to this script (not the working directory).
config_path = os.path.join(os.path.dirname(__file__), "config.json")
# The try covers only reading/parsing so the re-raised errors below describe load failures;
# they are chained so the original cause stays visible in the traceback.
try:
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
except FileNotFoundError as err:
    raise FileNotFoundError(f"config.json not found at {config_path}") from err
except json.JSONDecodeError as err:
    raise ValueError(f"Invalid JSON in {config_path}: {err}") from err

# Every lookup below uses config.get, so the top level must be a JSON object.
if not isinstance(config, dict):
    raise ValueError(f"Invalid configuration in {config_path}: expected a JSON object at the top level")

# Validate required configuration
if not config.get("openai_api_key"):
    raise ValueError("openai_api_key not found in config.json")

# Extract configuration sections; a missing section falls back to the per-key defaults below
OPENAI_CONFIG = config.get("openai_gpt-4.1", {})
OPENAI_CONFIG_O3 = config.get("openai_gpt-o3", {})
PROCESSING_CONFIG = config.get("processing", {})
DATASET_CONFIG = config.get("dataset", {})
OUTPUT_CONFIG = config.get("output", {})

# Processing settings (MAX_WORKERS, ENABLE_PARALLEL and AUTO_RESUME can be overridden by main())
MAX_WORKERS = PROCESSING_CONFIG.get("max_workers", 40)
ENABLE_PARALLEL = PROCESSING_CONFIG.get("enable_parallel", True)
ENABLE_INDIVIDUAL_FILES = PROCESSING_CONFIG.get("enable_individual_files", True)
AUTO_RESUME = PROCESSING_CONFIG.get("auto_resume", True)
DATASET_NAME = DATASET_CONFIG.get("name", "agentica-org/DeepScaleR-Preview-Dataset")

# OpenAI configuration
OPENAI_MODEL = OPENAI_CONFIG.get("model", "gpt-4.1")
OPENAI_TEMPERATURE = OPENAI_CONFIG.get("temperature", 0.6)
OPENAI_TIMEOUT = OPENAI_CONFIG.get("timeout", 60.0)
OPENAI_MAX_TOKENS = OPENAI_CONFIG.get("max_tokens", 4000)
OPENAI_MAX_RETRIES = OPENAI_CONFIG.get("max_retries", 3)
OPENAI_BASE_DELAY = OPENAI_CONFIG.get("base_delay", 1.0)

# Output configuration
INDIVIDUAL_FILES_SUFFIX = OUTPUT_CONFIG.get("individual_files_suffix", "individual_files_deepscaler_")
SUMMARY_FILE_PREFIX = OUTPUT_CONFIG.get("summary_file_prefix", "processing_summary_deepscaler_")
INPUT_DATA_DIR = OUTPUT_CONFIG.get("input_data_dir", "deepscaleR_input_data")
OUTPUT_DATA_DIR = OUTPUT_CONFIG.get("output_data_dir", "deepscaleR_output_data")

# Initialize OpenAI client
client = OpenAI(api_key=config["openai_api_key"])

def ensure_directories_exist():
    """Create INPUT_DATA_DIR and OUTPUT_DATA_DIR when missing, then log both paths."""
    for data_dir in (INPUT_DATA_DIR, OUTPUT_DATA_DIR):
        os.makedirs(data_dir, exist_ok=True)
    logger.info(f"📁 Input directory: {INPUT_DATA_DIR}")
    logger.info(f"📁 Output directory: {OUTPUT_DATA_DIR}")


# Create the data directories as soon as the module is loaded.
ensure_directories_exist()

# def build_structured_messages(question: str = "", final_answer: str = "") -> List[Dict[str, str]]:
#     """
#     Build structured messages for GPT-4 API call.
    
#     Args:
#         question: The original question
#         final_answer: The final answer
    
#     Returns:
#         List of message dictionaries for API call
#     """
#     return [
#         {
#             "role": "system",
#             "content": "You are a maths problem solver as well as cognitive modeling assistant. Your job is to generate a long detailed step-by-step solution first and then into multiple compressed reasoning formats as defined below. I will also give you the final answer for the question that will guide you. Always ensure that the compressed versions retain all critical details needed for answering the question. If a step is necessary for correct logic or code, it must appear in those outputs, even if phrased briefly. Also, output a difficulty rating for the question based on the difficulty of the question.\n"
#                 "You must generate all of the following:\n\n"
#                 "1. Long CoT – Detailed step by step solution.\n"
#                 "2. Short CoT – A concise summary in plain natural language.\n"
#                 "3. Mentalese CoT – A compact, symbolic chain of thought using conceptual primitives and logical structure, not natural language. Focus on core inference steps, compressed reasoning, and internal representations without any redundancy\n"
#                 "4. Difficulty Level: Scores (Integer) between 1-5 where 5 means extremely tough and 1 means extremely easy\n"
#                 "Output everything in a single JSON object with the following keys:\n"
#                 "{long_cot, short_cot, mentalese_cot, difficulty_score}\n"
#         },
#         {
#             "role": "user",
#             "content": f"QUESTION: {question}\nANSWER: {final_answer}"
#         }
#     ]

def build_structured_messages(question: str = "", final_answer: str = "") -> List[Dict[str, str]]:
    """
    Assemble the chat messages sent to the model for one question.

    The system message holds the generation instructions: the four output
    fields, the difficulty rubric and the Mentalese format rules. The user
    message carries the question together with its known final answer.

    Args:
        question: The original question text
        final_answer: The reference final answer that guides the solution

    Returns:
        A two-element list: the system message followed by the user message
    """
    # Built from adjacent string literals so trailing spaces are explicit and
    # cannot be lost to editor whitespace trimming.
    system_prompt = (
        "You are a mathematical reasoning and cognitive modeling assistant. "
        "Your job is to generate a detailed step-by-step solution, and then compress it into multiple formats for cognitive modeling. "
        "You are always given the final answer to guide your solution.\n"
        "\n"
        "Please output all of the following:\n"
        "\n"
        "1. **Long CoT** – A detailed natural-language explanation of all reasoning steps.\n"
        "2. **Short CoT** – A concise version of the reasoning steps in plain English with all critical steps preserved.\n"
        "3. **Mentalese CoT** – A symbolic, logic-based chain of thought using compact primitives and logical steps only. "
        "Each step should be separated by a semicolon (`;`) with **no natural language**. "
        "All core logic steps must be preserved.\n"
        "4. **Difficulty Score** – An integer between 1 and 5, scored based on the rubric below.\n"
        "\n"
        "---\n"
        "\n"
        "#Difficulty Score Rubric\n"
        "\n"
        "| Score | Description |\n"
        "|-------|-------------|\n"
        "| **1** | Very easy – Requires basic arithmetic, direct computation, or simple formula. |\n"
        "| **2** | Easy – Involves one or two operations with basic algebra or logic, still straightforward. |\n"
        "| **3** | Medium – Requires multi-step reasoning or applying formulas in a non-trivial way. |\n"
        "| **4** | Hard – Involves casework, clever manipulations, or multiple concepts. |\n"
        "| **5** | Very hard – Involves deep insights, uncommon methods, or significant abstraction. |\n"
        "\n"
        "---\n"
        "\n"
        "# Mentalese Format Rules\n"
        "\n"
        "- Use **ONLY** symbolic notation in the form: `OP:params` or `OP:result`\n"
        "- No natural language\n"
        "- No redundancy\n"
        "- Every step must be **necessary and sufficient** for solving the problem\n"
        "- Example:  \n"
        "  `SET:w;EQ:abs(180-5.5*w)=110;CASE1:180-5.5*w=110;SOLVE1:w=70/5.5;CASE2:180-5.5*w=-110;SOLVE2:w=290/5.5;CALC1:w1=70/5.5=12+8/11;CALC2:w2=290/5.5=52+8/11;DIFF:t=w2-w1=40;ANS:40`\n"
        "\n"
        "---\n"
        "\n"
        "Wrap the output in a single JSON object with the following keys:  \n"
        '`{ "long_cot", "short_cot", "mentalese_cot", "difficulty_score" }`'
    )
    user_prompt = f"QUESTION: {question}\nANSWER: {final_answer}\n\nRESPONSE: "

    system_message = {"role": "system", "content": system_prompt}
    user_message = {"role": "user", "content": user_prompt}
    return [system_message, user_message]


def _parse_json_object(content: str) -> Optional[Dict[str, Any]]:
    """
    Parse a model reply into a JSON object, tolerating common wrappers.

    Candidates are tried in order:
    1. the whole reply;
    2. the body of a markdown code fence (```json or a bare ```);
    3. the span from the first "{" to the last "}".
    Only a JSON object (dict) is accepted, because callers read fields with ``.get``.

    Args:
        content: Raw text returned by the model

    Returns:
        The parsed dict, or None if no candidate parses to a JSON object
    """
    candidates = [content]

    fence_open = content.find("```")
    fence_close = content.rfind("```")
    if fence_open != -1 and fence_close > fence_open:
        inner = content[fence_open + 3:fence_close]
        # Drop an optional language tag on the opening fence (e.g. ```json).
        if inner[:4].lower() == "json":
            inner = inner[4:]
        candidates.append(inner.strip())

    brace_open = content.find("{")
    brace_close = content.rfind("}")
    if brace_open != -1 and brace_close > brace_open:
        candidates.append(content[brace_open:brace_close + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


def make_api_call_with_retry(messages: List[Dict[str, str]], max_retries: Optional[int] = None, base_delay: Optional[float] = None) -> Optional[Dict[str, Any]]:
    """
    Make API call with retry logic and exponential backoff.

    Two kinds of failure are retried, both waiting ``base_delay * 2**attempt``
    seconds before the next attempt:
    - an exception from the API call, including an empty reply;
    - a reply that does not contain a JSON object.
    Sharing the backoff means malformed outputs do not hammer the endpoint.

    Args:
        messages: List of message dictionaries for API call
        max_retries: Maximum number of attempts (uses OPENAI_MAX_RETRIES if None)
        base_delay: Base delay in seconds for exponential backoff (uses OPENAI_BASE_DELAY if None)

    Returns:
        Parsed JSON object (dict), or None if every attempt failed
    """
    # Use config defaults if not provided
    if max_retries is None:
        max_retries = OPENAI_MAX_RETRIES
    if base_delay is None:
        base_delay = OPENAI_BASE_DELAY

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            logger.info(f"Making API call, attempt {attempt + 1}/{max_retries}")
            response = client.chat.completions.create(
                model=OPENAI_MODEL,
                messages=messages,
                temperature=OPENAI_TEMPERATURE,
                timeout=OPENAI_TIMEOUT,
                max_tokens=OPENAI_MAX_TOKENS
            )
            # response = client.chat.completions.create(
            #     model=OPENAI_MODEL,
            #     messages=messages,
            #     timeout=OPENAI_TIMEOUT,
            #     reasoning_effort="medium"
            # )

            content = response.choices[0].message.content
            if not content:
                raise ValueError("Empty response content")
        except Exception as e:
            logger.error(f"API call attempt {attempt + 1} failed: {str(e)}")
            if is_last_attempt:
                logger.error("Max retries reached, giving up on this request")
                return None
        else:
            result = _parse_json_object(content)
            if result is not None:
                logger.info("API call successful")
                return result

            logger.warning(f"JSON parsing error: no JSON object found in response (attempt {attempt + 1})")
            if is_last_attempt:
                logger.error(f"Failed to parse JSON after all attempts: {content[:200]}...")
                return None
            logger.warning("Retrying due to JSON parsing error")

        # Exponential backoff before the next attempt (reached only when another attempt remains)
        delay = base_delay * (2 ** attempt)
        logger.info(f"Waiting {delay} seconds before retry...")
        time.sleep(delay)

    return None


def extract_json_fields(result: Dict[str, Any]) -> Dict[str, Any]:
    """
    Project the API result onto the fields the pipeline needs.

    A missing key and a key whose value is JSON ``null`` are treated the same:
    both become ``""``, so the token counting done on these fields by callers
    never receives ``None``. Other values (e.g. an integer
    ``difficulty_score``, including 0) are passed through unchanged.

    Args:
        result: Parsed JSON object returned by the model

    Returns:
        Dict with exactly the keys long_cot, short_cot, mentalese_cot and
        difficulty_score, in that order
    """
    required_fields = ("long_cot", "short_cot", "mentalese_cot", "difficulty_score")

    processed_result = {}
    for field in required_fields:
        value = result.get(field)
        processed_result[field] = "" if value is None else value

    return processed_result


def process_single_entry(args: Tuple[int, Dict[str, str]]) -> Optional[Dict[str, Any]]:
    """
    Process a single entry through the API.

    Args:
        args: Tuple containing (idx, entry)
            idx: Position of the entry in the list being processed (used in logs)
            entry: Extracted sample. The "question" and "answer" keys are read, and
                "sample_index" / "dataset_source" are used when present.

    Returns:
        One of three values:
        - ``{"idx": idx, "entry": <output record>, "success": True}`` on success;
        - ``None`` if the API never returned a parseable response;
        - ``{"idx": idx, "entry": None, "success": False, "error": <message>}`` if
          an unexpected exception was raised while processing.
    """
    # Unpack before the try so `idx` is always bound in the exception handler below;
    # otherwise a malformed `args` would surface as an UnboundLocalError from the handler.
    idx, entry = args
    try:
        # Extract fields
        question = entry.get("question", "")
        final_answer = entry.get("answer", "")

        messages = build_structured_messages(question, final_answer)
        result = make_api_call_with_retry(messages)

        if result is None:
            logger.error(f"Entry {idx + 1}: Failed to get valid API response")
            return None

        # Keep only the expected fields (missing ones default to "")
        processed_result = extract_json_fields(result)

        # Build output entry
        output_entry = {
            "index": entry.get("sample_index", idx),
            "dataset_source": entry.get("dataset_source", "deepscaler"),
            "question": question,
            "final_answer": final_answer,
            "long_cot": processed_result["long_cot"],
            "long_cot_tokens": safe_count_tokens(processed_result["long_cot"]),
            "short_cot": processed_result["short_cot"],
            "short_cot_tokens": safe_count_tokens(processed_result["short_cot"]),
            "mentalese_cot": processed_result["mentalese_cot"],
            "mentalese_cot_tokens": safe_count_tokens(processed_result["mentalese_cot"]),
            "difficulty_score": processed_result["difficulty_score"]
        }

        logger.info(f"Entry {idx + 1}: Successfully processed")
        return {"idx": idx, "entry": output_entry, "success": True}

    except Exception as e:
        # logger.exception keeps the traceback, which a plain error message would lose.
        logger.exception(f"Entry {idx + 1}: Unexpected error: {str(e)}")
        return {"idx": idx, "entry": None, "success": False, "error": str(e)}


def extract_deepscaler_data(dataset_name: str = DATASET_NAME, output_path: Optional[str] = None, explore_first: bool = True) -> List[Dict[str, str]]:
    """
    Load a DeepScaleR split from HuggingFace and flatten it into simple records.

    Each record has "sample_index", "dataset_source" ("deepscaler"), "question"
    (from the row's "problem") and "answer". The "train" split is used when it
    exists; otherwise the first available split is used.

    Args:
        dataset_name: HuggingFace dataset name
        output_path: JSONL file to save the records to. None means
            INPUT_DATA_DIR/deepscaler_extracted_data.jsonl; an empty string
            disables saving.
        explore_first: Log dataset structure (info, splits, columns) before extracting

    Returns:
        The extracted records, or an empty list if anything fails
    """
    output_path = (
        os.path.join(INPUT_DATA_DIR, "deepscaler_extracted_data.jsonl")
        if output_path is None
        else output_path
    )
    try:
        # Imported lazily so the rest of the script works without `datasets` installed.
        from datasets import load_dataset

        logger.info(f"Loading DeepScaleR dataset: {dataset_name}")
        dataset = load_dataset(dataset_name)
        has_train = 'train' in dataset

        if explore_first:
            logger.info(f"Dataset info: {dataset}")
            logger.info(f"Available splits: {list(dataset.keys())}")
            logger.info(f"Sample columns: {dataset['train'].column_names if has_train else 'No train split'}")

        split_name = 'train' if has_train else list(dataset.keys())[0]
        split_data = dataset[split_name]
        logger.info(f"Using split: {split_name} with {len(split_data)} samples")

        extracted_data = [
            {
                "sample_index": position,
                "dataset_source": "deepscaler",
                "question": row.get("problem", ""),
                "answer": row.get("answer", ""),
            }
            for position, row in enumerate(split_data)
        ]

        if output_path:
            with open(output_path, 'w', encoding='utf-8') as handle:
                handle.writelines(json.dumps(record, ensure_ascii=False) + '\n' for record in extracted_data)
            logger.info(f"Saved extracted data to: {output_path}")

        logger.info(f"Extracted {len(extracted_data)} samples from DeepScaleR dataset")
        return extracted_data

    except Exception as e:
        logger.error(f"Failed to extract DeepScaleR data: {e}")
        return []


def save_individual_result(result: Dict[str, Any], output_dir: str, idx: int) -> bool:
    """
    Save individual processing result to ``<output_dir>/<idx>.jsonl``, atomically.

    The JSON is first written to ``<idx>.jsonl.tmp`` and then moved over the final
    name with ``os.replace``. An interrupted run therefore never leaves a truncated
    result file. The temporary name does not end in ``.jsonl``, so it never matches
    the ``<idx>.jsonl`` names read elsewhere in this script.

    Args:
        result: The processing result dictionary (must be JSON-serializable)
        output_dir: Directory to save individual files (created if missing)
        idx: Index for the filename

    Returns:
        True if saved successfully, False otherwise (the error is logged)
    """
    tmp_file = None
    try:
        individual_file = os.path.join(output_dir, f"{idx}.jsonl")
        tmp_file = f"{individual_file}.tmp"
        os.makedirs(output_dir, exist_ok=True)
        with open(tmp_file, 'w', encoding='utf-8') as f:
            json.dump(result, f, ensure_ascii=False)
        os.replace(tmp_file, individual_file)
        return True
    except Exception as e:
        logger.error(f"Failed to save individual result {idx}: {e}")
        if tmp_file is not None:
            try:
                os.remove(tmp_file)
            except OSError:
                pass  # best-effort cleanup; the temp file may never have been created
        return False


def check_individual_result_exists(output_dir: str, idx: int) -> bool:
    """
    Check if an individual result file exists and records a successful run.

    Args:
        output_dir: Directory containing individual files
        idx: Index to check

    Returns:
        True only if ``<output_dir>/<idx>.jsonl`` satisfies all of:
        - it can be read and parses as a JSON object;
        - its ``"success"`` value is truthy;
        - its ``"entry"`` value is not null.
        False otherwise, including for missing, unreadable or corrupt files.
    """
    individual_file = os.path.join(output_dir, f"{idx}.jsonl")
    try:
        with open(individual_file, 'r', encoding='utf-8') as f:
            result = json.load(f)
    except (OSError, ValueError):
        # Missing/unreadable files (OSError) and corrupt JSON or bad encoding
        # (json.JSONDecodeError / UnicodeDecodeError are ValueErrors) count as "not done".
        # Interrupts such as KeyboardInterrupt are deliberately not caught.
        return False

    return bool(
        isinstance(result, dict)
        and result.get("success", False)
        and result.get("entry") is not None
    )


def process_single_entry_with_individual_save(args: Tuple[int, Dict[str, str], str, bool]) -> Dict[str, Any]:
    """
    Process a single entry and save the outcome to ``<output_dir>/<idx>.jsonl``.

    Args:
        args: Tuple containing (idx, entry, output_dir, skip_if_exists)

    Returns:
        Status dict that always has "idx" and "success". Beyond those:
        - "skipped" and "reason" are set when an existing successful file was reused;
        - otherwise "saved" is set, reporting whether the result file was written;
        - "error" is present whenever "success" is False.
    """
    # Unpack outside the try so idx/output_dir are always bound in the handler below.
    idx, entry, output_dir, skip_if_exists = args
    try:
        # Check if individual file already exists and skip if requested
        if skip_if_exists and check_individual_result_exists(output_dir, idx):
            logger.debug(f"⏭️  Skipping entry {idx + 1}: individual file already exists")
            return {"idx": idx, "success": True, "skipped": True, "reason": "file_exists"}

        result = process_single_entry((idx, entry))

        if not result:
            # Persist a failure marker for tracking; it is not treated as done on resume.
            failed_result = {"idx": idx, "success": False, "entry": None, "error": "processing_failed"}
            saved = save_individual_result(failed_result, output_dir, idx)
            return {"idx": idx, "success": False, "saved": saved, "error": "processing_failed"}

        if not save_individual_result(result, output_dir, idx):
            logger.warning(f"⚠️  Failed to save individual result for entry {idx + 1}")
            return {"idx": idx, "success": False, "saved": False, "error": "save_failed"}

        logger.debug(f"💾 Saved individual result for entry {idx + 1}")
        status = {"idx": idx, "success": result.get("success", False), "saved": True}
        if not status["success"]:
            # Surface the worker's error so the caller's failure log is informative.
            status["error"] = result.get("error", "processing_failed")
        return status

    except Exception as e:
        logger.exception(f"Entry {idx + 1}: Unexpected error in enhanced processing: {str(e)}")
        error_result = {"idx": idx, "success": False, "entry": None, "error": str(e)}
        # save_individual_result catches and logs its own failures, so no extra guard here.
        save_individual_result(error_result, output_dir, idx)
        return {"idx": idx, "success": False, "saved": False, "error": str(e)}


def merge_individual_files(output_dir: str, final_output_path: str, run_id: str, extracted_data: List[Dict[str, str]]) -> Dict[str, int]:
    """
    Combine the per-entry result files into one JSONL output.

    Every ``<digits>.jsonl`` file in ``output_dir`` is read in ascending index
    order. A record with a truthy "success" and a truthy "entry" contributes
    its "entry" as one line of ``final_output_path``. Anything else, including
    a file that fails to load, is counted as failed.

    Args:
        output_dir: Directory holding the individual result files
        final_output_path: JSONL file to (over)write with the merged entries
        run_id: Run identifier (not used by the merge itself)
        extracted_data: Source samples (not used by the merge itself)

    Returns:
        Dict with "total_files", "successful", "failed" and "final_output"
    """
    logger.info(f"🔄 Merging individual files from {output_dir} to {final_output_path}")

    file_indices = []
    if os.path.exists(output_dir):
        file_indices = sorted(
            int(name[:-6])
            for name in os.listdir(output_dir)
            if name.endswith('.jsonl') and name[:-6].isdigit()
        )

    successful = 0
    failed = 0
    with open(final_output_path, 'w', encoding='utf-8') as merged:
        for idx in tqdm(file_indices, desc="Merging files"):
            try:
                source_path = os.path.join(output_dir, f"{idx}.jsonl")
                with open(source_path, 'r', encoding='utf-8') as source:
                    record = json.load(source)

                usable = record and record.get("success", False) and record.get("entry")
                if not usable:
                    failed += 1
                    continue

                merged.write(json.dumps(record["entry"], ensure_ascii=False) + "\n")
                successful += 1
            except Exception as e:
                logger.error(f"Failed to process individual file {idx}: {e}")
                failed += 1

    logger.info(f"✅ Merge complete: {successful} successful, {failed} failed")
    return {
        "total_files": len(file_indices),
        "successful": successful,
        "failed": failed,
        "final_output": final_output_path,
    }


def process_deepscaler_dataset(extracted_data: List[Dict[str, str]], output_path: str, run_id: Optional[str] = None) -> Dict[str, Any]:
    """
    Process DeepScaleR dataset entries with individual file processing.

    The run has three stages:
    1. Each entry is handled by a worker thread and written to its own file under
       ``<dirname(output_path)>/<INDIVIDUAL_FILES_SUFFIX><run_id>``.
    2. Those files are merged into ``output_path``.
    3. A JSON summary is written to
       ``<OUTPUT_DATA_DIR>/<SUMMARY_FILE_PREFIX><run_id>.json``.

    Args:
        extracted_data: List of extracted data dictionaries
        output_path: Path to save the final merged results
        run_id: Optional run identifier. When omitted it is derived from the output
            file name rather than the clock. Re-running with the same output path
            then reuses the same individual-files directory, so AUTO_RESUME can skip
            entries that already succeeded. The summary for that run id is
            overwritten on each run.

    Returns:
        Processing summary dictionary
    """
    if not run_id:
        # A timestamp here would give every run a fresh, empty directory and make
        # auto-resume impossible; the timestamp is only a fallback for odd paths.
        run_id = os.path.splitext(os.path.basename(output_path))[0] or time.strftime("%Y%m%d_%H%M%S")

    # Create individual files directory
    output_dir = os.path.dirname(output_path) or "."
    individual_files_dir = os.path.join(output_dir, f"{INDIVIDUAL_FILES_SUFFIX}{run_id}")

    # --no-parallel (ENABLE_PARALLEL=False) means one worker. The max(1, ...) guards
    # against a non-positive worker count, which ThreadPoolExecutor rejects.
    worker_count = max(1, MAX_WORKERS) if ENABLE_PARALLEL else 1

    logger.info(f"🚀 Processing {len(extracted_data)} DeepScaleR dataset entries")
    logger.info(f"📁 Individual files directory: {individual_files_dir}")
    logger.info(f"📄 Final output: {output_path}")
    logger.info(f"🔧 Run ID: {run_id}")
    logger.info(f"⚡ Workers: {worker_count}")
    logger.info(f"🔄 Auto-resume: {AUTO_RESUME}")

    # Count existing successful files (for reporting; workers re-check and skip them)
    existing_count = 0
    if AUTO_RESUME:
        existing_count = sum(
            1 for idx in range(len(extracted_data))
            if check_individual_result_exists(individual_files_dir, idx)
        )
        if existing_count > 0:
            logger.info(f"🔄 Found {existing_count} existing files - will resume from where left off")

    args_list = [(idx, entry, individual_files_dir, AUTO_RESUME)
                 for idx, entry in enumerate(extracted_data)]

    processed_count = 0
    error_count = 0
    skipped_count = 0

    with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as executor:
        future_to_idx = {executor.submit(process_single_entry_with_individual_save, args): args[0]
                         for args in args_list}

        with tqdm(total=len(extracted_data), desc="Processing DeepScaleR entries") as pbar:
            for future in concurrent.futures.as_completed(future_to_idx):
                original_idx = future_to_idx[future]
                try:
                    result = future.result()
                except Exception as e:
                    # One crashing worker must not abort the whole run; count it as a failure.
                    result = {"idx": original_idx, "success": False, "error": str(e)}

                if result is None:
                    error_count += 1
                elif result.get("skipped", False):
                    skipped_count += 1
                elif result.get("success", False):
                    processed_count += 1
                else:
                    error_count += 1
                    logger.warning(f"❌ Failed sample {original_idx}: {result.get('error', 'Unknown error')}")

                pbar.update(1)

    # Merge individual files into final output
    logger.info("🔗 Starting merge of individual files...")
    merge_stats = merge_individual_files(individual_files_dir, output_path, run_id, extracted_data)

    summary = {
        "run_id": run_id,
        "dataset_type": "deepscaler",
        "total_samples": len(extracted_data),
        "existing_files_found": existing_count,
        "new_processing_done": processed_count,
        "skipped_existing": skipped_count,
        "failed_processing": error_count,
        "total_successful": merge_stats["successful"],
        "total_failed": merge_stats["failed"],
        "success_rate": (merge_stats["successful"] / len(extracted_data)) * 100 if extracted_data else 0,
        "individual_files_dir": individual_files_dir,
        "final_output_file": output_path,
        "merge_stats": merge_stats,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
    }

    summary_file = os.path.join(OUTPUT_DATA_DIR, f"{SUMMARY_FILE_PREFIX}{run_id}.json")
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    logger.info("="*80)
    logger.info(f"🎉 DEEPSCALER DATASET PROCESSING COMPLETE!")
    logger.info(f"📊 Total samples: {len(extracted_data)}")
    if existing_count > 0:
        logger.info(f"⏭️  Existing files: {existing_count}")
        logger.info(f"🆕 New processing: {processed_count}")
    logger.info(f"✅ Total successful: {merge_stats['successful']}")
    logger.info(f"❌ Total failed: {merge_stats['failed']}")
    logger.info(f"📈 Success rate: {summary['success_rate']:.1f}%")
    logger.info(f"📁 Final results: {output_path}")
    logger.info(f"🗂️  Individual files: {individual_files_dir}")
    logger.info(f"📋 Summary: {summary_file}")
    logger.info("="*80)

    return summary


def test_processing():
    """
    Smoke-test the pipeline on the first 10 DeepScaleR samples.

    Only those 10 samples are written to the test extraction file, so the file
    contents match its name. Results go to a dedicated test output file in
    OUTPUT_DATA_DIR.
    """
    test_sample_count = 10
    test_input_path = os.path.join(INPUT_DATA_DIR, "test_deepscaler_extracted_10_samples.jsonl")
    test_output_path = os.path.join(OUTPUT_DATA_DIR, "test_deepscaler_processed_10_samples.jsonl")

    logger.info("="*60)
    logger.info("🧪 STARTING DEEPSCALER DATASET TEST - 10 EXAMPLES ONLY")
    logger.info("="*60)

    # output_path="" makes extract_deepscaler_data skip saving: it substitutes its
    # default path only for None and writes only when the path is truthy. The test
    # slice is saved explicitly below instead of the whole dataset.
    extracted_data = extract_deepscaler_data(
        dataset_name=DATASET_NAME,
        output_path="",
        explore_first=True
    )

    if not extracted_data:
        logger.error("Failed to extract data from DeepScaleR dataset")
        return

    # Process only first 10 samples for testing
    test_data = extracted_data[:test_sample_count]
    with open(test_input_path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(sample, ensure_ascii=False) + '\n' for sample in test_data)
    logger.info(f"Saved test samples to: {test_input_path}")
    logger.info(f"🔬 Testing with {len(test_data)} samples")

    # Process through GPT-4 API
    process_deepscaler_dataset(test_data, test_output_path)

    logger.info("="*60)
    logger.info("✅ TEST COMPLETE!")
    logger.info(f"📁 Results saved to: {test_output_path}")
    logger.info("🚀 If satisfied, run full processing: python deepscaleR_40k_generation.py --full")
    logger.info("="*60)


def full_processing():
    """Extract every DeepScaleR sample and run the whole dataset through the API."""
    separator = "=" * 80
    for banner_line in (
        separator,
        "🚀 STARTING FULL DEEPSCALER DATASET PROCESSING",
        "⚠️  WARNING: This will process the ENTIRE dataset!",
        "💡 TIP: Run 'python deepscaleR_40k_generation.py --test' first!",
        separator,
    ):
        logger.info(banner_line)

    extracted_data = extract_deepscaler_data(
        dataset_name=DATASET_NAME,
        output_path=os.path.join(INPUT_DATA_DIR, "deepscaler_extracted_data.jsonl"),
        explore_first=True,
    )
    if not extracted_data:
        logger.error("Failed to extract data from DeepScaleR dataset")
        return

    logger.info(f"📊 Processing {len(extracted_data)} samples")

    final_output_path = os.path.join(OUTPUT_DATA_DIR, "deepscaler_processed_output.jsonl")
    process_deepscaler_dataset(extracted_data, final_output_path)


def batch_processing(start_idx: int = 0, batch_size: Optional[int] = None):
    """
    Process one contiguous slice ``[start_idx, start_idx + batch_size)`` of the dataset.

    The run id depends only on the batch bounds, with no timestamp. Re-running
    the same batch therefore reuses its individual-files directory, and
    AUTO_RESUME skips entries that already succeeded.

    Args:
        start_idx: First dataset index of the batch; must be >= 0
        batch_size: Number of samples; must be > 0. Defaults to
            dataset.default_batch_size from config.json, or 10000 if unset.
    """
    if batch_size is None:
        batch_size = DATASET_CONFIG.get("default_batch_size", 10000)
    if start_idx < 0 or batch_size <= 0:
        logger.error(f"Invalid batch: start index must be >= 0 and batch size > 0 (got start={start_idx}, size={batch_size})")
        return
    end_idx = start_idx + batch_size

    logger.info("="*80)
    logger.info(f"🚀 STARTING DEEPSCALER DATASET BATCH PROCESSING")
    logger.info(f"📊 Batch range: [{start_idx}:{end_idx}] ({batch_size} samples)")
    logger.info("="*80)

    # Extract data from DeepScaleR dataset
    extracted_data = extract_deepscaler_data(
        dataset_name=DATASET_NAME,
        output_path=None,
        explore_first=False
    )

    if not extracted_data:
        logger.error("Failed to extract data from DeepScaleR dataset")
        return

    logger.info(f"Total dataset size: {len(extracted_data)}")

    if start_idx >= len(extracted_data):
        logger.error(f"Start index {start_idx} is beyond dataset size {len(extracted_data)}")
        return

    # Clamp the batch end to the dataset size
    actual_end_idx = min(end_idx, len(extracted_data))
    actual_batch_size = actual_end_idx - start_idx

    logger.info(f"📊 Processing batch [{start_idx}:{actual_end_idx}] - {actual_batch_size} samples")

    # Generate output filename with batch info
    output_path = os.path.join(OUTPUT_DATA_DIR, f"deepscaler_processed_batch_{start_idx}_{actual_end_idx}.jsonl")

    # Process through GPT-4 API
    batch_data = extracted_data[start_idx:actual_end_idx]
    batch_run_id = f"deepscaler_batch_{start_idx}_{actual_end_idx}"
    process_deepscaler_dataset(batch_data, output_path, batch_run_id)

    if actual_end_idx < len(extracted_data):
        # Include the batch size so the suggested command keeps the same batch width.
        logger.info(f"🔄 Next batch command: python deepscaleR_40k_generation.py --batch {actual_end_idx} {batch_size}")
    else:
        logger.info("🏁 Reached the end of the dataset - no further batches to run")


def main():
    """
    Parse command-line flags and dispatch to the requested processing mode.

    The modifier flags ``--no-parallel``, ``--workers N`` and ``--no-resume`` may
    appear anywhere. They update the module-level settings and are removed from
    ``sys.argv``. Only then is the mode flag read from ``sys.argv[1]``: one of
    ``--test``, ``--full`` or ``--batch [start] [size]``.
    """
    global MAX_WORKERS, ENABLE_PARALLEL, AUTO_RESUME

    if "--no-parallel" in sys.argv:
        ENABLE_PARALLEL = False
        sys.argv.remove("--no-parallel")
        logger.info("Parallel processing disabled")

    if "--workers" in sys.argv:
        worker_idx = sys.argv.index("--workers")
        try:
            requested_workers = int(sys.argv[worker_idx + 1])
            if requested_workers < 1:
                raise ValueError(f"worker count must be at least 1, got {requested_workers}")
        except (ValueError, IndexError):
            # Drop only the flag so a following mode flag (e.g. "--workers --test") still works.
            del sys.argv[worker_idx]
            logger.warning(f"Invalid --workers argument, using default from config: {MAX_WORKERS}")
        else:
            MAX_WORKERS = requested_workers
            # Delete by position: removing by value could drop an unrelated argument
            # equal to the worker count (e.g. "--batch 40 100 --workers 40").
            del sys.argv[worker_idx:worker_idx + 2]
            logger.info(f"Using {MAX_WORKERS} workers for parallel processing")

    if "--no-resume" in sys.argv:
        AUTO_RESUME = False
        sys.argv.remove("--no-resume")
        logger.info("Auto-resume disabled - will reprocess all files")

    mode = sys.argv[1] if len(sys.argv) > 1 else None
    if mode == "--test":
        test_processing()
    elif mode == "--full":
        full_processing()
    elif mode == "--batch":
        try:
            start_idx = int(sys.argv[2]) if len(sys.argv) > 2 else 0
            # None lets batch_processing use the configured default batch size,
            # matching the default shown in the usage text.
            batch_size = int(sys.argv[3]) if len(sys.argv) > 3 else None
        except ValueError:
            logger.error("Invalid --batch arguments: expected --batch <start_idx> [batch_size] with integer values")
            return
        batch_processing(start_idx, batch_size)
    else:
        logger.info("DeepScaleR 40K Dataset Generation Script")
        logger.info("Usage:")
        logger.info("  python deepscaleR_40k_generation.py --test                   # Test with 10 samples")
        logger.info("  python deepscaleR_40k_generation.py --full                   # Process full dataset")
        logger.info(f"  python deepscaleR_40k_generation.py --batch 0 {DATASET_CONFIG.get('default_batch_size', 10000)}  # Process batch")
        logger.info(f"  python deepscaleR_40k_generation.py --workers {MAX_WORKERS}             # Use {MAX_WORKERS} workers (configurable)")
        logger.info("  python deepscaleR_40k_generation.py --no-parallel            # Disable parallel processing")
        logger.info("  python deepscaleR_40k_generation.py --no-resume              # Disable auto-resume")
        logger.info("")
        logger.info("Configuration: Edit config.json to modify OpenAI settings, processing parameters, and dataset options")
        logger.info(f"📁 Input data directory: {INPUT_DATA_DIR}")
        logger.info(f"📁 Output data directory: {OUTPUT_DATA_DIR}")


if __name__ == "__main__":
    main()