import json
import csv
import sys
from collections import Counter, OrderedDict

# Force UTF-8 output on Windows so the non-ASCII text this script prints
# (e.g. the warning emoji in main) does not hit Unicode encoding errors.
# sys.stdout may be None (no console attached) or a replacement stream that
# has no reconfigure(); in that case the switch is skipped rather than
# crashing the script at import time.
if sys.platform == 'win32' and hasattr(sys.stdout, 'reconfigure'):
    sys.stdout.reconfigure(encoding='utf-8')

# Failure-analysis inputs as (JSON file name, display label) pairs, one per
# model/dataset combination. Each label keeps the model and the dataset on
# separate lines (joined by a newline).
_MODEL_VARIANTS = [
    # (slug used in the file name, model name shown in the label)
    ("gpt_4o_mini", "gpt-4o-mini"),
    ("gpt_3.5_turbo_1106", "gpt-3.5-turbo"),
]
_DATASETS = ["GSM8K", "ASDiv", "SVAMP"]

FAILURE_FILES = [
    (f"failure_analysis_{file_slug}_{dataset}.json", f"{model_label}\n{dataset}")
    for file_slug, model_label in _MODEL_VARIANTS
    for dataset in _DATASETS
]

def analyze_error_distribution(file_path):
    """Count the first-error categories recorded in one failure-analysis file.

    The file is expected to hold a JSON array of objects. The category of
    each object is derived from its ``first_error`` field:

    * a dict   -> its ``error_type`` value; ``'Unknown'`` when that key is
      missing or null, and the ``str()`` of the value when it is any other
      non-string (so every category is a string);
    * a string -> ``'Uncategorized by Analyst'``;
    * anything else, including a missing field -> ``'Unknown'``.

    Args:
        file_path: Path of the UTF-8 encoded JSON file to read.

    Returns:
        A ``(error_counts, total)`` tuple: a ``Counter`` mapping each
        category to the number of entries in it, and the number of entries
        in the file.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    error_counts = Counter()
    for entry in data:
        first_error = entry.get('first_error')
        if isinstance(first_error, dict):
            error_type = first_error.get('error_type')
            if error_type is None:
                # dict.get's default only covers a missing key; an explicit
                # JSON null must be mapped to the fallback category too.
                error_type = 'Unknown'
            elif not isinstance(error_type, str):
                # Keep categories hashable and formattable as text.
                error_type = str(error_type)
        elif isinstance(first_error, str):
            error_type = 'Uncategorized by Analyst'
        else:
            error_type = 'Unknown'
        error_counts[error_type] += 1

    # Every entry contributes exactly one count, so the sum is the entry count.
    total = sum(error_counts.values())

    return error_counts, total

def _percentage(count, total):
    """Return ``count`` as a percentage of ``total``; 0 when ``total`` is not positive."""
    return (count / total) * 100 if total > 0 else 0


def main():
    """Build the combined error-distribution table and export it.

    Analyzes every file listed in ``FAILURE_FILES``, orders the error
    categories by their overall frequency across all files (most frequent
    first) and, for each file, expresses every category as a percentage of
    that file's failures. The table is written to
    ``error_distribution_combined_percentages.csv`` and
    ``error_distribution_combined_percentages.json`` (relative paths) and a
    fixed-width preview is printed. Files that cannot be processed are
    reported and left out of the table.
    """
    print("Creating combined error distribution table (percentages only)...")

    combo_data = []
    for file_path, combo_name in FAILURE_FILES:
        try:
            error_counts, total = analyze_error_distribution(file_path)
        except Exception as e:
            # Deliberate best effort: one missing or malformed input must not
            # prevent the remaining combinations from being exported.
            print(f"⚠️  Error processing {file_path}: {e}")
            continue
        # Labels are stored with a newline between model and dataset; every
        # output of this function wants them on a single line.
        label = combo_name.replace('\n', ' + ')
        combo_data.append({'name': label, 'counts': error_counts, 'total': total})
        print(f"Processed: {label}")

    labels = [combo['name'] for combo in combo_data]

    # Sort error types by overall frequency (descending); ties keep the order
    # in which the categories were first seen.
    overall_counts = Counter()
    for combo in combo_data:
        overall_counts.update(combo['counts'])
    sorted_error_types = [error_type for error_type, _ in overall_counts.most_common()]

    # One row of per-combination percentages per category, computed once and
    # shared by the CSV, JSON and preview outputs.
    percentages = {
        error_type: [
            _percentage(combo['counts'].get(error_type, 0), combo['total'])
            for combo in combo_data
        ]
        for error_type in sorted_error_types
    }

    # Create CSV with combined table (percentages only)
    csv_filename = "error_distribution_combined_percentages.csv"
    with open(csv_filename, 'w', encoding='utf-8', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['Error Category'] + labels)
        for error_type in sorted_error_types:
            writer.writerow(
                [error_type] + [f"{value:.1f}%" for value in percentages[error_type]]
            )
        # Blank separator row, then the per-combination failure counts.
        writer.writerow([])
        writer.writerow(['Total Failures (N)'] + [combo['total'] for combo in combo_data])

    print(f"\nSaved: {csv_filename}")

    # Also create a JSON version
    json_data = {
        'error_categories': sorted_error_types,
        'combinations': [
            {
                'name': combo['name'],
                'total_failures': combo['total'],
                'errors': [
                    {
                        'category': error_type,
                        'percentage': round(percentages[error_type][index], 1),
                    }
                    for error_type in sorted_error_types
                ],
            }
            for index, combo in enumerate(combo_data)
        ],
    }

    json_filename = "error_distribution_combined_percentages.json"
    with open(json_filename, 'w', encoding='utf-8') as f:
        json.dump(json_data, f, indent=2)

    print(f"Saved: {json_filename}")

    # Preview layout: widths grow with the content so that header and data
    # columns stay aligned; 30 / 18 / 100 are the minimum widths.
    category_width = max([30] + [len(str(error_type)) for error_type in sorted_error_types])
    col_width = max([18] + [len(label) for label in labels])
    # Every column is rendered as "| " + col_width characters + " ".
    rule_width = max(100, category_width + len(combo_data) * (col_width + 3))

    # Print preview
    print("\n" + "=" * rule_width)
    print("PREVIEW OF COMBINED TABLE (PERCENTAGES ONLY)")
    print("=" * rule_width)
    print()

    # Header
    header_cells = [f"{'Error Category':<{category_width}}"]
    header_cells.extend(f"| {label:^{col_width}} " for label in labels)
    print("".join(header_cells))
    print("-" * rule_width)

    # Data rows
    for error_type in sorted_error_types:
        # str() keeps the preview working even for a non-string category.
        row_cells = [f"{str(error_type):<{category_width}}"]
        row_cells.extend(
            f"| {value:>{col_width - 1}.1f}% " for value in percentages[error_type]
        )
        print("".join(row_cells))

    # Total row
    print("-" * rule_width)
    total_cells = [f"{'Total Failures (N)':<{category_width}}"]
    total_cells.extend(f"| {combo['total']:>{col_width}} " for combo in combo_data)
    print("".join(total_cells))

    print("\n" + "=" * rule_width)
    print("EXPORT COMPLETE")
    print("=" * rule_width)

# Run the export only when this file is executed as a script, so importing
# the module does not trigger it.
if __name__ == "__main__":
    main()
