"""
DeepScaleR 40K Dataset Generation Script - DeepSeek Version

This script processes the DeepScaleR dataset through DeepSeek API to generate multiple
chain-of-thought reasoning formats. It's designed to be clean, modular, and focused
specifically on DeepScaleR dataset processing with DeepSeek models.

Key Features:
- Parallel processing with configurable workers
- Individual file processing for fault tolerance
- Auto-resume capability from existing files
- Comprehensive error handling and logging
- Batch processing support
- Failed sample tracking and reprocessing
- DeepSeek API integration

Usage:
    python deepscaleR_40k_generation_deepseek.py --test                          # Test with 10 samples
    python deepscaleR_40k_generation_deepseek.py --full                          # Process full dataset
    python deepscaleR_40k_generation_deepseek.py --batch 0 10000                 # Process batch
    python deepscaleR_40k_generation_deepseek.py --workers 40                    # Use 40 workers (configurable)
    python deepscaleR_40k_generation_deepseek.py --no-parallel                   # Disable parallel processing
    python deepscaleR_40k_generation_deepseek.py --no-resume                     # Disable auto-resume
    
Configuration:
    Edit config.json to modify DeepSeek settings, processing parameters, and dataset options
    
Directory Structure:
    - Input data: deepscaleR_input_data/
    - Output data: deepscaleR_output_data/
    - Individual files: deepscaleR_output_data/individual_files_*/
"""

import json
import tiktoken
import time
import os
import sys
import logging
import requests
from typing import Dict, Any, Optional, List, Tuple
import concurrent.futures
from tqdm import tqdm
from utility_functions import safe_count_tokens

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Load configuration from the config.json that sits next to this script.
config_path = os.path.join(os.path.dirname(__file__), "config.json")
try:
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
except FileNotFoundError as err:
    # Re-raise with the resolved path; `from err` keeps the original cause visible.
    raise FileNotFoundError(f"config.json not found at {config_path}") from err
except json.JSONDecodeError as err:
    raise ValueError(f"Invalid JSON in {config_path}") from err

# Every lookup below uses dict.get, so the document must be a JSON object.
if not isinstance(config, dict):
    raise ValueError(f"Invalid JSON in {config_path}: top-level value must be an object")

# Validate required configuration
if not config.get("deepseek_api_key"):
    raise ValueError("deepseek_api_key not found in config.json")

# Optional sections; `or {}` also tolerates a section explicitly set to null.
DEEPSEEK_CONFIG = config.get("deepseek_model") or {}
PROCESSING_CONFIG = config.get("processing") or {}
DATASET_CONFIG = config.get("dataset") or {}
OUTPUT_CONFIG = config.get("output") or {}

# Processing settings (main() may override MAX_WORKERS, ENABLE_PARALLEL and AUTO_RESUME from CLI flags)
MAX_WORKERS = PROCESSING_CONFIG.get("max_workers", 40)
ENABLE_PARALLEL = PROCESSING_CONFIG.get("enable_parallel", True)
ENABLE_INDIVIDUAL_FILES = PROCESSING_CONFIG.get("enable_individual_files", True)
AUTO_RESUME = PROCESSING_CONFIG.get("auto_resume", True)
DATASET_NAME = DATASET_CONFIG.get("name", "agentica-org/DeepScaleR-Preview-Dataset")

# DeepSeek configuration
DEEPSEEK_API_KEY = config.get("deepseek_api_key")
DEEPSEEK_MODEL = DEEPSEEK_CONFIG.get("model", "deepseek-reasoner")
DEEPSEEK_BASE_URL = DEEPSEEK_CONFIG.get("base_url", "https://api.model.hippocraticai.com/dev/ds671r1/v1")
DEEPSEEK_TIMEOUT = DEEPSEEK_CONFIG.get("timeout", 60.0)
DEEPSEEK_MAX_RETRIES = DEEPSEEK_CONFIG.get("max_retries", 3)
DEEPSEEK_BASE_DELAY = DEEPSEEK_CONFIG.get("base_delay", 1.0)
DEEPSEEK_MAX_TOKENS = DEEPSEEK_CONFIG.get("max_tokens", 4000)
DEEPSEEK_TEMPERATURE = DEEPSEEK_CONFIG.get("temperature", 0.2)

# Output configuration
INDIVIDUAL_FILES_SUFFIX = OUTPUT_CONFIG.get("individual_files_suffix", "individual_files_deepscaler_")
SUMMARY_FILE_PREFIX = OUTPUT_CONFIG.get("summary_file_prefix", "processing_summary_deepscaler_")
INPUT_DATA_DIR = OUTPUT_CONFIG.get("input_data_dir", "deepscaleR_input_data")
OUTPUT_DATA_DIR = OUTPUT_CONFIG.get("output_data_dir", "deepscaleR_output_data")

