

import os
import json
import re
import sys
import numpy as np
from tqdm import tqdm
from datasets import load_dataset
from openai import OpenAI

# Fix Unicode encoding issues on Windows: when running there
# (sys.platform == 'win32'), switch stdout to UTF-8 before anything is printed.
if sys.platform == 'win32':
    sys.stdout.reconfigure(encoding='utf-8')

# --- SCRIPT TO DIAGNOSE THE NATURE OF REASONING FAILURES ---

# ------------------------------
# Config
# ------------------------------
# Response files - all available model-dataset combinations.
# Names follow "responses_{model}_{dataset}.jsonl"; main() recovers the model
# and dataset names by splitting the file stem on its last underscore.
RESPONSE_FILES = [
    "responses_gpt_4o_mini_GSM8K.jsonl",
    "responses_gpt_4o_mini_ASDiv.jsonl",
    "responses_gpt_4o_mini_SVAMP.jsonl",
    "responses_gpt_3.5_turbo_1106_GSM8K.jsonl",
    "responses_gpt_3.5_turbo_1106_ASDiv.jsonl",
    "responses_gpt_3.5_turbo_1106_SVAMP.jsonl"
]

# Datasets to include (exclude MathQA)
INCLUDED_DATASETS = ["GSM8K", "ASDiv", "SVAMP"]

# Model to use for analysis: passed as `model=` to the chat completion call
# that classifies each reasoning step.
ANALYSIS_MODEL = "gpt-4o-mini"

# Output files will be named: failure_analysis_{model}_{dataset}.json

# ------------------------------
# Load API key & Initialize Client
# ------------------------------
# Read the key from the environment and fail fast, with a clear message, when
# it is missing or empty; the shared client below is used for every API call.
api_key = os.environ.get("OPENAI_API_KEY")
if api_key is None or api_key == "":
    raise ValueError("OPENAI_API_KEY environment variable is not set")
client = OpenAI(api_key=api_key)

# ------------------------------
# Helper Functions
# ------------------------------
def extract_number_from_answer(answer_str):
    """Extract the final answer from a ground-truth value.

    Handles GSM8K-style strings (the text after the last "####"), strips
    parenthesised units such as "(pies)", and removes thousands separators.

    Args:
        answer_str: Ground-truth value; may be a string or a number.

    Returns:
        A float when the cleaned text is numeric, the cleaned string when it
        is not (e.g. a ratio such as "2:3"), or None when the input is None
        or nothing is left after cleaning.
    """
    # Test identity, not truthiness: a numeric ground truth of 0 / 0.0 is a
    # valid answer and must not be treated as "missing".
    if answer_str is None:
        return None

    text = str(answer_str).strip()

    # For GSM8K format: "#### 21" - the answer follows the last marker.
    if '####' in text:
        text = text.split('####')[-1].strip()

    # Remove units in parentheses like "(pies)", "(books)", etc.
    text = re.sub(r'\s*\([^)]*\)\s*', '', text)

    # Ratios (e.g., "2:3") are kept as strings and compared textually.
    if ':' in text:
        return text.strip()

    # Remove thousands separators before numeric conversion.
    text = text.replace(',', '').strip()

    try:
        return float(text)
    except ValueError:
        # Non-numeric answers are returned as text; empty text means no answer.
        return text if text else None

def safe_to_float(value):
    """Convert a model answer to float, or return it as a string if non-numeric.

    Thousands separators are ignored for the numeric conversion, matching how
    extract_number_from_answer cleans ground-truth values, so that "1,234"
    and 1234 compare equal.

    Args:
        value: Model answer; may be a string, a number, or None.

    Returns:
        A float when the text is numeric, the stripped original string when
        it is not (ratios containing ":" are always returned as strings), or
        None when the input is None or blank.
    """
    if value is None:
        return None

    text = str(value).strip()

    # Ratios (e.g., "2:3") are compared as text, never as numbers.
    if ':' in text:
        return text

    try:
        # Commas are dropped only for the conversion attempt; if it fails the
        # original text (commas included) is what gets returned.
        return float(text.replace(',', ''))
    except ValueError:
        return text if text else None

