"""
Script for extracting pivots from reasoning traces and building prompts for generating synthetic reasoning traces.
This module provides functionality to identify different types of reasoning pivots and
structures in traces, group traces into templates, and create prompts for generating synthetic traces.
"""

import os
import sys
import re
import json
import random
import logging
from typing import Dict, List, Any, Optional, Tuple, Set
import spacy
import numpy as np
from tqdm import tqdm
from collections import Counter, defaultdict

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Try to load the spaCy English model; fall back to simpler processing if it cannot
# be loaded. `except Exception` (rather than a bare `except:`) keeps KeyboardInterrupt
# and SystemExit propagating instead of silently swallowing them.
try:
    nlp = spacy.load("en_core_web_sm")
    SPACY_AVAILABLE = True
except Exception:
    logging.warning("spaCy model not available. Using simpler text processing.")
    # Keep `nlp` defined so callers can check `nlp is None` instead of hitting NameError.
    nlp = None
    SPACY_AVAILABLE = False

# Pivot type patterns.
# Maps each pivot type (a turn in the reasoning, e.g. a realization, a verification
# step or a detected contradiction) to a list of regex strings. Every pattern starts
# with the inline (?i) flag, so matching is case-insensitive. A trace is considered
# to contain a pivot type when any of that type's patterns matches (see extract_pivots).
# NOTE(review): the pattern strings carry no word-boundary anchors of their own, and
# some alternatives are very generic ("Let", "Here", "Maybe", "What if", "Hmm,"), so
# they can fire on text that is not really a pivot (e.g. "Here" inside "where") unless
# the matcher adds boundaries. Entries also overlap across/within types, e.g.
# "Let's consider" (exploration) vs. "Let's consider the cases" (proof_by_cases), and
# "This contradicts" appears twice under "contradiction".
PIVOT_PATTERNS = {
    "realization": [
        r"(?i)(I realize|I now realize|I see|I now see|Oh,|Ah,|Hmm,|Actually,|In fact,|Wait,|I notice|I observe|I discover|I find that)",
        r"(?i)(Upon reflection|Upon further reflection|On second thought|Thinking about this more|Looking more carefully)",
        r"(?i)(This means|This implies|This suggests|This indicates|This shows|This reveals)",
        r"(?i)(It dawns on me|It becomes clear|It becomes obvious|It is apparent|It is evident)"
    ],
    "verification": [
        r"(?i)(Let me verify|Let me check|Let me confirm|Let's check|Let's verify|Let's make sure|Let's confirm|I should verify|I should check)",
        r"(?i)(To verify|To check|To confirm|To ensure|To validate|To double-check)",
        r"(?i)(We can verify|We can check|We can confirm|We can ensure|We can validate)",
        r"(?i)(Let's test|Let me test|I will test|I should test|We can test|To test)"
    ],
    "decomposition": [
        r"(?i)(Let me break|Let's break|I'll break|We can break|Breaking|To break)",
        r"(?i)(Let me split|Let's split|I'll split|We can split|Splitting|To split)",
        r"(?i)(Let me divide|Let's divide|I'll divide|We can divide|Dividing|To divide)",
        r"(?i)(Let me decompose|Let's decompose|I'll decompose|We can decompose|Decomposing|To decompose)",
        r"(?i)(Let me approach|Let's approach|I'll approach|We can approach|Approaching|To approach|tackle this)"
    ],
    "exploration": [
        r"(?i)(Let me try|Let's try|I'll try|We can try|Let me explore|Let's explore|I'll explore|We can explore)",
        r"(?i)(Let me see if|Let's see if|I'll see if|We can see if|Let me investigate|Let's investigate|I'll investigate)",
        r"(?i)(Let's consider|Let me consider|I'll consider|We can consider|I wonder if|What if|Maybe)",
        r"(?i)(Let me attempt|Let's attempt|I'll attempt|We can attempt|Trying|Exploring|Considering|Investigating)"
    ],
    "contradiction": [
        r"(?i)(That's not right|This can't be right|This doesn't make sense|This is incorrect|This is wrong|This contradicts)",
        r"(?i)(I made a mistake|I made an error|That was an error|That was a mistake|That's a mistake|That's an error)",
        r"(?i)(This contradicts|This conflicts with|This is inconsistent with|This doesn't match|This doesn't agree with)",
        r"(?i)(Wait, that's wrong|Actually, that's incorrect|No, that's not right|I need to correct myself)"
    ],
    "variable_assignment": [
        r"(?i)(Let's define|Let me define|Let's denote|Let me denote|Let's set|Let me set|Let's call|Let me call)",
        r"(?i)(We define|I define|We denote|I denote|We set|I set|We call|I call)",
        r"(?i)(Let's use|Let me use|We use|I use|Let|In this case|Here)",
        r"(?i)(Let's say|Let me say|We say|I say|Suppose|Assuming|Given that)"
    ],
    "proof_by_cases": [
        r"(?i)(Let's consider the cases|Let me consider the cases|I'll consider the cases|We can consider the cases)",
        r"(?i)(Let's analyze the cases|Let me analyze the cases|I'll analyze the cases|We can analyze the cases)",
        r"(?i)(Case 1|Case 2|Case 3|Case I|Case II|Case III|First case|Second case|Third case)",
        r"(?i)(In the first case|In the second case|In the third case|For the first case|For the second case|For the third case)"
    ],
    "insight": [
        r"(?i)(The key insight|The main insight|The crucial insight|The important insight|The fundamental insight)",
        r"(?i)(The key observation|The main observation|The crucial observation|The important observation)",
        r"(?i)(The key idea|The main idea|The crucial idea|The important idea|The fundamental idea)",
        r"(?i)(The key is|The main point is|The crucial aspect is|The important part is|The fundamental issue is)"
    ],
}