def ensure_directories_exist():
    """
    Create the configured input and output data directories when missing.

    Both paths come from the module-level configuration. ``exist_ok=True``
    makes repeated calls harmless. The resolved paths are logged once both
    directories exist.
    """
    data_dirs = (INPUT_DATA_DIR, OUTPUT_DATA_DIR)
    for directory in data_dirs:
        os.makedirs(directory, exist_ok=True)
    logger.info(f"📁 Input directory: {INPUT_DATA_DIR}")
    logger.info(f"📁 Output directory: {OUTPUT_DATA_DIR}")

# Create the data directories at import time so every entry point can write into them.
ensure_directories_exist()

def build_structured_messages(question: str = "", final_answer: str = "") -> str:
    """
    Assemble the single completion prompt sent to the DeepSeek API.

    The prompt is a fixed instruction block followed by a blank line and then
    this sample's QUESTION/ANSWER pair. The instruction block covers the
    requested output formats, the difficulty rubric, the Mentalese rules and
    the expected JSON keys. It ends with an open ``RESPONSE:`` cue.

    Args:
        question: Problem statement placed after ``QUESTION:``.
        final_answer: Reference answer placed after ``ANSWER:``.

    Returns:
        The complete prompt string.
    """
    instructions = """You are a mathematical reasoning and cognitive modeling assistant. Your job is to generate a detailed step-by-step solution, and then compress it into multiple formats for cognitive modeling. You are always given the final answer to guide your solution.

Please output all of the following:

1. **Long CoT** – A detailed natural-language explanation of all reasoning steps.
2. **Short CoT** – A concise version of the reasoning steps in plain English with all critical steps preserved.
3. **Mentalese CoT** – A symbolic, logic-based chain of thought using compact primitives and logical steps only. Each step should be separated by a semicolon (`;`) with **no natural language**. All core logic steps must be preserved.
4. **Difficulty Score** – An integer between 1 and 5, scored based on the rubric below.

---

#Difficulty Score Rubric

| Score | Description |
|-------|-------------|
| **1** | Very easy – Requires basic arithmetic, direct computation, or simple formula. |
| **2** | Easy – Involves one or two operations with basic algebra or logic, still straightforward. |
| **3** | Medium – Requires multi-step reasoning or applying formulas in a non-trivial way. |
| **4** | Hard – Involves casework, clever manipulations, or multiple concepts. |
| **5** | Very hard – Involves deep insights, uncommon methods, or significant abstraction. |

---

# Mentalese Format Rules

- Use **ONLY** symbolic notation in the form: `OP:params` or `OP:result`
- No natural language
- No redundancy
- Every step must be **necessary and sufficient** for solving the problem
- Example:  
  `SET:w;EQ:abs(180-5.5*w)=110;CASE1:180-5.5*w=110;SOLVE1:w=70/5.5;CASE2:180-5.5*w=-110;SOLVE2:w=290/5.5;CALC1:w1=70/5.5=12+8/11;CALC2:w2=290/5.5=52+8/11;DIFF:t=w2-w1=40;ANS:40`

---

Wrap the output in a single JSON object with the following keys:  
`{ "long_cot", "short_cot", "mentalese_cot", "difficulty_score" }`"""

    # Sample section: QUESTION line, ANSWER line, blank line, then the open RESPONSE cue.
    sample_lines = (f"QUESTION: {question}", f"ANSWER: {final_answer}", "", "RESPONSE: ")
    sample_block = "\n".join(sample_lines)

    return "\n\n".join((instructions, sample_block))

def _extract_json_object(content: str) -> Optional[Dict[str, Any]]:
    """
    Best-effort extraction of a JSON object from raw completion text.

    Candidates are tried in this order:
    1. the whole text;
    2. the body of a Markdown code fence (```json ... ``` or a bare ``` ... ```);
    3. the span from the first "{" to the last "}".

    Only a JSON *object* is accepted, because callers read named fields from
    the result.

    Returns:
        The parsed dict, or None if no candidate parses to an object.
    """
    candidates = [content]

    fence_open = content.find("```")
    fence_close = content.rfind("```")
    if fence_open != -1 and fence_close > fence_open:
        fenced = content[fence_open + 3:fence_close]
        # Drop a language tag such as "json" on the opening fence line.
        newline_pos = fenced.find("\n")
        if newline_pos != -1 and fenced[:newline_pos].strip().isalpha():
            fenced = fenced[newline_pos + 1:]
        candidates.append(fenced)

    brace_open = content.find("{")
    brace_close = content.rfind("}")
    if brace_open != -1 and brace_close > brace_open:
        candidates.append(content[brace_open:brace_close + 1])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


