#!/usr/bin/env python3
"""
Script to generate samples for the SmolTraces-R1 dataset with concurrent processing.
This generates reasoning traces from R1 and extracts (Question, Thinking, Answer) triples,
taking the final answer from inside \\boxed{...} with the \\boxed{} wrapper removed.

This version includes enhanced error handling for DeepSeek API JSON errors and saves samples
to a single dataset file rather than individual files for better scalability.
"""

import os
import sys
import json
import time
import random
import logging
import re
import argparse
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from dotenv import load_dotenv
from openai import OpenAI

# Configure logging: records at INFO and above go both to stdout and to
# st_dataset_generation.log (a relative path, i.e. in the current working
# directory). NOTE: this runs at import time, so merely importing this module
# configures the root logger and opens the log file.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(sys.stdout),
        logging.FileHandler("st_dataset_generation.log")
    ]
)

# Prompt template sent to DeepSeek R1. It is filled with str.format(question):
# the question replaces {0}, and the doubled braces in \\boxed{{}} become single
# braces, so the model sees a literal \boxed{} (the backslash is written as \\
# because this is not a raw string). It asks for reasoning inside <think> tags
# and a \boxed{} final answer outside them, which is the layout the parsing
# code in this file looks for first.
R1_REASONING_PROMPT = """I want you to solve the following problem step-by-step, showing your reasoning process. Make sure to verify your work and correct any mistakes you find.

Question: {0}

Put all your reasoning inside <think>...</think> tags. After you've completed your thinking, present your final answer (and only your final answer) using LaTeX format with \\boxed{{}} notation outside the think tags.

For example:
<think>
This is where you'd show your work, step by step.
</think>

The answer is \\boxed{{x=5}}

Let me think through this carefully.
"""

def load_env_keys():
    """
    Return the provider API keys visible in the environment.

    A .env file is loaded first (via python-dotenv's load_dotenv), then each
    provider's key is read from os.environ. A provider whose variable is unset
    maps to "" rather than being left out, so callers can always index by
    provider name.

    Returns:
        Dict with keys "openai", "anthropic", "mistral" and "deepseek".
    """
    load_dotenv()

    # provider name -> environment variable holding its key
    env_var_by_provider = {
        "openai": "OPENAI_API_KEY",
        "anthropic": "ANTHROPIC_API_KEY",
        "mistral": "MISTRAL_API_KEY",
        "deepseek": "DEEPSEEK_API_KEY",
    }
    return {
        provider: os.environ.get(env_var, "")
        for provider, env_var in env_var_by_provider.items()
    }

def extract_final_answer(text):
    """
    Pull a final answer out of free-form text using common lead-in phrases.

    This is a last-resort fallback for when no boxed answer was found. The
    lead-ins are tried in a fixed priority order (not by where they occur in
    the text), case-insensitively. For the first lead-in present anywhere,
    the rest of that line after it is returned, stripped. If none is present,
    the whole text is returned stripped. Empty or None input gives "".
    """
    if not text:
        return ""

    lead_in_patterns = (
        r"the answer is\s+(.+)",
        r"final answer:?\s+(.+)",
        r"the result is\s+(.+)",
        r"therefore,\s+(.+)",
        r"thus,\s+(.+)",
        r"conclusion:?\s+(.+)",
    )
    # Lazily search pattern by pattern; stop at the first one that hits.
    searches = (re.search(p, text, re.IGNORECASE) for p in lead_in_patterns)
    first_hit = next((m for m in searches if m), None)

    chosen = first_hit.group(1) if first_hit is not None else text
    return chosen.strip()

# First <think>...</think> block in a response; DOTALL lets reasoning span lines.
_THINK_BLOCK_RE = re.compile(r'<think>(.*?)</think>', re.DOTALL)

# Phrases that may introduce the answer when no stronger separator exists.
# The one occurring latest in the text is used as the split point.
_ANSWER_LEAD_INS = (
    "Therefore, the answer is",
    "The answer is",
    "In conclusion",
    "So, the final result is",
)


