"""
Script for generating SmolTraces-HardCoded (ST-HC) datasets.
These datasets have hardcoded reasoning traces with specific structures and pivot types.
"""

import os
import sys
import json
import argparse
import logging
import time
import random
import re
from typing import Dict, List, Any, Optional, Tuple
import requests
from tqdm import tqdm
from datasets import load_dataset, Dataset

# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data_generation.synthetic_trace_generation import extract_pivots

# Set up logging
# Configures the root logger once at import time; every logging.* call in this
# module goes through it.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Constants
# Number of HTTP attempts gpt4o_api_call makes before giving up on one request.
MAX_ATTEMPTS = 3
# Seconds gpt4o_api_call sleeps between failed HTTP attempts.
SLEEP_TIME = 1
# Default value of the --output_dir command-line option.
RESULTS_DIR = "results/hardcoded_traces"
# Endpoint that gpt4o_api_call POSTs chat-completion requests to.
API_BASE_URL = "https://api.openai.com/v1/chat/completions"

def count_tokens(text: str) -> int:
    """
    Approximate the number of tokens in a text by counting its words.

    This is a cheap heuristic for summary statistics only. It does not run a
    real tokenizer, so the value will differ from a model's true token count.

    Args:
        text: The text to count tokens for. ``None`` or an empty string
            counts as zero.

    Returns:
        Number of whitespace-separated words in ``text``
    """
    if not text:
        return 0
    # str.split() with no arguments splits on runs of whitespace and drops
    # empty pieces, so this is a word count (not a characters/4 estimate).
    return len(text.split())

def gpt4o_api_call(prompt: str, api_key: str, max_tokens: int = 4000, temperature: float = 0.2, timeout: float = 120.0) -> Optional[str]:
    """
    Make an API call to GPT-4o, retrying on failure.

    The request is attempted up to MAX_ATTEMPTS times, sleeping SLEEP_TIME
    seconds between attempts. Any failure (network error, HTTP error status,
    unexpected response shape) is logged and counts as a failed attempt.

    Args:
        prompt: The prompt for GPT-4o
        api_key: OpenAI API key
        max_tokens: Maximum tokens to generate
        temperature: Temperature for generation
        timeout: Seconds to wait for the HTTP request before treating the
            attempt as failed. Without a timeout a stalled connection would
            block the whole run indefinitely.

    Returns:
        Generated response or None if failed
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    data = {
        "model": "gpt-4o-2024-11-20",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": temperature
    }

    for attempt in range(MAX_ATTEMPTS):
        try:
            response = requests.post(API_BASE_URL, headers=headers, json=data, timeout=timeout)
            response.raise_for_status()
            result = response.json()
            return result["choices"][0]["message"]["content"]
        except Exception as e:
            # Deliberately broad: this is the retry boundary, and every failure
            # mode (transport, HTTP status, malformed JSON, missing keys) is
            # handled the same way -- log it and try again.
            logging.error(f"API call failed on attempt {attempt+1}: {e}")
            if attempt < MAX_ATTEMPTS - 1:
                logging.info(f"Retrying in {SLEEP_TIME} seconds...")
                time.sleep(SLEEP_TIME)

    # Reached only when no attempt returned a response.
    logging.error("All attempts failed")
    return None

def create_hardcoded_prompt(question: str) -> str:
    """
    Build the fixed instruction prompt for one question.

    The prompt text is hardcoded; only the question varies. It asks the model to:
      - reason with explicit lexical pivots (realization, verification,
        exploration, integration),
      - move through four stages (problem framing, exploration, verification,
        synthesis),
      - wrap its reasoning in the begin_of_thought / end_of_thought markers and
        its final answer in the begin_of_solution / end_of_solution markers,
      - put the final answer in LaTeX boxed notation.

    The template is an f-string, so literal braces are written doubled and
    LaTeX backslashes are escaped. The question is interpolated verbatim.

    Args:
        question: The question to generate a trace for

    Returns:
        Prompt text
    """
    prompt = f"""This task requires solving problems using structured, real-time reasoning, including explicit self-monitoring and self-correction. Mimic the thought process of an agent that regularly pauses to reconsider assumptions, verify intermediate results, explore alternatives, and integrate findings into coherent solutions. Use explicit lexical pivots to signal shifts in thinking or corrections to your reasoning. Examples include:

Realization pivots (recognizing errors or oversights): "Wait—", "Oh—", "Actually -", "I missed something -".  

Verification pivots (explicitly testing assumptions or results): "Let me double-check -", "To verify—", "Checking again -".  

