#!/usr/bin/env python3
"""
Statistical significance tests for best/worst clusters across model-dataset combinations.
"""
import sys
# This script prints non-ASCII characters ('─', '✓', '✗'); on Windows switch
# stdout to UTF-8 — presumably to avoid encoding errors on consoles that use
# a legacy code page.
if sys.platform == 'win32':
    sys.stdout.reconfigure(encoding='utf-8')

import numpy as np
from scipy.stats import chi2_contingency, fisher_exact
import json
import os

# Configuration
# Each entry pairs a cluster file with the response file for the same
# model-dataset combination; the two public lists below are derived from it,
# so they always have matching order and length.
_FILE_PAIRS = [
    ("clusters_gpt_4o_mini_GSM8K.json", "responses_gpt_4o_mini_GSM8K.jsonl"),
    ("clusters_gpt_4o_mini_ASDiv.json", "responses_gpt_4o_mini_ASDiv.jsonl"),
    ("clusters_gpt_4o_mini_SVAMP.json", "responses_gpt_4o_mini_SVAMP.jsonl"),
    ("clusters_gpt_3.5_turbo_1106_GSM8K.json", "responses_gpt_3.5_turbo_1106_GSM8K.jsonl"),
    ("clusters_gpt_3.5_turbo_1106_ASDiv.json", "responses_gpt_3.5_turbo_1106_ASDiv.jsonl"),
    ("clusters_gpt_3.5_turbo_1106_SVAMP.json", "responses_gpt_3.5_turbo_1106_SVAMP.jsonl"),
]

CLUSTER_FILES = [cluster_path for cluster_path, _ in _FILE_PAIRS]

RESPONSE_FILES = [response_path for _, response_path in _FILE_PAIRS]

# Clusters with fewer sentences than this are excluded from testing.
MIN_SENTENCES = 10
# Number of clusters taken from each end of the correctness ranking.
TOP_N = 2

def perform_statistical_test(name1, success1, total1, name2, success2, total2):
    """
    Compare the success rates of two groups on a 2x2 contingency table.

    Runs Fisher's Exact Test (used for the significance decision) and, when
    it can be computed, a chi-squared test of independence.

    Args:
        name1: Label of the first group (not used in the computation).
        success1: Number of successes in the first group.
        total1: Number of observations in the first group.
        name2: Label of the second group (not used in the computation).
        success2: Number of successes in the second group.
        total2: Number of observations in the second group.

    Returns:
        None if either group has no observations. Otherwise a dict with:
            'contingency_table': [[success1, fail1], [success2, fail2]]
            'chi2_statistic':    chi-squared statistic, or None if unavailable
            'chi2_p_value':      chi-squared p-value, or None if unavailable
            'fisher_p_value':    p-value of Fisher's Exact Test
            'significant':       True when the Fisher p-value is below 0.05
    """
    fail1 = total1 - success1
    fail2 = total2 - success2

    # A group with neither successes nor failures cannot be compared.
    if (success1 == 0 and fail1 == 0) or (success2 == 0 and fail2 == 0):
        return None

    contingency_table = np.array([[success1, fail1],
                                  [success2, fail2]])

    # Fisher's Exact Test (more accurate for small samples)
    _, p_value_fisher = fisher_exact(contingency_table)

    # Chi-Squared Test. chi2_contingency raises ValueError when an expected
    # frequency is zero (e.g. no successes in either group); report the
    # chi-squared fields as unavailable in that case instead of failing.
    try:
        chi2, p_value_chi2, _, _ = chi2_contingency(contingency_table)
    except ValueError:
        chi2, p_value_chi2 = None, None

    return {
        'contingency_table': contingency_table.tolist(),
        'chi2_statistic': float(chi2) if chi2 is not None else None,
        'chi2_p_value': float(p_value_chi2) if p_value_chi2 is not None else None,
        'fisher_p_value': float(p_value_fisher),
        'significant': bool(p_value_fisher < 0.05)
    }

def calculate_overall_baseline(response_file, dataset_name):
    """
    Calculate the overall correctness rate for a model-dataset combination.

    Reads a JSONL response file, keeps the entries whose 'dataset' field
    equals *dataset_name*, and counts an entry as correct when its
    'model_answer' equals the final answer in 'ground_truth' (the text after
    the last '####' marker, if present). Thousands separators and
    surrounding whitespace are ignored on both sides.

    NOTE: sentence counts are an estimate. Each problem is assumed to
    contribute 5 sentences, all sharing the problem's correctness, so the
    returned counts are the problem counts multiplied by 5.

    Args:
        response_file: Path to the JSONL file with one response per line.
        dataset_name: Only entries with this 'dataset' value are counted.

    Returns:
        None if *response_file* does not exist. Otherwise a dict with
        'total_sentences', 'correct_sentences' and 'rate' (0 when no entry
        matched the dataset).
    """
    if not os.path.exists(response_file):
        return None

    total_problems = 0
    correct_problems = 0

    with open(response_file, 'r', encoding='utf-8') as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue

            entry = json.loads(line)
            if entry.get('dataset') != dataset_name:
                continue

            total_problems += 1

            ground_truth = str(entry.get('ground_truth', ''))
            model_answer = str(entry.get('model_answer', ''))

            # The final answer follows the last '####' marker; when there is
            # no marker, split() returns the whole string unchanged.
            gt_value = ground_truth.split('####')[-1]

            # Simple string comparison, normalised identically on both sides
            # so that e.g. "1,000" and "1000" compare equal.
            if model_answer.replace(',', '').strip() == gt_value.replace(',', '').strip():
                correct_problems += 1

    # Estimate: 5 sentences per problem on average
    avg_sentences = 5
    total_sentences = total_problems * avg_sentences
    correct_sentences = correct_problems * avg_sentences

    return {
        'total_sentences': total_sentences,
        'correct_sentences': correct_sentences,
        'rate': correct_sentences / total_sentences if total_sentences > 0 else 0
    }