def _split_thinking_and_answer(full_text, reasoning_content, question):
    """
    Split one R1 completion into (thinking, answer_text).

    Separators are tried in priority order:
      1. <think>...</think> tags in the message content (what the prompt asks for);
      2. a separate reasoning_content string supplied by the API;
      3. the last "Final Answer:" marker;
      4. the latest occurrence of one of _ANSWER_LEAD_INS (not at index 0).
    If none applies, the whole content is the thinking and the answer is "".

    Args:
        full_text: Message content ("" when the API returned none)
        reasoning_content: Separate reasoning text, or "" when absent
        question: The question, used only in log messages

    Returns:
        Tuple of (thinking, answer_text)
    """
    preview = question[:50]

    think_match = _THINK_BLOCK_RE.search(full_text)
    if think_match:
        thinking = think_match.group(1).strip()
        # Get everything after the last </think> tag
        answer_text = full_text.split('</think>')[-1].strip()
        logging.info(f"Successfully extracted thinking and answer using tags for: {preview}...")
        return thinking, answer_text

    if reasoning_content:
        logging.info(f"Using reasoning_content as thinking for: {preview}...")
        return reasoning_content.strip(), full_text.strip()

    # rpartition splits at the LAST marker and keeps every character on one
    # side or the other; str.split(...)[1] would drop any text that follows a
    # second "Final Answer:" marker.
    before, marker, after = full_text.rpartition("Final Answer:")
    if marker:
        logging.info(f"Successfully extracted thinking and answer using 'Final Answer:' for: {preview}...")
        return before.strip(), "Final Answer: " + after.strip()

    # Find the last occurrence of any answer lead-in (first listed wins ties).
    last_idx, last_pattern = -1, None
    for pattern in _ANSWER_LEAD_INS:
        idx = full_text.rfind(pattern)
        if idx > last_idx:
            last_idx, last_pattern = idx, pattern

    # A lead-in at index 0 would leave no reasoning before it, so it is not
    # used as a split point.
    if last_idx > 0:
        logging.info(f"Found alternative answer pattern '{last_pattern}' for: {preview}...")
        return full_text[:last_idx].strip(), full_text[last_idx:].strip()

    # If still no clear separation, use the whole text as reasoning
    logging.info(f"No clear answer pattern found for: {preview}...")
    return full_text, ""


