#!/usr/bin/env python3
"""
Analyze top 2 best and 2 worst accuracy clusters for each model-dataset combination.
Works with separate cluster files per combination.
"""
import sys
# On Windows, switch stdout to UTF-8 before anything is printed.
# NOTE(review): presumably this avoids UnicodeEncodeError when cluster labels
# or sample sentences contain characters outside the console code page — confirm.
if sys.platform == 'win32':
    sys.stdout.reconfigure(encoding='utf-8')

import json
import os

# Configuration
# One cluster file per model-dataset combination, named
# clusters_<model>_<dataset>.json and listed model by model.
_MODEL_TAGS = ("gpt_4o_mini", "gpt_3.5_turbo_1106")
_DATASET_TAGS = ("GSM8K", "ASDiv", "SVAMP")
CLUSTER_FILES = [
    f"clusters_{model_tag}_{dataset_tag}.json"
    for model_tag in _MODEL_TAGS
    for dataset_tag in _DATASET_TAGS
]

# Clusters with fewer sentences than this are left out of the ranking.
MIN_SENTENCES = 10
# Number of clusters reported at each end (best and worst) of the ranking.
TOP_N = 2

def load_cluster_file(file_path):
    """Load cluster data from a single JSON file.

    Args:
        file_path: Path of the UTF-8 encoded JSON cluster file to read.

    Returns:
        The decoded JSON content, or None when the file does not exist or
        does not contain valid JSON. A warning is printed in both cases so
        the caller can simply skip the file.

    Raises:
        OSError: For I/O failures other than a missing file (for example a
            permission error); these are not treated as "skip this file".
    """
    # Open directly instead of checking os.path.exists() first: the file could
    # disappear between the check and the open, and this needs one syscall less.
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Warning: {file_path} not found, skipping...")
    except json.JSONDecodeError as err:
        # A corrupt file should not abort the analysis of the remaining files.
        print(f"Warning: {file_path} is not valid JSON ({err}), skipping...")
    return None

def analyze_combination(cluster_data, *, min_sentences=None, top_n=None):
    """Pick the lowest- and highest-accuracy clusters of one combination.

    Args:
        cluster_data: Mapping with 'model', 'dataset' and 'clusters' keys.
            Each cluster is a mapping that provides at least
            'sentence_count' and 'correctness_percentage'.
        min_sentences: Minimum 'sentence_count' a cluster needs to be
            considered. Defaults to the module-level MIN_SENTENCES.
        top_n: Number of clusters to report at each end of the ranking.
            Defaults to the module-level TOP_N.

    Returns:
        None when no cluster reaches the size threshold. Otherwise a dict
        with 'model', 'dataset', 'total_clusters' (number of clusters that
        passed the size filter), 'worst' (lowest accuracy first) and 'best'
        (highest accuracy first). When fewer than 2 * top_n clusters pass
        the filter, the same cluster can appear in both lists.

    Raises:
        ValueError: If top_n is negative.
    """
    # Resolve the defaults at call time so the module constants stay the
    # single source of truth for callers that pass nothing.
    if min_sentences is None:
        min_sentences = MIN_SENTENCES
    if top_n is None:
        top_n = TOP_N
    if top_n < 0:
        raise ValueError(f"top_n must be non-negative, got {top_n}")

    eligible = [
        cluster for cluster in cluster_data['clusters']
        if cluster['sentence_count'] >= min_sentences
    ]
    if not eligible:
        return None

    # Ascending by accuracy: the worst clusters come first, the best last.
    ranked = sorted(eligible, key=lambda cluster: cluster['correctness_percentage'])

    worst_clusters = ranked[:top_n]
    # Use an explicit start index: ranked[-top_n:] would return the whole
    # list for top_n == 0 because -0 == 0.
    best_start = max(len(ranked) - top_n, 0)
    best_clusters = ranked[best_start:][::-1]  # Reverse to show best first

    return {
        'model': cluster_data['model'],
        'dataset': cluster_data['dataset'],
        'total_clusters': len(eligible),
        'worst': worst_clusters,
        'best': best_clusters,
    }

def _print_cluster_section(title, clusters):
    """Print a centered section title, a rule, and one entry per cluster."""
    print(f"  {title:^76}")
    print("-" * 80)
    for rank, cluster in enumerate(clusters, 1):
        print(f"\n{rank}. Cluster {cluster['cluster_id']}: {cluster['auto_label']}")
        print(f"   Accuracy: {cluster['correctness_percentage']:.1f}%")
        print(f"   Total sentences: {cluster['sentence_count']}")
        print("   Sample sentences:")
        # Show at most three samples, each capped at 100 characters.
        for sentence in cluster['sample_sentences'][:3]:
            ellipsis = '...' if len(sentence) > 100 else ''
            print(f"     - {sentence[:100]}{ellipsis}")


def print_results(results):
    """Print the best and worst accuracy clusters for each combination.

    Args:
        results: Iterable of result dicts with 'model', 'dataset',
            'total_clusters', 'worst' and 'best' keys. Each cluster in
            'worst'/'best' must provide 'cluster_id', 'auto_label',
            'correctness_percentage', 'sentence_count' and
            'sample_sentences'. Falsy entries are skipped.
    """
    banner = "=" * 80

    print(banner)
    # Derive the count from TOP_N so the banner stays correct if it changes.
    print(f"TOP {TOP_N} BEST & WORST ACCURACY CLUSTERS BY MODEL-DATASET COMBINATION")
    print(banner)
    print(f"Minimum sentences per cluster: {MIN_SENTENCES}")
    print(f"Showing top {TOP_N} best and worst clusters per combination")
    print()

    for result in results:
        if not result:
            continue

        print(banner)
        print(f"MODEL: {result['model']}")
        print(f"DATASET: {result['dataset']}")
        print(banner)
        print(f"Total clusters analyzed: {result['total_clusters']}")
        print()

        _print_cluster_section(f"{TOP_N} CLUSTERS (Lowest Accuracy)", result['worst'])

        # Blank line separating the worst section from the best section.
        print()
        _print_cluster_section(f"{TOP_N} CLUSTERS (Highest Accuracy)", result['best'])

        print("\n")

def main():
    """Load every configured cluster file, analyze it, and print the report.

    Files that fail to load and combinations that yield no analysis result
    are skipped. If nothing remains, a message is printed and the function
    returns without printing a report.
    """
    print("Loading cluster files...")

    # Both generators are lazy, so each file is loaded and analyzed before
    # the next one is opened.
    loaded = (load_cluster_file(path) for path in CLUSTER_FILES)
    analyzed = (analyze_combination(data) for data in loaded if data)
    results = [summary for summary in analyzed if summary]

    if not results:
        print("No valid cluster data found!")
        return

    print(f"Loaded {len(results)} combinations\n")
    print_results(results)

    banner = "=" * 80
    print(banner)
    print("ANALYSIS COMPLETE")
    print(banner)


if __name__ == "__main__":
    main()