# Structure patterns.
# Maps each structural move (how the reasoning is organized or presented, e.g. a
# summary/conclusion, numbered steps, an analogy or a calculation) to a list of
# case-insensitive regex strings, in the same format as PIVOT_PATTERNS; consumed by
# extract_structures.
# NOTE(review): several alternatives are short common words ("So", "Then, ",
# "Similarly", "This problem") and the pattern strings contain no word-boundary
# anchors, so e.g. "synthesis" can match "So" inside unrelated words such as "also"
# unless the matcher adds boundaries. "property_application" uses "[a-z]+" wildcards
# ("The [a-z]+ rule"), which accept any run of letters in that slot.
STRUCTURE_PATTERNS = {
    "synthesis": [
        r"(?i)(Putting (this|it|these|those|all this|all of this) together|Combining (this|these|those))",
        r"(?i)(In summary|To summarize|Summarizing|In conclusion|To conclude|Concluding|To wrap up)",
        r"(?i)(Therefore|Thus|Hence|So|Consequently|As a result|Accordingly|It follows that)",
        r"(?i)(Finally, we|Ultimately, we|Eventually, we|In the end, we|We can finally)"
    ],
    "step_by_step": [
        r"(?i)(Step 1|Step 2|Step 3|Step one|Step two|Step three|First step|Second step|Third step)",
        r"(?i)(First, |Second, |Third, |Next, |Then, |Afterwards, |Subsequently, |Following this, )",
        r"(?i)(To begin|Let's start|I'll start|We start|We begin|I begin)",
        r"(?i)(The first step|The next step|The last step|The final step)"
    ],
    "reasoning_by_analogy": [
        r"(?i)(This is (similar|analogous|comparable|akin|equivalent) to|This resembles|This is like)",
        r"(?i)(By analogy|Using an analogy|Drawing an analogy|Making an analogy|Through analogy)",
        r"(?i)(Similarly|Likewise|In the same way|In a similar manner|Comparably)",
        r"(?i)(We can draw a parallel|We can make a comparison|We can relate this to)"
    ],
    "definition_application": [
        r"(?i)(By definition|According to the definition|From the definition|Using the definition)",
        r"(?i)(Applying the definition|The definition states|The definition says|The definition tells us)",
        r"(?i)(The definition of|As defined|As per the definition|Based on the definition)",
        r"(?i)(Using the concept of|Applying the concept of|According to the concept of)"
    ],
    "property_application": [
        r"(?i)(Using the property|Applying the property|By the property|According to the property)",
        r"(?i)(Due to the property|Because of the property|Thanks to the property)",
        r"(?i)(The property of|The properties of|This property|These properties)",
        r"(?i)(The [a-z]+ property|The rule of|The [a-z]+ rule|The [a-z]+ law)"
    ],
    "calculation": [
        r"(?i)(Computing|Computing this|Calculating|Calculating this|Evaluating|Evaluating this)",
        r"(?i)(Let me compute|Let me calculate|Let me evaluate|Let's compute|Let's calculate|Let's evaluate)",
        r"(?i)(If we compute|If we calculate|If we evaluate|When we compute|When we calculate|When we evaluate)",
        r"(?i)(The computation|The calculation|The evaluation|This computation|This calculation|This evaluation)"
    ],
    "problem_framing": [
        r"(?i)(The problem asks|The problem is asking|We are asked|We need to|We want to|The goal is to)",
        r"(?i)(The question is|The task is|The challenge is|The objective is|The aim is|The purpose is)",
        r"(?i)(We're looking for|I'm looking for|We need to find|I need to find|We need to determine|I need to determine)",
        r"(?i)(In this problem|For this problem|This problem|To solve this|To tackle this|To address this)"
    ],
}

