"""
Script for generating SmolTraces-HardCoded-Wrong (ST-HC-W) datasets.
These datasets are derived from ST-HC: each hardcoded reasoning trace is rewritten so that it leads to a plausible but incorrect answer.
"""

import os
import sys
import json
import argparse
import logging
import time
import random
import re
from typing import Dict, List, Any, Optional, Tuple
import requests
from tqdm import tqdm
from datasets import load_dataset, Dataset

# Add parent directory to path for imports
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data_generation.synthetic_trace_generation import extract_pivots

# Logging: timestamped, INFO-level messages for the whole script.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)

# Constants
MAX_ATTEMPTS: int = 3  # number of API attempts per sample in generate_incorrect_answer
SLEEP_TIME: int = 1  # seconds to wait after a failed API attempt before retrying
RESULTS_DIR: str = "results/wrong_answer_traces"  # default value of --output_dir
API_BASE_URL: str = "https://api.openai.com/v1/chat/completions"  # endpoint POSTed to for completions

def _extract_boxed(text: str) -> Optional[str]:
    """Return the brace-balanced contents of the first closed ``\\boxed{...}`` in *text*, or None."""
    marker = "\\boxed{"
    start = text.find(marker)
    while start != -1:
        content_start = start + len(marker)
        depth = 1
        for pos in range(content_start, len(text)):
            char = text[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                depth -= 1
                if depth == 0:
                    return text[content_start:pos]
        # This occurrence was never closed; try the next one.
        start = text.find(marker, start + 1)
    return None


def extract_answer(text: str) -> str:
    """
    Extract the answer from a response text.

    The following are tried in order and the first hit wins:
      1. the text between ``<|begin_of_solution|>`` and ``<|end_of_solution|>``;
      2. the contents of the first ``\\boxed{...}`` (nested braces such as
         ``\\boxed{\\frac{1}{2}}`` are handled);
      3. the text after "The answer is" / "The final answer is", "Therefore,",
         "Thus," or "In conclusion," up to the end of that sentence.
    If nothing matches, the last sentence of the (stripped) text is returned.

    Args:
        text: The response text

    Returns:
        The extracted answer, stripped of surrounding whitespace ("" for empty text)
    """
    solution = re.search(r"<\|begin_of_solution\|>(.*?)<\|end_of_solution\|>", text, re.DOTALL)
    if solution:
        return solution.group(1).strip()

    boxed = _extract_boxed(text)
    if boxed is not None:
        return boxed.strip()

    # A sentence ends at a period followed by whitespace / end of text (or at the
    # end of text), so decimals such as "3.14" are not cut at the decimal point.
    sentence_end = r"(?:\.(?=\s|$)|$)"
    lead_ins = (
        r"The\s+(?:final\s+)?answer\s+is:?\s+",
        r"Therefore,\s+",
        r"Thus,\s+",
        r"In\s+conclusion,\s+",
    )
    for lead in lead_ins:
        match = re.search(lead + r"(.*?)" + sentence_end, text, re.DOTALL)
        if match:
            return match.group(1).strip()

    # Fall back to the last sentence. Stripping first avoids an empty trailing
    # "sentence" when the text ends with whitespace after its final punctuation.
    sentences = re.split(r"(?<=[.!?])\s+", text.strip())
    return sentences[-1].strip()

def generate_incorrect_answer(
    question: str,
    correct_answer: str,
    api_key: str,
    model: str = "gpt-4o-mini",
    timeout: float = 60.0,
) -> Optional[str]:
    """
    Generate a plausible but incorrect answer for a question via the OpenAI chat API.

    Up to MAX_ATTEMPTS requests are made. An attempt is retried when the request
    fails (network/HTTP error or an unexpected response body), sleeping
    SLEEP_TIME seconds first, or when the returned text is empty or equal to the
    correct answer (ignoring case and surrounding whitespace).

    Args:
        question: The question to answer incorrectly
        correct_answer: The correct answer to avoid
        api_key: OpenAI API key
        model: Chat model to query (defaults to the previously hard-coded model)
        timeout: Per-request timeout in seconds, so a stalled connection cannot
            hang the whole dataset run

    Returns:
        A plausible but incorrect answer, or None if no attempt produced one
    """
    prompt = f"""You are helping to create a dataset of incorrect but plausible answers for educational purposes.

Given the following question and its correct answer, generate a different answer that:
1. Is INCORRECT but seems plausible
2. Contains a common mistake or misconception that a student might make
3. Is presented confidently as if it were correct
4. Is clearly different from the correct answer

Question: {question}

Correct answer (DO NOT USE THIS DIRECTLY): {correct_answer}

Your task is to provide an incorrect answer that seems believable. Present only the incorrect answer, without explaining that it's wrong.

Incorrect answer:"""

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}"
    }

    data = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 500,
        "temperature": 0.7
    }

    # Normalise once; str() guards against non-string answers from the dataset.
    correct_normalized = str(correct_answer).strip().casefold()

    for attempt in range(MAX_ATTEMPTS):
        try:
            response = requests.post(API_BASE_URL, headers=headers, json=data, timeout=timeout)
            response.raise_for_status()
            content = response.json()["choices"][0]["message"]["content"]
        except (requests.RequestException, ValueError, KeyError, IndexError, TypeError) as e:
            # Network/HTTP failures, undecodable JSON, or an unexpected body shape.
            logging.error(f"API call failed on attempt {attempt+1}: {e}")
            if attempt < MAX_ATTEMPTS - 1:
                logging.info(f"Retrying in {SLEEP_TIME} seconds...")
                time.sleep(SLEEP_TIME)
                continue
            logging.error("All attempts failed")
            return None

        # The API may return a null content field; treat it as an empty answer.
        incorrect_answer = (content or "").strip()

        # Verify it's different from the correct answer
        if incorrect_answer and incorrect_answer.casefold() != correct_normalized:
            return incorrect_answer
        logging.warning("Generated answer too similar to correct answer, retrying")

    # If we couldn't generate a sufficiently different answer
    return None

