import os
import json
import re
import numpy as np
from collections import defaultdict

# ------------------------------
# Config
# ------------------------------
# Response files to analyze (currently one JSONL file per model/dataset pair).
# Two naming formats exist:
#   - old format, all datasets combined: responses_gpt_3.5_turbo_1106.jsonl
#   - new format, one dataset per file:  responses_gpt_4o_mini_GSM8K.jsonl, etc.
# Either kind of file can be listed here, because the analysis groups records
# by each record's "dataset" field rather than by filename.
RESPONSE_FILES = [
    # gpt-4o-mini
    "responses_gpt_4o_mini_GSM8K.jsonl",
    "responses_gpt_4o_mini_ASDiv.jsonl",
    "responses_gpt_4o_mini_SVAMP.jsonl",
    # gpt-3.5-turbo-1106
    "responses_gpt_3.5_turbo_1106_GSM8K.jsonl",
    "responses_gpt_3.5_turbo_1106_ASDiv.jsonl",
    "responses_gpt_3.5_turbo_1106_SVAMP.jsonl",
]

# ------------------------------
# Helper Functions
# ------------------------------
def extract_number_from_answer(answer_str):
    """Extract the comparable answer value from a ground-truth answer.

    Handles GSM8K-style answers ("... #### 21") by keeping only the text
    after the last "####". It also strips parenthesised units such as
    "(pies)" and removes comma thousands separators.

    Args:
        answer_str: The ground-truth answer. Any value is converted with
            ``str()`` first, so numeric ground truths (including 0) are
            accepted.

    Returns:
        - A float when the cleaned text parses as a number.
        - The cleaned string for ratios (anything containing ":") or other
          non-numeric text.
        - None when the input is None or cleans to an empty string.
    """
    # Compare against None explicitly. A numeric ground truth of 0 / 0.0 is
    # falsy but is still a valid answer and must not be discarded.
    if answer_str is None:
        return None

    answer_str = str(answer_str).strip()

    # For GSM8K format: "#### 21" -- the final answer follows the last marker.
    if '####' in answer_str:
        match_str = answer_str.split('####')[-1].strip()
    else:
        match_str = answer_str

    # Remove units in parentheses like "(pies)", "(books)", etc.
    match_str = re.sub(r'\s*\([^)]*\)\s*', '', match_str)

    # Ratios (e.g. "2:3") are returned as strings and compared textually.
    if ':' in match_str:
        return match_str.strip()

    # Remove thousands separators before numeric conversion.
    match_str = match_str.replace(',', '').strip()

    try:
        return float(match_str)
    except ValueError:
        # Non-numeric text is returned as-is so callers can string-compare it.
        return match_str if match_str else None

def safe_to_float(value):
    """Convert a model answer to a float when possible.

    Comma thousands separators are removed before conversion. This makes an
    answer such as "1,000" compare equal to a ground truth normalised by
    ``extract_number_from_answer``, which strips commas the same way.

    Args:
        value: The raw model answer (any type; converted with ``str()``).

    Returns:
        - None for None or blank input.
        - The stripped string for ratios (containing ":") or other
          non-numeric text.
        - Otherwise a float.
    """
    if value is None:
        return None

    value_str = str(value).strip()

    # Ratios are compared as strings by the caller.
    if ':' in value_str:
        return value_str

    try:
        # Drop thousands separators ("1,000" -> "1000") before parsing.
        return float(value_str.replace(',', ''))
    except ValueError:
        # Keep the original text (commas included) for string comparison.
        return value_str if value_str else None

def is_correct(ground_truth, model_answer, dataset=None):
    """Return True if the model answer matches the ground truth.

    Args:
        ground_truth: Reference answer. For "MathQA" this is an option
            letter. Otherwise it is normalised with
            ``extract_number_from_answer``.
        model_answer: The model's answer. For non-MathQA datasets it is
            normalised with ``safe_to_float``.
        dataset: Dataset name. Only "MathQA" changes the comparison.

    Returns:
        A plain ``bool``, determined as follows:
        - Two numbers are compared with a relative tolerance of 1e-5.
        - Two strings (e.g. ratios) are compared case-insensitively.
        - A missing value, or a number vs. a non-numeric string, is False.
    """
    # MathQA ground truth is an option letter (a-e): compare letters directly.
    if dataset == "MathQA":
        gt_str = str(ground_truth).strip().lower()
        ma_str = str(model_answer).strip().lower() if model_answer else ""
        return gt_str == ma_str

    gt = extract_number_from_answer(ground_truth)
    ma = safe_to_float(model_answer)

    if gt is None or ma is None:
        return False

    # Both non-numeric (e.g. ratios such as "2:3"): string comparison.
    if isinstance(gt, str) and isinstance(ma, str):
        return gt.strip().lower() == ma.strip().lower()

    # Otherwise attempt a numeric comparison. A string that cannot be parsed
    # as a float means the two answers cannot match.
    try:
        gt_float = float(gt)
        ma_float = float(ma)
    except (ValueError, TypeError):
        return False
    # np.isclose returns numpy.bool_; convert it so callers get a real bool.
    return bool(np.isclose(gt_float, ma_float, rtol=1e-5))