def make_api_call_with_retry(prompt: str, max_retries: Optional[int] = None, base_delay: Optional[float] = None) -> Optional[Dict[str, Any]]:
    """
    Call the DeepSeek completions endpoint and return the parsed JSON object.

    These failures are retried:
    - transport-level failures: non-200 status, request exceptions, a
      malformed response envelope, or empty text;
    - completion text that does not contain a JSON object.

    Between attempts the function sleeps ``base_delay * 2**attempt`` seconds.

    Args:
        prompt: Formatted prompt string for API call
        max_retries: Maximum number of attempts (uses config default if None)
        base_delay: Base delay for exponential backoff (uses config default if None)

    Returns:
        Parsed JSON object (dict), or None if every attempt failed
    """
    # Use config defaults if not provided
    if max_retries is None:
        max_retries = DEEPSEEK_MAX_RETRIES
    if base_delay is None:
        base_delay = DEEPSEEK_BASE_DELAY

    headers = {
        "x-llm-model-api-key": DEEPSEEK_API_KEY,
        "Content-Type": "application/json"
    }

    payload = {
        "model": DEEPSEEK_MODEL,
        "prompt": prompt,
        "max_tokens": DEEPSEEK_MAX_TOKENS,
        "temperature": DEEPSEEK_TEMPERATURE,
        "top_p": 1.0,
        "n": 1,
        "stream": False,
    }

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            logger.info(f"Making DeepSeek API call, attempt {attempt + 1}/{max_retries}")

            response = requests.post(
                f"{DEEPSEEK_BASE_URL}/completions",
                headers=headers,
                json=payload,
                timeout=DEEPSEEK_TIMEOUT
            )

            if response.status_code != 200:
                raise Exception(f"API error {response.status_code}: {response.text}")

            response_data = response.json()
            # `or [{}]` also covers an empty "choices" list.
            choices = response_data.get("choices") or [{}]
            content = choices[0].get("text", "")

            if not content:
                raise ValueError("Empty response content")
            if not isinstance(content, str):
                raise ValueError(f"Unexpected response text type: {type(content).__name__}")
        except Exception as e:
            logger.error(f"DeepSeek API call attempt {attempt + 1} failed: {str(e)}")
            if is_last_attempt:
                logger.error("Max retries reached, giving up on this request")
                return None
        else:
            result = _extract_json_object(content)
            if result is not None:
                logger.info("DeepSeek API call successful")
                return result

            if is_last_attempt:
                logger.error(f"Failed to parse JSON after all attempts: {content[:200]}...")
                return None
            logger.warning("Retrying due to JSON parsing error")

        # Exponential backoff before the next attempt (transport and parse failures alike).
        delay = base_delay * (2 ** attempt)
        logger.info(f"Waiting {delay} seconds before retry...")
        time.sleep(delay)

    return None

def extract_json_fields(result: Dict[str, Any]) -> Dict[str, Any]:
    """
    Keep only the required output fields from a parsed API response.

    Extra keys are dropped. A missing key, or a key explicitly set to JSON
    ``null``, becomes ``""``, so downstream consumers never receive ``None``.
    Other values (e.g. an integer ``difficulty_score``) are passed through
    unchanged.

    Args:
        result: Parsed JSON object returned by the model

    Returns:
        Dict with exactly the keys long_cot, short_cot, mentalese_cot and
        difficulty_score

    Raises:
        TypeError: if ``result`` is not a dict (e.g. the model returned a JSON list)
    """
    if not isinstance(result, dict):
        raise TypeError(f"Expected a JSON object from the API, got {type(result).__name__}")

    required_fields = (
        "long_cot",
        "short_cot",
        "mentalese_cot",
        "difficulty_score",
    )

    processed_result = {}
    for field in required_fields:
        value = result.get(field)
        processed_result[field] = "" if value is None else value

    return processed_result

def process_single_entry(args: Tuple[int, Dict[str, str]]) -> Optional[Dict[str, Any]]:
    """
    Process a single entry through the DeepSeek API.

    Args:
        args: Tuple containing (idx, entry)
            idx: Position of the entry in the current run
            entry: Dict with "question" and "answer" keys, optionally
                "sample_index" and "dataset_source"

    Returns:
        - ``{"idx", "entry", "success": True}`` on success;
        - ``None`` when the API never returned a usable JSON object;
        - ``{"idx", "entry": None, "success": False, "error"}`` when an
          unexpected exception occurred.

    Raises:
        ValueError/TypeError: if ``args`` is not an (idx, entry) pair. This is
        a caller bug and is not converted into a failure record.
    """
    # Unpack outside the try: the error handler below needs `idx`, and a
    # malformed args tuple would otherwise surface as a confusing UnboundLocalError.
    idx, entry = args

    try:
        # Extract fields
        question = entry.get("question", "")
        final_answer = entry.get("answer", "")

        prompt = build_structured_messages(question, final_answer)
        result = make_api_call_with_retry(prompt)

        if result is None:
            logger.error(f"Entry {idx + 1}: Failed to get valid DeepSeek API response")
            return None

        # Process and validate result
        processed_result = extract_json_fields(result)

        # Build output entry
        output_entry = {
            "index": entry.get("sample_index", idx),
            "dataset_source": entry.get("dataset_source", "deepscaler"),
            "question": question,
            "final_answer": final_answer,
            "long_cot": processed_result["long_cot"],
            "long_cot_tokens": safe_count_tokens(processed_result["long_cot"]),
            "short_cot": processed_result["short_cot"],
            "short_cot_tokens": safe_count_tokens(processed_result["short_cot"]),
            "mentalese_cot": processed_result["mentalese_cot"],
            "mentalese_cot_tokens": safe_count_tokens(processed_result["mentalese_cot"]),
            "difficulty_score": processed_result["difficulty_score"]
        }

        logger.info(f"Entry {idx + 1}: Successfully processed")
        return {"idx": idx, "entry": output_entry, "success": True}

    except Exception as e:
        # logger.exception records the traceback, which str(e) alone loses.
        logger.exception(f"Entry {idx + 1}: Unexpected error: {str(e)}")
        return {"idx": idx, "entry": None, "success": False, "error": str(e)}