def modify_trace_with_wrong_answer(
    thinking: str,
    correct_answer: str,
    wrong_answer: str,
    rng: Optional[random.Random] = None,
) -> str:
    """
    Modify a reasoning trace so that it concludes with an incorrect answer.

    The trace is split into paragraphs on blank lines. Traces with three or
    fewer paragraphs are returned unchanged. Otherwise the final paragraphs are
    dropped (roughly the last 20%, but never more than three). They are replaced
    by a randomly chosen "pivot" paragraph followed by a conclusion that states
    ``wrong_answer``.

    Args:
        thinking: Original reasoning trace
        correct_answer: The correct answer (currently unused; kept for API compatibility)
        wrong_answer: The incorrect answer to lead to
        rng: Optional random generator used to pick the pivot sentence. Pass a
            seeded ``random.Random`` for reproducible output; defaults to the
            module-level ``random`` state.

    Returns:
        Modified reasoning trace, or ``thinking`` unchanged if it is too short
    """
    paragraphs = thinking.split("\n\n")
    num_paragraphs = len(paragraphs)

    if num_paragraphs <= 3:
        # Not enough paragraphs to make a meaningful modification
        return thinking

    # Keep everything before the conclusion. Taking the max of "n - 3" and
    # "80% of n" drops ~20% of paragraphs but caps the drop at three.
    conclusion_start = max(num_paragraphs - 3, int(num_paragraphs * 0.8))
    modified_paragraphs = paragraphs[:conclusion_start]

    # Pivot sentences that signal a change of direction
    pivot_sentences = (
        "Wait, I made a mistake earlier.",
        "Actually, I need to reconsider my approach.",
        "Let me recalculate this.",
        "On second thought, I believe I misunderstood a key aspect.",
        "I need to correct an error in my previous steps.",
    )
    chooser = rng if rng is not None else random

    modified_paragraphs.append(chooser.choice(pivot_sentences) + "\n\n" +
                               "Looking more carefully at the problem, I realize I need to adjust my approach.")

    # Add modified conclusion that leads to the wrong answer
    modified_paragraphs.append(f"After recalculating and verifying my work, I conclude that {wrong_answer}")

    return "\n\n".join(modified_paragraphs)