def _compile_bounded(pattern: str) -> "re.Pattern":
    """
    Compile a pattern-table regex so it only matches whole words/phrases.

    The pattern tables use bare alternations without word boundaries, so e.g. "So"
    would otherwise match inside "also"/"solve" and "Here" inside "where". The
    compiled regex requires that a match does not start right after a word
    character and does not end in the middle of a word. The trailing check also
    passes when the match itself ends in a non-word character (e.g. "Oh," or
    "First, "), so such phrases keep matching.

    A leading inline "(?i)" is converted to re.IGNORECASE, because inline global
    flags must appear at the very start of a regex and the pattern is wrapped here.
    """
    flags = 0
    if pattern.startswith("(?i)"):
        pattern = pattern[len("(?i)"):]
        flags = re.IGNORECASE
    return re.compile(r"(?<!\w)(?:" + pattern + r")(?:(?<!\w)|(?!\w))", flags)


def _extract_contexts(text: str, patterns_by_type: Dict[str, List[str]]) -> Dict[str, List[str]]:
    """
    Collect a short context snippet around every pattern match in ``text``.

    The text is split into paragraphs on blank lines. Each match contributes the
    slice of its paragraph from 20 characters before the match to 100 characters
    after it (clamped to the paragraph, whitespace-stripped). A snippet is skipped
    if it is a substring of a snippet already collected for the same type.

    Args:
        text: The reasoning trace text
        patterns_by_type: Mapping from type name to a list of regex strings

    Returns:
        Dictionary with every type name as a key and its snippets (possibly empty) as value
    """
    found: Dict[str, List[str]] = {type_name: [] for type_name in patterns_by_type}
    paragraphs = text.split("\n\n")

    for type_name, patterns in patterns_by_type.items():
        snippets = found[type_name]
        for pattern in patterns:
            regex = _compile_bounded(pattern)
            for paragraph in paragraphs:
                for match in regex.finditer(paragraph):
                    start, end = match.span()
                    # Some context around the match (try to get most of the sentence).
                    context = paragraph[max(0, start - 20):min(len(paragraph), end + 100)].strip()
                    # Avoid duplicates: skip contexts already covered by an earlier snippet.
                    if not any(context in existing for existing in snippets):
                        snippets.append(context)

    return found


def extract_pivots(text: str, patterns: Optional[Dict[str, List[str]]] = None) -> Dict[str, List[str]]:
    """
    Extract pivot statements from a reasoning trace.

    Args:
        text: The reasoning trace text
        patterns: Optional mapping from pivot type to regex strings; defaults to PIVOT_PATTERNS

    Returns:
        Dictionary with pivot types as keys and lists of matching context snippets as values
    """
    return _extract_contexts(text, PIVOT_PATTERNS if patterns is None else patterns)