def r1_api_call(question: str, api_key: str, temperature: float = 0.2, max_tokens: int = 4096):
    """
    Call the DeepSeek R1 API to generate a reasoning trace and answer.

    Up to 5 attempts are made. Between attempts the call sleeps 20s, 40s,
    80s, then 160s. Any exception is retried, including a JSONDecodeError
    from a malformed response body.

    Args:
        question: The question to solve
        api_key: DeepSeek API key
        temperature: Temperature for generation (NOTE(review): DeepSeek's
            reasoning-model docs list temperature as accepted but ignored for
            deepseek-reasoner; confirm before relying on it)
        max_tokens: Maximum tokens to generate

    Returns:
        Tuple of (reasoning_trace, answer)

    Raises:
        Exception: if every attempt fails. The last underlying error is
            chained as __cause__.
    """
    prompt = R1_REASONING_PROMPT.format(question)

    # Create the OpenAI client with DeepSeek base URL and appropriate timeout
    client = OpenAI(
        api_key=api_key,
        base_url="https://api.deepseek.com",
        timeout=1800.0  # Set a 1800-second timeout
    )

    MAX_ATTEMPTS = 5
    SLEEP_TIME = 20   # base delay in seconds; doubled after each failure
    last_error = None

    for attempt in range(MAX_ATTEMPTS):
        try:
            logging.info(f"Making API call attempt {attempt+1}/{MAX_ATTEMPTS} for: {question[:50]}...")
            start_time = time.time()

            response = client.chat.completions.create(
                model="deepseek-reasoner",  # DeepSeek-R1 is accessed via "deepseek-reasoner"
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
                max_tokens=max_tokens,
                stream=False,
                timeout=1800.0  # 30-minute timeout
            )

            elapsed_time = time.time() - start_time
            logging.info(f"API response received in {elapsed_time:.2f} seconds for: {question[:50]}...")

            message = response.choices[0].message
            # content may be None; coerce so parsing cannot raise and trigger a
            # pointless (slow, paid) retry of a request that actually succeeded.
            full_text = message.content or ""
            # Per DeepSeek's reasoning-model API docs, deepseek-reasoner returns
            # its chain of thought in a separate `reasoning_content` field. It is
            # not part of the OpenAI schema, hence getattr with a default.
            reasoning_content = getattr(message, "reasoning_content", None) or ""

            return _split_thinking_and_answer(full_text, reasoning_content, question)

        except json.JSONDecodeError as je:
            last_error = je
            # Specific handler for JSON decode errors from the DeepSeek API
            logging.error(f"JSON decode error in API response for '{question[:50]}...': {str(je)}")
            if attempt < MAX_ATTEMPTS - 1:
                sleep_time = SLEEP_TIME * (2 ** attempt)  # Exponential backoff
                logging.info(f"DeepSeek API returned malformed JSON. Retrying in {sleep_time} seconds...")
                time.sleep(sleep_time)
        except Exception as e:
            last_error = e
            logging.error(f"API call attempt {attempt+1}/{MAX_ATTEMPTS} failed for '{question[:50]}...': {str(e)}")
            if attempt < MAX_ATTEMPTS - 1:
                sleep_time = SLEEP_TIME * (2 ** attempt)  # Exponential backoff
                logging.info(f"Retrying in {sleep_time} seconds...")
                time.sleep(sleep_time)

    raise Exception(f"Failed to get response after {MAX_ATTEMPTS} attempts for '{question[:50]}...'") from last_error

# Literal text that opens a LaTeX \boxed{...} group.
_BOXED_OPENER = "\\boxed{"


def _balanced_group_end(text, start):
    """Return the index of the '}' that closes a group opened just before *start*, or -1."""
    depth = 1
    pos = start
    while pos < len(text):
        ch = text[pos]
        if ch == "\\":
            # Skip the escaped character so \{ and \} do not change the depth.
            pos += 2
            continue
        if ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return pos
        pos += 1
    return -1


def extract_boxed_answer(text):
    r"""
    Extract the answer from a boxed LaTeX format and strip the \boxed{} notation.

    Braces inside the box are matched to any depth, so an answer such as
    \boxed{\frac{\sqrt{3}}{2}} is returned whole rather than truncated.
    Escaped braces (\{, \}) count as ordinary characters. When several boxes
    appear, the last non-empty one wins, since that is typically the final
    answer. A box whose braces never close (e.g. a truncated response) is
    ignored instead of being returned partially.

    Args:
        text: The text containing potential boxed answers (None or "" allowed)

    Returns:
        The content inside the last boxed answer without the \boxed{} notation,
        or "" when no complete, non-empty box is found
    """
    if not text:
        logging.warning("No boxed answer found in the text")
        return ""

    candidates = []
    search_from = 0
    while True:
        opener = text.find(_BOXED_OPENER, search_from)
        if opener == -1:
            break
        content_start = opener + len(_BOXED_OPENER)
        content_end = _balanced_group_end(text, content_start)
        if content_end == -1:
            # Unterminated box: everything after it lies inside it, so stop.
            break
        if content_end > content_start:
            candidates.append(text[content_start:content_end])
        search_from = content_end + 1

    if candidates:
        answer_content = candidates[-1]
        logging.info(f"Found boxed answer content: {answer_content}")
        return answer_content

    logging.warning("No boxed answer found in the text")
    return ""

