#!/usr/bin/env python3
"""
Merge the existing per-database prediction files in a predictions directory into a single merged predictions file.
"""

import json
import os
from pathlib import Path
from datetime import datetime

# Maps each key of the merged 'token_usage' section to the key it is read
# from in a per-database file's 'token_usage' section.
_TOKEN_USAGE_FIELDS = (
    ('input_tokens', 'input_tokens'),
    ('output_tokens', 'output_tokens'),
    ('cache_creation_tokens', 'cache_creation_input_tokens'),
    ('cache_read_tokens', 'cache_read_input_tokens'),
    ('estimated_cost', 'total_cost'),
)

# Per-database files are named '<db_name>_predictions.json'.
_PREDICTIONS_SUFFIX = '_predictions'


def _db_name_from_file(db_file: Path) -> str:
    """Return the database name encoded in a ``<db_name>_predictions.json`` path."""
    stem = db_file.stem
    # Strip only the trailing suffix so a database name that itself contains
    # '_predictions' is preserved.
    if stem.endswith(_PREDICTIONS_SUFFIX):
        return stem[:-len(_PREDICTIONS_SUFFIX)]
    return stem


def _load_db_result(db_file: Path) -> tuple:
    """Read one per-database file and return ``(predictions, token_usage, metadata)``.

    Missing sections default to empty dicts. Raises ``TypeError`` when the
    file's top-level value or any of these sections is not a JSON object
    (plus whatever ``open``/``json.load`` raise), so the caller can reject
    the file before merging any of its contents.
    """
    with open(db_file, 'r', encoding='utf-8') as f:
        db_result = json.load(f)
    if not isinstance(db_result, dict):
        raise TypeError(f"expected a JSON object, got {type(db_result).__name__}")
    sections = []
    for key in ('predictions', 'token_usage', 'metadata'):
        value = db_result.get(key, {})
        if not isinstance(value, dict):
            raise TypeError(f"'{key}' should be a JSON object, got {type(value).__name__}")
        sections.append(value)
    return tuple(sections)


def _question_sort_key(question_id):
    """Sort key: numeric IDs first in numeric order, then any other IDs as text."""
    try:
        return (0, int(question_id), '')
    except (TypeError, ValueError):
        return (1, 0, str(question_id))