def extract_structures(text: str, patterns: Optional[Dict[str, List[str]]] = None) -> Dict[str, List[str]]:
    """
    Extract structural patterns from a reasoning trace.

    Args:
        text: The reasoning trace text
        patterns: Optional mapping from structure type to regex strings; defaults to STRUCTURE_PATTERNS

    Returns:
        Dictionary with structure types as keys and lists of matching context snippets as values
    """
    return _extract_contexts(text, STRUCTURE_PATTERNS if patterns is None else patterns)

def get_pivot_profile(pivots: Dict[str, List[str]]) -> str:
    """
    Summarize which pivot types occur in a trace.

    Args:
        pivots: Mapping from pivot type to the list of matched snippets

    Returns:
        Alphabetically sorted, comma-separated names of the pivot types that have
        at least one snippet (an empty string when none do)
    """
    present_types = sorted(name for name, snippets in pivots.items() if snippets)
    return ", ".join(present_types)

def get_structure_profile(structures: Dict[str, List[str]]) -> str:
    """
    Summarize which structure types occur in a trace.

    Args:
        structures: Mapping from structure type to the list of matched snippets

    Returns:
        Alphabetically sorted, comma-separated names of the structure types that
        have at least one snippet (an empty string when none do)
    """
    present_types = [name for name, snippets in structures.items() if len(snippets) > 0]
    present_types.sort()
    return ", ".join(present_types)

def analyze_trace(trace: str) -> Dict[str, Any]:
    """
    Compute summary statistics plus pivot and structure matches for one trace.

    Args:
        trace: The reasoning trace text

    Returns:
        Dictionary with keys token_count, paragraph_count, equation_count,
        pivot_counts, total_pivots, structure_counts, total_structures,
        pivot_profile, structure_profile, pivots and structures
    """
    pivots = extract_pivots(trace)
    structures = extract_structures(trace)

    pivot_counts = {name: len(hits) for name, hits in pivots.items()}
    structure_counts = {name: len(hits) for name, hits in structures.items()}

    return {
        # Whitespace-delimited words, not model tokens.
        "token_count": len(trace.split()),
        # Blank-line-separated chunks: one more than the number of "\n\n" separators.
        "paragraph_count": trace.count("\n\n") + 1,
        # Rough heuristic: every "\boxed{" plus every "=" character (so "==" counts twice).
        "equation_count": trace.count("\\boxed{") + trace.count("="),
        "pivot_counts": pivot_counts,
        "total_pivots": sum(pivot_counts.values()),
        "structure_counts": structure_counts,
        "total_structures": sum(structure_counts.values()),
        "pivot_profile": get_pivot_profile(pivots),
        "structure_profile": get_structure_profile(structures),
        "pivots": pivots,
        "structures": structures,
    }

def _average_counts(group: List[Dict[str, Any]], key: str) -> Dict[str, float]:
    """Average the per-type counts stored under ``key`` across every analysis in ``group``."""
    per_type: Dict[str, List[int]] = defaultdict(list)
    for analysis in group:
        for type_name, count in analysis[key].items():
            per_type[type_name].append(count)
    return {type_name: sum(counts) / len(counts) for type_name, counts in per_type.items()}


def identify_templates(
    analyses: List[Dict[str, Any]],
    min_count: int = 2,
    max_examples: int = 3,
) -> List[Dict[str, Any]]:
    """
    Identify common templates from a list of trace analyses.

    Traces are grouped by their ``pivot_profile``; every group with at least
    ``min_count`` members becomes a template with averaged statistics.

    Args:
        analyses: List of trace analysis results (as produced by analyze_trace)
        min_count: Minimum group size for a profile to become a template (default 2,
            i.e. singleton profiles are skipped)
        max_examples: Maximum number of example indices stored per template (default 3)

    Returns:
        List of template dictionaries, most common first. ``example_idx`` holds
        positions in ``analyses`` of the first ``max_examples`` group members.
    """
    # Group analysis *indices* so example positions come straight from enumerate.
    # analyses.index() was O(n) per lookup and returned the first equal dict when
    # two analyses compared equal, yielding duplicate/wrong indices.
    profile_groups: Dict[str, List[int]] = defaultdict(list)
    for idx, analysis in enumerate(analyses):
        profile_groups[analysis["pivot_profile"]].append(idx)

    templates = []
    for profile, indices in profile_groups.items():
        if len(indices) < min_count:
            continue

        group = [analyses[i] for i in indices]
        size = len(group)  # always >= 1, so the averages below are safe
        templates.append({
            "pivot_profile": profile,
            "count": size,
            "avg_token_count": sum(a["token_count"] for a in group) / size,
            "avg_paragraph_count": sum(a["paragraph_count"] for a in group) / size,
            "avg_equation_count": sum(a["equation_count"] for a in group) / size,
            "avg_pivot_counts": _average_counts(group, "pivot_counts"),
            "avg_structure_counts": _average_counts(group, "structure_counts"),
            "example_idx": indices[:max_examples],
        })

    # Most common first; the sort is stable, so ties keep first-seen profile order.
    templates.sort(key=lambda t: t["count"], reverse=True)
    return templates