def load_dataset_chunk(combined_data_path, start_idx=0, chunk_size=10, total_samples=None):
    """
    Load a specific chunk of samples from the combined seed data.

    Two reading strategies are used:
      * total_samples is None or <= 10000: the file is parsed as a single JSON
        array and sliced;
      * total_samples > 10000: the file is read as JSON Lines (one record per
        line) and only the lines in the requested range are parsed.

    Args:
        combined_data_path: Path to the combined data file
        start_idx: Starting index for the chunk
        chunk_size: Size of the chunk to load
        total_samples: Optional total sample count; caps the chunk end and
            selects the reading strategy (see above)

    Returns:
        The samples at indices [start_idx, start_idx + chunk_size), clipped to
        the data. Returns [] if the file cannot be read or parsed (the error
        is logged).
    """
    logging.info(f"Loading dataset chunk from {combined_data_path} (start: {start_idx}, size: {chunk_size})")
    end_idx = start_idx + chunk_size

    try:
        with open(combined_data_path, 'r', encoding='utf-8') as f:
            if total_samples is None or total_samples <= 10000:
                combined_data = json.load(f)
                if total_samples is None:
                    total_samples = len(combined_data)
                end_idx = min(end_idx, total_samples)
                chunk = combined_data[start_idx:end_idx]
            else:
                # Streaming path: the rows collected here already ARE the chunk,
                # so they must not be sliced by start_idx a second time.
                end_idx = min(end_idx, total_samples)
                chunk = []
                for i, line in enumerate(f):
                    if i >= end_idx:
                        break
                    if i >= start_idx:
                        chunk.append(json.loads(line.strip()))

        logging.info(f"Loaded chunk with {len(chunk)} samples (indices {start_idx}-{start_idx + len(chunk) - 1})")
        return chunk
    except Exception as e:
        # Best-effort by design: callers treat an empty chunk as "nothing to do".
        logging.error(f"Error loading dataset chunk: {str(e)}")
        return []

# Last-resort patterns for answers that never appeared in \boxed{}; tried in
# order, first match in the text wins. NOTE(review): several of these
# (k = ..., students is ...) look tailored to one specific problem and can
# pick up unrelated numbers on other datasets; consider pruning them.
_CONCLUSION_PATTERNS = tuple(
    re.compile(pattern, re.IGNORECASE)
    for pattern in (
        r".*?conclusion.*?is\s+(\d+|\w+)",
        r".*?answer\s+is\s+(\d+|\w+)",
        r".*?k\s+=\s+(\d+)",
        r".*?k\s+is\s+(\d+)",
        r".*?value\s+of\s+k\s+is\s+(\d+)",
        r".*?the\s+maximum\s+k\s+.*?is\s+(\d+)",
        r".*?final\s+answer\s+is\s+(\d+|\w+)",
        r".*?students\s+is\s+(\d+)",  # Pattern specific to our thinking sample
        r".*?\*\*Conclusion\*\*:.*?students\s+is\s+(\d+)",  # Markdown format in thinking
    )
)


def _answer_without_box(combined_text, thinking, answer_text):
    """
    Best-effort answer for a response that has no usable boxed answer.

    Tries, in order: the conclusion patterns over the combined text, the
    last integer in the final 100 characters of the thinking, and finally
    extract_final_answer on the answer text.
    """
    for pattern in _CONCLUSION_PATTERNS:
        match = pattern.search(combined_text)
        if match:
            answer = match.group(1).strip()
            logging.info(f"Extracted answer from conclusion: {answer}")
            return answer

    numbers = re.findall(r"(\d+)", thinking[-100:])
    if numbers:
        logging.info(f"Last number in thinking: {numbers[-1]}")
        return numbers[-1]

    answer = extract_final_answer(answer_text)
    logging.info(f"Using clean extracted answer: {answer[:100]}...")
    return answer


