#!/usr/bin/env python3
"""
Export best and worst clusters with count, label, and correctness percentage.
"""
import sys
# main() prints non-ASCII text (e.g. the "✅" marker); force UTF-8 on Windows,
# where the default console encoding may not be able to encode it.
if sys.platform == 'win32':
    sys.stdout.reconfigure(encoding='utf-8')

import json
import os
import csv

# Configuration
# Cluster JSON files read by main(), in this order. The paths are relative, so
# they are resolved against the current working directory; files that do not
# exist are skipped with a warning by load_cluster_file().
CLUSTER_FILES = [
    "clusters_gpt_4o_mini_GSM8K.json",
    "clusters_gpt_4o_mini_ASDiv.json",
    "clusters_gpt_4o_mini_SVAMP.json",
    "clusters_gpt_3.5_turbo_1106_GSM8K.json",
    "clusters_gpt_3.5_turbo_1106_ASDiv.json",
    "clusters_gpt_3.5_turbo_1106_SVAMP.json"
]

# Clusters whose 'sentence_count' is below this threshold are ignored by
# analyze_combination().
MIN_SENTENCES = 10
# Number of clusters taken from each end of the correctness ranking
# (lowest = "worst", highest = "best") by analyze_combination().
TOP_N = 2

def load_cluster_file(file_path):
    """Read one cluster JSON file and return its decoded contents.

    When ``file_path`` does not exist, a warning is printed and ``None`` is
    returned so the caller can skip that file. Otherwise the file is opened
    as UTF-8 text and the parsed JSON value is returned.
    """
    if os.path.exists(file_path):
        with open(file_path, 'r', encoding='utf-8') as handle:
            payload = json.load(handle)
        return payload

    # Missing input is not fatal: report it and let the caller move on.
    print(f"Warning: {file_path} not found, skipping...")
    return None

def _summarize_cluster(cluster):
    """Reduce a raw cluster record to the three exported fields."""
    return {
        'count': cluster['sentence_count'],
        # 'auto_label' is optional in the input; the other two keys are required.
        'label': cluster.get('auto_label', 'Unknown'),
        'correctness_percentage': cluster['correctness_percentage']
    }


def analyze_combination(cluster_data, min_sentences=None, top_n=None):
    """Pick the worst and best clusters for a single model-dataset combination.

    Args:
        cluster_data: Mapping with the keys 'model', 'dataset' and 'clusters'.
            Each entry of 'clusters' must provide 'sentence_count' and
            'correctness_percentage'; 'auto_label' is optional and falls
            back to 'Unknown'.
        min_sentences: Minimum 'sentence_count' a cluster needs to be
            considered. Defaults to the module-level MIN_SENTENCES.
        top_n: How many clusters to take from each end of the ranking.
            Defaults to the module-level TOP_N; negative values are treated
            as 0.

    Returns:
        ``None`` when no cluster reaches ``min_sentences``. Otherwise a dict
        with 'model', 'dataset', 'worst_clusters' (ascending correctness) and
        'best_clusters' (descending correctness), each a list of dicts with
        the keys 'count', 'label' and 'correctness_percentage'.

    Note:
        When fewer than ``2 * top_n`` clusters qualify, the same cluster can
        appear in both lists.

    Raises:
        KeyError: If a required key is missing from ``cluster_data`` or from
            one of its clusters.
    """
    # Resolve the defaults at call time so later changes to the module
    # constants are picked up.
    if min_sentences is None:
        min_sentences = MIN_SENTENCES
    if top_n is None:
        top_n = TOP_N
    top_n = max(top_n, 0)

    model = cluster_data['model']
    dataset = cluster_data['dataset']
    clusters = cluster_data['clusters']

    # Filter clusters with minimum sentences
    valid_clusters = [c for c in clusters if c['sentence_count'] >= min_sentences]
    if not valid_clusters:
        return None

    # Ascending by correctness: worst first, best last.
    ranked = sorted(valid_clusters, key=lambda c: c['correctness_percentage'])

    worst_clusters = ranked[:top_n]
    # Reverse before slicing: ranked[-top_n:] would return the whole list for
    # top_n == 0 because -0 == 0.
    best_clusters = ranked[::-1][:top_n]

    return {
        'model': model,
        'dataset': dataset,
        'worst_clusters': [_summarize_cluster(c) for c in worst_clusters],
        'best_clusters': [_summarize_cluster(c) for c in best_clusters]
    }

def _write_csv(results, csv_path):
    """Write one CSV row per exported cluster, worst clusters before best."""
    # newline='' lets the csv module control line endings itself.
    with open(csv_path, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['Model', 'Dataset', 'Category', 'Rank', 'Count', 'Label', 'Correctness %'])

        for result in results:
            for category, key in (('Worst', 'worst_clusters'), ('Best', 'best_clusters')):
                for rank, cluster in enumerate(result[key], 1):
                    writer.writerow([
                        result['model'],
                        result['dataset'],
                        category,
                        rank,
                        cluster['count'],
                        cluster['label'],
                        cluster['correctness_percentage']
                    ])


def _print_summary(results):
    """Print a fixed-width worst/best table for every analyzed combination."""
    rule = "=" * 120

    print("\n" + rule)
    print("BEST & WORST CLUSTERS SUMMARY")
    print(rule)

    for result in results:
        print(f"\n{rule}")
        print(f"Model: {result['model']} | Dataset: {result['dataset']}")
        print(rule)

        for heading, key in (('WORST', 'worst_clusters'), ('BEST', 'best_clusters')):
            # The heading follows TOP_N instead of a hard-coded "2".
            print(f"\n{heading} {TOP_N} CLUSTERS:")
            print(f"{'Rank':<6} {'Count':<8} {'Label':<60} {'Correctness %':<15}")
            print("-" * 120)
            for rank, cluster in enumerate(result[key], 1):
                # !s keeps the row printable when the label is not a string
                # (e.g. a JSON null 'auto_label' arrives here as None).
                print(f"{rank:<6} {cluster['count']:<8} {cluster['label']!s:<60} {cluster['correctness_percentage']:<15.1f}")


def main():
    """Analyze every file in CLUSTER_FILES and export the best/worst clusters.

    Files that cannot be loaded, and combinations for which
    analyze_combination() returns nothing, are skipped. The collected results
    are written to 'best_worst_clusters_table.json' and
    'best_worst_clusters_table.csv' in the current working directory and then
    printed as a summary table.
    """
    print("Extracting best and worst clusters...")

    results = []
    for file_path in CLUSTER_FILES:
        cluster_data = load_cluster_file(file_path)
        if not cluster_data:
            continue

        result = analyze_combination(cluster_data)
        if result:
            results.append(result)
            print(f"✅ Processed: {result['model']} + {result['dataset']}")

    # Save to JSON
    with open('best_worst_clusters_table.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)
    print("\nSaved: best_worst_clusters_table.json")

    # Save to CSV
    _write_csv(results, 'best_worst_clusters_table.csv')
    print("Saved: best_worst_clusters_table.csv")

    # Print summary table
    _print_summary(results)

    print("\n" + "=" * 120)
    print("EXPORT COMPLETE")
    print("=" * 120)

# Run the export only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