Exploration pivots (considering alternative approaches): "What if -", "Another way to look at this -", "Alternatively -".  

Integration pivots (synthesizing different ideas or resolving contradictions): "Now I see how -", "This connects back to -", "Putting this together -".

When solving the problem, follow a structured reasoning trace that clearly moves through the following stages:

1. Problem Framing: Restate the problem and identify key elements clearly.

2. Exploration: Consider one or more potential solution paths, openly weighing alternatives.

3. Verification: Explicitly test intermediate results or assumptions; if inconsistencies arise, pivot explicitly to clarify or correct.

4. Synthesis: Clearly integrate findings into a coherent solution, explicitly connecting back to the original problem.

Use direct, concise language. Short sentences should represent your evolving thoughts clearly. Use pivots naturally to signal shifts in reasoning, corrections, or deeper insights. Be explicit about confusion or uncertainty when it arises (e.g., "Something doesn't add up here," or "I'm not sure yet—").  

IMPORTANT: For your final answer, use the LaTeX \\boxed{{}} notation to enclose your answer. For example:
- If the answer is a number: \\boxed{{42}}
- If the answer is a letter choice: \\boxed{{A}}
- If the answer is an expression: \\boxed{{x^2 + 2x - 3}}
- If the answer is a word: \\boxed{{True}}

Here are examples of well-formatted answers from similar problems:
- "Since $1 \\le \\sqrt{{1}} < \\sqrt{{2}} < \\sqrt{{3}} < 2$, the first three terms are 1. There are 5 terms equal to 2, and 7 terms equal to 3. The last term is 4. So the sum is $3(1) + 5(2) + 7(3) + 4 = \\boxed{{38}}$."
- "The angles in a triangle sum to $180^{{\\circ}}$. So $x + 53^{{\\circ}} + 37^{{\\circ}} = 180^{{\\circ}}$, meaning $x = 180^{{\\circ}} - 53^{{\\circ}} - 37^{{\\circ}} = 90^{{\\circ}}$. So the answer is $\\boxed{{90^{{\\circ}}}}$."
- "Solving the inequality, we get $-3 < x < 4$. Among the given options, only $x = 2$ falls within this range. The answer is $\\boxed{{B}}$."
- "Following the constraints, the only valid arrangement places goods in this order: F, M, T, P, L, G. Therefore, statement D must be false. The answer is $\\boxed{{D}}$."

Here is the problem you need to solve:

{question}