def extract_deepscaler_data(dataset_name: str = DATASET_NAME, output_path: Optional[str] = None, explore_first: bool = True) -> List[Dict[str, str]]:
    """
    Load a DeepScaleR split from Hugging Face and normalize its rows.

    The "train" split is used when present; otherwise the first split is used.
    Each row becomes a dict with these keys:
    - sample_index: the row's position in the split;
    - dataset_source: always "deepscaler";
    - question: the row's "problem" value;
    - answer: the row's "answer" value.

    The rows are also written as JSON lines to ``output_path``. A ``None``
    path is replaced by INPUT_DATA_DIR/deepscaler_extracted_data.jsonl;
    passing an empty string skips saving.

    Args:
        dataset_name: HuggingFace dataset name
        output_path: Where to save the extracted rows (None selects the default path)
        explore_first: Log dataset structure (splits, columns) before extracting

    Returns:
        The normalized rows, or an empty list if anything fails (the error is logged)
    """
    if output_path is None:
        output_path = os.path.join(INPUT_DATA_DIR, "deepscaler_extracted_data.jsonl")
    try:
        from datasets import load_dataset

        logger.info(f"Loading DeepScaleR dataset: {dataset_name}")
        dataset = load_dataset(dataset_name)
        has_train = 'train' in dataset

        if explore_first:
            logger.info(f"Dataset info: {dataset}")
            logger.info(f"Available splits: {list(dataset.keys())}")
            column_summary = dataset['train'].column_names if has_train else 'No train split'
            logger.info(f"Sample columns: {column_summary}")

        split_name = 'train' if has_train else list(dataset.keys())[0]
        split_data = dataset[split_name]
        logger.info(f"Using split: {split_name} with {len(split_data)} samples")

        extracted_data = [
            {
                "sample_index": position,
                "dataset_source": "deepscaler",
                "question": row.get("problem", ""),
                "answer": row.get("answer", ""),
            }
            for position, row in enumerate(split_data)
        ]

        if output_path:
            with open(output_path, 'w', encoding='utf-8') as handle:
                handle.writelines(json.dumps(record, ensure_ascii=False) + '\n' for record in extracted_data)
            logger.info(f"Saved extracted data to: {output_path}")

        logger.info(f"Extracted {len(extracted_data)} samples from DeepScaleR dataset")
        return extracted_data

    except Exception as e:
        logger.error(f"Failed to extract DeepScaleR data: {e}")
        return []

def save_individual_result(result: Dict[str, Any], output_dir: str, idx: int) -> bool:
    """
    Atomically save one processing result to ``{output_dir}/{idx}.jsonl``.

    The result is dumped as a single JSON object into ``{idx}.jsonl.tmp``,
    which is then moved over the target with ``os.replace``. A failed or
    interrupted write therefore never leaves a truncated ``{idx}.jsonl``
    behind. The ``.tmp`` suffix is not matched by the ``*.jsonl`` scanners.

    Args:
        result: The processing result dictionary (must be JSON-serializable)
        output_dir: Directory to save individual files (created if missing)
        idx: Index for the filename

    Returns:
        True if saved successfully, False otherwise (the error is logged)
    """
    temp_file = None
    try:
        individual_file = os.path.join(output_dir, f"{idx}.jsonl")
        temp_file = individual_file + ".tmp"
        os.makedirs(output_dir, exist_ok=True)
        with open(temp_file, 'w', encoding='utf-8') as f:
            json.dump(result, f, ensure_ascii=False)
        os.replace(temp_file, individual_file)
        return True
    except Exception as e:
        logger.error(f"Failed to save individual result {idx}: {e}")
        # Best-effort cleanup of a partially written temp file.
        if temp_file is not None and os.path.exists(temp_file):
            try:
                os.remove(temp_file)
            except OSError:
                pass
        return False

def check_individual_result_exists(output_dir: str, idx: int) -> bool:
    """
    Report whether ``{output_dir}/{idx}.jsonl`` holds a successful result.

    A file counts only when it parses as a JSON object whose ``success`` is
    truthy and whose ``entry`` is not None. Missing, unreadable, truncated or
    non-UTF-8 files all count as "not done", so they are reprocessed on resume.

    Args:
        output_dir: Directory containing individual files
        idx: Index to check

    Returns:
        True if the file exists and records a success, False otherwise
    """
    individual_file = os.path.join(output_dir, f"{idx}.jsonl")
    try:
        with open(individual_file, 'r', encoding='utf-8') as f:
            result = json.load(f)
    except (OSError, ValueError):
        # OSError: missing/unreadable file; ValueError: bad JSON or bad UTF-8.
        return False

    return bool(
        isinstance(result, dict)
        and result.get("success", False)
        and result.get("entry") is not None
    )