def load_cluster_file(file_path):
    """Return the parsed JSON content of *file_path*, or None if the path does not exist."""
    if os.path.exists(file_path):
        with open(file_path, 'r', encoding='utf-8') as handle:
            raw_text = handle.read()
        return json.loads(raw_text)
    return None

def _test_clusters_vs_baseline(clusters, category, baseline, model, dataset):
    """Test each cluster in *clusters* against the baseline, print the outcome, and return result records."""
    records = []

    for rank, cluster in enumerate(clusters, 1):
        cluster_total = cluster['sentence_count']
        cluster_rate = cluster['correctness_percentage'] / 100.0
        # Use round(), not int(): the float product can land just below the
        # true count (100 * 0.29 == 28.999999999999996), and truncating
        # would then undercount the successes by one.
        cluster_success = round(cluster_total * cluster_rate)
        label = cluster.get('auto_label', 'Unknown')
        cluster_name = f"Cluster {cluster['cluster_id']}: {label}"

        test_result = perform_statistical_test(
            cluster_name, cluster_success, cluster_total,
            "Baseline", baseline['correct_sentences'], baseline['total_sentences']
        )

        # perform_statistical_test returns None when a group is empty.
        if not test_result:
            continue

        print(f"\n{rank}. {cluster_name}")
        print(f"   Count: {cluster_total}, Correctness: {cluster_rate:.1%}")
        print(f"   Fisher's Exact Test p-value: {test_result['fisher_p_value']:.4e}")
        print(f"   Significant: {'✓ YES' if test_result['significant'] else '✗ NO'}")

        records.append({
            'model': model,
            'dataset': dataset,
            'category': category,
            'rank': rank,
            'cluster_id': cluster['cluster_id'],
            'label': label,
            'count': cluster_total,
            'correctness': cluster_rate,
            **test_result
        })

    return records


def analyze_combination(cluster_file, response_file):
    """
    Analyze statistical significance for one model-dataset combination.

    Loads the cluster file, computes the overall baseline from the response
    file, and tests the TOP_N least and most correct clusters (among those
    with at least MIN_SENTENCES sentences) against that baseline. Progress
    and results are printed to stdout.

    Args:
        cluster_file: Path to a JSON file with 'model', 'dataset' and
            'clusters' keys.
        response_file: Path to the JSONL response file for the same
            model-dataset combination.

    Returns:
        A list of result dicts (one per performed test, worst clusters
        first), or None when the cluster file is missing/empty, the
        baseline cannot be calculated, or no cluster is large enough.
    """
    cluster_data = load_cluster_file(cluster_file)
    if not cluster_data:
        return None

    model = cluster_data['model']
    dataset = cluster_data['dataset']
    clusters = cluster_data['clusters']

    print(f"\n{'=' * 80}")
    print(f"Model: {model} | Dataset: {dataset}")
    print(f"{'=' * 80}")

    # Calculate overall baseline
    baseline = calculate_overall_baseline(response_file, dataset)
    if not baseline:
        print("Could not calculate baseline")
        return None

    print("\nOverall baseline:")
    print(f"  Total sentences (estimated): {baseline['total_sentences']}")
    print(f"  Correct sentences (estimated): {baseline['correct_sentences']}")
    print(f"  Correctness rate: {baseline['rate']:.1%}")

    # Filter valid clusters
    valid_clusters = [c for c in clusters if c.get('sentence_count', 0) >= MIN_SENTENCES]
    if not valid_clusters:
        print("No valid clusters found")
        return None

    # Sort by correctness, ascending
    sorted_clusters = sorted(valid_clusters, key=lambda x: x.get('correctness_percentage', 0))

    # NOTE: with fewer than 2 * TOP_N valid clusters the two slices overlap,
    # so the same cluster can be reported in both groups.
    worst_clusters = sorted_clusters[:TOP_N]
    best_clusters = sorted_clusters[-TOP_N:][::-1]

    groups = (
        ("WORST CLUSTERS vs. BASELINE", 'worst', worst_clusters),
        ("BEST CLUSTERS vs. BASELINE", 'best', best_clusters),
    )

    results = []
    for heading, category, group in groups:
        print(f"\n{'─' * 80}")
        print(heading)
        print(f"{'─' * 80}")
        results.extend(
            _test_clusters_vs_baseline(group, category, baseline, model, dataset)
        )

    return results

def main():
    """Run the analysis for every configured file pair and write all results to a JSON file."""
    rule = "=" * 80

    print(rule)
    print("STATISTICAL SIGNIFICANCE TESTING")
    print("Testing best/worst clusters vs. baseline for each model-dataset combination")
    print(rule)

    all_results = []

    # Cluster and response files are matched by position.
    for cluster_path, response_path in zip(CLUSTER_FILES, RESPONSE_FILES):
        combination_results = analyze_combination(cluster_path, response_path)
        if not combination_results:
            # None or an empty list: nothing to record for this combination.
            continue
        all_results.extend(combination_results)

    # Save results
    with open('statistical_test_results.json', 'w', encoding='utf-8') as out_file:
        json.dump(all_results, out_file, indent=2)

    print("\n" + rule)
    print("Saved results to: statistical_test_results.json")
    print(f"Total tests performed: {len(all_results)}")
    print(rule)

# Run the analysis only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