def create_template_prompt(template: Dict[str, Any], question: str, domain: str = "general") -> str:
    """
    Create a prompt for generating a synthetic trace based on a template.

    Pivot suggestions are listed only for pivot types named in the template's
    ``pivot_profile`` whose average count is at least 1. Structure suggestions are
    listed for every structure type with an average count of at least 1. Counts are
    rounded to the nearest integer, and each suggestion is a "- " bullet on its own line.

    Args:
        template: Template dictionary (as produced by identify_templates) with
            pivot_profile, avg_pivot_counts, avg_structure_counts and avg_paragraph_count
        question: The question to generate a trace for
        domain: Problem domain (e.g., 'math', 'logic', 'physics')

    Returns:
        Prompt text for generating a trace
    """
    pivot_types = template["pivot_profile"].split(", ")
    pivot_counts = template["avg_pivot_counts"]
    structure_counts = template["avg_structure_counts"]

    pivot_suggestions = []
    for pivot_type in pivot_types:
        avg = pivot_counts.get(pivot_type, 0)
        if avg >= 1:
            count = int(round(avg))
            pivot_suggestions.append(f"- {count} {pivot_type} pivot{'s' if count > 1 else ''}")

    structure_suggestions = []
    for structure_type, avg in structure_counts.items():
        if avg >= 1:
            count = int(round(avg))
            structure_suggestions.append(f"- {count} {structure_type} structure{'s' if count > 1 else ''}")

    # Suggestions are "- " bullets, so each goes on its own line; joining them with
    # ", " put every bullet marker on a single line.
    pivot_block = "\n".join(pivot_suggestions)
    structure_block = "\n".join(structure_suggestions)
    # Rounded like the pivot/structure counts above (previously truncated).
    paragraph_target = int(round(template["avg_paragraph_count"]))

    prompt = f"""You are an expert in solving {domain} problems and explaining your reasoning in a clear, step-by-step manner.

Please solve the following problem:

{question}

In your solution, include your thinking process with approximately:
{pivot_block}

Also, structure your solution with:
{structure_block}

Your solution should be approximately {paragraph_target} paragraphs long and should explain your thinking thoroughly, showing all your work and calculations.

First frame the problem, then work through it methodically, and finally provide a clearly marked final answer."""

    return prompt

def _extract_trace(example: Dict[str, Any]) -> Optional[str]:
    """
    Return the reasoning trace stored in a dataset example, or None if there is none.

    Uses the ``thinking`` field when it is present and not None. Otherwise it uses the
    first entry of a non-empty ``thinking_trajectories`` field.
    """
    thinking = example.get("thinking")
    if thinking is not None:
        return thinking
    trajectories = example.get("thinking_trajectories")
    if trajectories:
        return trajectories[0]
    return None