def process_single_entry_with_individual_save(args: Tuple[int, Dict[str, str], str, bool]) -> Dict[str, Any]:
    """
    Process one entry and persist its outcome to ``{output_dir}/{idx}.jsonl``.

    If ``skip_if_exists`` is set and a successful result file already exists,
    the API is not called. Failures are persisted too (with ``success: False``),
    so the merge step and later resumes can see them.

    Args:
        args: Tuple containing (idx, entry, output_dir, skip_if_exists)

    Returns:
        Status dict with keys:
        - always ``idx`` and ``success``;
        - ``skipped``/``reason`` for a skipped entry;
        - otherwise ``saved``, plus ``error`` on failure (the worker's own
          error message when available).
    """
    # Unpacked outside the try so the error handler can always use idx/output_dir.
    idx, entry, output_dir, skip_if_exists = args

    try:
        # Check if individual file already exists and skip if requested
        if skip_if_exists and check_individual_result_exists(output_dir, idx):
            logger.debug(f"⏭️  Skipping entry {idx + 1}: individual file already exists")
            return {"idx": idx, "success": True, "skipped": True, "reason": "file_exists"}

        result = process_single_entry((idx, entry))

        if not result:
            # No usable API response: persist a failure record for tracking.
            failed_result = {"idx": idx, "success": False, "entry": None, "error": "processing_failed"}
            saved = save_individual_result(failed_result, output_dir, idx)
            return {"idx": idx, "success": False, "saved": saved, "error": "processing_failed"}

        if not save_individual_result(result, output_dir, idx):
            logger.warning(f"⚠️  Failed to save individual result for entry {idx + 1}")
            return {"idx": idx, "success": False, "saved": False, "error": "save_failed"}

        logger.debug(f"💾 Saved individual result for entry {idx + 1}")
        outcome = {"idx": idx, "success": result.get("success", False), "saved": True}
        if not outcome["success"]:
            # Surface the worker's error so the caller's failure log is informative.
            outcome["error"] = result.get("error", "processing_failed")
        return outcome

    except Exception as e:
        logger.error(f"Entry {idx + 1}: Unexpected error in enhanced processing: {str(e)}")
        error_result = {"idx": idx, "success": False, "entry": None, "error": str(e)}
        # save_individual_result traps and logs its own failures, so no extra guard is needed.
        saved = save_individual_result(error_result, output_dir, idx)
        return {"idx": idx, "success": False, "saved": saved, "error": str(e)}

def merge_individual_files(output_dir: str, final_output_path: str, run_id: str, extracted_data: List[Dict[str, str]]) -> Dict[str, int]:
    """
    Concatenate the successful per-entry results into one JSONL file.

    Every ``<digits>.jsonl`` file in ``output_dir`` is read in ascending
    numeric order. Files recording ``success`` with a non-empty ``entry``
    contribute that entry as one output line. All others, including
    unreadable files, are counted as failed.

    Args:
        output_dir: Directory containing individual files (may not exist)
        final_output_path: Path for final merged output (overwritten)
        run_id: Run identifier (accepted for interface symmetry; not read here)
        extracted_data: Original extracted data (accepted for interface symmetry; not read here)

    Returns:
        Dict with total_files, successful, failed and final_output
    """
    logger.info(f"🔄 Merging individual files from {output_dir} to {final_output_path}")

    indices = []
    if os.path.exists(output_dir):
        # "<n>.jsonl" -> n; everything else in the directory is ignored.
        indices = sorted(
            int(name[:-6])
            for name in os.listdir(output_dir)
            if name.endswith('.jsonl') and name[:-6].isdigit()
        )

    success_count = failed_count = 0

    with open(final_output_path, 'w', encoding='utf-8') as outfile:
        for idx in tqdm(indices, desc="Merging files"):
            try:
                with open(os.path.join(output_dir, f"{idx}.jsonl"), 'r', encoding='utf-8') as infile:
                    record = json.load(infile)

                if not (record and record.get("success", False) and record.get("entry")):
                    failed_count += 1
                    continue

                outfile.write(json.dumps(record["entry"], ensure_ascii=False) + "\n")
                success_count += 1
            except Exception as e:
                logger.error(f"Failed to process individual file {idx}: {e}")
                failed_count += 1

    logger.info(f"✅ Merge complete: {success_count} successful, {failed_count} failed")
    return {
        "total_files": len(indices),
        "successful": success_count,
        "failed": failed_count,
        "final_output": final_output_path
    }