Begin your reasoning with "<|begin_of_thought|>" and end with "<|end_of_thought|>".
After your reasoning trace, state the final answer with "<|begin_of_solution|>" and "<|end_of_solution|>".
Remember that your final answer must use the \\boxed{{}} notation.
"""
    
    return prompt

def _boxed_contents(text):
    r"""
    Return the contents of every ``\boxed{...}`` in ``text``, in order.

    Braces are matched by depth, so nested LaTeX such as
    ``\boxed{\frac{1}{2}}`` yields ``\frac{1}{2}`` instead of being cut at
    the first closing brace. If the braces of a box never balance, the text
    up to the first closing brace is used as a best effort.
    """
    marker = '\\boxed{'
    contents = []
    search_from = 0
    while True:
        start = text.find(marker, search_from)
        if start == -1:
            break
        body_start = start + len(marker)
        depth = 1
        pos = body_start
        while pos < len(text) and depth:
            if text[pos] == '{':
                depth += 1
            elif text[pos] == '}':
                depth -= 1
            pos += 1
        if depth:
            # Unbalanced braces: fall back to the first closing brace, if any.
            close = text.find('}', body_start)
            if close == -1:
                break
            contents.append(text[body_start:close])
            search_from = close + 1
        else:
            # pos is one past the brace that closed the box.
            contents.append(text[body_start:pos - 1])
            search_from = pos
    return contents


def _normalize_boxed(content):
    r"""Unwrap ``\text{...}`` and collapse whitespace inside a boxed answer."""
    content = re.sub(r'\\text\{([^}]*)\}', r'\1', content)
    return re.sub(r'\s+', ' ', content).strip()


def extract_letter(ans_text):
    """
    Extract answer from text, handling different formats including LaTeX boxed format.

    Patterns are tried in order and the first hit wins:
      1. the first LaTeX boxed answer (a single letter a-d is upper-cased),
      2. a leading standalone choice letter such as "(A)", "A." or "A:",
      3. a phrase such as "answer is B", "option C" or "statement D",
      4. a starred choice such as "**(B)**",
      5. any standalone upper-case letter A-D.
    If nothing matches, the text is returned unchanged. Empty input yields "".
    """
    if not ans_text:
        return ""
    ans_text = str(ans_text)

    # First, check for \boxed{} LaTeX format
    boxed = _boxed_contents(ans_text)
    if boxed:
        result = boxed[0].strip()
        # If the boxed content is a single letter A-D, convert to uppercase
        if re.fullmatch(r'[A-Da-d]', result):
            return result.upper()
        return result

    # Leading choice letter like "(A)", "A.", "A:" or just "A". The lookahead
    # keeps ordinary words that merely start with A-D ("Because", "Definitely")
    # from being read as a choice.
    letter_match = re.match(r'\(?([A-D])(?!\w)', ans_text.strip())
    if letter_match:
        return letter_match.group(1).upper()

    # "The correct answer is X" style phrases. Keywords are case-insensitive,
    # but the choice itself must be an upper-case standalone letter so that
    # "the answer is a number" is not read as choice A.
    answer_match = re.search(
        r'(?i:answer|option|statement)\s*(?i:is)?\s*:?\s*\(?([A-D])\)?(?!\w)',
        ans_text,
    )
    if answer_match:
        return answer_match.group(1).upper()

    # Try to find "*(X)*" or "**(X)**" format
    marked_match = re.search(r'\*+\(?([A-D])\)\*+', ans_text)
    if marked_match:
        return marked_match.group(1).upper()

    # Look for letter answer in the text with surrounding markers
    letter_in_text = re.search(r'[^\w]([A-D])[^\w]', ' ' + ans_text + ' ')
    if letter_in_text:
        return letter_in_text.group(1).upper()

    # If we can't extract a specific pattern, return the original text
    return ans_text


def compare_answers(generated, ground_truth):
    """
    Compare generated answer with ground truth, handling different formats.

    When both sides contain boxed answers:
      - a long ground truth (more than 100 characters, i.e. a worked solution
        that may box several values) matches if any generated boxed answer
        matches any ground-truth boxed answer;
      - otherwise the first boxed answer on each side is compared.
    When at most one side has a boxed answer, both sides go through
    extract_letter and the extracted values are compared.

    Returns False if either input is empty. Non-string inputs (for example a
    numeric ground truth) are converted with str() first.
    """
    # If either answer is empty, they won't match
    if not generated or not ground_truth:
        return False
    generated = str(generated)
    ground_truth = str(ground_truth)

    gen_boxed_values = [_normalize_boxed(c) for c in _boxed_contents(generated)]
    gt_boxed_values = [_normalize_boxed(c) for c in _boxed_contents(ground_truth)]

    if gen_boxed_values and gt_boxed_values:
        if len(ground_truth) > 100:
            # Long explanation with possibly several boxed values: any pair may match.
            return any(
                compare_single_answers(gen, gt)
                for gen in gen_boxed_values
                for gt in gt_boxed_values
            )
        # Direct boxed comparison between short answers
        return compare_single_answers(gen_boxed_values[0], gt_boxed_values[0])

    # Extract answers for standard comparison
    return compare_single_answers(extract_letter(generated), extract_letter(ground_truth))


def compare_single_answers(generated, ground_truth):
    """
    Compare a single extracted answer with ground truth.

    Checks, in order: multiple-choice letters (case-insensitive), plain
    decimal numbers (equal within 1e-6), two golden-ratio spellings, and
    finally whether either string contains the other. The containment check
    is intentionally lenient; it is not a symbolic comparison.

    Returns False if either value is None or blank after stripping.
    """
    if generated is None or ground_truth is None:
        return False

    # Clean the answers for comparison
    generated = str(generated).strip()
    ground_truth = str(ground_truth).strip()

    # Blank values must not reach the containment check below, because the
    # empty string is a substring of everything.
    if not generated or not ground_truth:
        return False

    # Special case for multiple-choice options (A/B/C/D)
    if re.fullmatch(r'[A-Da-d]', generated) and re.fullmatch(r'[A-Da-d]', ground_truth):
        return generated.upper() == ground_truth.upper()

    # For numerical answers, normalize and compare
    number = r'-?\d+(\.\d+)?'
    if re.fullmatch(number, generated) and re.fullmatch(number, ground_truth):
        try:
            return abs(float(generated) - float(ground_truth)) < 1e-6
        except ValueError:
            pass

    # Handle LaTeX sequences for special symbols
    gen_clean = generated.replace('\\phi', 'φ')
    truth_clean = ground_truth.replace('\\phi', 'φ')

    # Handle common alternatives for the same mathematical entity
    golden = '\\frac{1 + \\sqrt{5}}{2}'
    if ('φ' in gen_clean and golden in ground_truth) or \
       ('φ' in truth_clean and golden in generated):
        return True

    conjugate = '\\frac{1 - \\sqrt{5}}{2}'
    if ('-\\frac{1}{φ}' in gen_clean and conjugate in ground_truth) or \
       ('-\\frac{1}{φ}' in truth_clean and conjugate in generated):
        return True

    # Handle expressions that may have different forms but same meaning.
    # This is a simplified check - more robust would need a symbolic computation engine.
    # Equal strings are also covered here, since a string contains itself.
    return (generated in ground_truth) or (ground_truth in generated)

def _split_trace_response(response: str) -> Tuple[str, str]:
    """Split a model response into (thinking, answer) using the thought/solution markers."""
    thinking_match = re.search(r"<\|begin_of_thought\|>(.*?)<\|end_of_thought\|>", response, re.DOTALL)
    answer_match = re.search(r"<\|begin_of_solution\|>(.*?)<\|end_of_solution\|>", response, re.DOTALL)

    thinking = thinking_match.group(1).strip() if thinking_match else ""
    answer = answer_match.group(1).strip() if answer_match else ""

    # If markers weren't found, treat the last blank-line-separated paragraph
    # as the answer and everything before it as the thinking.
    if not thinking and not answer:
        parts = response.split("\n\n")
        if len(parts) >= 2:
            thinking = "\n\n".join(parts[:-1])
            answer = parts[-1]

    return thinking, answer


def _log_answer_mismatch(answer: str, ground_truth_answer: str, extracted_answer: str, extracted_ground_truth: str) -> None:
    """Log verbose diagnostics explaining why a generated answer did not match."""
    logging.info(f"ANSWER COMPARISON: INCORRECT")
    logging.info(f"Generated answer (raw): '{answer}'")
    logging.info(f"Generated (extracted): '{extracted_answer}'")
    logging.info(f"Ground truth (raw): '{ground_truth_answer}'")
    logging.info(f"Ground truth (extracted): '{extracted_ground_truth}'")

    # Check for specific comparison issues
    if '\\boxed{' in ground_truth_answer and len(ground_truth_answer) > 100:
        gt_boxed = re.findall(r'\\boxed{([^}]*)}', ground_truth_answer)
        gen_boxed = re.findall(r'\\boxed{([^}]*)}', answer)
        if gt_boxed and not gen_boxed:
            logging.info(f"COMPARISON ISSUE: Ground truth has {len(gt_boxed)} boxed answers, but generated answer has none")
        elif gt_boxed and gen_boxed:
            if len(gt_boxed) != len(gen_boxed):
                logging.info(f"COMPARISON ISSUE: Ground truth has {len(gt_boxed)} boxed answers, but generated answer has {len(gen_boxed)}")
            else:
                mismatches = [(i, gt, gen) for i, (gt, gen) in enumerate(zip(gt_boxed, gen_boxed)) if gt.strip() != gen.strip()]
                if mismatches:
                    logging.info(f"COMPARISON ISSUE: Mismatches in boxed answers:")
                    for i, gt, gen in mismatches:
                        logging.info(f"  Answer {i+1}: Ground truth '{gt}' vs Generated '{gen}'")
    elif extracted_answer != extracted_ground_truth:
        logging.info(f"COMPARISON ISSUE: Simple mismatch between extracted answers")
    else:
        logging.info(f"COMPARISON ISSUE: Unknown comparison failure")


def process_question(question: str, question_id: str, api_key: str, output_dir: str, ground_truth_answer: str = None, max_attempts: int = 5, verbose: bool = False) -> Dict[str, Any]:
    """
    Process a single question to generate a hardcoded reasoning trace.

    Each attempt sends the prompt to the model, splits the response into a
    thinking trace and an answer, and (when a ground truth is given) verifies
    the answer. The first accepted attempt is written to
    ``<output_dir>/<question_id>.json`` and returned. With a ground truth, an
    attempt is accepted only if an answer was extracted and it matches; an
    attempt with no extractable answer is retried. Without a ground truth the
    first parsed response is accepted as-is.

    Args:
        question: The question to process
        question_id: Unique ID for the question
        api_key: OpenAI API key
        output_dir: Directory to save results
        ground_truth_answer: The ground truth answer to verify against
        max_attempts: Maximum number of attempts to generate a correct answer
        verbose: Whether to print detailed information

    Returns:
        Dictionary with the processed result including number of attempts.
        If every attempt fails, "answer_correct" is False, "thinking" and
        "answer" are empty, and nothing is written to disk.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Comparison helpers work on text; keep the caller's original value for
    # the result record so its JSON representation is unchanged.
    truth_text = None if ground_truth_answer is None else str(ground_truth_answer)

    if verbose:
        logging.info(f"\n{'='*80}\nPROCESSING QUESTION {question_id}\n{'='*80}")
        logging.info(f"QUESTION: {question}")
        if truth_text:
            logging.info(f"GROUND TRUTH ANSWER: {truth_text}")

    # The prompt depends only on the question, so build it once.
    prompt = create_hardcoded_prompt(question)

    # Try up to max_attempts times to get a correct answer
    for attempt in range(max_attempts):
        if verbose:
            logging.info(f"\n{'-'*40}\nAttempt {attempt+1}/{max_attempts}\n{'-'*40}")

        # Call API
        response = gpt4o_api_call(prompt, api_key)

        if not response:
            logging.error(f"Failed to generate response for question {question_id} on attempt {attempt+1}")
            continue

        thinking, answer = _split_trace_response(response)

        if verbose:
            logging.info(f"THINKING:\n{thinking}")
            logging.info(f"GENERATED ANSWER: {answer}")

        # Without a ground truth there is nothing to verify against, so the
        # attempt is accepted as-is.
        answer_correct = True
        if truth_text:
            if not answer:
                # An answer that could not be extracted cannot be verified;
                # it must not be recorded as correct.
                logging.info(f"Attempt {attempt+1}: No answer could be extracted for question {question_id}")
                continue

            answer_correct = compare_answers(answer, truth_text)

            # Extract answers for logging
            extracted_answer = extract_letter(answer)
            extracted_ground_truth = extract_letter(truth_text)

            if not answer_correct:
                if verbose:
                    _log_answer_mismatch(answer, truth_text, extracted_answer, extracted_ground_truth)
                logging.info(f"Attempt {attempt+1}: Answer incorrect for question {question_id}")
                # Continue to next attempt
                continue

            if verbose:
                logging.info(f"ANSWER COMPARISON: CORRECT")
                logging.info(f"Extracted answer: '{extracted_answer}'")
                logging.info(f"Ground truth: '{extracted_ground_truth}'")

        # If we reached here, either the answer is correct or no ground truth was provided.
        # Fall through to save and return; leaving the loop here instead would
        # land in the "all attempts failed" path below.
        result_file = os.path.join(output_dir, f"{question_id}.json")

        # Create the result dictionary
        result = {
            "question_id": question_id,
            "question": question,
            "thinking": thinking,
            "answer": answer,
            "ground_truth_answer": ground_truth_answer,
            "answer_correct": answer_correct,
            "attempts": attempt + 1
        }

        # Save the result to a file
        with open(result_file, "w", encoding="utf-8") as f:
            json.dump(result, f, indent=2)

        if verbose:
            logging.info(f"Saved result to {result_file}")

        return result

    # If we reached here, all attempts failed
    logging.error(f"All {max_attempts} attempts failed for question {question_id}")
    return {
        "question_id": question_id,
        "question": question,
        "thinking": "",
        "answer": "",
        "ground_truth_answer": ground_truth_answer,
        "answer_correct": False,
        "attempts": max_attempts
    }