def generate_st_hc_w_dataset(input_dataset_path: str, num_samples: int, api_key: str, output_dir: str) -> Dataset:
    """
    Build the SmolTraces-HardCoded-Wrong (ST-HC-W) dataset from an ST-HC dataset.

    A seeded (seed=42) shuffle of the input is truncated to ``num_samples`` rows.
    For each row, an incorrect answer is requested and the reasoning trace is
    rewritten to reach it; rows for which no incorrect answer could be
    generated are skipped. Each kept record is written to
    ``<output_dir>/individual/<id>.json``, and the full set is saved with
    ``save_to_disk`` under ``<output_dir>/st_hc_w_dataset``.

    Args:
        input_dataset_path: Path to the ST-HC dataset (passed to ``load_dataset``)
        num_samples: Maximum number of samples to generate
        api_key: OpenAI API key
        output_dir: Directory to save results

    Returns:
        A Hugging Face dataset whose ``pivot_stats`` column holds JSON strings
    """
    individual_dir = os.path.join(output_dir, "individual")
    os.makedirs(output_dir, exist_ok=True)
    os.makedirs(individual_dir, exist_ok=True)

    # A DatasetDict is a dict subclass; use its "train" split in that case.
    source = load_dataset(input_dataset_path)
    if isinstance(source, dict):
        source = source["train"]

    limit = min(len(source), num_samples)
    source = source.shuffle(seed=42).select(range(limit))

    records: List[Dict[str, Any]] = []

    for idx, row in enumerate(tqdm(source, desc="Generating ST-HC-W traces")):
        question = row["question"]
        original_thinking = row["thinking"]
        original_answer = row["answer"]

        wrong = generate_incorrect_answer(question, original_answer, api_key)
        if not wrong:
            logging.warning(f"Could not generate incorrect answer for sample {idx}, skipping")
            continue

        new_thinking = modify_trace_with_wrong_answer(original_thinking, original_answer, wrong)
        pivots = extract_pivots(new_thinking)

        # Key order matters: it determines the layout of the per-sample JSON files.
        record = {
            "id": f"sthcw_{idx:04d}",
            "question": question,
            "thinking": new_thinking,
            "answer": wrong,
            "original_thinking": original_thinking,
            "original_answer": original_answer,
            "success": True,
            "pivot_stats": {kind: len(found) for kind, found in pivots.items()},
        }

        record_path = os.path.join(individual_dir, f"{record['id']}.json")
        with open(record_path, "w") as handle:
            json.dump(record, handle, indent=2)

        records.append(record)

    saved = Dataset.from_list(records)
    saved.save_to_disk(os.path.join(output_dir, "st_hc_w_dataset"))

    # Returned copy: same rows without "success", with pivot_stats serialized to JSON text.
    plain_columns = ("id", "question", "thinking", "answer", "original_thinking", "original_answer")
    columns = {name: [rec[name] for rec in records] for name in plain_columns}
    columns["pivot_stats"] = [json.dumps(rec["pivot_stats"]) for rec in records]

    return Dataset.from_dict(columns)

def analyze_dataset_differences(dataset: "Dataset") -> Dict[str, Any]:
    """
    Summarize pivot statistics and lengths of the ST-HC-W dataset.

    A pivot type "covers" a sample only when that sample's count for it is
    greater than zero. The pivot stats may list types with zero instances, and
    those must not count as coverage. Every pivot type seen in any sample still
    gets an entry, so reports list it even when its coverage is 0.
    All averages and percentages are 0 for an empty dataset instead of raising
    ZeroDivisionError.

    Args:
        dataset: The ST-HC-W dataset; its ``pivot_stats`` column holds JSON strings
            mapping pivot type -> instance count

    Returns:
        Dictionary with keys num_samples, pivot_type_coverage, pivot_type_percent,
        pivot_type_avg_counts (mean count over covering samples), the four
        *_length_avg character-length means, and unique_pivots_in_wrong
    """
    num_samples = len(dataset)

    def _mean(values) -> float:
        values = list(values)
        return sum(values) / len(values) if values else 0.0

    def _percent(part: int) -> float:
        return part / num_samples * 100 if num_samples else 0.0

    # Parse pivot stats
    pivot_stats = [json.loads(ps_str) for ps_str in dataset["pivot_stats"]]

    # Collect per-type counts, keeping only samples where the type actually occurs
    pivot_type_counts: Dict[str, List[int]] = {}
    for ps in pivot_stats:
        for pivot_type, count in ps.items():
            bucket = pivot_type_counts.setdefault(pivot_type, [])
            if count > 0:
                bucket.append(count)

    analysis = {
        "num_samples": num_samples,
        "pivot_type_coverage": {pivot_type: len(counts) for pivot_type, counts in pivot_type_counts.items()},
        "pivot_type_percent": {pivot_type: _percent(len(counts))
                               for pivot_type, counts in pivot_type_counts.items()},
        "pivot_type_avg_counts": {pivot_type: _mean(counts)
                                  for pivot_type, counts in pivot_type_counts.items()},
        "thinking_length_avg": _mean(len(t) for t in dataset["thinking"]),
        "answer_length_avg": _mean(len(a) for a in dataset["answer"]),
        "original_thinking_length_avg": _mean(len(t) for t in dataset["original_thinking"]),
        "original_answer_length_avg": _mean(len(a) for a in dataset["original_answer"]),
    }

    # Samples whose modified trace contains at least one pivot of these types
    analysis["unique_pivots_in_wrong"] = {
        "realization": sum(1 for ps in pivot_stats if ps.get("realization", 0) > 0),
        "contradiction": sum(1 for ps in pivot_stats if ps.get("contradiction", 0) > 0),
    }

    return analysis