def process_deepscaler_dataset(extracted_data: List[Dict[str, str]], output_path: str, run_id: str = None) -> Dict[str, Any]:
    """
    Process DeepScaleR entries through DeepSeek, one result file per entry, then merge.

    Every entry goes through ``process_single_entry_with_individual_save`` in a
    thread pool: ``MAX_WORKERS`` threads, or a single thread when
    ``ENABLE_PARALLEL`` is False. Afterwards the per-entry files are merged
    into ``output_path`` and a JSON summary is written to ``OUTPUT_DATA_DIR``.

    The per-entry directory is named after ``output_path``, not after
    ``run_id``. Callers pass (or default to) a fresh timestamped ``run_id``
    on every run, so a run_id-based directory was always new and AUTO_RESUME
    could never find an interrupted run's files. ``run_id`` still names the
    summary file.

    Args:
        extracted_data: Entries with "question"/"answer" keys (optionally
            "sample_index"/"dataset_source")
        output_path: Path to save the final merged results
        run_id: Optional run identifier (defaults to the current timestamp)

    Returns:
        Processing summary dictionary (also written to disk)
    """
    if not run_id:
        run_id = time.strftime("%Y%m%d_%H%M%S")

    # Per-entry files live next to the final output, keyed by its file name so reruns resume.
    output_dir = os.path.dirname(output_path) or "."
    output_stem = os.path.splitext(os.path.basename(output_path))[0]
    individual_files_dir = os.path.join(output_dir, f"{INDIVIDUAL_FILES_SUFFIX}{output_stem}")

    # Honor --no-parallel, and never hand the pool a non-positive worker count.
    worker_count = max(1, MAX_WORKERS) if ENABLE_PARALLEL else 1

    logger.info(f"🚀 Processing {len(extracted_data)} DeepScaleR dataset entries with DeepSeek")
    logger.info(f"📁 Individual files directory: {individual_files_dir}")
    logger.info(f"📄 Final output: {output_path}")
    logger.info(f"🔧 Run ID: {run_id}")
    logger.info(f"⚡ Workers: {worker_count}")
    logger.info(f"🔄 Auto-resume: {AUTO_RESUME}")
    logger.info(f"🤖 Model: {DEEPSEEK_MODEL}")

    # Count existing successful files (informational; workers re-check before skipping)
    existing_count = 0
    if AUTO_RESUME:
        existing_count = sum(
            1 for idx in range(len(extracted_data))
            if check_individual_result_exists(individual_files_dir, idx)
        )
        if existing_count > 0:
            logger.info(f"🔄 Found {existing_count} existing files - will resume from where left off")

    args_list = [(idx, entry, individual_files_dir, AUTO_RESUME)
                 for idx, entry in enumerate(extracted_data)]

    processed_count = 0
    error_count = 0
    skipped_count = 0

    with concurrent.futures.ThreadPoolExecutor(max_workers=worker_count) as executor:
        future_to_idx = {executor.submit(process_single_entry_with_individual_save, args): args[0]
                         for args in args_list}

        with tqdm(total=len(extracted_data), desc="Processing DeepScaleR entries") as pbar:
            for future in concurrent.futures.as_completed(future_to_idx):
                original_idx = future_to_idx[future]
                try:
                    result = future.result()
                except Exception as e:
                    # A crash in one worker must not abort the whole run.
                    result = {"idx": original_idx, "success": False, "error": str(e)}

                if result is None:
                    error_count += 1
                elif result.get("skipped", False):
                    skipped_count += 1
                elif result.get("success", False):
                    processed_count += 1
                else:
                    error_count += 1
                    logger.warning(f"❌ Failed sample {original_idx}: {result.get('error', 'Unknown error')}")

                pbar.update(1)

    # Merge individual files into final output
    logger.info("🔗 Starting merge of individual files...")
    merge_stats = merge_individual_files(individual_files_dir, output_path, run_id, extracted_data)

    summary = {
        "run_id": run_id,
        "dataset_type": "deepscaler",
        "model_used": DEEPSEEK_MODEL,
        "total_samples": len(extracted_data),
        "existing_files_found": existing_count,
        "new_processing_done": processed_count,
        "skipped_existing": skipped_count,
        "failed_processing": error_count,
        "total_successful": merge_stats["successful"],
        "total_failed": merge_stats["failed"],
        "success_rate": (merge_stats["successful"] / len(extracted_data)) * 100 if extracted_data else 0,
        "individual_files_dir": individual_files_dir,
        "final_output_file": output_path,
        "merge_stats": merge_stats,
        "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
    }

    summary_file = os.path.join(OUTPUT_DATA_DIR, f"{SUMMARY_FILE_PREFIX}{run_id}.json")
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)

    logger.info("="*80)
    logger.info("🎉 DEEPSCALER DATASET PROCESSING COMPLETE WITH DEEPSEEK!")
    logger.info(f"📊 Total samples: {len(extracted_data)}")
    if existing_count > 0:
        logger.info(f"⏭️  Existing files: {existing_count}")
        logger.info(f"🆕 New processing: {processed_count}")
    logger.info(f"✅ Total successful: {merge_stats['successful']}")
    logger.info(f"❌ Total failed: {merge_stats['failed']}")
    logger.info(f"📈 Success rate: {summary['success_rate']:.1f}%")
    logger.info(f"🤖 Model used: {DEEPSEEK_MODEL}")
    logger.info(f"📁 Final results: {output_path}")
    logger.info(f"🗂️  Individual files: {individual_files_dir}")
    logger.info(f"📋 Summary: {summary_file}")
    logger.info("="*80)

    return summary