def process_sample(sample, api_key):
    """
    Process a single sample with the R1 API and build its triple.

    The answer is whatever extract_boxed_answer finds in the model's
    thinking plus answer text (the boxed content without the \\boxed{}
    wrapper). If that is empty, the heuristics in _answer_without_box are
    used instead.

    Args:
        sample: Seed record. "question" is required; "answer", "domain" and
            "dataset" are copied through when present.
        api_key: DeepSeek API key

    Returns:
        The triple dict. Returns None when the question is empty, or when the
        API call or answer extraction raised (the error is logged).
    """
    question = sample.get("question", "")
    expected_answer = sample.get("answer", "")

    if not question:
        logging.warning("Empty question found in sample")
        return None

    try:
        thinking, answer_text = r1_api_call(question, api_key)

        # Search both parts: the final box is usually in the answer, but the
        # model sometimes boxes its result inside the reasoning.
        combined_text = thinking + " " + answer_text

        # Reuse the shared extractor instead of duplicating its regexes here.
        boxed_answer_content = extract_boxed_answer(combined_text)
        if not boxed_answer_content:
            boxed_answer_content = _answer_without_box(combined_text, thinking, answer_text)

        return {
            "question": question,
            "thinking": thinking,
            "answer": boxed_answer_content,  # The answer content without \boxed{} notation
            "expected_answer": expected_answer,
            "domain": sample.get("domain", "unknown"),
            "dataset": sample.get("dataset", "unknown"),
        }
    except Exception as e:
        logging.error(f"Error processing sample: {str(e)}")
        return None

def process_chunk_concurrent(chunk, api_key, output_dir, max_workers=5, save_individual=False):
    """
    Run process_sample over every item of a chunk using a thread pool.

    Triples are collected in completion order, not input order. A sample is
    left out when process_sample returns a falsy value, or when its future
    raises; the latter is logged with the sample's position in the chunk.

    Args:
        chunk: List of samples to process
        api_key: DeepSeek API key
        output_dir: Directory for the optional per-sample debug files
        max_workers: Maximum number of concurrent workers
        save_individual: When True, also write sample_<position>.json for each
            successful triple (position is relative to this chunk)

    Returns:
        List of the triples that were produced successfully
    """
    collected = []

    def dump_for_debugging(position, triple):
        # Individual files are only written on request (debugging aid).
        if not save_individual:
            return
        sample_path = os.path.join(output_dir, f"sample_{position}.json")
        with open(sample_path, 'w') as f:
            json.dump(triple, f, indent=2)
        logging.info(f"Saved individual triple {position} to {sample_path}")

    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # future -> position of its sample within the chunk
        position_of = {}
        for position, sample in enumerate(chunk):
            position_of[pool.submit(process_sample, sample, api_key)] = position

        progress = tqdm(as_completed(position_of), total=len(position_of),
                        desc=f"Processing chunk of {len(chunk)} samples")
        for finished in progress:
            position = position_of[finished]
            try:
                triple = finished.result()
                if not triple:
                    continue
                collected.append(triple)
                dump_for_debugging(position, triple)
            except Exception as e:
                logging.error(f"Error processing sample {position}: {str(e)}")

    return collected

def load_existing_dataset(output_dir):
    """
    Load the existing dataset from the output directory.

    samples.json is the primary source. If it exists but cannot be read, or
    does not contain a JSON list (e.g. it was truncated by an interrupted
    write), samples.jsonl in the same directory is tried instead. This avoids
    resuming from an empty list and then overwriting the previous samples on
    the next save.

    Args:
        output_dir: Path to the output directory

    Returns:
        List of existing samples, or an empty list if none can be loaded
    """
    main_dataset_path = os.path.join(output_dir, "samples.json")
    jsonl_dataset_path = os.path.join(output_dir, "samples.jsonl")

    if not os.path.exists(main_dataset_path):
        logging.info(f"No existing dataset found at {main_dataset_path}. Starting fresh.")
        return []

    try:
        with open(main_dataset_path, 'r', encoding='utf-8') as f:
            existing_samples = json.load(f)
        if not isinstance(existing_samples, list):
            raise ValueError(f"expected a JSON list, found {type(existing_samples).__name__}")
        logging.info(f"Loaded {len(existing_samples)} existing samples from {main_dataset_path}")
        return existing_samples
    except (OSError, ValueError) as e:  # JSONDecodeError/UnicodeDecodeError are ValueErrors
        logging.warning(f"Error loading existing dataset: {str(e)}.")

    # Fallback: the JSON Lines copy holds the same records, one per line.
    if os.path.exists(jsonl_dataset_path):
        try:
            with open(jsonl_dataset_path, 'r', encoding='utf-8') as f:
                recovered = [json.loads(line) for line in f if line.strip()]
            logging.warning(f"Recovered {len(recovered)} existing samples from {jsonl_dataset_path}")
            return recovered
        except (OSError, ValueError) as e:
            logging.warning(f"Error loading fallback dataset {jsonl_dataset_path}: {str(e)}.")

    logging.warning("No readable existing dataset. Starting fresh.")
    return []

