"""
Script for analyzing reasoning traces to identify specific pivot types and structural patterns.

This analysis focuses on the four pivot categories described in the paper:
1. Realization pivots (e.g., "Wait", "Oh") - signal recognition of errors or oversights
2. Verification pivots (e.g., "Let me check") - explicitly validate intermediate hypotheses
3. Exploration pivots (e.g., "What if", "Another approach") - prompt consideration of alternative solution paths
4. Integration pivots (e.g., "Now I see how") - synthesize previously explored ideas into a coherent solution

And the four reasoning stages:
1. Problem framing - restating and clarifying key aspects of the problem
2. Exploration - considering hypotheses or potential solution paths
3. Verification - explicit checking of hypotheses or intermediate results
4. Synthesis - integrating insights from earlier stages into a coherent solution

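The paper's analysis examines 1,000 successful reasoning traces; in this script the number of sampled traces is set by `num_samples` in `analyze_dataset`, and traces are not filtered by correctness.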
"""

import os
import re
import json
import argparse
import logging
import time
import random
import requests
from typing import Dict, List, Any, Optional, Tuple
from collections import Counter, defaultdict
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from datasets import load_dataset
from tqdm import tqdm
from dotenv import load_dotenv

# Constants for API calls: the attempt budget per request and the pause between
# attempts (SLEEP_TIME is presumably in seconds — confirm at the call sites).
MAX_ATTEMPTS = 3
SLEEP_TIME = 2

# Root logger configuration: INFO and above, formatted as "<time> - <level> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Custom JSON encoder to handle NumPy types
class NumpyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super(NumpyEncoder, self).default(obj)

# Define pivot patterns based on the four categories in the paper.
# Every pattern starts with \b, so a cue only matches at the start of a word.
# Without it, case-insensitive matching found "So," inside "also," and "Oh," inside
# "though,", inflating the counts.
PIVOT_PATTERNS = {
    "realization": r"\b(Wait[,\s\-—]|Oh[,\s\-—]|Actually[,\s\-—]|I missed[,\s\-—]|I realize[d]?[,\s\-—]|I notice[d]?[,\s\-—]|I see now[,\s\-—]|I overlooked[,\s\-—]|Hmm[,\s\-—]|I forgot[,\s\-—]|I didn'?t consider[,\s\-—]|Let'?s reconsider[,\s\-—]|Upon reflection[,\s\-—]|On second thought[,\s\-—]|I made a mistake[,\s\-—]|That's not right[,\s\-—]|I need to correct[,\s\-—]|I was wrong[,\s\-—]|Oops[,\s\-—])",
    # "Let me check" is the paper's own example of a verification pivot. So "Let me ..."
    # is accepted alongside "Let's ..." / "Lets ..." / "Let ...".
    "verification": r"\b(Let'?s?(?:\s+me)?\s+(?:check|verify|validate|ensure|test|make\s+sure|double[\-\s]check))",
    "exploration": r"\b(What if[,\s\-—]|Another (way|approach|method|strategy|possibility|option|alternative)[,\s\-—]|Alternatively[,\s\-—]|Let'?s try[,\s\-—]|We could also[,\s\-—]|I could[,\s\-—]|We could[,\s\-—]|Maybe[,\s\-—]|Perhaps[,\s\-—]|I can try[,\s\-—]|Let'?s explore[,\s\-—]|Let'?s consider[,\s\-—]|A different approach[,\s\-—]|We might[,\s\-—]|I might[,\s\-—]|Instead[,\s\-—]|Let'?s attempt[,\s\-—]|One possibility[,\s\-—]|Focusing on[,\s\-—])",
    "integration": r"\b(Now I see[,\s\-—]|This connects[,\s\-—]|Putting (this|these) together[,\s\-—]|So combining[,\s\-—]|In summary[,\s\-—]|To synthesize[,\s\-—]|Bringing (this|these) together[,\s\-—]|Combining (this|these)[,\s\-—]|Now (we|I) understand[,\s\-—]|Now (we|I) can see[,\s\-—]|This means that[,\s\-—]|This shows that[,\s\-—]|This leads to[,\s\-—]|This implies[,\s\-—]|Therefore[,\s\-—]|Thus[,\s\-—]|So[,\s\-—]|In conclusion[,\s\-—]|To conclude[,\s\-—]|To sum up[,\s\-—]|Finally[,\s\-—]|Ultimately[,\s\-—])"
}

# Define structure patterns based on the four stages of reasoning in the paper.
# Every pattern starts with \b, so it only matches at a word start (e.g. "to verify"
# no longer matches inside "into verify"). The bare-keyword synthesis pattern also
# ends with \b: without it "so" matched inside "also", "solve", "some" and "reasoning".
# Other patterns deliberately have no trailing \b, so inflections such as
# "what if we used" or "let's considered" still match as before.
STRUCTURE_PATTERNS = {
    "problem_framing": [
        r"\b(let|let's) (look|analyze|break down) the problem",
        r"\bfirst, (let|I'll) (understand|analyze|restate) the problem",
        r"\b(the problem|we) (asks|want|need) to",
        r"\b(I|let me) need to (find|determine|calculate)"
    ],
    "exploration": [
        r"\b(let|let's) (explore|consider|think about)",
        r"\b(one|first|another) (approach|way|method|strategy|possibility|option|alternative) (is|would be)",
        r"\bI (could|can|might) (try|use|apply)",
        r"\b(let's|I'll) (start|begin) (by|with)",
        r"\b(perhaps|maybe) (we|I) (can|could|should|might)",
        r"\bwhat if (we|I) (try|consider|use|apply)",
        r"\b(we|I) could also (try|consider|use)",
        r"\b(alternatively|instead)",
        r"\ba different (approach|way|method) (is|would be)"
    ],
    "verification": [
        r"\b(let|let's) (check|verify|test|confirm|make sure)",
        r"\b(to|we need to) (verify|check|ensure|confirm)",
        r"\b(does|is) this (correct|right|make sense)",
        r"\b(let me|I should) (double[-\s]check|verify)"
    ],
    "synthesis": [
        r"\b(therefore|thus|so|hence)\b",
        r"\b(in conclusion|to conclude|in summary|summing up)",
        r"\b(the|my|our) (answer|solution|result) (is|would be)",
        r"\b(putting|bringing) (this|it|everything) together"
    ]
}

def load_api_keys(env_path=".env"):
    """
    Load a .env file, then read the OpenAI and DeepSeek API keys from the environment.

    Args:
        env_path: Location of the .env file handed to load_dotenv

    Returns:
        Dict with "openai" and "deepseek" entries. An entry is None (or empty) when
        its environment variable is not set, and a warning lists those providers.
    """
    load_dotenv(env_path)

    env_var_by_provider = {
        "openai": "OPENAI_API_KEY",
        "deepseek": "DEEPSEEK_API_KEY",
    }
    keys = {provider: os.getenv(env_var) for provider, env_var in env_var_by_provider.items()}

    absent = [provider for provider in keys if not keys[provider]]
    if absent:
        logging.warning(f"Missing API keys: {', '.join(absent)}")

    return keys

# Maximum number of trace characters included in a GPT prompt (keeps requests within
# the model's token limits).
_GPT_TRACE_CHAR_LIMIT = 8000
# Seconds to wait for an OpenAI API response before the attempt counts as failed.
_GPT_REQUEST_TIMEOUT = 60
_OPENAI_CHAT_URL = "https://api.openai.com/v1/chat/completions"


def _coerce_excerpts(value: Any) -> List[str]:
    """Normalize one model-returned value into a list of non-empty, stripped strings."""
    # The model may return a bare string instead of a list; iterating that later would
    # treat every character as an excerpt.
    if isinstance(value, str):
        value = [value]
    if not isinstance(value, list):
        return []
    return [item.strip() for item in value if isinstance(item, str) and item.strip()]


def _query_gpt_for_excerpts(text: str, api_key: str, category: str,
                            type_descriptions: Dict[str, str]) -> Optional[Dict[str, List[str]]]:
    """
    Ask gpt-4o-mini for the excerpts of *text* that belong to each type.

    Transient failures are retried up to MAX_ATTEMPTS times with SLEEP_TIME pauses:
    network errors, timeouts, 429/5xx responses, and malformed or non-object JSON.
    HTTP 4xx errors other than 429 (e.g. an invalid key) are not retried.

    Args:
        text: The reasoning trace (only the first _GPT_TRACE_CHAR_LIMIT chars are sent)
        api_key: OpenAI API key
        category: Singular label used in the prompt, e.g. "pivot" or "structure"
        type_descriptions: Mapping of type name -> description shown to the model

    Returns:
        A dict with exactly the keys of *type_descriptions*, each mapped to a list of
        string excerpts, or None when no attempt produced a usable response.
    """
    types_list = "\n".join(f"- {name}: {desc}" for name, desc in type_descriptions.items())

    # The trace is truncated before interpolation. Anything written inside this
    # f-string, including would-be comments, is sent to the model verbatim.
    prompt = f"""Analyze the following reasoning trace and identify specific instances of reasoning {category}s.

{category.capitalize()} types to identify:
{types_list}

For each {category} type found, extract the exact sentence(s) where it occurs.
Return the results in JSON format with {category} types as keys and lists of text excerpts as values.

Reasoning trace:
{text[:_GPT_TRACE_CHAR_LIMIT]}
"""

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }
    data = {
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 2000,
        "temperature": 0.2,
        "response_format": {"type": "json_object"}  # Request JSON response
    }

    for attempt in range(1, MAX_ATTEMPTS + 1):
        try:
            response = requests.post(_OPENAI_CHAT_URL, headers=headers, json=data,
                                     timeout=_GPT_REQUEST_TIMEOUT)
            status = response.status_code
            if 400 <= status < 500 and status != 429:
                # Client errors (bad key, malformed request) will not succeed on retry.
                logging.error(f"GPT {category} request rejected with HTTP {status}")
                return None
            response.raise_for_status()
            content = response.json()["choices"][0]["message"]["content"]
            parsed = json.loads(content)
        except (requests.RequestException, KeyError, IndexError, TypeError, ValueError) as e:
            logging.warning(f"GPT {category} request failed (attempt {attempt}/{MAX_ATTEMPTS}): {e}")
            if attempt < MAX_ATTEMPTS:
                time.sleep(SLEEP_TIME)
            continue

        if not isinstance(parsed, dict):
            logging.warning(f"GPT {category} response was not a JSON object (attempt {attempt}/{MAX_ATTEMPTS})")
            continue

        # Keep only the requested types and make every value a clean list of strings.
        return {name: _coerce_excerpts(parsed.get(name)) for name in type_descriptions}

    return None


def identify_pivots_with_gpt(text: str, api_key: str) -> Dict[str, List[str]]:
    """
    Use GPT-4o-mini to identify pivot statements in reasoning traces.

    Falls back to the regex-based identify_pivots when the API cannot produce a
    usable answer.

    Args:
        text: The reasoning trace text
        api_key: OpenAI API key

    Returns:
        Dictionary with pivot types as keys and lists of matching text as values
    """
    pivot_types_desc = {
        "realization": "Pivots that signal recognition of errors or oversights, such as 'Wait' or 'Oh'",
        "verification": "Pivots that explicitly validate intermediate hypotheses, such as 'Let me check'",
        "exploration": "Pivots that prompt the consideration of alternative solution paths, such as 'What if' or 'Another approach'",
        "integration": "Pivots that synthesize previously explored ideas into a coherent solution, such as 'Now I see how'"
    }

    pivots = _query_gpt_for_excerpts(text, api_key, "pivot", pivot_types_desc)
    if pivots is None:
        logging.error("GPT pivot identification failed; falling back to regex patterns")
        return identify_pivots(text)
    return pivots


def identify_structures_with_gpt(text: str, api_key: str) -> Dict[str, List[str]]:
    """
    Use GPT-4o-mini to identify structure patterns in reasoning traces.

    Falls back to the regex-based identify_structures when the API cannot produce a
    usable answer.

    Args:
        text: The reasoning trace text
        api_key: OpenAI API key

    Returns:
        Dictionary with structure types as keys and lists of matching text as values
    """
    structure_types_desc = {
        "problem_framing": "Restating and clarifying key aspects of the problem leading to a plan",
        "exploration": "Considering hypotheses or potential solution paths",
        "verification": "Explicit checking of hypotheses or intermediate results",
        "synthesis": "Integrating insights from earlier stages into a coherent solution"
    }

    structures = _query_gpt_for_excerpts(text, api_key, "structure", structure_types_desc)
    if structures is None:
        logging.error("GPT structure identification failed; falling back to regex patterns")
        return identify_structures(text)
    return structures

def extract_thinking_trace(text: str) -> str:
    """
    Extract the thinking trace from a model's response.

    Tries, in order:
    1. Explicit thought delimiters (<|begin_of_thought|>, <thinking>, <im_start>think).
    2. A "step by step" preamble running up to the first conclusion marker
       ("So, to summarize/conclude", "The answer is", "Therefore,").
    3. Everything before the first conclusion marker.
    If none of these yields text, the input is returned unchanged.

    Args:
        text: The full response text. A list of strings (as in the simplescaling
            "thinking_trajectories" column) is joined with newlines; None yields "".

    Returns:
        The extracted thinking trace
    """
    if text is None:
        return ""
    # Lists must be handled before any regex call, because re.search raises TypeError
    # on a list.
    if isinstance(text, list):
        return "\n".join(str(part) for part in text)

    conclusion = r"(?:So,? (?:to )?(?:summarize|conclude)|The answer is|Therefore,)"
    thinking_patterns = [
        r"<\|begin_of_thought\|>(.*?)<\|end_of_thought\|>",
        r"<thinking>(.*?)</thinking>",
        r"<im_start>think\n(.*?)<im_end>",
        r"Let me think step by step[\.\s:]+(.*?)" + conclusion,
        r"Let's think step by step[\.\s:]+(.*?)" + conclusion,
        r"I'll solve this step by step[\.\s:]+(.*?)" + conclusion,
        r"step by step[\.\s:]+(.*?)" + conclusion,
    ]

    for pattern in thinking_patterns:
        match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
        if match:
            return match.group(1).strip()

    # Fall back to the text before the first conclusion marker. If that prefix is
    # empty (the text starts with the marker), keep the whole text rather than
    # returning "", which callers treat as "no trace" and skip.
    parts = re.split(conclusion, text, maxsplit=1, flags=re.IGNORECASE)
    if len(parts) > 1 and parts[0].strip():
        return parts[0].strip()

    return text

def extract_answer(text: str) -> str:
    """
    Extract the final answer from a model's response.

    Looks, in order, for a <|begin_of_solution|> block, "The (final) answer is ...",
    "Therefore, ..." and "So, to summarize/conclude ...". The first match is used.
    Otherwise the last sentence of the text is returned.

    Args:
        text: The full response text (None or "" yields "")

    Returns:
        The extracted answer
    """
    if not text:
        return ""

    # A period ends the answer only when followed by whitespace or the end of the
    # text, so decimals such as "3.14" are not cut at the first dot.
    answer_end = r"(?:\.(?=\s|$)|$)"
    answer_patterns = [
        r"<\|begin_of_solution\|>(.*?)<\|end_of_solution\|>",
        r"The (?:final )?answer is[:\s]+(.*?)" + answer_end,
        r"Therefore,[:\s]+(.*?)" + answer_end,
        r"So,? (?:to )?(?:summarize|conclude)[:\s]+(.*?)" + answer_end,
    ]

    for pattern in answer_patterns:
        match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
        if match:
            return match.group(1).strip()

    # Strip first: trailing whitespace would otherwise produce an empty last "sentence".
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    return sentences[-1].strip()

# Sentence terminator used when cutting an excerpt out of a paragraph: '.', '!' or
# '?' followed by whitespace or the end of the paragraph. Requiring the whitespace
# keeps decimals such as "3.5" inside a single sentence.
_SENTENCE_END = re.compile(r"[.!?](?=\s|$)")


def _enclosing_sentence(paragraph: str, start: int, end: int) -> str:
    """Return the stripped sentence of *paragraph* that contains the span [start, end)."""
    sentence_start = 0
    for boundary in _SENTENCE_END.finditer(paragraph):
        if boundary.end() > start:
            break
        sentence_start = boundary.end()
    next_boundary = _SENTENCE_END.search(paragraph, end)
    sentence_end = next_boundary.start() if next_boundary else len(paragraph)
    return paragraph[sentence_start:sentence_end].strip()


def _find_pattern_sentences(text: str, patterns_by_type: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """Map each type to the sentences (one entry per regex match) where its patterns match."""
    found = defaultdict(list)
    # Paragraphs are separated by a blank line (Unix or Windows line endings).
    for paragraph in re.split(r'\n\n|\r\n\r\n', text):
        for match_type, patterns in patterns_by_type.items():
            for pattern in patterns:
                for match in re.finditer(pattern, paragraph, re.IGNORECASE):
                    sentence = _enclosing_sentence(paragraph, match.start(), match.end())
                    if sentence:
                        found[match_type].append(sentence)
    return found


def identify_pivots(text: str, patterns: Optional[Dict[str, str]] = None) -> Dict[str, List[str]]:
    """
    Identify pivot points in the reasoning trace using regex patterns.

    Each match contributes the sentence containing it, so a sentence with two matches
    of the same type is listed twice. The list length is therefore the match count.

    Args:
        text: The reasoning trace text
        patterns: Optional mapping of pivot type -> regex; defaults to PIVOT_PATTERNS

    Returns:
        A defaultdict(list) mapping pivot types to the sentences that contain their
        matches; types without any match are absent
    """
    if patterns is None:
        patterns = PIVOT_PATTERNS
    return _find_pattern_sentences(text, {pivot_type: [pattern] for pivot_type, pattern in patterns.items()})


def identify_structures(text: str, patterns: Optional[Dict[str, List[str]]] = None) -> Dict[str, List[str]]:
    """
    Identify structural elements in the reasoning trace using regex patterns.

    Each match of each pattern contributes the sentence containing it.

    Args:
        text: The reasoning trace text
        patterns: Optional mapping of structure type -> list of regexes; defaults to
            STRUCTURE_PATTERNS

    Returns:
        A defaultdict(list) mapping structure types to the sentences that contain
        their matches; types without any match are absent
    """
    if patterns is None:
        patterns = STRUCTURE_PATTERNS
    return _find_pattern_sentences(text, patterns)

def detect_pivots(text: str) -> Tuple[Dict[str, int], Dict[str, List[str]]]:
    """
    Count and collect the pivots that identify_pivots finds in a reasoning trace.

    Both returned dicts contain every key of PIVOT_PATTERNS. A type that does not
    occur gets a count of 0 and an empty list.

    Args:
        text: The reasoning trace text

    Returns:
        (pivot_counts, pivot_instances): number of matches per type and the
        matched sentences per type
    """
    pivot_instances = {name: [] for name in PIVOT_PATTERNS}
    pivot_instances.update(identify_pivots(text))
    pivot_counts = {name: len(sentences) for name, sentences in pivot_instances.items()}
    return pivot_counts, pivot_instances

def detect_structures(text: str) -> Tuple[Dict[str, int], Dict[str, List[str]]]:
    """
    Count and collect the reasoning-stage structures that identify_structures finds.

    Both returned dicts contain every key of STRUCTURE_PATTERNS. A structure type
    that does not occur gets a count of 0 and an empty list.

    Args:
        text: The reasoning trace text

    Returns:
        (structure_counts, structure_instances): number of matches per type and the
        matched sentences per type
    """
    found = identify_structures(text)
    structure_instances = {kind: [] for kind in STRUCTURE_PATTERNS}
    structure_instances.update(found)
    structure_counts = {kind: len(matches) for kind, matches in structure_instances.items()}
    return structure_counts, structure_instances

def analyze_trace_structure(text: str, api_key: Optional[str] = None, use_gpt: bool = False) -> Dict[str, Any]:
    """
    Analyze the structure of a reasoning trace.

    The trace is split into blank-line-separated paragraphs. Each paragraph is
    labelled with the pivot types whose detected excerpts occur in it, or
    ["computation"] when none do.

    Args:
        text: The reasoning trace text
        api_key: OpenAI API key; GPT is only used when this is set and use_gpt is True
        use_gpt: Whether to use GPT (instead of the regex patterns) to find pivots

    Returns:
        A dictionary with the following keys:
        - "num_paragraphs"
        - "paragraph_lengths": characters per paragraph
        - "num_equations": heuristic count of math-like lines
        - "reasoning_sequence": per-paragraph list of labels
        - "token_length": whitespace-separated word count, a crude token estimate
    """
    paragraphs = re.split(r'\n\n|\r\n\r\n', text)
    paragraph_lengths = [len(p) for p in paragraphs]

    # Count lines containing math-like characters.
    # NOTE(review): '-' and '/' also match ordinary prose (hyphens, "and/or"),
    # so this over-counts equations.
    equation_pattern = r'(?:^|\n).*?(?:\$|\\\(|\\begin\{|=|\+|-|\*|\/|\^|\})\s*.*(?:$|\n)'
    num_equations = len(re.findall(equation_pattern, text, re.MULTILINE))

    if use_gpt and api_key:
        pivots_in_text = identify_pivots_with_gpt(text, api_key)
    else:
        pivots_in_text = identify_pivots(text)

    # Keep only non-empty string excerpts. An empty string is a substring of every
    # paragraph, and a bare string value would be iterated character by character.
    # Either would tag every paragraph with that pivot type (possible with GPT output).
    excerpts_by_type = {}
    for pivot_type in PIVOT_PATTERNS:
        excerpts = pivots_in_text.get(pivot_type) or []
        if isinstance(excerpts, str):
            excerpts = [excerpts]
        elif not isinstance(excerpts, (list, tuple)):
            excerpts = []
        excerpts_by_type[pivot_type] = [e for e in excerpts if isinstance(e, str) and e]

    reasoning_sequence = []
    for paragraph in paragraphs:
        labels = [
            pivot_type
            for pivot_type, excerpts in excerpts_by_type.items()
            if any(excerpt in paragraph for excerpt in excerpts)
        ]
        reasoning_sequence.append(labels or ["computation"])  # Default if no pivot found

    return {
        "num_paragraphs": len(paragraphs),
        "paragraph_lengths": paragraph_lengths,
        "num_equations": num_equations,
        "reasoning_sequence": reasoning_sequence,
        "token_length": len(text.split())
    }

# Chat roles that mark the user prompt / model reply in "conversations" datasets.
# OpenThoughts uses "user"/"assistant"; ShareGPT-style exports commonly use "human"/"gpt".
_USER_ROLES = {"user", "human"}
_ASSISTANT_ROLES = {"assistant", "gpt"}

# Columns that identify a supported dataset format (at least one must be present).
_SUPPORTED_COLUMNS = ("conversations", "thinking_trajectories", "thinking", "response", "answer", "reasoning")


def _extract_example_fields(example: Dict[str, Any], col_names: List[str]) -> Tuple[Any, Any, Any, Any]:
    """
    Return (question, trace, answer, is_correct) for one dataset example.

    Formats are checked in this priority:
    1. "conversations" (chat format; the last assistant turn holds the trace).
    2. "thinking_trajectories" (simplescaling).
    3. The standard question/problem and thinking/reasoning/response/answer columns.
    Missing fields come back as "" (and False for is_correct).
    """
    question, trace, answer, is_correct = "", "", "", False

    if "conversations" in col_names:
        for conv in example.get("conversations") or []:
            role = conv.get("from")
            value = conv.get("value") or ""
            if role in _USER_ROLES:
                question = value
            elif role in _ASSISTANT_ROLES:
                trace = extract_thinking_trace(value)
                answer = extract_answer(value)
    elif "thinking_trajectories" in col_names:
        if "question" in col_names:
            question = example["question"]
        trajectories = example["thinking_trajectories"]
        if trajectories:
            trace = "\n".join(trajectories) if isinstance(trajectories, list) else str(trajectories)
        if "solution" in col_names:
            answer = example["solution"]
    else:
        if "question" in col_names:
            question = example["question"]
        elif "problem" in col_names:
            question = example["problem"]

        if "thinking" in col_names:
            trace = example["thinking"]
        elif "reasoning" in col_names:
            trace = example["reasoning"]
        elif "response" in col_names:
            trace = extract_thinking_trace(example["response"])
        elif "answer" in col_names:
            trace = extract_thinking_trace(example["answer"])

        if "answer" in col_names:
            answer = example["answer"]
        elif "response" in col_names:
            answer = extract_answer(example["response"])

        if "is_correct" in col_names:
            is_correct = example["is_correct"]
        elif "correct" in col_names:
            is_correct = example["correct"]

    return question, trace, answer, is_correct


def _build_result_row(example_id: Any, question: Any, answer: Any, is_correct: Any,
                      trace: str, api_key: Optional[str], use_gpt: bool) -> Dict[str, Any]:
    """Compute the per-trace metrics forming one row of the results DataFrame."""
    pivot_counts, _ = detect_pivots(trace)
    structure_counts, _ = detect_structures(trace)
    structure_analysis = analyze_trace_structure(trace, api_key, use_gpt)

    row = {
        "example_id": example_id,
        "question": question,
        "answer": answer,
        "is_correct": is_correct,
        "trace_length": len(trace),
        "token_count": structure_analysis["token_length"],
        "num_paragraphs": structure_analysis["num_paragraphs"],
        "num_equations": structure_analysis["num_equations"],
        "pivot_diversity": sum(1 for count in pivot_counts.values() if count > 0),
        "structure_diversity": sum(1 for count in structure_counts.values() if count > 0),
        "has_pivot": any(count > 0 for count in pivot_counts.values()),
    }
    for pivot_type, count in pivot_counts.items():
        row[f"pivot_{pivot_type}"] = count
    row["total_pivots"] = sum(pivot_counts.values())
    for structure_type, count in structure_counts.items():
        row[f"structure_{structure_type}"] = count
    row["total_structures"] = sum(structure_counts.values())
    return row


def _correctness_mask(is_correct: pd.Series) -> pd.Series:
    """Boolean mask of correct rows; missing values (None/NaN) count as incorrect."""
    # `~series` is only a logical NOT for bool dtype (on ints it is bitwise, and on
    # object columns holding None it raises), so coerce explicitly.
    return is_correct.apply(lambda value: False if pd.isna(value) else bool(value)).astype(bool)


def _add_type_stats(stats: Dict[str, Any], df: pd.DataFrame, correct_mask: pd.Series,
                    prefix: str, type_names) -> None:
    """Add average/presence/correctness statistics for each "<prefix>_<type>" column."""
    for type_name in type_names:
        col = f"{prefix}_{type_name}"
        if col not in df.columns:
            continue
        present = df[col] > 0
        examples_with = present.sum()
        pct_with = present.mean() * 100
        in_correct = df.loc[correct_mask, col].mean()
        in_incorrect = df.loc[~correct_mask, col].mean()

        stats[f"avg_{col}"] = df[col].mean()
        # Legacy keys use the bare type name. Pivot and structure types share the names
        # "exploration" and "verification", so for those the structure values overwrite
        # the pivot values. These keys are kept unchanged for existing readers.
        stats[f"examples_with_{type_name}"] = examples_with
        stats[f"pct_with_{type_name}"] = pct_with
        stats[f"{type_name}_in_correct"] = in_correct
        stats[f"{type_name}_in_incorrect"] = in_incorrect
        # Unambiguous, category-qualified keys, e.g. "examples_with_pivot_verification".
        stats[f"examples_with_{col}"] = examples_with
        stats[f"pct_with_{col}"] = pct_with
        stats[f"{col}_in_correct"] = in_correct
        stats[f"{col}_in_incorrect"] = in_incorrect


def _aggregate_stats(df: pd.DataFrame) -> Dict[str, Any]:
    """Aggregate dataset-level statistics from the per-trace results DataFrame."""
    if df.empty:
        return {
            "total_samples": 0,
            "avg_trace_length": 0,
            "avg_token_count": 0,
            "avg_paragraphs": 0,
            "avg_equations": 0,
            "correct_answers": 0,
            "incorrect_answers": 0,
            "avg_total_pivots": 0,
            "avg_total_structures": 0,
            "avg_pivot_diversity": 0,
            "avg_structure_diversity": 0,
        }

    correct_mask = _correctness_mask(df["is_correct"])
    num_correct = int(correct_mask.sum())
    stats = {
        "total_samples": len(df),
        "avg_trace_length": df["trace_length"].mean(),
        "avg_token_count": df["token_count"].mean(),
        "avg_paragraphs": df["num_paragraphs"].mean(),
        "avg_equations": df["num_equations"].mean(),
        "correct_answers": num_correct,
        "incorrect_answers": len(df) - num_correct,
        "avg_total_pivots": df["total_pivots"].mean(),
        "avg_total_structures": df["total_structures"].mean(),
        "avg_pivot_diversity": df["pivot_diversity"].mean(),
        "avg_structure_diversity": df["structure_diversity"].mean(),
    }
    _add_type_stats(stats, df, correct_mask, "pivot", PIVOT_PATTERNS.keys())
    _add_type_stats(stats, df, correct_mask, "structure", STRUCTURE_PATTERNS.keys())
    return stats


def analyze_dataset(dataset_name: str, split: str = "train", num_samples: Optional[int] = None,
                    api_key: Optional[str] = None, use_gpt: bool = False, output_dir: str = "analysis/results",
                    seed: Optional[int] = None) -> Tuple[pd.DataFrame, Dict[str, Any]]:
    """
    Analyze reasoning traces in a dataset.

    Writes trace_analysis_results.csv and trace_analysis_stats.json to *output_dir*.
    Examples whose extracted trace is empty are skipped.

    Args:
        dataset_name: The name of the dataset on Hugging Face
        split: The split to analyze
        num_samples: Number of samples to analyze (None = all)
        api_key: OpenAI API key (required if use_gpt=True)
        use_gpt: Whether to use GPT for analysis
        output_dir: Directory to save results
        seed: Optional seed for reproducible sampling; None uses the global
            `random` state

    Returns:
        A DataFrame with analysis results for each example,
        A dictionary with aggregated statistics
    """
    os.makedirs(output_dir, exist_ok=True)

    try:
        dataset = load_dataset(dataset_name, split=split)
        logging.info(f"Loaded {len(dataset)} examples from {dataset_name} ({split})")
    except Exception as e:
        logging.error(f"Error loading dataset {dataset_name}: {e}")
        return pd.DataFrame(), {}

    if num_samples and num_samples < len(dataset):
        sampler = random if seed is None else random.Random(seed)
        indices = sampler.sample(range(len(dataset)), num_samples)
        dataset = dataset.select(indices)
        logging.info(f"Sampled {len(dataset)} examples for analysis")

    col_names = dataset.column_names
    if not any(col in col_names for col in _SUPPORTED_COLUMNS):
        logging.error(f"Unsupported dataset format. Columns: {col_names}")
        return pd.DataFrame(), {}

    results = []
    skipped = 0
    for i, example in enumerate(tqdm(dataset, desc="Analyzing traces")):
        question, trace, answer, is_correct = _extract_example_fields(example, col_names)
        if not trace:
            skipped += 1
            continue
        example_id = example.get("id", str(i))
        results.append(_build_result_row(example_id, question, answer, is_correct, trace, api_key, use_gpt))

    if skipped:
        logging.info(f"Skipped {skipped} examples with an empty reasoning trace")

    df = pd.DataFrame(results)
    stats = _aggregate_stats(df)

    df.to_csv(os.path.join(output_dir, "trace_analysis_results.csv"), index=False)
    with open(os.path.join(output_dir, "trace_analysis_stats.json"), "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2, cls=NumpyEncoder)

    return df, stats

def generate_pivot_visualizations(df: pd.DataFrame, stats: Dict[str, Any], output_dir: str):
    """
    Save pivot-analysis charts as PNG files in *output_dir*.

    Always produces pivot_frequency.png, pivot_presence.png and pivot_diversity.png.
    pivot_success_rate.png is added when "is_correct" has more than one distinct
    value and at least one pivot type occurs. When *df* has no rows or no pivot
    columns (e.g. analyze_dataset failed to load the dataset), a warning is logged
    and nothing is plotted.

    Args:
        df: Per-trace results with "pivot_<type>", "pivot_diversity" and "is_correct" columns
        stats: Aggregated statistics (not used by this function)
        output_dir: Directory the PNG files are written to (created if needed)
    """
    os.makedirs(output_dir, exist_ok=True)

    if df.empty:
        logging.warning("No analysis results to plot; skipping pivot visualizations")
        return

    # Only the four pivot-type columns (not derived columns like pivot_diversity).
    pivot_cols = [col for col in df.columns if col.startswith("pivot_") and col.replace("pivot_", "") in PIVOT_PATTERNS.keys()]
    if not pivot_cols:
        logging.warning("No pivot-type columns found; skipping pivot visualizations")
        return
    pivot_types = [col.replace("pivot_", "") for col in pivot_cols]

    # 1. Pivot frequency bar chart
    plt.figure(figsize=(12, 6))
    pivot_means = df[pivot_cols].mean().sort_values(ascending=False)
    pivot_means.index = [col.replace("pivot_", "") for col in pivot_means.index]

    sns.barplot(x=pivot_means.index, y=pivot_means.values, palette="viridis")
    plt.title("Average Frequency of Pivot Types in Reasoning Traces")
    plt.xlabel("Pivot Type")
    plt.ylabel("Average Occurrences")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "pivot_frequency.png"), dpi=300)
    plt.close()

    # 2. Pivot presence percentage
    plt.figure(figsize=(12, 6))
    pivot_presence = ((df[pivot_cols] > 0).mean() * 100).sort_values(ascending=False)
    pivot_presence.index = [col.replace("pivot_", "") for col in pivot_presence.index]

    sns.barplot(x=pivot_presence.index, y=pivot_presence.values, palette="muted")
    plt.title("Percentage of Traces Containing Each Pivot Type")
    plt.xlabel("Pivot Type")
    plt.ylabel("Percentage Present (%)")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "pivot_presence.png"), dpi=300)
    plt.close()

    # 3. Pivot correlation with success (if we have correctness data)
    if "is_correct" in df.columns and df["is_correct"].nunique() > 1:
        pivot_success = []
        for pivot_type in pivot_types:
            with_pivot = df[df[f"pivot_{pivot_type}"] > 0]
            if with_pivot.empty:
                # No trace has this pivot, so its success rate is undefined (NaN).
                # A NaN key would also make the sort order below unreliable.
                continue
            pivot_success.append((pivot_type, with_pivot["is_correct"].mean() * 100))

        pivot_success.sort(key=lambda x: x[1], reverse=True)

        if pivot_success:
            plt.figure(figsize=(12, 6))
            plt.bar([x[0] for x in pivot_success], [x[1] for x in pivot_success])
            plt.title("Success Rate When Pivot Type is Present")
            plt.xlabel("Pivot Type")
            plt.ylabel("Success Rate (%)")
            plt.xticks(rotation=45, ha="right")
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, "pivot_success_rate.png"), dpi=300)
            plt.close()

    # 4. Pivot diversity histogram (one bin per possible number of distinct pivot types)
    plt.figure(figsize=(10, 6))
    sns.histplot(df["pivot_diversity"], kde=True, bins=range(len(pivot_types)+2))
    plt.title("Distribution of Pivot Diversity in Reasoning Traces")
    plt.xlabel("Number of Different Pivot Types")
    plt.ylabel("Count")
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "pivot_diversity.png"), dpi=300)
    plt.close()

def generate_structure_visualizations(df: pd.DataFrame, stats: Dict[str, Any], output_dir: str):
    """
    Render bar charts summarizing the structure-type columns of ``df``.

    Only columns named ``structure_<name>`` whose ``<name>`` is a key of
    ``STRUCTURE_PATTERNS`` are plotted. Up to three PNG files are written into
    ``output_dir``:

    * ``structure_frequency.png``: mean number of occurrences per trace;
    * ``structure_presence.png``: percentage of traces with at least one occurrence;
    * ``structure_success_rate.png``: mean of ``is_correct`` (as a percentage)
      among traces containing each structure type. This chart is produced only
      when ``df`` has an ``is_correct`` column with more than one distinct value.
      Structure types whose success rate is undefined (no trace contains them)
      are left out.

    Args:
        df: Per-trace analysis results.
        stats: Aggregated statistics. Not used here; kept so all visualization
            helpers share one signature.
        output_dir: Directory to write the figures into; created if missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Only the four named structure types; other "structure_*" columns are ignored.
    structure_cols = [
        col for col in df.columns
        if col.startswith("structure_") and col.replace("structure_", "") in STRUCTURE_PATTERNS.keys()
    ]
    if not structure_cols:
        logging.warning("No structure columns found; skipping structure visualizations")
        return
    structure_types = [col.replace("structure_", "") for col in structure_cols]

    # 1. Structure frequency bar chart
    plt.figure(figsize=(12, 6))
    structure_means = df[structure_cols].mean().sort_values(ascending=False)

    # Clean column names for display
    structure_means.index = [col.replace("structure_", "") for col in structure_means.index]

    sns.barplot(x=structure_means.index, y=structure_means.values, palette="viridis")
    plt.title("Average Frequency of Structure Types in Reasoning Traces")
    plt.xlabel("Structure Type")
    plt.ylabel("Average Occurrences")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "structure_frequency.png"), dpi=300)
    plt.close()

    # 2. Structure presence percentage
    plt.figure(figsize=(12, 6))
    structure_presence = (df[structure_cols] > 0).mean() * 100
    structure_presence = structure_presence.sort_values(ascending=False)

    # Clean column names for display
    structure_presence.index = [col.replace("structure_", "") for col in structure_presence.index]

    sns.barplot(x=structure_presence.index, y=structure_presence.values, palette="muted")
    plt.title("Percentage of Traces Containing Each Structure Type")
    plt.xlabel("Structure Type")
    plt.ylabel("Percentage Present (%)")
    plt.xticks(rotation=45, ha="right")
    plt.tight_layout()
    plt.savefig(os.path.join(output_dir, "structure_presence.png"), dpi=300)
    plt.close()

    # 3. Structure correlation with success (if we have correctness data)
    if "is_correct" in df.columns and df["is_correct"].nunique() > 1:
        structure_success = []
        for structure_type in structure_types:
            col = f"structure_{structure_type}"
            rate = df.loc[df[col] > 0, "is_correct"].mean()
            # An empty selection (type never present) or all-missing labels give
            # NaN. A NaN bar is meaningless and NaN keys make the sort below
            # order arbitrarily, so skip such types.
            if pd.isna(rate):
                continue
            structure_success.append((structure_type, rate * 100))

        if structure_success:
            structure_success.sort(key=lambda x: x[1], reverse=True)

            plt.figure(figsize=(12, 6))
            plt.bar([x[0] for x in structure_success], [x[1] for x in structure_success])
            plt.title("Success Rate When Structure Type is Present")
            plt.xlabel("Structure Type")
            plt.ylabel("Success Rate (%)")
            plt.xticks(rotation=45, ha="right")
            plt.tight_layout()
            plt.savefig(os.path.join(output_dir, "structure_success_rate.png"), dpi=300)
            plt.close()

def generate_summary_report(df: pd.DataFrame, stats: Dict[str, Any], output_dir: str) -> str:
    """
    Generate a comprehensive Markdown summary report of the analysis.

    The report contains:

    * headline averages;
    * a pivot-type table and a structure-type table, with a row only for types
      whose ``avg_*`` and ``pct_with_*`` keys are present in ``stats``;
    * a correctness comparison, when ``df`` has an ``is_correct`` column with
      more than one distinct value;
    * a short insights section whose pivot-diversity statement is derived
      from the data.

    Args:
        df: DataFrame with per-trace analysis results; must contain a
            ``pivot_diversity`` column.
        stats: Dictionary with aggregated statistics; must contain
            ``total_samples``, ``avg_trace_length``, ``avg_token_count``,
            ``avg_paragraphs``, ``avg_equations``, ``avg_pivot_diversity`` and
            ``avg_total_pivots``.
        output_dir: Directory to save the report in; created if it does not exist.

    Returns:
        Path to the saved report (``<output_dir>/trace_analysis_report.md``).

    Raises:
        KeyError: If a required statistic or the ``pivot_diversity`` column is missing.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Computed once so the diversity section and the insights always agree.
    pivot_3plus = (df["pivot_diversity"] >= 3).mean() * 100
    avg_pivot_diversity = stats["avg_pivot_diversity"]

    report = """# Reasoning Trace Analysis Report

## Summary

This report summarizes the analysis of reasoning traces, focusing on the four pivot types and four reasoning stages described in the paper.

### Key Findings

- **Total Samples Analyzed**: {total_samples}
- **Average Trace Length**: {avg_trace_length:.1f} characters
- **Average Token Count**: {avg_token_count:.1f} tokens
- **Average Paragraphs**: {avg_paragraphs:.1f} paragraphs
- **Average Equations**: {avg_equations:.1f} equations
- **Average Total Pivots**: {avg_total_pivots:.2f} per trace

## Pivot Analysis

### Pivot Diversity

- **Average Pivot Diversity**: {avg_pivot_diversity:.2f} distinct pivot types per trace
- **Traces with at least 3 pivot types**: {pivot_3plus:.1f}%

### Pivot Type Frequencies

| Pivot Type | Average Occurrences | Percentage Present |
|------------|---------------------|-------------------|
""".format(
        total_samples=stats["total_samples"],
        avg_trace_length=stats["avg_trace_length"],
        avg_token_count=stats["avg_token_count"],
        avg_paragraphs=stats["avg_paragraphs"],
        avg_equations=stats["avg_equations"],
        avg_pivot_diversity=avg_pivot_diversity,
        avg_total_pivots=stats["avg_total_pivots"],
        pivot_3plus=pivot_3plus,
    )

    # Add pivot type statistics
    for pivot_type in PIVOT_PATTERNS.keys():
        avg_col = f"avg_pivot_{pivot_type}"
        pct_col = f"pct_with_{pivot_type}"
        if avg_col in stats and pct_col in stats:
            report += f"| {pivot_type} | {stats[avg_col]:.2f} | {stats[pct_col]:.1f}% |\n"

    report += """
## Structure Analysis

### Structure Type Frequencies

| Structure Type | Average Occurrences | Percentage Present |
|----------------|---------------------|-------------------|
"""

    # Add structure type statistics
    for structure_type in STRUCTURE_PATTERNS.keys():
        avg_col = f"avg_structure_{structure_type}"
        pct_col = f"pct_with_{structure_type}"
        if avg_col in stats and pct_col in stats:
            report += f"| {structure_type} | {stats[avg_col]:.2f} | {stats[pct_col]:.1f}% |\n"

    # Add information about correlations if there is correctness data
    if "is_correct" in df.columns and df["is_correct"].nunique() > 1:
        correct_pct = df["is_correct"].mean() * 100
        report += f"""
## Correctness Analysis

- **Correct Answers**: {correct_pct:.1f}% of traces

### Pivot Types and Correctness

| Pivot Type | Frequency in Correct | Frequency in Incorrect | Difference |
|------------|----------------------|------------------------|------------|
"""

        for pivot_type in PIVOT_PATTERNS.keys():
            correct_col = f"{pivot_type}_in_correct"
            incorrect_col = f"{pivot_type}_in_incorrect"
            if correct_col in stats and incorrect_col in stats:
                diff = stats[correct_col] - stats[incorrect_col]
                report += f"| {pivot_type} | {stats[correct_col]:.2f} | {stats[incorrect_col]:.2f} | {diff:.2f} |\n"

    # Add insights section. The diversity insight quotes the measured share
    # instead of unconditionally asserting that a majority of traces have 3+ types.
    report += f"""
## Key Insights

1. **Pivot Diversity**: {pivot_3plus:.1f}% of analyzed traces contain at least three distinct pivot categories (average {avg_pivot_diversity:.2f} distinct types per trace).

2. **Structure Alignment**: The traces follow the cognitive science framework of problem framing, exploration, verification, and synthesis.

3. **Metacognitive Transitions**: The presence of explicit pivots signals non-linear, reflective thinking that characterizes effective reasoning.
"""

    # Save the report; explicit encoding keeps output identical across platforms.
    report_path = os.path.join(output_dir, "trace_analysis_report.md")
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(report)

    return report_path

def analyze_traces(dataset_name="simplescaling/data_ablation_full59K", sample_size=1000, output_dir="analysis/results_simplescaling", use_gpt=False, api_key=None):
    """
    End-to-end driver: analyze a dataset, plot the results, print and report them.

    Targets the four pivot types (realization, verification, exploration,
    integration) and the four reasoning stages (problem framing, exploration,
    verification, synthesis) described in the paper.

    Args:
        dataset_name: Dataset identifier forwarded to ``analyze_dataset``.
        sample_size: Number of traces to analyze, forwarded as ``num_samples``.
        output_dir: Destination for figures and the Markdown report; created if absent.
        use_gpt: Forwarded to ``analyze_dataset``.
        api_key: Forwarded to ``analyze_dataset``.

    Returns:
        The statistics dictionary returned by ``analyze_dataset``, or ``None``
        when the resulting DataFrame is empty.
    """
    os.makedirs(output_dir, exist_ok=True)

    analyzed, summary = analyze_dataset(
        dataset_name=dataset_name,
        num_samples=sample_size,
        api_key=api_key,
        use_gpt=use_gpt,
        output_dir=output_dir,
    )

    # Nothing was produced: there is nothing to plot or report on.
    if analyzed.empty:
        logging.error("No data to analyze")
        return None

    generate_pivot_visualizations(analyzed, summary, output_dir)
    generate_structure_visualizations(analyzed, summary, output_dir)

    # Headline numbers on the console
    print("\nAnalysis Summary:")
    print(f"Total samples analyzed: {len(analyzed)}")
    print(f"Average trace length: {analyzed['trace_length'].mean():.1f} characters")
    print(f"Average token count: {analyzed['token_count'].mean():.1f}")
    print(f"Average number of pivots: {analyzed['total_pivots'].mean():.2f}")

    # Keep only the named pivot categories, in DataFrame column order.
    known_pivot_cols = []
    for column in analyzed.columns:
        if column.startswith("pivot_") and column.replace("pivot_", "") in PIVOT_PATTERNS:
            known_pivot_cols.append(column)

    if known_pivot_cols:
        ranked = analyzed[known_pivot_cols].mean().sort_values(ascending=False)
        print("\nPivot type frequencies:")
        for column, average in ranked.items():
            label = column.replace("pivot_", "")
            print(f"- {label}: {average:.2f} average occurrences")

        diverse_share = (analyzed["pivot_diversity"] >= 3).mean() * 100
        print(f"\nPercentage of traces with at least 3 pivot types: {diverse_share:.1f}%")

    print(f"\nFull results saved to {output_dir}")

    report_location = generate_summary_report(analyzed, summary, output_dir)
    print(f"\nComprehensive report generated: {report_location}")

    return summary

def main():
    """
    Command-line entry point: parse arguments, load API keys if needed, run the analysis.

    Returns:
        Process exit status:
        * ``1`` when ``--use_gpt`` is set but no OpenAI key is found;
        * ``1`` when ``analyze_traces`` returns ``None`` (no data analyzed);
        * ``0`` otherwise.
    """
    parser = argparse.ArgumentParser(description="Analyze reasoning traces in datasets")
    parser.add_argument("--dataset", type=str, default="simplescaling/data_ablation_full59K",
                        help="Dataset name on Hugging Face")
    parser.add_argument("--split", type=str, default="train",
                        help="Dataset split to analyze (currently not forwarded to the analysis)")
    parser.add_argument("--num_samples", type=int, default=1000,
                        help="Number of samples to analyze (default: 1000 as specified in the paper)")
    parser.add_argument("--output_dir", type=str, default="analysis/results_simplescaling",
                        help="Directory to save results")
    parser.add_argument("--use_gpt", action="store_true",
                        help="Use GPT to identify pivots and structures")
    parser.add_argument("--env_file", type=str, default=".env",
                        help="Path to .env file with API keys")

    args = parser.parse_args()

    # analyze_traces() takes no split argument, so --split cannot affect the
    # run; warn rather than silently ignoring a non-default request.
    if args.split != "train":
        logging.warning("--split=%s is ignored: the split is not passed to analyze_traces", args.split)

    # Load API keys if using GPT
    api_key = None
    if args.use_gpt:
        keys = load_api_keys(args.env_file)
        api_key = keys.get("openai")
        if not api_key:
            logging.error("OpenAI API key required for GPT analysis")
            return 1

    # Run analysis; analyze_traces returns None when there was no data.
    stats = analyze_traces(
        dataset_name=args.dataset,
        sample_size=args.num_samples,
        output_dir=args.output_dir,
        use_gpt=args.use_gpt,
        api_key=api_key
    )
    return 0 if stats is not None else 1

if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status
    # (a None return still exits with status 0).
    raise SystemExit(main())