def test_processing():
    """
    Smoke-test the pipeline on the first 10 DeepScaleR samples.

    The full dataset is extracted (and cached under a test-specific file name
    in INPUT_DATA_DIR), but only its first 10 rows are sent to DeepSeek. The
    merged results go to a test-specific file in OUTPUT_DATA_DIR. Returns
    early if extraction yields nothing.
    """
    rule = "=" * 60
    extracted_path = os.path.join(INPUT_DATA_DIR, "test_deepscaler_extracted_10_samples.jsonl")
    results_path = os.path.join(OUTPUT_DATA_DIR, "test_deepscaler_processed_10_samples.jsonl")

    for message in (rule, "🧪 STARTING DEEPSCALER DATASET TEST WITH DEEPSEEK - 10 EXAMPLES ONLY", rule):
        logger.info(message)

    extracted_data = extract_deepscaler_data(
        dataset_name=DATASET_NAME,
        output_path=extracted_path,
        explore_first=True,
    )
    if not extracted_data:
        logger.error("Failed to extract data from DeepScaleR dataset")
        return

    test_data = extracted_data[:10]
    logger.info(f"🔬 Testing with {len(test_data)} samples using DeepSeek")

    process_deepscaler_dataset(test_data, results_path)

    for message in (
        rule,
        "✅ TEST COMPLETE!",
        f"📁 Results saved to: {results_path}",
        "🚀 If satisfied, run full processing: python deepscaleR_40k_generation_deepseek.py --full",
        rule,
    ):
        logger.info(message)

def full_processing():
    """
    Extract the complete DeepScaleR dataset and run every sample through DeepSeek.

    Extracted inputs are cached in INPUT_DATA_DIR and the merged results are
    written to OUTPUT_DATA_DIR. Nothing is processed if extraction yields no rows.
    """
    rule = "=" * 80
    for message in (
        rule,
        "🚀 STARTING FULL DEEPSCALER DATASET PROCESSING WITH DEEPSEEK",
        "⚠️  WARNING: This will process the ENTIRE dataset!",
        "💡 TIP: Run 'python deepscaleR_40k_generation_deepseek.py --test' first!",
        rule,
    ):
        logger.info(message)

    extracted_data = extract_deepscaler_data(
        dataset_name=DATASET_NAME,
        output_path=os.path.join(INPUT_DATA_DIR, "deepscaler_extracted_data.jsonl"),
        explore_first=True,
    )
    if not extracted_data:
        logger.error("Failed to extract data from DeepScaleR dataset")
        return

    logger.info(f"📊 Processing {len(extracted_data)} samples with DeepSeek")

    results_path = os.path.join(OUTPUT_DATA_DIR, "deepscaler_processed_output.jsonl")
    process_deepscaler_dataset(extracted_data, results_path)

def batch_processing(start_idx: int = 0, batch_size: Optional[int] = None):
    """
    Process the contiguous slice ``[start_idx, start_idx + batch_size)`` of the dataset.

    The slice is clipped to the dataset size. Results go to
    ``deepscaler_processed_batch_{start}_{end}.jsonl`` in OUTPUT_DATA_DIR.
    The command for the following batch is logged with the same batch size,
    so it continues seamlessly. On the final batch, the end of the dataset is
    reported instead.

    Args:
        start_idx: First dataset index to process (must be >= 0)
        batch_size: Number of samples; None uses dataset.default_batch_size
            from config (10000 if unset). Must be > 0.
    """
    if batch_size is None:
        batch_size = DATASET_CONFIG.get("default_batch_size", 10000)

    # Negative starts would silently slice from the end of the list.
    if start_idx < 0:
        logger.error(f"Start index must be non-negative, got {start_idx}")
        return
    if batch_size <= 0:
        logger.error(f"Batch size must be positive, got {batch_size}")
        return

    end_idx = start_idx + batch_size

    logger.info("="*80)
    logger.info("🚀 STARTING DEEPSCALER DATASET BATCH PROCESSING WITH DEEPSEEK")
    logger.info(f"📊 Batch range: [{start_idx}:{end_idx}] ({batch_size} samples)")
    logger.info("="*80)

    # Extract data from DeepScaleR dataset
    extracted_data = extract_deepscaler_data(
        dataset_name=DATASET_NAME,
        output_path=None,
        explore_first=False
    )

    if not extracted_data:
        logger.error("Failed to extract data from DeepScaleR dataset")
        return

    logger.info(f"Total dataset size: {len(extracted_data)}")

    if start_idx >= len(extracted_data):
        logger.error(f"Start index {start_idx} is beyond dataset size {len(extracted_data)}")
        return

    # Adjust end_idx if it exceeds dataset size
    actual_end_idx = min(end_idx, len(extracted_data))
    actual_batch_size = actual_end_idx - start_idx

    logger.info(f"📊 Processing batch [{start_idx}:{actual_end_idx}] - {actual_batch_size} samples")

    # Generate output filename with batch info
    output_path = os.path.join(OUTPUT_DATA_DIR, f"deepscaler_processed_batch_{start_idx}_{actual_end_idx}.jsonl")

    # Process through DeepSeek API
    batch_data = extracted_data[start_idx:actual_end_idx]
    batch_run_id = f"deepscaler_batch_{start_idx}_{actual_end_idx}_{time.strftime('%Y%m%d_%H%M%S')}"
    process_deepscaler_dataset(batch_data, output_path, batch_run_id)

    if actual_end_idx < len(extracted_data):
        # Include the batch size: main() would otherwise fall back to the config default.
        logger.info(f"🔄 Next batch command: python deepscaleR_40k_generation_deepseek.py --batch {actual_end_idx} {batch_size}")
    else:
        logger.info("🏁 Reached the end of the dataset - no further batches to run")

