#!/usr/bin/env python3
"""
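Combine the per-task ``logistics.pt`` result files written by parallel
data-processing tasks (``./output_task_*/logistics.pt``) into a single
``./combined_results.pt`` file and print a summary of the combined metrics.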
"""

import torch
import os
import glob

# Per-task counters that are summed across all task files. The tuple order
# also fixes the key order of the saved combined dictionary.
_SUMMED_KEYS = (
    "original_correct",
    "optimized_correct",
    "total",
    "update_count",
    "original_length",
    "optimized_length",
    "fitten_length",
)


def _safe_div(numerator, denominator):
    """Return ``numerator / denominator``, or 0 when the denominator is not positive."""
    return numerator / denominator if denominator > 0 else 0


def combine_results(pattern="./output_task_*/logistics.pt",
                    output_path="./combined_results.pt"):
    """Combine results from all parallel tasks.

    Loads every file matching *pattern* with ``torch.load``, sums the
    per-task counters listed in ``_SUMMED_KEYS``, derives accuracies and
    per-sample averages, saves the combined dictionary with ``torch.save``
    and prints a summary.

    A file that cannot be loaded, or that lacks any of the expected keys, is
    reported and skipped as a whole. It never contributes a partial record.

    Args:
        pattern: glob pattern locating the per-task result files.
        output_path: destination of the combined results dictionary.

    Returns:
        The combined results dictionary, or ``None`` when no file matches
        *pattern*. Nothing is saved in that case.
    """
    logistics_files = sorted(glob.glob(pattern))
    if not logistics_files:
        print("No logistics files found!")
        return None

    print(f"Found {len(logistics_files)} task results to combine")

    totals = dict.fromkeys(_SUMMED_KEYS, 0)
    loaded = 0
    for file_path in logistics_files:
        try:
            data = torch.load(file_path)
            # Read every value before adding any. Otherwise a file missing a
            # later key would already have bumped "total" and earlier
            # counters, skewing all derived averages.
            values = {key: data[key] for key in _SUMMED_KEYS}
        except Exception as e:
            # Best-effort: a corrupt or incomplete task file is reported and
            # skipped so the remaining tasks can still be combined.
            print(f"Error loading {file_path}: {e}")
            continue
        for key, value in values.items():
            totals[key] += value
        loaded += 1
        print(f"Task {file_path}: {values['total']} samples processed")

    if loaded < len(logistics_files):
        print(f"Warning: only {loaded} of {len(logistics_files)} task files were combined")

    total_samples = totals["total"]
    combined_results = dict(totals)
    combined_results.update({
        "original_accuracy": _safe_div(totals["original_correct"], total_samples),
        "optimized_accuracy": _safe_div(totals["optimized_correct"], total_samples),
        "avg_update_length": _safe_div(totals["update_count"], total_samples),
        "avg_original_length": _safe_div(totals["original_length"], total_samples),
        "avg_optimized_length": _safe_div(totals["optimized_length"], total_samples),
        "avg_fitten_length": _safe_div(totals["fitten_length"], total_samples),
    })

    torch.save(combined_results, output_path)

    print("\n" + "=" * 50)
    print("COMBINED RESULTS:")
    print("=" * 50)
    print(f"Original accuracy:      {combined_results['original_accuracy']:.4f}")
    print(f"Optimized accuracy:     {combined_results['optimized_accuracy']:.4f}")
    print(f"Avg update length:      {combined_results['avg_update_length']:.4f}")
    print(f"Avg original length:    {combined_results['avg_original_length']:.4f}")
    print(f"Avg optimized length:   {combined_results['avg_optimized_length']:.4f}")
    print(f"Avg fitten length:      {combined_results['avg_fitten_length']:.4f}")
    print("=" * 50)

    return combined_results

# Run the combiner only when this file is executed directly, never on import.
if __name__ == "__main__":
    # Use the default input pattern and output path.
    combine_results()