def is_correct(ground_truth, model_answer, dataset=None):
    """Return whether the model's answer agrees with the ground truth.

    MathQA answers are option letters and are compared case-insensitively
    after stripping whitespace. For every other dataset both sides are
    parsed (extract_number_from_answer / safe_to_float): two strings are
    compared case-insensitively, anything else is compared numerically with
    a relative tolerance of 1e-5. A side that parses to None never matches.
    """
    if dataset == "MathQA":
        expected_letter = str(ground_truth).strip().lower()
        given_letter = ""
        if model_answer:
            given_letter = str(model_answer).strip().lower()
        return expected_letter == given_letter

    expected = extract_number_from_answer(ground_truth)
    given = safe_to_float(model_answer)

    # A missing value on either side can never count as a match.
    if expected is None or given is None:
        return False

    # Two textual answers (e.g., ratios) are compared as normalized strings.
    both_textual = isinstance(expected, str) and isinstance(given, str)
    if both_textual:
        return expected.strip().lower() == given.strip().lower()

    # Otherwise compare numerically; a side that is not numeric is a mismatch.
    try:
        expected_number = float(expected)
        given_number = float(given)
    except (ValueError, TypeError):
        return False
    else:
        return np.isclose(expected_number, given_number, rtol=1e-5)

# Step header such as "Step 1:", "Step 1", "Step1:" or "step 1 -".
# - case-insensitive, so lowercase "step 1" headers are recognised;
# - \b keeps words that merely end in "step" (e.g. "Footstep 3") intact;
# - whitespace on both sides of the optional ":" / "-" is consumed, so
#   "Step 1 - text" yields "text" rather than "- text".
_STEP_HEADER_RE = re.compile(r'\n?\bStep\s*\d+\s*[:\-]?\s*', re.IGNORECASE)

def split_into_steps(text):
    """Split a combined reasoning trace back into its individual steps.

    Every occurrence of a step header ("Step" followed by a number, in any
    letter case) is treated as a boundary, wherever it appears in the text.
    Text before the first header, if any, is returned as the first element.

    Args:
        text: The full reasoning text; may be None or empty.

    Returns:
        A list of non-empty, whitespace-stripped step bodies (headers
        removed); an empty list when there is no text.
    """
    if not text:
        return []
    # Splitting leaves empty fragments (typically before the first header);
    # those are dropped.
    return [chunk.strip() for chunk in _STEP_HEADER_RE.split(text) if chunk.strip()]

# ------------------------------
# Main Analysis Logic
# ------------------------------
# Category names the analyst is told to choose from; used to canonicalize
# replies that differ only in letter case or surrounding punctuation.
_ANALYST_CATEGORIES = (
    "Consistent",
    "Calculation Error",
    "Reasoning Error",
    "Misinterpretation Error",
    "Factual Invention",
)

_ANALYST_SYSTEM_PROMPT = """
    You are an expert logical analyst. Your task is to analyze a single reasoning step from a math problem and categorize its error type based on the original question.

    The possible error categories are:
    - "Consistent": The reasoning step is logically sound and correctly uses information from the question or previous correct steps.
    - "Calculation Error": The logic is correct, but a basic arithmetic error occurred (e.g., "2 + 2 = 5").
    - "Reasoning Error": The step uses the correct numbers but applies a flawed logical operation (e.g., adding when subtraction was needed).
    - "Misinterpretation Error": The step misunderstands the meaning of words, relationships, or the core request of the question.
    - "Factual Invention": The step introduces new information or numbers not present in the question.

    You must respond with ONLY ONE of these category names. Your entire response should be just the category name.
    """

