#!/usr/bin/env python3
"""
Multi-model candidate profiling script

Runs candidate profiling for multiple models sequentially.
"""

import os
import sys
import json
import time
import logging
import argparse
import subprocess
from pathlib import Path
from typing import List, Dict, Any
from datetime import datetime
from tqdm import tqdm

def setup_logging():
    """Configure root logging for the multi-model runner.

    Every handler currently attached to the root logger is detached and
    closed, then two fresh handlers are installed at INFO level: a console
    stream handler and a file handler that truncates
    ``multi_model_profiling.log`` in the current working directory.
    """
    # Iterate over a copy: removeHandler() mutates the list being walked.
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
        # Detaching alone leaves the handler's stream/file descriptor open.
        handler.close()
    logging.basicConfig(
        level=logging.INFO,
        format='[%(levelname)s] %(message)s',
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler("multi_model_profiling.log", mode='w')
        ]
    )

def parse_args():
    """Define the runner's command-line options and parse ``sys.argv``.

    ``--models`` and ``--api_key`` are mandatory; every other option has a
    default. Returns the populated ``argparse.Namespace``.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_table = [
        ('--models', dict(
            type=str, nargs='+', required=True,
            help='List of model names to profile (e.g., llama3-70b-8192 llama3-8b-8192)')),
        ('--api_key', dict(
            type=str, required=True,
            help='API key')),
        ('--base_url', dict(
            type=str, default='https://api.example.com/v1',
            help='Base URL for API calls')),
        ('--seeds_path', dict(
            type=str, default='./seeds.json',
            help='Path to seeds.json')),
        ('--metrics_path', dict(
            type=str, default='./metrics.yaml',
            help='Path to metrics.yaml')),
        ('--nums', dict(
            type=int, default=500,
            help='Number of instances to process per model (default: 500, use 0 for all instances)')),
        ('--max_retries', dict(
            type=int, default=3,
            help='Maximum number of retries for failed API calls')),
        ('--workers', dict(
            type=int, default=10,
            help='Number of parallel workers per model (default: 10)')),
        ('--output_dir', dict(
            type=str, default='./results',
            help='Directory to save results for each model (default: ./results)')),
        ('--skip_completed', dict(
            action='store_true',
            help='Skip models that already have results files')),
        ('--summary_only', dict(
            action='store_true',
            help='Only generate summary, skip individual model runs')),
    ]

    cli = argparse.ArgumentParser(
        description="Run candidate profiling for multiple models sequentially"
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def check_dependencies():
    """Verify that the profiling script and its data files are present.

    Looks for each expected file inside the ``MOE-Judge`` subdirectory of the
    current working directory. Logs the names of any that are absent and
    returns ``False``; returns ``True`` when all of them exist.
    """
    base_dir = os.path.join(os.getcwd(), 'MOE-Judge')
    expected = ('candidate_profiling_groq_sync.py', 'seeds.json', 'metrics.yaml')

    absent = [
        name for name in expected
        if not os.path.exists(os.path.join(base_dir, name))
    ]
    if not absent:
        return True

    logging.error(f"Missing required files: {', '.join(absent)}")
    logging.error("Please ensure all required files are in the MOE-Judge subdirectory")
    return False

def get_model_safe_name(model: str) -> str:
    """Return *model* with every ``-`` and ``/`` replaced by ``_``.

    The result is used as a component of result filenames.
    """
    substitutions = str.maketrans('-/', '__')
    return model.translate(substitutions)

def check_model_completed(model: str, output_dir: str, num_instances: int) -> bool:
    """Report whether both output files for *model* already exist.

    A model counts as completed only when the results JSON and its
    ``_summary`` companion are both present in *output_dir* for the given
    instance count.
    """
    stem = f"candidate_profiling_groq_sync_results_{get_model_safe_name(model)}_{num_instances}"
    expected_paths = (
        os.path.join(output_dir, f"{stem}.json"),
        os.path.join(output_dir, f"{stem}_summary.json"),
    )
    return all(os.path.exists(path) for path in expected_paths)

def run_single_model_profiling(model: str, args, original_dir: str) -> Dict[str, Any]:
    """Run the profiling script for one model in a child process.

    Args:
        model: Model name forwarded to the child via ``--model``.
        args: Parsed CLI namespace; ``api_key``, ``base_url``, ``seeds_path``,
            ``metrics_path``, ``nums``, ``max_retries`` and ``workers`` are
            read from it.
        original_dir: Directory that contains the ``MOE-Judge`` subdirectory
            holding the profiling script.

    Returns:
        A dict with the keys ``model``, ``status``, ``duration``,
        ``returncode``, ``stdout`` and ``stderr``. ``status`` is ``"success"``
        for exit code 0 and ``"failed"`` for any other exit code. If the
        child could not be run at all, ``status`` is ``"timeout"`` or
        ``"exception"`` and ``duration``/``returncode`` are ``None``.
    """
    model_safe = get_model_safe_name(model)
    # Bare filename: no directory component is added here.
    output_file = f"candidate_profiling_groq_sync_results_{model_safe}_{args.nums}.json"

    moe_judge_dir = os.path.join(original_dir, 'MOE-Judge')
    script_path = os.path.join(moe_judge_dir, 'candidate_profiling_groq_sync.py')
    # Resolve data paths against MOE-Judge. str.lstrip('./') must not be used
    # for this: it strips any leading run of '.' and '/' characters, turning
    # '../seeds.json' or '/abs/seeds.json' into a different file. os.path.join
    # keeps absolute paths intact and normpath collapses './' and '../'.
    seeds_path = os.path.normpath(os.path.join(moe_judge_dir, args.seeds_path))
    metrics_path = os.path.normpath(os.path.join(moe_judge_dir, args.metrics_path))

    cmd = [
        sys.executable, script_path,
        '--model', model,
        '--api_key', args.api_key,
        '--base_url', args.base_url,
        '--seeds_path', seeds_path,
        '--metrics_path', metrics_path,
        '--nums', str(args.nums),
        '--max_retries', str(args.max_retries),
        '--workers', str(args.workers),
        '--output', output_file
    ]

    # The command line is logged for debugging, but the API key must never
    # reach the console or the log file, so mask the value after '--api_key'.
    key_index = cmd.index('--api_key') + 1
    printable_cmd = cmd[:key_index] + ['***'] + cmd[key_index + 1:]

    logging.info(f"Starting profiling for model: {model}")
    logging.info(f"Command: {' '.join(printable_cmd)}")

    start_time = time.time()

    try:
        result = subprocess.run(
            cmd,
            capture_output=True,
            text=True,
            timeout=None  # No timeout, let it run as long as needed
        )
    except subprocess.TimeoutExpired:
        # Only reachable if a timeout is ever passed to subprocess.run above.
        logging.error(f"Model {model} timed out")
        return {
            "model": model,
            "status": "timeout",
            "duration": None,
            "returncode": None,
            "stdout": "",
            "stderr": "Timeout expired"
        }
    except Exception as e:
        # E.g. the interpreter or script could not be launched.
        logging.error(f"Model {model} failed with exception: {e}")
        return {
            "model": model,
            "status": "exception",
            "duration": None,
            "returncode": None,
            "stdout": "",
            "stderr": str(e)
        }

    duration = time.time() - start_time

    if result.returncode == 0:
        status = "success"
        logging.info(f"Model {model} completed successfully in {duration:.2f} seconds")
    else:
        status = "failed"
        logging.error(f"Model {model} failed with return code {result.returncode}")
        logging.error(f"STDOUT: {result.stdout}")
        logging.error(f"STDERR: {result.stderr}")

    return {
        "model": model,
        "status": status,
        "duration": duration,
        "returncode": result.returncode,
        "stdout": result.stdout,
        "stderr": result.stderr
    }

def generate_multi_model_summary(results: List[Dict[str, Any]], output_dir: str, num_instances: int):
    """Write a JSON summary of all model runs and print a results table.

    Args:
        results: One dict per model with at least ``model`` and ``status``;
            ``duration`` and ``returncode`` are optional and may be ``None``.
        output_dir: Directory that receives
            ``multi_model_summary_<num_instances>.json``.
        num_instances: Instance count, used only in the summary filename.

    Models with status ``"skipped"`` are reported separately in
    ``skipped_models`` and are not counted as failures. An empty *results*
    list is valid and yields a 0.0% success rate.
    """
    total = len(results)
    successful = sum(1 for r in results if r["status"] == "success")
    skipped = sum(1 for r in results if r["status"] == "skipped")

    # Only truthy durations contribute: skipped runs (0) and runs that never
    # produced a duration (None) would otherwise distort the average.
    timed = [r["duration"] for r in results if r.get("duration")]
    total_duration = sum(timed)
    average_duration = total_duration / len(timed) if timed else 0

    summary = {
        "timestamp": datetime.now().isoformat(),
        "total_models": total,
        "successful_models": successful,
        "skipped_models": skipped,
        "failed_models": total - successful - skipped,
        "model_results": results,
        "statistics": {
            "total_duration": total_duration,
            "average_duration": average_duration
        }
    }

    # Save summary with number of instances in filename
    summary_file = os.path.join(output_dir, f"multi_model_summary_{num_instances}.json")
    with open(summary_file, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)

    # Print summary table
    print("\n" + "="*80)
    print("MULTI-MODEL PROFILING SUMMARY")
    print("="*80)
    # The status column holds "<icon> <status>", hence the wider field.
    print(f"{'Model':<30} {'Status':<14} {'Duration':<12} {'Return Code':<12}")
    print("-"*80)

    for result in results:
        status = result["status"]
        if status == "success":
            status_icon = "YES"
        elif status == "skipped":
            status_icon = "--"
        else:
            status_icon = "NO"
        status_cell = f"{status_icon} {status}"

        duration = f"{result['duration']:.2f}s" if result.get("duration") else "N/A"
        # The key is usually present with value None, so a .get() default
        # alone would print "None" instead of "N/A".
        code = result.get("returncode")
        returncode = "N/A" if code is None else str(code)

        print(f"{result['model']:<30} {status_cell:<14} {duration:<12} {returncode:<12}")

    print("-"*80)
    # Guard against an empty result list.
    success_rate = successful / total * 100 if total else 0.0

    print(f"Success Rate: {successful}/{total} ({success_rate:.1f}%)")
    if skipped:
        print(f"Skipped (already completed): {skipped}")
    print(f"Total Duration: {total_duration:.2f} seconds ({total_duration/3600:.2f} hours)")
    print(f"Average Duration per Model: {average_duration:.2f} seconds")
    print(f"Summary saved to: {summary_file}")
    print("="*80)

def _skipped_result(model: str) -> Dict[str, Any]:
    """Build the result record for a model whose output files already exist."""
    return {
        "model": model,
        "status": "skipped",
        "duration": 0,
        "returncode": 0,
        "stdout": "Skipped - already completed",
        "stderr": ""
    }


def main():
    """Entry point: profile every requested model in turn, then summarize.

    The working directory is switched to ``--output_dir`` for the duration of
    the run and restored afterwards. With ``--skip_completed``, models whose
    result files already exist are not re-run. With ``--summary_only``, no
    model is run at all: each model is recorded as ``skipped`` when its result
    files exist and as ``not_run`` otherwise. The summary is always generated.
    Exits with status 1 when required files are missing or on Ctrl-C.
    """
    setup_logging()
    original_dir = os.getcwd()  # Captured before any chdir so it can be restored

    try:
        args = parse_args()

        # The profiling script and data files are only needed when models
        # are actually going to be run.
        if not args.summary_only and not check_dependencies():
            sys.exit(1)

        # Create output directory
        os.makedirs(args.output_dir, exist_ok=True)

        # Change to output directory for results
        os.chdir(args.output_dir)

        logging.info(f"Starting multi-model profiling for {len(args.models)} models")
        logging.info(f"Models: {', '.join(args.models)}")
        logging.info(f"Output directory: {args.output_dir}")
        logging.info(f"API endpoint: {args.base_url}")
        logging.info(f"Parallel processing: {args.workers} workers per model with global rate limiting on 429 errors")

        results = []

        # Create progress bar for models
        with tqdm(args.models, desc="Processing models", unit="model") as model_pbar:
            for i, model in enumerate(model_pbar):
                # Update progress bar description
                model_pbar.set_description(f"Processing model {i+1}/{len(args.models)}: {model}")

                logging.info(f"\n{'='*60}")
                logging.info(f"Processing model {i+1}/{len(args.models)}: {model}")
                logging.info(f"{'='*60}")

                # Summary-only mode never launches a run; it just records
                # whether results for this model are already on disk.
                if args.summary_only:
                    if check_model_completed(model, ".", args.nums):
                        model_pbar.set_postfix({"status": "skipped"})
                        results.append(_skipped_result(model))
                    else:
                        logging.warning(f"No results found for {model}; not running it (--summary_only)")
                        model_pbar.set_postfix({"status": "not_run"})
                        results.append({
                            "model": model,
                            "status": "not_run",
                            "duration": None,
                            "returncode": None,
                            "stdout": "",
                            "stderr": "No results found (summary only)"
                        })
                    continue

                # Check if model is already completed
                if args.skip_completed and check_model_completed(model, ".", args.nums):
                    logging.info(f"Skipping {model} (already completed)")
                    model_pbar.set_postfix({"status": "skipped"})
                    results.append(_skipped_result(model))
                    continue

                # Run profiling for this model
                result = run_single_model_profiling(model, args, original_dir)
                results.append(result)

                # Update progress bar with status
                if result["status"] == "success":
                    model_pbar.set_postfix({"status": "YES", "duration": f"{result.get('duration', 0):.1f}s"})
                else:
                    model_pbar.set_postfix({"status": "NO", "error": result["status"]})

        # The summary is produced in every mode; --summary_only suppresses
        # the runs, not the summary.
        generate_multi_model_summary(results, ".", args.nums)

        logging.info("Multi-model profiling completed!")

    except KeyboardInterrupt:
        logging.info("Interrupted by user.")
        sys.exit(1)
    except Exception as e:
        logging.error(f"Error during multi-model profiling: {e}")
        raise
    finally:
        # Change back to original directory
        os.chdir(original_dir)
        logging.shutdown()

# Run the multi-model profiler only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()