def generate_st_hc_dataset(input_dataset_path: str, num_samples: int, api_key: str, output_dir: str, verbose: bool = False) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Build the SmolTraces-HardCoded (ST-HC) dataset from an input dataset.

    The input is shuffled with a fixed seed and truncated to ``num_samples``
    rows. Every row that has a "question" field is sent through
    process_question; rows whose result is marked correct are kept and
    annotated with per-type pivot counts. Kept rows are written to
    ``st_hc_dataset.jsonl`` and ``st_hc_dataset.json`` (only when there is at
    least one), and the per-question attempt log is always written to
    ``attempts_data.json``.

    Args:
        input_dataset_path: Path to the input dataset
        num_samples: Maximum number of samples to generate
        api_key: OpenAI API key
        output_dir: Directory to save results
        verbose: Whether to print detailed information

    Returns:
        Tuple of (results list, attempts data list)
    """
    os.makedirs(output_dir, exist_ok=True)
    individual_dir = os.path.join(output_dir, "individual")
    os.makedirs(individual_dir, exist_ok=True)

    # A multi-split dataset arrives as a mapping; use its "train" split.
    source = load_dataset(input_dataset_path)
    if isinstance(source, dict):
        source = source["train"]

    # Fixed seed keeps the selected subset reproducible between runs.
    keep_count = min(len(source), num_samples)
    source = source.shuffle(seed=42).select(range(keep_count))

    kept = []
    attempt_log = []

    for index, row in enumerate(tqdm(source, desc="Generating ST-HC traces")):
        if "question" not in row:
            logging.warning(f"No question found in example {index}")
            continue

        # Ground truth comes from "answer" when present, otherwise "solution".
        truth = None
        for field in ("answer", "solution"):
            if field in row:
                truth = row[field]
                break

        qid = f"sthc_{index:04d}"
        outcome = process_question(
            row["question"],
            qid,
            api_key,
            individual_dir,
            ground_truth_answer=truth,
            verbose=verbose
        )

        # The attempt count is logged for successes and failures alike.
        attempt_log.append({
            "id": qid,
            "success": outcome["answer_correct"],
            "attempts": outcome["attempts"]
        })

        if not outcome["answer_correct"]:
            continue

        # Summarise the trace as {pivot type: number of occurrences}.
        found = extract_pivots(outcome["thinking"])
        outcome["pivot_stats"] = {kind: len(hits) for kind, hits in found.items()}
        kept.append(outcome)

    if kept:
        # One JSON object per line ...
        with open(os.path.join(output_dir, "st_hc_dataset.jsonl"), "w") as f:
            f.writelines(json.dumps(item) + "\n" for item in kept)

        # ... and the same records as a single JSON array.
        with open(os.path.join(output_dir, "st_hc_dataset.json"), "w") as f:
            json.dump(kept, f, indent=2)

    with open(os.path.join(output_dir, "attempts_data.json"), "w") as f:
        json.dump(attempt_log, f, indent=2)

    logging.info(f"Successfully generated {len(kept)} examples out of {len(source)} attempts")

    return kept, attempt_log

def analyze_attempts(attempts_data: List[Dict[str, Any]], output_dir: str) -> None:
    """
    Analyze the attempt statistics and save to a file.

    Writes ``attempts_analysis.json`` into ``output_dir`` and logs the same
    summary. The directory is created if it does not exist, so this works
    even when no question was processed (and therefore nothing else created
    the directory).

    Args:
        attempts_data: Data about all attempts made. Each entry must have a
            "success" flag and an "attempts" count.
        output_dir: Directory to save results
    """
    # Compute attempt metrics
    questions_attempted = len(attempts_data)
    questions_answered = sum(1 for attempt in attempts_data if attempt["success"])
    total_attempts = sum(attempt["attempts"] for attempt in attempts_data)
    # Guard the divisions so an empty attempt list yields zeros instead of raising.
    avg_attempts = total_attempts / questions_attempted if questions_attempted > 0 else 0
    success_rate = (questions_answered / questions_attempted * 100) if questions_attempted > 0 else 0

    # Build analysis dictionary
    analysis = {
        "questions_attempted": questions_attempted,
        "questions_answered": questions_answered,
        "success_rate": success_rate,
        "total_attempts": total_attempts,
        "avg_attempts_per_question": avg_attempts,
    }

    # Log the analysis
    logging.info(f"Attempt analysis: {json.dumps(analysis, indent=2)}")

    # Save the analysis to a file
    os.makedirs(output_dir, exist_ok=True)
    with open(os.path.join(output_dir, "attempts_analysis.json"), "w") as f:
        json.dump(analysis, f, indent=2)

def _pivot_counts_for_sample(sample: Dict[str, Any]) -> Dict[str, int]:
    """Return {pivot type: count} for one sample, deriving it from "pivots" if needed."""
    stats = sample.get("pivot_stats")
    if stats is not None:
        return stats
    # Samples built by main() carry the raw "pivots" value rather than
    # "pivot_stats". When it is a mapping of type -> instances, count them;
    # any other shape (e.g. the empty list used when there is no thinking
    # text) contributes no pivots.
    pivots = sample.get("pivots")
    if isinstance(pivots, dict):
        return {pivot_type: len(instances) for pivot_type, instances in pivots.items()}
    return {}


def analyze_pivots_in_dataset(dataset: List[Dict[str, Any]], attempts_data: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Analyze the pivot types present in the dataset and attempt metrics.

    Args:
        dataset: The ST-HC dataset. Each sample needs "thinking" and "answer"
            text plus either "pivot_stats" ({pivot type: count}) or "pivots"
            ({pivot type: instances}).
        attempts_data: Data about all attempts made

    Returns:
        Dictionary with analysis results: sample and attempt counts, success
        rate, per-pivot-type coverage / percentage / average count, and the
        average approximate token length of thinking and answer text. All
        averages are 0 when there is nothing to average.
    """
    # Collect, per pivot type, the count from every sample that reports it
    pivot_type_counts = {}
    for sample in dataset:
        for pivot_type, count in _pivot_counts_for_sample(sample).items():
            pivot_type_counts.setdefault(pivot_type, []).append(count)

    # Compute statistics
    thinking_token_counts = [count_tokens(sample["thinking"]) for sample in dataset]
    answer_token_counts = [count_tokens(sample["answer"]) for sample in dataset]

    # Compute attempt metrics
    questions_attempted = len(attempts_data)
    questions_answered = len(dataset)
    total_attempts = sum(attempt["attempts"] for attempt in attempts_data)
    avg_attempts = total_attempts / questions_attempted if questions_attempted > 0 else 0
    success_rate = (questions_answered / questions_attempted * 100) if questions_attempted > 0 else 0

    # Build analysis dictionary
    analysis = {
        "num_samples": len(dataset),
        "questions_attempted": questions_attempted,
        "questions_answered": questions_answered,
        "success_rate": success_rate,
        "total_attempts": total_attempts,
        "avg_attempts_per_question": avg_attempts,
        # Number of samples in which each pivot type was reported
        "pivot_type_coverage": {pivot_type: len(counts) for pivot_type, counts in pivot_type_counts.items()},
        "pivot_type_percent": {pivot_type: len(counts) / len(dataset) * 100
                             for pivot_type, counts in pivot_type_counts.items() if len(dataset) > 0},
        "pivot_type_avg_counts": {pivot_type: sum(counts) / len(counts) if counts else 0
                                for pivot_type, counts in pivot_type_counts.items()},
        "thinking_tokens_avg": sum(thinking_token_counts) / len(thinking_token_counts) if thinking_token_counts else 0,
        "answer_tokens_avg": sum(answer_token_counts) / len(answer_token_counts) if answer_token_counts else 0,
    }

    return analysis