def _load_incorrect_examples(response_file, model_name, dataset_name):
    """Collect the entries of a JSONL response file that is_correct rejects."""
    incorrect_examples = []
    with open(response_file, "r", encoding="utf-8") as f:
        for line in f:
            # Blank lines (e.g. a trailing newline at end of file) are not
            # records and would make json.loads raise.
            if not line.strip():
                continue

            entry = json.loads(line)
            dataset = entry.get("dataset")

            # Skip if not the dataset we want
            if dataset != dataset_name:
                continue

            ground_truth = entry.get("ground_truth")
            model_answer = entry.get("model_answer")

            if not is_correct(ground_truth, model_answer, dataset):
                incorrect_examples.append({
                    "id": entry["id"],
                    "model": model_name,
                    "dataset": dataset,
                    "question": entry["question"],
                    "reasoning": entry["response"],
                    "ground_truth": ground_truth,
                    "model_answer": model_answer
                })
    return incorrect_examples

def _normalize_category(raw_reply):
    """Map the analyst's reply onto a known category name when it matches one."""
    label = raw_reply.replace('"', '').strip().strip("'.").strip()
    for known in _ANALYST_CATEGORIES:
        if label.casefold() == known.casefold():
            return known
    # Unknown labels are kept verbatim so they remain visible in the report.
    return label

def _build_step_prompt(question, prior_steps, step):
    """Build the analyst's user message for one step, with earlier steps as context."""
    # The system prompt refers to "previous correct steps", so the analyst
    # needs to see them; otherwise values derived earlier look invented.
    if prior_steps:
        context = "\n".join(
            f"Step {number}: {text}"
            for number, text in enumerate(prior_steps, start=1)
        )
    else:
        context = "(none - this is the first step)"
    return (
        f'Original Question: "{question}"\n'
        f"Previous Steps (already judged consistent; context only):\n{context}\n"
        f'Reasoning Step to Analyze: "{step}"\n'
        "What is the error category for this reasoning step?"
    )

def _classify_step(question, prior_steps, step, response_id):
    """Ask the analysis model for one step's category; "API_Error" on failure."""
    try:
        response = client.chat.completions.create(
            model=ANALYSIS_MODEL,
            messages=[
                {"role": "system", "content": _ANALYST_SYSTEM_PROMPT},
                {"role": "user", "content": _build_step_prompt(question, prior_steps, step)}
            ],
            temperature=0.0,
            max_tokens=20
        )
        return _normalize_category(response.choices[0].message.content)
    except Exception as e:
        # Best-effort boundary: one failed call must not abort the whole run.
        print(f"\nAPI call failed for response ID {response_id}: {e}")
        return "API_Error"

def _find_first_error(question, reasoning, response_id):
    """Return a dict describing the first non-"Consistent" step, or None."""
    steps = split_into_steps(reasoning)
    if not steps:
        return {
            "step_number": 0,
            "error_type": "Parsing Error",
            "step_content": "Could not parse reasoning into steps."
        }

    for i, step in enumerate(steps):
        # Every step before index i was judged "Consistent", otherwise the
        # loop would already have returned.
        error_category = _classify_step(question, steps[:i], step, response_id)
        if error_category != "Consistent":
            return {
                "step_number": i + 1,
                "error_type": error_category,
                "step_content": step
            }
    return None