def main():
    """
    Command-line entry point: generate, analyze and report on an ST-HC-W dataset.

    Writes ``analysis.json`` and ``summary_report.md`` into ``--output_dir``.
    The dataset itself is saved there by ``generate_st_hc_w_dataset``. The API
    key may be given with ``--api_key`` or through the ``OPENAI_API_KEY``
    environment variable, which avoids exposing it in shell history and
    process listings.
    """
    parser = argparse.ArgumentParser(description="Generate SmolTraces-HardCoded-Wrong dataset")
    parser.add_argument("--input_dataset", type=str, required=True,
                        help="Path to the ST-HC dataset")
    parser.add_argument("--output_dir", type=str, default=RESULTS_DIR,
                        help="Directory to save results")
    parser.add_argument("--num_samples", type=int, default=10,
                        help="Number of samples to generate")
    parser.add_argument("--api_key", type=str, default=os.environ.get("OPENAI_API_KEY"),
                        help="OpenAI API key (defaults to the OPENAI_API_KEY environment variable)")

    args = parser.parse_args()
    if not args.api_key:
        parser.error("an OpenAI API key is required: pass --api_key or set OPENAI_API_KEY")

    # Generate dataset
    st_hc_w_dataset = generate_st_hc_w_dataset(
        args.input_dataset,
        args.num_samples,
        args.api_key,
        args.output_dir
    )
    if len(st_hc_w_dataset) == 0:
        logging.warning("No ST-HC-W samples were generated; statistics will be empty")

    # Analyze dataset
    analysis = analyze_dataset_differences(st_hc_w_dataset)

    # Save analysis
    with open(os.path.join(args.output_dir, "analysis.json"), "w", encoding="utf-8") as f:
        json.dump(analysis, f, indent=2)

    # Print summary
    logging.info(f"Generated {len(st_hc_w_dataset)} ST-HC-W traces")
    logging.info(f"Average thinking length: {analysis['thinking_length_avg']:.2f} characters")

    num_samples = analysis["num_samples"]

    def _share(count: int) -> float:
        # Percentage of samples; 0 when nothing was generated (avoids ZeroDivisionError).
        return count / num_samples * 100 if num_samples else 0.0

    unique = analysis["unique_pivots_in_wrong"]

    # Save summary report
    with open(os.path.join(args.output_dir, "summary_report.md"), "w", encoding="utf-8") as f:
        f.write("# SmolTraces-HardCoded-Wrong Dataset Generation Summary\n\n")

        f.write("## Dataset Statistics\n\n")
        f.write(f"- Number of samples: {num_samples}\n")
        f.write(f"- Average thinking length: {analysis['thinking_length_avg']:.2f} characters (original: {analysis['original_thinking_length_avg']:.2f})\n")
        f.write(f"- Average answer length: {analysis['answer_length_avg']:.2f} characters (original: {analysis['original_answer_length_avg']:.2f})\n\n")

        f.write("## Pivot Type Coverage\n\n")
        f.write("| Pivot Type | Samples | % of Traces | Average Count |\n")
        f.write("|------------|---------|-------------|---------------|\n")
        for pivot_type in sorted(analysis["pivot_type_coverage"]):
            coverage = analysis["pivot_type_coverage"][pivot_type]
            percent = analysis["pivot_type_percent"][pivot_type]
            avg_count = analysis["pivot_type_avg_counts"][pivot_type]
            f.write(f"| {pivot_type} | {coverage} | {percent:.1f}% | {avg_count:.2f} |\n")

        f.write("\n## Unique Pivots in Wrong Answers\n\n")
        f.write(f"- Realization pivots: {unique['realization']} samples ({_share(unique['realization']):.1f}%)\n")
        f.write(f"- Contradiction pivots: {unique['contradiction']} samples ({_share(unique['contradiction']):.1f}%)\n")

        f.write("\n\nST-HC-W dataset saved to: " + os.path.join(args.output_dir, "st_hc_w_dataset") + "\n")


if __name__ == "__main__":
    main()