def main(args):
    """
    Generate traces for the dataset named on the command line.

    Loads ``args.input_dataset`` (preferring the "test" split, then "train",
    then the first available split), runs process_question on up to
    ``args.num_samples`` examples, writes attempt statistics via
    analyze_attempts, and returns the collected results.

    Args:
        args: Parsed command-line namespace providing input_dataset,
            output_dir, num_samples, max_attempts, api_key, verbose and
            verify_answers.

    Returns:
        Dictionary with "successful_examples" (results marked correct, each
        annotated with "pivots" and "pivot_stats") and "all_attempts"
        (one record per processed question).
    """
    # Load the dataset
    dataset = load_dataset(args.input_dataset)

    # Get the appropriate split (try test first, then train if test doesn't exist)
    if "test" in dataset:
        dataset = dataset["test"]
    elif "train" in dataset:
        dataset = dataset["train"]
    else:
        # If neither exists, use the first available split
        first_split = next(iter(dataset.keys()))
        dataset = dataset[first_split]
        logging.info(f"Using '{first_split}' split from dataset")

    # Set up logging
    logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    # Filter and preprocess the examples
    examples = []
    for idx, example in enumerate(dataset):
        # Ground truth comes from "answer", falling back to "solution" --
        # the same two fields generate_st_hc_dataset accepts.
        ground_truth = example.get("answer", None)
        if ground_truth is None:
            ground_truth = example.get("solution", None)
        examples.append({
            "id": str(idx).zfill(4),  # Give each a unique ID
            "question": example["question"],
            "ground_truth_answer": ground_truth
        })

    if args.num_samples and args.num_samples < len(examples):
        examples = examples[:args.num_samples]

    logging.info(f"Loaded {len(examples)} examples")

    # Process each example
    all_attempts = []
    successful_examples = []

    for example in tqdm(examples, desc="Generating ST-HC traces"):
        # Skip examples without ground truth if verification is required
        if args.verify_answers and not example.get("ground_truth_answer"):
            logging.warning(f"Skipping question {example['id']} as no ground truth answer is available")
            continue

        # Process the question
        result = process_question(
            question=example["question"],
            question_id=f"sthc_{example['id']}",
            api_key=args.api_key,
            output_dir=args.output_dir,
            ground_truth_answer=example.get("ground_truth_answer"),
            max_attempts=args.max_attempts,
            verbose=args.verbose
        )

        # Record the result
        all_attempts.append({
            "id": example['id'],
            "success": result["answer_correct"],
            "attempts": result["attempts"]
        })

        if not result["answer_correct"]:
            continue

        # Extract pivots only if there's thinking to analyze. A failure here
        # must not discard an otherwise successful example.
        pivots = []
        if result["thinking"]:
            try:
                pivots = extract_pivots(result["thinking"])
            except Exception as e:
                logging.error(f"Error extracting pivots for question {example['id']}: {str(e)}")
                pivots = []

        result["pivots"] = pivots
        # analyze_pivots_in_dataset reads per-type counts from "pivot_stats";
        # derive them the same way generate_st_hc_dataset does.
        if isinstance(pivots, dict):
            result["pivot_stats"] = {p_type: len(instances) for p_type, instances in pivots.items()}
        else:
            result["pivot_stats"] = {}

        successful_examples.append(result)

    # Analyze the results
    analyze_attempts(all_attempts, args.output_dir)

    # Log successful examples
    logging.info(f"Successfully generated {len(successful_examples)} examples out of {len(examples)} attempts")

    # Return the results for further analysis
    return {
        "successful_examples": successful_examples,
        "all_attempts": all_attempts
    }