def merge_existing_predictions(predictions_dir: str, output_file: str):
    """Merge all per-database prediction files in a directory into one file.

    Every ``*_predictions.json`` in ``predictions_dir`` whose name does not
    start with ``merged`` is read, in sorted filename order. The files'
    ``predictions`` maps are combined. If a question ID appears in more than
    one file, the file that sorts last wins and a warning is printed.
    ``token_usage`` counters and ``metadata.processing_time_seconds`` are
    summed, and a per-database timing/cost summary is built.

    A file that cannot be read, or whose structure is unexpected, is
    reported and counted in ``failed_dbs``. It contributes nothing to the
    merged result.

    Args:
        predictions_dir: Directory containing the per-database files.
        output_file: Path the merged JSON document is written to.

    Returns:
        The merged document ``{'metadata': ..., 'predictions': ...}``, with
        predictions ordered by question ID (numeric IDs numerically).
    """
    predictions_path = Path(predictions_dir)
    all_predictions = {}
    total_token_usage = {
        'input_tokens': 0,
        'output_tokens': 0,
        'cache_creation_tokens': 0,
        'cache_read_tokens': 0,
        'estimated_cost': 0.0
    }

    db_summaries = []
    databases = []
    failed_files = []
    total_questions = 0
    total_time = 0.0

    # Sort once so listing, merging and "last file wins" are deterministic.
    db_files = sorted(
        f for f in predictions_path.glob("*_predictions.json")
        if not f.name.startswith("merged")
    )

    print(f"Found {len(db_files)} database prediction files:")
    for db_file in db_files:
        print(f"  - {db_file.name}")

    # Load and merge each database file
    for db_file in db_files:
        try:
            db_predictions, db_tokens, db_metadata = _load_db_result(db_file)
            db_name = _db_name_from_file(db_file)
            question_count = len(db_predictions)
            processing_time = db_metadata.get('processing_time_seconds', 0)
            db_cost = db_tokens.get('total_cost', 0)
            # Compute every derived value before touching the accumulators,
            # so a malformed value (e.g. a non-numeric token count) rejects
            # the whole file instead of leaving it half-merged.
            new_token_usage = {
                merged_key: total_token_usage[merged_key] + db_tokens.get(source_key, 0)
                for merged_key, source_key in _TOKEN_USAGE_FIELDS
            }
            new_total_time = total_time + processing_time
            db_avg = processing_time / question_count if question_count else 0
        except Exception as e:
            # Best effort: report and skip this file, keep merging the rest.
            print(f"  ❌ Error processing {db_file}: {e}")
            failed_files.append(db_file.name)
            continue

        duplicate_ids = all_predictions.keys() & db_predictions.keys()
        if duplicate_ids:
            print(f"  ⚠️ {db_name}: {len(duplicate_ids)} question IDs already seen; "
                  f"overwriting earlier predictions")

        # Commit this file's contribution.
        databases.append(db_name)
        all_predictions.update(db_predictions)
        total_questions += question_count
        total_token_usage = new_token_usage
        total_time = new_total_time
        db_summaries.append({
            'database': db_name,
            'questions': question_count,
            'time_seconds': processing_time,
            'avg_per_question': db_avg,
            'cost': db_cost,
            'status': 'completed'
        })

        print(f"  ✅ {db_name}: {question_count} questions")

    # Sort predictions by question ID for consistent ordering
    sorted_predictions = {
        question_id: all_predictions[question_id]
        for question_id in sorted(all_predictions, key=_question_sort_key)
    }
    # Guarded: no files, or only failed files, means zero questions.
    avg_per_question = total_time / total_questions if total_questions > 0 else 0

    # Create merged metadata
    timestamp = datetime.now().isoformat()
    merged_metadata = {
        'experiment': {
            'description': 'Complete initial_dev_baseline merged from individual database files',
            'timestamp': timestamp,
            'model': 'claude-sonnet-4-20250514',
            'total_questions': total_questions,
            'databases': sorted(databases),
            'limited_per_db': None,
            'parallel_processing': {
                'max_concurrent_dbs': 3,
                'successful_dbs': len(databases),
                'failed_dbs': len(failed_files)
            }
        },
        'timing': {
            'total_seconds': total_time,
            'total_minutes': total_time / 60,
            'avg_per_question': avg_per_question,
            'by_database': db_summaries
        },
        'token_usage': total_token_usage,
        'config': {
            'dev_data_path': 'benchmark_resources/datasets/dev/dev_20240627/dev.json',
            'db_root_path': 'benchmark_resources/datasets/dev/dev_20240627/dev_databases',
            'single_db_filter': None,
            'output_directory': predictions_dir
        }
    }

    # Create final merged output
    merged_output = {
        'metadata': merged_metadata,
        'predictions': sorted_predictions
    }

    # Save merged results
    with open(output_file, 'w', encoding='utf-8') as f:
        json.dump(merged_output, f, indent=2)

    print(f"\n✅ Created merged predictions file: {output_file}")
    print("📊 Summary:")
    print(f"  - Total databases: {len(databases)}")
    if failed_files:
        print(f"  - Failed databases: {len(failed_files)} ({', '.join(failed_files)})")
    print(f"  - Total questions: {total_questions:,}")
    print(f"  - Total time: {total_time:.1f}s ({total_time/60:.1f}m)")
    print(f"  - Total cost: ${total_token_usage['estimated_cost']:.4f}")
    print(f"  - Avg per question: {avg_per_question:.2f}s")

    return merged_output

if __name__ == "__main__":
    # Script entry point: merge the baseline run's per-database files and
    # write the combined result next to them in the same directory.
    baseline_dir = "predictions/initial_dev_baseline"
    merged_path = os.path.join(baseline_dir, "merged_predictions.json")
    merge_existing_predictions(baseline_dir, merged_path)