def _write_summary_report(path: str, analyses: List[Dict[str, Any]], templates: List[Dict[str, Any]]) -> None:
    """
    Write the Markdown summary report for ``analyses`` and ``templates`` to ``path``.

    Averages and percentages fall back to 0 when ``analyses`` is empty, so an
    empty analysis run produces a report instead of a ZeroDivisionError.
    """
    n = len(analyses)

    def per_trace(total: float) -> float:
        # Mean over analysed traces; 0.0 when there are none.
        return total / n if n else 0.0

    with open(path, "w", encoding="utf-8") as f:

        def write_frequency_rows(patterns: Dict[str, List[str]], counts_key: str) -> None:
            # One table row per type: traces containing it, % of traces, mean count.
            for type_name in patterns:
                traces = sum(1 for a in analyses if a[counts_key][type_name] > 0)
                percent = per_trace(traces) * 100
                avg = per_trace(sum(a[counts_key][type_name] for a in analyses))
                f.write(f"| {type_name} | {traces} | {percent:.1f}% | {avg:.2f} |\n")

        f.write("# Trace Analysis Summary\n\n")

        f.write("## Dataset Statistics\n\n")
        f.write(f"- Number of examples: {n}\n")
        f.write(f"- Average trace length: {per_trace(sum(a['token_count'] for a in analyses)):.2f} tokens\n")
        f.write(f"- Average paragraphs: {per_trace(sum(a['paragraph_count'] for a in analyses)):.2f}\n")
        f.write(f"- Average equations: {per_trace(sum(a['equation_count'] for a in analyses)):.2f}\n\n")

        f.write("## Pivot Analysis\n\n")
        f.write(f"- Average total pivots per trace: {per_trace(sum(a['total_pivots'] for a in analyses)):.2f}\n\n")

        f.write("### Pivot Type Frequency\n\n")
        f.write("| Pivot Type | Traces | % of Traces | Average Count |\n")
        f.write("|------------|--------|-------------|---------------|\n")
        write_frequency_rows(PIVOT_PATTERNS, "pivot_counts")

        f.write("\n## Structure Analysis\n\n")
        f.write(f"- Average total structures per trace: {per_trace(sum(a['total_structures'] for a in analyses)):.2f}\n\n")

        f.write("### Structure Type Frequency\n\n")
        f.write("| Structure Type | Traces | % of Traces | Average Count |\n")
        f.write("|---------------|--------|-------------|---------------|\n")
        write_frequency_rows(STRUCTURE_PATTERNS, "structure_counts")

        f.write("\n## Templates\n\n")
        f.write("| Template (Pivot Profile) | Count | % of Dataset | Avg Tokens | Avg Paragraphs | Avg Equations |\n")
        f.write("|--------------------------|-------|--------------|------------|----------------|---------------|\n")

        for template in templates[:10]:  # Top 10 templates
            profile = template["pivot_profile"]
            count = template["count"]
            percent = per_trace(count) * 100
            tokens = template["avg_token_count"]
            paragraphs = template["avg_paragraph_count"]
            equations = template["avg_equation_count"]
            f.write(f"| {profile} | {count} | {percent:.1f}% | {tokens:.2f} | {paragraphs:.2f} | {equations:.2f} |\n")

        f.write("\n\nFull analysis results saved to JSON files in the output directory.\n")


def analyze_traces_in_dataset(dataset_path: str, output_dir: str, sample_limit: Optional[int] = None) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
    """
    Analyze reasoning traces in a dataset and identify templates.

    Writes trace_analyses.json, templates.json and summary_report.md into ``output_dir``.
    Examples without a usable trace are skipped with a warning.

    Args:
        dataset_path: Path or name passed to ``datasets.load_dataset``
        output_dir: Directory to save analysis results (created if missing)
        sample_limit: Optional limit on the number of samples to analyze
            (None or 0 means analyze everything)

    Returns:
        Tuple of (trace analyses, templates)
    """
    from datasets import load_dataset

    logging.info(f"Loading dataset from {dataset_path}")
    dataset = load_dataset(dataset_path)

    if isinstance(dataset, dict):
        # A dict of splits: prefer "train", otherwise fall back to the first split
        # instead of raising KeyError.
        dataset = dataset["train"] if "train" in dataset else next(iter(dataset.values()))

    if sample_limit:
        dataset = dataset.select(range(min(len(dataset), sample_limit)))

    logging.info(f"Analyzing {len(dataset)} traces")

    os.makedirs(output_dir, exist_ok=True)

    analyses = []
    for i, example in enumerate(tqdm(dataset, desc="Analyzing traces")):
        trace = _extract_trace(example)
        if trace is None:
            logging.warning(f"No thinking trace found in example {i}")
            continue

        analysis = analyze_trace(trace)
        analysis["id"] = example.get("id", f"example_{i}")
        analysis["question"] = example.get("question", "")
        analyses.append(analysis)

    if not analyses:
        logging.warning("No reasoning traces were found; the summary report will contain no statistics.")

    templates = identify_templates(analyses)

    # Replace the snippet lists with their lengths so the saved file stays compact.
    serializable_analyses = []
    for analysis in analyses:
        serializable = {k: v for k, v in analysis.items() if k not in ["pivots", "structures"]}
        serializable["pivots"] = {k: len(v) for k, v in analysis["pivots"].items()}
        serializable["structures"] = {k: len(v) for k, v in analysis["structures"].items()}
        serializable_analyses.append(serializable)

    with open(os.path.join(output_dir, "trace_analyses.json"), "w", encoding="utf-8") as f:
        json.dump(serializable_analyses, f, indent=2)

    with open(os.path.join(output_dir, "templates.json"), "w", encoding="utf-8") as f:
        json.dump(templates, f, indent=2)

    _write_summary_report(os.path.join(output_dir, "summary_report.md"), analyses, templates)

    logging.info(f"Analysis complete. Results saved to {output_dir}")
    return analyses, templates