def _write_file_atomically(path, write_contents):
    """
    Write a file through a sibling temp file so readers never see a partial file.

    write_contents(f) receives the open text file. The temp file is flushed
    and fsync'ed, then moved over path with os.replace. On any failure the
    temp file is removed and the error re-raised, leaving path untouched.
    """
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            write_contents(f)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, path)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # the temp file may never have been created
        raise


def update_dataset_files(output_dir, all_samples):
    """
    Update the dataset files with the combined samples.

    Rewrites samples.json (indented JSON array) and samples.jsonl (one record
    per line) in output_dir. Each file is replaced atomically. An interruption
    mid-write therefore leaves the previous version in place rather than a
    truncated file that could not be loaded on the next run.

    Args:
        output_dir: Path to the output directory
        all_samples: List of all samples to save

    Returns:
        None
    """
    main_dataset_json_path = os.path.join(output_dir, "samples.json")
    main_dataset_jsonl_path = os.path.join(output_dir, "samples.jsonl")

    # JSON version
    _write_file_atomically(
        main_dataset_json_path,
        lambda f: json.dump(all_samples, f, indent=2),
    )

    # JSONL version (more efficient for large datasets)
    _write_file_atomically(
        main_dataset_jsonl_path,
        lambda f: f.writelines(json.dumps(triple) + "\n" for triple in all_samples),
    )

    logging.info(f"Updated dataset files with {len(all_samples)} total samples")