def main():
    """
    Parse command-line flags and dispatch to test, full, or batch processing.

    Option flags may appear anywhere. They update the module-level settings
    and are stripped from ``sys.argv`` before the mode is read from the first
    remaining argument.

    Option flags:
        --no-parallel: run with a single worker.
        --workers N: set the worker count (N must be a positive integer).
        --no-resume: reprocess entries even if result files exist.

    Modes:
        --test: process the first 10 samples.
        --full: process the whole dataset.
        --batch [start] [size]: process one slice; size defaults to
            dataset.default_batch_size from config.

    With no recognized mode, usage is printed.
    """
    global MAX_WORKERS, ENABLE_PARALLEL, AUTO_RESUME

    if "--no-parallel" in sys.argv:
        ENABLE_PARALLEL = False
        sys.argv.remove("--no-parallel")
        logger.info("Parallel processing disabled")

    if "--workers" in sys.argv:
        worker_idx = sys.argv.index("--workers")
        worker_value = sys.argv[worker_idx + 1] if worker_idx + 1 < len(sys.argv) else None
        try:
            requested_workers = int(worker_value)
        except (TypeError, ValueError):
            requested_workers = None

        if requested_workers is not None and requested_workers > 0:
            MAX_WORKERS = requested_workers
            # Delete by position: removing by value could strip an identical
            # token that belongs to another option (e.g. "--batch 40").
            del sys.argv[worker_idx:worker_idx + 2]
            logger.info(f"Using {MAX_WORKERS} workers for parallel processing")
        else:
            # Drop only the flag so a following mode flag is not swallowed.
            del sys.argv[worker_idx]
            logger.warning(f"Invalid --workers argument, using default from config: {MAX_WORKERS}")

    if "--no-resume" in sys.argv:
        AUTO_RESUME = False
        sys.argv.remove("--no-resume")
        logger.info("Auto-resume disabled - will reprocess all files")

    mode = sys.argv[1] if len(sys.argv) > 1 else None

    if mode == "--test":
        test_processing()
    elif mode == "--full":
        full_processing()
    elif mode == "--batch":
        try:
            start_idx = int(sys.argv[2]) if len(sys.argv) > 2 else 0
            # None lets batch_processing fall back to dataset.default_batch_size from config.
            batch_size = int(sys.argv[3]) if len(sys.argv) > 3 else None
        except ValueError:
            logger.error("Invalid --batch arguments: expected integers, e.g. --batch 0 10000")
            return
        batch_processing(start_idx, batch_size)
    else:
        logger.info("DeepScaleR 40K Dataset Generation Script - DeepSeek Version")
        logger.info("Usage:")
        logger.info("  python deepscaleR_40k_generation_deepseek.py --test                   # Test with 10 samples")
        logger.info("  python deepscaleR_40k_generation_deepseek.py --full                   # Process full dataset")
        logger.info(f"  python deepscaleR_40k_generation_deepseek.py --batch 0 {DATASET_CONFIG.get('default_batch_size', 10000)}  # Process batch")
        logger.info(f"  python deepscaleR_40k_generation_deepseek.py --workers {MAX_WORKERS}             # Use {MAX_WORKERS} workers (configurable)")
        logger.info("  python deepscaleR_40k_generation_deepseek.py --no-parallel            # Disable parallel processing")
        logger.info("  python deepscaleR_40k_generation_deepseek.py --no-resume              # Disable auto-resume")
        logger.info("")
        logger.info("Configuration: Edit config.json to modify DeepSeek settings, processing parameters, and dataset options")
        logger.info(f"🤖 Model: {DEEPSEEK_MODEL}")
        logger.info(f"📁 Input data directory: {INPUT_DATA_DIR}")
        logger.info(f"📁 Output data directory: {OUTPUT_DATA_DIR}")

if __name__ == "__main__":
    main()