def main():
    """
    Command-line entry point.

    Parses arguments and runs analyze_traces_in_dataset. Unless --no_viz is given,
    it then saves bar charts of pivot, structure and top-10 template frequencies to
    <output_dir>/visualizations. Plotting is skipped with a warning when matplotlib
    or seaborn cannot be imported.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Analyze reasoning traces and extract templates")
    parser.add_argument("--dataset", type=str, required=True,
                        help="Path to the dataset")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save analysis results")
    parser.add_argument("--samples", type=int, default=None,
                        help="Number of samples to analyze (default: all)")
    parser.add_argument("--no_viz", action="store_true",
                        help="Skip visualization generation")

    args = parser.parse_args()

    # Analyze traces and identify templates
    analyses, templates = analyze_traces_in_dataset(
        args.dataset,
        args.output_dir,
        args.samples
    )

    if args.no_viz:
        return

    # Only the imports are optional; keep the try narrow so unrelated ImportErrors
    # raised while plotting are not misreported as "not available".
    try:
        import matplotlib.pyplot as plt
        import seaborn as sns
    except ImportError:
        logging.warning("Matplotlib or seaborn not available. Skipping visualizations.")
        return

    sns.set_style("whitegrid")

    viz_dir = os.path.join(args.output_dir, "visualizations")
    os.makedirs(viz_dir, exist_ok=True)

    def save_bar_chart(labels, values, title, filename):
        # Draw one bar per label with rotated tick labels and save it into viz_dir.
        fig = plt.figure(figsize=(12, 6))
        try:
            positions = range(len(labels))
            plt.bar(positions, values)
            plt.xticks(positions, labels, rotation=45, ha="right")
            plt.title(title)
            plt.ylabel("Number of Traces")
            plt.tight_layout()
            plt.savefig(os.path.join(viz_dir, filename))
        finally:
            # Close each figure; pyplot otherwise keeps every figure open in memory.
            plt.close(fig)

    # Number of traces containing each pivot / structure type at least once.
    pivot_types = list(PIVOT_PATTERNS)
    save_bar_chart(
        pivot_types,
        [sum(1 for a in analyses if a["pivot_counts"][p] > 0) for p in pivot_types],
        "Pivot Type Frequency",
        "pivot_frequency.png",
    )

    structure_types = list(STRUCTURE_PATTERNS)
    save_bar_chart(
        structure_types,
        [sum(1 for a in analyses if a["structure_counts"][s] > 0) for s in structure_types],
        "Structure Type Frequency",
        "structure_frequency.png",
    )

    top_templates = templates[:10]
    save_bar_chart(
        [t["pivot_profile"] for t in top_templates],
        [t["count"] for t in top_templates],
        "Top 10 Template Frequencies",
        "template_frequency.png",
    )

    logging.info(f"Visualizations saved to {viz_dir}")


if __name__ == "__main__":
    main()