def analyze_failures_for_file(response_file, model_name, dataset_name):
    """Analyze failures for a specific model-dataset combination.

    Reads the JSONL response file, keeps the entries for *dataset_name* whose
    answer is_correct rejects, and asks ANALYSIS_MODEL to classify each
    reasoning step in order until one is not "Consistent". The per-response
    results are written to "failure_analysis_{model_name}_{dataset_name}.json"
    in the current working directory, and progress is printed to stdout.

    Args:
        response_file: Path to the JSONL file of model responses.
        model_name: Model label recorded in the report and the output name.
        dataset_name: Only entries whose "dataset" field equals this are used.

    Returns:
        The number of incorrect responses analyzed, or None when the file
        contains no incorrect responses for the dataset (no file is written).
    """
    print(f"\n{'=' * 80}")
    print(f"Analyzing: {model_name} + {dataset_name}")
    print(f"{'=' * 80}")

    incorrect_examples = _load_incorrect_examples(response_file, model_name, dataset_name)

    print(f"Found {len(incorrect_examples)} incorrect responses to diagnose.")

    if not incorrect_examples:
        print("No incorrect responses found. Nothing to diagnose.")
        return None

    print(f"Using {ANALYSIS_MODEL} to analyze each reasoning step...")

    analysis_report = []

    for ex in tqdm(incorrect_examples, desc=f"Diagnosing {model_name}+{dataset_name}"):
        first_error_found = _find_first_error(ex['question'], ex['reasoning'], ex['id'])

        analysis_report.append({
            "response_id": ex['id'],
            "model": ex['model'],
            "dataset": ex['dataset'],
            "question": ex['question'],
            "ground_truth": ex['ground_truth'],
            "model_answer": ex['model_answer'],
            "full_reasoning_trace": ex['reasoning'],
            "first_error": first_error_found if first_error_found else "No specific error found by analyst."
        })

    # Save to model-dataset specific file
    output_file = f"failure_analysis_{model_name}_{dataset_name}.json"
    with open(output_file, "w", encoding="utf-8") as f:
        json.dump(analysis_report, f, indent=2)

    print(f" Saved analysis to '{output_file}'")
    print(f"Total failures analyzed: {len(analysis_report)}")

    return len(analysis_report)

def _split_model_and_dataset(response_file):
    """Parse (model_name, dataset_name) from a response file path, or None.

    e.g., "responses_gpt_4o_mini_GSM8K.jsonl" -> ("gpt_4o_mini", "GSM8K")
    """
    # Work on the base name only, so a directory component never leaks into
    # the model name (and from there into the output file name).
    stem = os.path.basename(response_file)
    if stem.endswith('.jsonl'):
        stem = stem[:-len('.jsonl')]
    if stem.startswith('responses_'):
        stem = stem[len('responses_'):]

    # The dataset is whatever follows the last underscore.
    parts = stem.rsplit('_', 1)
    if len(parts) != 2:
        return None
    return parts[0], parts[1]

def main():
    """Run the failure diagnosis for every configured response file.

    Each file in RESPONSE_FILES that exists, has a parsable name, and belongs
    to a dataset listed in INCLUDED_DATASETS is passed to
    analyze_failures_for_file; the number of analyzed failures is summed and
    printed at the end. Files that do not qualify are skipped with a message.
    """
    print("=" * 80)
    print("STEP-BY-STEP FAILURE DIAGNOSIS - PER MODEL-DATASET COMBINATION")
    print("=" * 80)

    total_analyzed = 0

    # Process each response file separately
    for response_file in RESPONSE_FILES:
        if not os.path.exists(response_file):
            print(f"\nWarning: {response_file} not found, skipping...")
            continue

        parsed = _split_model_and_dataset(response_file)
        if parsed is None:
            # Old format - no dataset in the filename
            print(f"\nSkipping {response_file} - old format without dataset in filename")
            continue
        model_name, dataset_name = parsed

        # Enforce the configured dataset list (it exists to exclude MathQA).
        if dataset_name not in INCLUDED_DATASETS:
            print(f"\nSkipping {response_file} - dataset '{dataset_name}' is not in INCLUDED_DATASETS")
            continue

        # Analyze this specific model-dataset combination; the result is None
        # when there was nothing to analyze.
        count = analyze_failures_for_file(response_file, model_name, dataset_name)
        if count:
            total_analyzed += count

    print("\n" + "=" * 80)
    print("ALL ANALYSIS COMPLETE")
    print(f"Total failures analyzed across all combinations: {total_analyzed}")
    print("=" * 80)

if __name__ == "__main__":
    main()