if __name__ == "__main__":
    # Command-line interface: only the dataset and the API key are mandatory.
    cli = argparse.ArgumentParser(description="Generate SmolTraces-HardCoded dataset")
    cli.add_argument("--input_dataset", type=str, required=True,
                     help="Path to the input dataset")
    cli.add_argument("--output_dir", type=str, default=RESULTS_DIR,
                     help="Directory to save results")
    cli.add_argument("--num_samples", type=int, default=10,
                     help="Number of samples to generate")
    cli.add_argument("--max_attempts", type=int, default=5,
                     help="Maximum number of attempts per question")
    cli.add_argument("--api_key", type=str, required=True,
                     help="OpenAI API key")
    cli.add_argument("--verbose", action="store_true",
                     help="Print detailed information about each sample")
    cli.add_argument("--verify_answers", action="store_true",
                     help="Verify answers against ground truth")

    args = cli.parse_args()

    # Generate the traces, then summarise whatever succeeded.
    results = main(args)
    successful_examples = results["successful_examples"]

    if not successful_examples:
        logging.warning("No analysis performed as no successful examples were generated.")
    else:
        analysis = analyze_pivots_in_dataset(successful_examples, results["all_attempts"])

        # Machine-readable copy of the analysis
        analysis_path = os.path.join(args.output_dir, "analysis.json")
        with open(analysis_path, "w") as f:
            json.dump(analysis, f, indent=2)

        # Console summary
        logging.info(f"Generated {len(successful_examples)} ST-HC traces")
        logging.info(f"Average thinking length: {analysis['thinking_tokens_avg']:.2f} tokens")
        logging.info(f"Success rate: {analysis['success_rate']:.2f}%")
        logging.info(f"Average attempts per question: {analysis['avg_attempts_per_question']:.2f}")

        # Human-readable Markdown report, assembled in order and written once
        report_path = os.path.join(args.output_dir, "summary_report.md")
        with open(report_path, "w") as f:
            report = [
                "# SmolTraces-HardCoded Dataset Generation Summary\n\n",
                "## Dataset Statistics\n\n",
                f"- Number of samples: {analysis['num_samples']}\n",
                f"- Questions attempted: {analysis['questions_attempted']}\n",
                f"- Questions answered correctly: {analysis['questions_answered']}\n",
                f"- Success rate: {analysis['success_rate']:.2f}%\n",
                f"- Average attempts per question: {analysis['avg_attempts_per_question']:.2f}\n",
                f"- Average thinking length: {analysis['thinking_tokens_avg']:.2f} tokens\n",
                f"- Average answer length: {analysis['answer_tokens_avg']:.2f} tokens\n\n",
                "## Pivot Type Coverage\n\n",
                "| Pivot Type | Samples | % of Traces | Average Count |\n",
                "|------------|---------|-------------|---------------|\n",
            ]

            # One table row per pivot type, in alphabetical order
            for pivot_type in sorted(analysis["pivot_type_coverage"].keys()):
                coverage = analysis["pivot_type_coverage"][pivot_type]
                percent = analysis["pivot_type_percent"][pivot_type]
                avg_count = analysis["pivot_type_avg_counts"][pivot_type]
                report.append(f"| {pivot_type} | {coverage} | {percent:.1f}% | {avg_count:.2f} |\n")

            report.append("\n\nST-HC dataset saved to: " + args.output_dir)
            f.writelines(report)