def main():
    """
    Main entry point for generating ST dataset samples.

    Steps:
      1. Parse the CLI options and read the DeepSeek key via load_env_keys.
      2. Load the requested slice of datasets/combined_seed_data.json.
      3. Send every question to R1 through a thread pool.
      4. Append the successful triples to the dataset in --output_dir.

    The combined files are rewritten after every 10 new samples and once more
    at the end. metadata.json is written only when at least one new sample
    succeeded.
    """

    # Display important warnings about DeepSeek R1 API
    logging.warning("⚠️ IMPORTANT: DeepSeek R1 API responses take time (5-15 minutes per call)")
    logging.warning("⚠️ The API may return malformed JSON responses despite 200 status codes")
    logging.warning("⚠️ The script will automatically retry with increasing delays")
    logging.warning("⚠️ Consider reducing max_workers if you experience too many failed attempts")

    parser = argparse.ArgumentParser(description="Generate SmolTraces dataset with concurrent processing")
    parser.add_argument("--start_idx", type=int, default=0, help="Starting index for processing")
    parser.add_argument("--num_samples", type=int, default=10, help="Number of samples to process")
    parser.add_argument("--max_workers", type=int, default=3, help="Maximum number of concurrent workers")
    # Still accepted so existing command lines keep working. The input range is
    # now loaded in a single pass, so this no longer affects processing.
    parser.add_argument("--chunk_size", type=int, default=5, help="Size of each processing chunk")
    parser.add_argument("--output_dir", type=str, default="datasets/SmolTraces-R1", help="Output directory")
    parser.add_argument("--save_individual", action="store_true", help="Save individual sample files for debugging")
    args = parser.parse_args()

    # Load API key
    keys = load_env_keys()
    api_key = keys.get("deepseek")

    if not api_key:
        logging.error("DeepSeek API key not found. Please check your .env file.")
        return

    # Parameters
    combined_data_path = "datasets/combined_seed_data.json"
    output_dir = args.output_dir
    start_idx = args.start_idx
    num_samples = args.num_samples
    max_workers = args.max_workers
    save_individual = args.save_individual

    # Rough completion estimate: ~15 minutes per R1 call, spread over max_workers threads
    estimated_time_per_sample = 15 * 60  # 15 minutes in seconds
    total_estimated_time = num_samples * estimated_time_per_sample / max_workers
    estimated_completion_time = time.time() + total_estimated_time

    logging.info(f"Processing {num_samples} samples with {max_workers} workers")
    logging.info(f"Estimated completion time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(estimated_completion_time))}")
    logging.info(f"(approximately {total_estimated_time/3600:.1f} hours)")

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Load existing dataset if available; new triples are appended to it
    existing_samples = load_existing_dataset(output_dir)
    all_samples = existing_samples.copy()

    # Load the whole requested range with one call. load_dataset_chunk parses
    # the entire seed file on every call, so loading chunk_size rows per call
    # re-parsed the file once per chunk to produce exactly the same rows.
    input_rows = load_dataset_chunk(combined_data_path, start_idx, num_samples)
    all_input_samples = list(enumerate(input_rows, start=start_idx))

    # Process samples concurrently using a single ThreadPoolExecutor
    successful_count = 0
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = [
            (executor.submit(process_sample, sample, api_key), idx)
            for idx, sample in all_input_samples
        ]

        # Results are consumed in submission order (not completion order), so
        # the dataset keeps the seed data's order. The progress bar therefore
        # advances only when the next sample in that order has finished.
        for future, idx in tqdm(futures, desc=f"Processing {len(all_input_samples)} samples"):
            try:
                triple = future.result()
                if triple:
                    all_samples.append(triple)
                    successful_count += 1

                    # Save individual sample if requested
                    if save_individual:
                        sample_path = os.path.join(output_dir, f"sample_{idx}.json")
                        with open(sample_path, 'w') as f:
                            json.dump(triple, f, indent=2)

                    # Update combined files periodically (every 10 successful samples)
                    if successful_count % 10 == 0:
                        update_dataset_files(output_dir, all_samples)
                        logging.info(f"Updated dataset with latest samples (total: {len(all_samples)})")
            except Exception as e:
                logging.error(f"Error processing sample {idx}: {str(e)}")

    # Final update to dataset files
    if successful_count > 0:
        update_dataset_files(output_dir, all_samples)

        metadata_path = os.path.join(output_dir, "metadata.json")
        date_created = time.strftime("%Y-%m-%d")
        if existing_samples:
            # Extending an existing dataset: keep its original creation date
            # instead of overwriting it with today's date on every run.
            try:
                with open(metadata_path, 'r', encoding='utf-8') as f:
                    previous_metadata = json.load(f)
                if isinstance(previous_metadata, dict) and previous_metadata.get("date_created"):
                    date_created = previous_metadata["date_created"]
            except (OSError, ValueError):
                pass  # best-effort: fall back to today's date

        # Create dataset metadata
        metadata = {
            "dataset_name": "SmolTraces-R1",
            "description": "High-quality reasoning traces generated by the DeepSeek-R1 model",
            "version": "1.0.0",
            "creator": "DeepSeek-R1 via API",
            "date_created": date_created,
            "date_updated": time.strftime("%Y-%m-%d %H:%M:%S"),
            "num_samples": len(all_samples),
            "new_samples_added": successful_count,
            "license": "Research use only"
        }

        with open(metadata_path, 'w') as f:
            json.dump(metadata, f, indent=2)

        # Denominator is what was actually loaded (the seed data may be shorter
        # than start_idx + num_samples).
        logging.info(f"Successfully processed {successful_count}/{len(all_input_samples)} samples")
        logging.info(f"Total dataset size: {len(all_samples)} samples")
        logging.info(f"Dataset metadata saved to {output_dir}/metadata.json")
    else:
        logging.error("No samples were successfully processed")


if __name__ == "__main__":
    main()