#!/usr/bin/env python3
"""
Save the top-2 best and worst clusters (ranked by correctness percentage) for each model-dataset combination to JSON and CSV.
"""
import sys
# On Windows the console encoding may not be UTF-8; presumably this forces
# UTF-8 so printing non-ASCII cluster labels cannot raise UnicodeEncodeError.
if sys.platform == 'win32':
    sys.stdout.reconfigure(encoding='utf-8')

import json
import os
import csv

# Configuration
# One cluster file per (model, dataset) pair, model-major order:
# every dataset for the first model, then every dataset for the second.
CLUSTER_FILES = [
    f"clusters_{model}_{dataset}.json"
    for model in ("gpt_4o_mini", "gpt_3.5_turbo_1106")
    for dataset in ("GSM8K", "ASDiv", "SVAMP")
]

# Clusters with fewer sentences than this are ignored as too small to rank.
MIN_SENTENCES = 10
# Number of best and of worst clusters kept per combination.
TOP_N = 2

def load_cluster_file(file_path):
    """Load cluster data from a single JSON file.

    Args:
        file_path: Path of the cluster file to read (UTF-8 encoded JSON).

    Returns:
        The parsed JSON content, or ``None`` if the file does not exist
        (a warning is printed in that case).

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    # EAFP: open directly instead of checking os.path.exists() first, which
    # avoids the check-then-open race and a redundant filesystem call.
    try:
        f = open(file_path, 'r', encoding='utf-8')
    except FileNotFoundError:
        print(f"Warning: {file_path} not found, skipping...")
        return None

    with f:
        return json.load(f)

def _summarize_cluster(cluster):
    """Reduce a raw cluster record to the fields exported for best/worst lists."""
    return {
        'cluster_name': f"Cluster {cluster['cluster_id']}",
        'label': cluster.get('auto_label', 'Unknown'),
        'correctness_percentage': cluster['correctness_percentage'],
        'sentence_count': cluster['sentence_count']
    }


def analyze_combination(cluster_data, min_sentences=None, top_n=None):
    """Pick the lowest- and highest-correctness clusters for one combination.

    Args:
        cluster_data: Parsed cluster file contents. Must contain ``'model'``,
            ``'dataset'`` and ``'clusters'``. Each cluster needs
            ``'cluster_id'``, ``'correctness_percentage'`` and
            ``'sentence_count'``, and may have ``'auto_label'``.
        min_sentences: Clusters with fewer sentences are ignored. Defaults to
            the module-level ``MIN_SENTENCES``.
        top_n: How many worst and best clusters to keep. Defaults to the
            module-level ``TOP_N``. ``0`` yields empty lists.

    Returns:
        A dict with ``model``, ``dataset``, ``total_clusters_analyzed``,
        ``worst_clusters`` (ascending correctness) and ``best_clusters``
        (descending correctness), or ``None`` if no cluster meets
        ``min_sentences``. When fewer than ``2 * top_n`` clusters qualify,
        the same cluster can appear in both lists.

    Raises:
        ValueError: If ``top_n`` is negative.
    """
    # Resolve defaults at call time so later changes to the module constants
    # are still honoured.
    if min_sentences is None:
        min_sentences = MIN_SENTENCES
    if top_n is None:
        top_n = TOP_N
    if top_n < 0:
        raise ValueError(f"top_n must be non-negative, got {top_n}")

    model = cluster_data['model']
    dataset = cluster_data['dataset']
    clusters = cluster_data['clusters']

    valid_clusters = [c for c in clusters if c['sentence_count'] >= min_sentences]
    if not valid_clusters:
        return None

    # Stable ascending sort by correctness.
    sorted_clusters = sorted(valid_clusters, key=lambda x: x['correctness_percentage'])

    worst_clusters = sorted_clusters[:top_n]
    # Slice from an explicit start index: sorted_clusters[-top_n:] would
    # return the WHOLE list when top_n == 0. Reverse to show highest first.
    best_start = max(len(sorted_clusters) - top_n, 0)
    best_clusters = sorted_clusters[best_start:][::-1]

    return {
        'model': model,
        'dataset': dataset,
        'total_clusters_analyzed': len(valid_clusters),
        'worst_clusters': [_summarize_cluster(c) for c in worst_clusters],
        'best_clusters': [_summarize_cluster(c) for c in best_clusters]
    }

def save_to_json(results, output_file):
    """Serialize the per-combination results to ``output_file`` as JSON.

    The file is written as UTF-8 with two-space indentation. Non-ASCII
    characters are escaped (json's default ``ensure_ascii``). A confirmation
    line is printed once the file is closed.
    """
    with open(output_file, mode='w', encoding='utf-8') as json_out:
        json.dump(results, json_out, indent=2)
    print("\nSaved JSON: {}".format(output_file))

def save_to_csv(results, output_file):
    """Flatten the per-combination results into one CSV row per cluster.

    For each result, its worst clusters are written first (category
    ``Worst``), then its best clusters (category ``Best``). ``rank`` restarts
    at 1 within each category. A confirmation line is printed at the end.
    """
    columns = ['model', 'dataset', 'category', 'rank', 'cluster_name',
               'label', 'correctness_percentage', 'sentence_count']
    # Columns copied verbatim from each cluster summary dict.
    cluster_fields = columns[4:]
    # (category value written to the CSV, key of the cluster list), in output order.
    sections = (('Worst', 'worst_clusters'), ('Best', 'best_clusters'))

    with open(output_file, 'w', encoding='utf-8', newline='') as csv_out:
        writer = csv.DictWriter(csv_out, fieldnames=columns)
        writer.writeheader()

        for entry in results:
            entry_model = entry['model']
            entry_dataset = entry['dataset']

            for category, list_key in sections:
                for position, cluster in enumerate(entry[list_key], start=1):
                    row = {
                        'model': entry_model,
                        'dataset': entry_dataset,
                        'category': category,
                        'rank': position,
                    }
                    for field in cluster_fields:
                        row[field] = cluster[field]
                    writer.writerow(row)

    print(f"Saved CSV: {output_file}")

def main():
    """Run the extraction over every file in ``CLUSTER_FILES`` and export results.

    Writes ``best_worst_clusters.json`` and ``best_worst_clusters.csv`` to the
    current working directory and prints a per-combination summary. Files
    that are missing, or that contain no cluster with at least
    ``MIN_SENTENCES`` sentences, are reported and skipped.
    """
    rule = "=" * 80

    print(rule)
    print("BEST & WORST CLUSTERS EXTRACTION")
    print(rule)
    print(f"Minimum sentences per cluster: {MIN_SENTENCES}")
    print(f"Top N best/worst: {TOP_N}\n")

    results = []

    for file_path in CLUSTER_FILES:
        cluster_data = load_cluster_file(file_path)
        if not cluster_data:
            # Missing file (load_cluster_file already warned) or empty content.
            continue

        result = analyze_combination(cluster_data)
        if result is None:
            # Make the omission visible instead of silently dropping it.
            print(f"Skipped: {file_path} has no clusters with at least "
                  f"{MIN_SENTENCES} sentences")
            continue

        results.append(result)
        print(f"Processed: {result['model']} + {result['dataset']}")
        print(f"   Clusters analyzed: {result['total_clusters_analyzed']}")

    if not results:
        print("\nNo results found.")
        return

    print(f"\n{rule}")
    print(f"Total combinations processed: {len(results)}")
    print(rule)

    # Save to files
    save_to_json(results, "best_worst_clusters.json")
    save_to_csv(results, "best_worst_clusters.csv")

    # Print summary
    print(f"\n{rule}")
    print("SUMMARY")
    print(rule)

    for result in results:
        print(f"\n{result['model']} + {result['dataset']}:")
        worst = result['worst_clusters']
        best = result['best_clusters']
        # Guard the [0] lookups: the lists are empty when TOP_N is 0.
        if worst:
            print(f"  Worst cluster: {worst[0]['label']} ({worst[0]['correctness_percentage']:.1f}%)")
        if best:
            print(f"  Best cluster: {best[0]['label']} ({best[0]['correctness_percentage']:.1f}%)")

    print(f"\n{rule}")
    print("EXPORT COMPLETE")
    print(rule)

# Run the export only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