# ------------------------------
# Main Analysis
# ------------------------------
def _collect_stats(response_file):
    """Tally per-dataset outcome counts from one JSONL response file.

    Each non-blank line must be a JSON object with "dataset",
    "ground_truth" and "model_answer" keys. An optional "response" text is
    checked for the substring "Error:". MathQA entries are skipped.

    Returns:
        A defaultdict mapping each dataset name to a dict of counts with
        keys "total", "correct", "incorrect" and "error". An error entry is
        counted in both "error" and "incorrect".
    """
    stats = defaultdict(lambda: {"total": 0, "correct": 0, "incorrect": 0, "error": 0})

    with open(response_file, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline) instead of
            # letting json.loads raise.
            if not line.strip():
                continue
            entry = json.loads(line)
            dataset = entry["dataset"]

            # Skip MathQA dataset (uses letter answers, incompatible format)
            if dataset == "MathQA":
                continue

            ground_truth = entry["ground_truth"]
            model_answer = entry["model_answer"]

            stat = stats[dataset]
            stat["total"] += 1

            # A missing or JSON-null "response" is treated as empty text;
            # `"Error:" in None` would raise TypeError.
            response_text = entry.get("response") or ""
            if "Error:" in response_text or model_answer is None:
                stat["error"] += 1
                stat["incorrect"] += 1
            elif is_correct(ground_truth, model_answer, dataset):
                stat["correct"] += 1
            else:
                stat["incorrect"] += 1

    return stats


def _percent(part, whole):
    """Return part / whole as a percentage, or 0 when whole is 0."""
    return part / whole * 100 if whole > 0 else 0


def _print_summary(stats):
    """Print the per-dataset table, an OVERALL row, and failure/error rates."""
    print(f"\n{'Dataset':<15} {'Total':<10} {'Correct':<10} {'Failed':<10} {'Errors':<10} {'Accuracy':<10}")
    print("-" * 80)

    totals = {"total": 0, "correct": 0, "incorrect": 0, "error": 0}
    for dataset in sorted(stats):
        stat = stats[dataset]
        accuracy = _percent(stat["correct"], stat["total"])

        print(f"{dataset:<15} {stat['total']:<10} {stat['correct']:<10} "
              f"{stat['incorrect']:<10} {stat['error']:<10} {accuracy:.2f}%")

        for key in totals:
            totals[key] += stat[key]

    # Print overall statistics
    print("-" * 80)
    accuracy_all = _percent(totals["correct"], totals["total"])
    print(f"{'OVERALL':<15} {totals['total']:<10} {totals['correct']:<10} "
          f"{totals['incorrect']:<10} {totals['error']:<10} {accuracy_all:.2f}%")

    # These rates are guarded by _percent. A file with no scorable entries
    # (empty, or MathQA only) has a total of 0, and dividing directly would
    # raise ZeroDivisionError.
    print(f"\n{'Failure Rate:':<20} {_percent(totals['incorrect'], totals['total']):.2f}%")
    print(f"{'Error Rate:':<20} {_percent(totals['error'], totals['total']):.2f}%")


def main():
    """Print a failure-analysis summary for each file in RESPONSE_FILES.

    Files that do not exist are reported with a warning and skipped.
    """
    print("=" * 80)
    print("FAILURE ANALYSIS BY MODEL AND DATASET")
    print("=" * 80)

    for response_file in RESPONSE_FILES:
        if not os.path.exists(response_file):
            print(f"\nWarning: {response_file} not found, skipping...")
            continue

        # Label derived from the filename. For per-dataset files it still
        # carries the dataset suffix (e.g. "gpt_4o_mini_GSM8K").
        model_name = response_file.replace("responses_", "").replace(".jsonl", "")

        print(f"\n{'=' * 80}")
        print(f"Model: {model_name}")
        print(f"{'=' * 80}")

        _print_summary(_collect_stats(response_file))

    print("\n" + "=" * 80)
    print("ANALYSIS COMPLETE")
    print("=" * 80)

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
