#!/usr/bin/env python3
"""
Merge results from batch processing into a single torch file.
"""

import argparse
import os
import json
import torch
from glob import glob


def _load_metadata(batch_dir):
    """Read and return the parsed ``metadata.json`` found in *batch_dir*.

    Raises:
        FileNotFoundError: if ``metadata.json`` does not exist in *batch_dir*.
    """
    metadata_path = os.path.join(batch_dir, "metadata.json")
    if not os.path.exists(metadata_path):
        raise FileNotFoundError(f"metadata.json not found in {batch_dir}")
    with open(metadata_path, "r") as f:
        return json.load(f)


def _load_batches(results_dir, num_batches):
    """Load ``batch_0000.torch`` .. ``batch_{num_batches-1:04d}.torch``.

    Returns:
        (all_results, missing_batches): the concatenation of every batch list
        that loaded successfully, and the indices of batches that were absent
        or could not be loaded as a list.
    """
    all_results = []
    missing_batches = []
    for i in range(num_batches):
        result_path = os.path.join(results_dir, f"batch_{i:04d}.torch")
        if not os.path.exists(result_path):
            missing_batches.append(i)
            continue
        try:
            # SECURITY: weights_only=False unpickles arbitrary Python objects
            # and can execute code; only merge result files from a trusted
            # source.
            batch_results = torch.load(result_path, weights_only=False)
            if not isinstance(batch_results, list):
                raise ValueError(f"Batch {i} results are not a list")
        except Exception as e:
            # Deliberate best-effort: an unreadable/invalid batch is reported
            # and counted as missing, so --allow-missing decides what happens.
            print(f"  ERROR loading batch {i}: {e}")
            missing_batches.append(i)
            continue
        all_results.extend(batch_results)
        print(f"  Loaded batch {i}: {len(batch_results)} samples")
    return all_results, missing_batches


def _validate_results(all_results, sample_size=10):
    """Print warnings if merged results look structurally inconsistent.

    Checks that the first result has the expected keys, then that the next
    ``sample_size - 1`` results share the first result's key set. Only warns;
    never raises. Expects a non-empty *all_results*.
    """
    print("\nValidating result structure...")
    first_result = all_results[0]
    if not isinstance(first_result, dict):
        # Key checks are meaningless for non-dict results; warn instead of
        # crashing on .keys().
        print(
            f"  Warning: Result 0 is a {type(first_result).__name__}, "
            f"not a dict; skipping key checks"
        )
        return

    required_keys = ["input", "gold_answer", "scores", "stats"]
    for key in required_keys:
        if key not in first_result:
            print(f"  Warning: Missing key '{key}' in results")

    # Spot-check a few more results for a consistent schema.
    expected_keys = set(first_result)
    for i, result in enumerate(all_results[1:sample_size], start=1):
        if not isinstance(result, dict) or set(result) != expected_keys:
            print(f"  Warning: Result {i} has different keys than result 0")


def main(args):
    """Merge per-batch result files into a single torch file.

    Reads ``metadata.json`` from ``args.batch_dir`` (using its
    ``num_batches``, ``total_samples`` and ``batch_size`` entries), loads
    ``batch_XXXX.torch`` files from ``args.results_dir``, concatenates their
    lists, optionally spot-checks the result structure, and saves the merged
    list to ``args.output_path``.

    Args:
        args: namespace with ``batch_dir``, ``results_dir``, ``output_path``,
            ``allow_missing`` and ``validate`` attributes.

    Raises:
        FileNotFoundError: if ``metadata.json`` is missing.
        ValueError: if some batches are missing and ``allow_missing`` is
            false, or if no samples were loaded at all.
    """
    print(f"Merging results from: {args.results_dir}")
    print(f"Using metadata from: {args.batch_dir}")

    metadata = _load_metadata(args.batch_dir)
    num_batches = metadata["num_batches"]
    total_samples = metadata["total_samples"]
    batch_size = metadata["batch_size"]

    print("Dataset info:")
    print(f"  Total samples: {total_samples}")
    print(f"  Batch size: {batch_size}")
    print(f"  Number of batches: {num_batches}")

    print("\nLoading batch results...")
    all_results, missing_batches = _load_batches(args.results_dir, num_batches)
    loaded_samples = len(all_results)

    if missing_batches:
        print(f"\nWarning: Missing batches: {missing_batches}")
        print(f"Missing {len(missing_batches)} out of {num_batches} batches")

        if not args.allow_missing:
            raise ValueError(
                f"Some batches are missing. Use --allow-missing to continue anyway.\n"
                f"Missing batches: {missing_batches}"
            )

    print("\nMerge summary:")
    print(f"  Expected samples: {total_samples}")
    print(f"  Loaded samples: {loaded_samples}")
    print(f"  Missing samples: {total_samples - loaded_samples}")

    if loaded_samples == 0:
        raise ValueError("No samples loaded! Check that batch results exist.")

    if all_results and args.validate:
        _validate_results(all_results)

    print(f"\nSaving merged results to: {args.output_path}")
    output_dir = os.path.dirname(args.output_path)
    # dirname() is "" for a bare filename, and os.makedirs("") raises, so only
    # create a directory when the path actually has one.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torch.save(all_results, args.output_path)

    # Report the on-disk size as a basic sanity check that the save happened.
    file_size_mb = os.path.getsize(args.output_path) / (1024 * 1024)
    print(f"Saved successfully! File size: {file_size_mb:.2f} MB")

    print("\n" + "=" * 50)
    print("MERGE COMPLETE")
    print(f"Total samples merged: {len(all_results)}")
    print(f"Output file: {args.output_path}")
    print("=" * 50)


if __name__ == "__main__":
    # Command-line entry point: each option below becomes an attribute of the
    # namespace handed to main().
    cli = argparse.ArgumentParser(description="Merge batch processing results")

    # Required input/output locations.
    cli.add_argument(
        "--batch-dir",
        type=str,
        required=True,
        help="Directory containing metadata.json from shard_dataset.py",
    )
    cli.add_argument(
        "--results-dir",
        type=str,
        required=True,
        help="Directory containing batch result files (batch_XXXX.torch)",
    )
    cli.add_argument(
        "--output-path",
        type=str,
        required=True,
        help="Path for the merged output file",
    )

    # Behaviour switches.
    cli.add_argument(
        "--allow-missing",
        action="store_true",
        help="Continue even if some batches are missing",
    )
    # --validate and --no-validate write the same ``validate`` dest. It
    # defaults to True, so --validate alone changes nothing and --no-validate
    # turns validation off.
    cli.add_argument(
        "--validate",
        action="store_true",
        default=True,
        help="Validate result structure (default: True)",
    )
    cli.add_argument(
        "--no-validate",
        dest="validate",
        action="store_false",
        help="Skip result validation",
    )

    main(cli.parse_args())