#!/usr/bin/env python3
"""
Universal Physics Evaluation Script — supports multiple models and data formats.

Usage examples:
    # Evaluate all available datasets
    python eval_physics.py

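    # Evaluate a specific dataset (one of: ipho_2025, apho_2025, eupho_2025,
    # nbpho_2025, panpho_2025, panmechanics_2025, fma_2025)
    python eval_physics.py --dataset panpho_2025

    # Evaluate another dataset; the evaluator handles each dataset's data format
    python eval_physics.py --dataset apho_2025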

    # Use a Judge model for fine-grained marking
    python eval_physics.py --judge-model gpt-4o --api-key YOUR_API_KEY

    # Specify number of parallel processes
    python eval_physics.py --nproc 8

    # Disable Judge model (fine-grained score = 0; use answer-matching only)
    python eval_physics.py --no-judge

    # Evaluate multiple runs and compute statistics
    python eval_physics.py --multi-runs --dataset apho_2025

    # Evaluate multi-run results for all datasets
    python eval_physics.py --multi-runs

    # Specify the log directory
    python eval_physics.py --log-dir my_logs
"""

import argparse
import os
import sys
import logging
from datetime import datetime
from pathlib import Path

# Load API configuration from .env
from dotenv import load_dotenv
load_dotenv('.env')

from universal_physics_evaluator import UniversalPhysicsEvaluator, safe_print

# Module-level logger for this script. It stays None until setup_logging()
# assigns the configured 'eval_physics' logger; while it is None, log_print()
# falls back to plain print().
logger: "logging.Logger | None" = None


def setup_logging(log_dir="logs/logger", dataset_name=None, multi_runs=False):
    """Configure the module-level ``eval_physics`` logger.

    Records are written to a timestamped file inside *log_dir* (with a
    timestamp prefix) and to stdout (message only). Calling this function
    again replaces, and closes, the handlers installed by a previous call.

    Args:
        log_dir: Directory for the log file; created (with parents) if missing.
        dataset_name: Dataset being evaluated, embedded in the file name;
            a falsy value means "all datasets".
        multi_runs: If true, the file name gets a ``multi_runs_`` tag.

    Returns:
        pathlib.Path of the opened log file, named
        ``eval_physics_[multi_runs_]<dataset|all>_<YYYYmmdd_HHMMSS>.log``.
    """
    global logger

    log_path = Path(log_dir)
    log_path.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    mode_tag = "multi_runs_" if multi_runs else ""
    scope = dataset_name if dataset_name else "all"
    log_file = log_path / f"eval_physics_{mode_tag}{scope}_{timestamp}.log"

    logger = logging.getLogger('eval_physics')
    logger.setLevel(logging.INFO)
    # Our handlers fully handle every record; don't also pass records to
    # ancestor loggers, which would duplicate console output whenever the
    # root logger has handlers (e.g. after logging.basicConfig elsewhere).
    logger.propagate = False

    # Detach AND close handlers from an earlier call so the previous log
    # file is not left open.
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)
        handler.close()

    # File output keeps timestamps for later inspection.
    file_handler = logging.FileHandler(log_file, encoding='utf-8')
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')
    )

    # Console output stays simple (no timestamp).
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter('%(message)s'))

    logger.addHandler(file_handler)
    logger.addHandler(console_handler)

    return log_file


def log_print(*args, **kwargs):
    """Drop-in replacement for ``safe_print`` that routes output through logging.

    Positional arguments are converted with ``str`` and joined with the
    ``sep`` keyword (default: a single space), as ``print`` would do.

    When the module logger has been configured by ``setup_logging`` the
    message is emitted at INFO level, and the remaining keyword arguments
    (``end``, ``file``, ``flush``) are ignored. Otherwise it falls back to
    ``print`` and forwards those keyword arguments.
    """
    sep = kwargs.pop('sep', None)
    # Honour ``sep`` ourselves: ``print`` would only ever see one argument.
    message = (' ' if sep is None else sep).join(str(arg) for arg in args)
    if logger is not None:
        logger.info(message)
    else:
        print(message, **kwargs)


def _positive_int(value):
    """argparse ``type=`` converter that accepts only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def parse_args(argv=None):
    """Parse command-line arguments.

    Args:
        argv: Optional list of argument strings; ``None`` (the default)
            parses ``sys.argv[1:]`` as before.

    Returns:
        argparse.Namespace with the options declared below.

    Raises:
        SystemExit: on invalid arguments (argparse's standard behavior),
            including a non-positive ``--nproc``.
    """
    parser = argparse.ArgumentParser(description="Universal physics evaluation script — supports multiple models and data formats.")

    # Basic arguments
    parser.add_argument(
        "--results-dir",
        type=str,
        default="results",
        help="Inference results directory (default: results)"
    )

    parser.add_argument(
        "--dataset",
        type=str,
        choices=[
            'ipho_2025', 'apho_2025', 'eupho_2025', 'nbpho_2025',
            'panpho_2025', 'panmechanics_2025', 'fma_2025'
        ],
        help="Dataset to evaluate (if not specified, evaluates all available datasets)"
    )

    # Validated at parse time: a process count below 1 cannot work, and a 0
    # would otherwise be silently dropped from the Judge kwargs.
    parser.add_argument(
        "--nproc",
        type=_positive_int,
        default=4,
        help="Number of parallel processes (default: 4)"
    )

    # Judge model args
    parser.add_argument(
        "--judge-model",
        type=str,
        help="Judge model name (e.g., gpt-4o, gpt-4-turbo, claude-3-sonnet-20240229)"
    )

    parser.add_argument(
        "--api-key",
        type=str,
        help="API key (optional; can also be set via environment variable)"
    )

    parser.add_argument(
        "--no-judge",
        action="store_true",
        help="Disable Judge model (fine-grained score = 0; coarse-grained evaluation only)"
    )

    parser.add_argument(
        "--multi-runs",
        action="store_true",
        help="Evaluate multiple runs and compute statistics"
    )

    # Other arguments
    parser.add_argument(
        "--output-dir",
        type=str,
        default="evaluation_results",
        help="Output directory for evaluation results (default: evaluation_results)"
    )

    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Verbose mode"
    )

    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Dry run: only check datasets without performing evaluation"
    )

    parser.add_argument(
        "--log-dir",
        type=str,
        default="logs",
        help="Directory to save log files (default: logs)"
    )

    return parser.parse_args(argv)


def setup_environment(args):
    """Validate paths and prepare the process environment for evaluation.

    - Exports ``args.api_key`` (if given) as ``OPENAI_API_KEY``.
    - Exits with status 1 if ``args.results_dir`` is missing or is not a
      directory.
    - If ``args.output_dir`` is non-empty, rewrites it (mutating *args*) to
      end with a path separator.
    - Creates the output directory, exiting with status 1 if that fails.
    """
    if args.api_key:
        os.environ["OPENAI_API_KEY"] = args.api_key
        log_print("✅ API key has been set")

    results_dir = Path(args.results_dir)
    if not results_dir.exists():
        log_print(f"❌ Results directory does not exist: {results_dir}")
        sys.exit(1)
    # A regular file at this path would otherwise pass validation and only
    # fail later, deep inside the evaluator.
    if not results_dir.is_dir():
        log_print(f"❌ Results path is not a directory: {results_dir}")
        sys.exit(1)

    if args.output_dir:
        # Normalize to a trailing separator. Presumably downstream code joins
        # paths by string concatenation — TODO confirm.
        args.output_dir = os.path.join(args.output_dir, "")
        log_print(f"✅ Output directory specified: {args.output_dir}")

    output_dir = Path(args.output_dir).resolve()
    try:
        output_dir.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        # e.g. the path exists as a regular file or is not writable.
        log_print(f"❌ Cannot create output directory {output_dir}: {e}")
        sys.exit(1)
    log_print(f"📁 Output directory: {output_dir}")


def build_judge_kwargs(args):
    """Assemble the keyword arguments handed to the evaluator's Judge step.

    With ``--no-judge`` an empty dict is returned immediately. Otherwise the
    dict contains ``model`` (only when ``--judge-model`` was given) and
    ``nproc`` (only when it is truthy). Status messages go through log_print.
    """
    if args.no_judge:
        log_print("⚠️  Judge model disabled; fine-grained score will be 0; coarse-grained evaluation will still run")
        return {}

    kwargs_for_judge = {}
    model_name = args.judge_model

    if not model_name:
        log_print("ℹ️  No Judge model specified; fine-grained score will be 0; coarse-grained evaluation only")
    else:
        kwargs_for_judge['model'] = model_name
        log_print(f"🤖 Using Judge model: {model_name}")

        # Either source of credentials is acceptable; warn only if both are absent.
        key_available = bool(os.getenv("OPENAI_API_KEY")) or bool(args.api_key)
        if not key_available:
            log_print("⚠️  Warning: API key not set; the Judge model may not work")
            log_print("   Please set via --api-key or OPENAI_API_KEY environment variable")

    if args.nproc:
        kwargs_for_judge['nproc'] = args.nproc

    return kwargs_for_judge


def _log_startup_banner(args, log_file):
    """Log the run configuration before any work starts."""
    log_print("🚀 Universal Physics Evaluation System started")
    log_print("=" * 60)
    log_print(f"📂 Results directory: {args.results_dir}")
    log_print(f"📊 Parallel processes: {args.nproc}")
    log_print(f"💾 Output directory: {args.output_dir}")
    log_print(f"📝 Log file: {log_file}")
    if args.dataset:
        log_print(f"🎯 Selected dataset: {args.dataset}")
    else:
        log_print("🎯 Evaluation mode: all available datasets")
    log_print("=" * 60)


def _run_dry_run(evaluator):
    """List the datasets the evaluator detects, without evaluating anything."""
    log_print("\n🔍 Dry-run mode: checking available datasets...")
    available_datasets = evaluator.detect_available_datasets()
    log_print(f"\n📊 Found {len(available_datasets)} available datasets:")
    for dataset_key in available_datasets:
        config = evaluator.DATASET_CONFIGS[dataset_key]
        log_print(f"   ✓ {config['display_name']} ({dataset_key})")
    log_print("\n✅ Dry run complete")


def _run_multi_runs_single(evaluator, dataset, judge_kwargs):
    """Multi-run evaluation of one dataset.

    Returns True on success, False on failure, and None when the dataset
    has no multi-run results (nothing to do).
    """
    log_print(f"\n🔄 Starting multi-run evaluation: {dataset}")

    if not evaluator.has_multiple_runs(dataset):
        log_print(f"⚠️  Dataset {dataset} has no multi-run results")
        return None

    multi_run_results = evaluator.evaluate_multiple_runs(dataset, judge_kwargs)
    if not multi_run_results:
        log_print(f"❌ Multi-run evaluation failed for dataset {dataset}")
        return False

    overall = multi_run_results['overall_statistics']
    log_print("\n🎉 Multi-run evaluation complete!")
    log_print(f"🔄 Number of runs: {overall['num_runs']}")
    log_print(f"📈 Mean score rate: {overall['mean_score_rate']:.2f}% ± {overall['std_score_rate']:.2f}%")
    return True


def _run_multi_runs_all(evaluator, judge_kwargs):
    """Multi-run evaluation of every available dataset that has several runs.

    Per-dataset exceptions are logged and recorded as None without aborting
    the remaining datasets. A combined summary is saved afterwards.

    Returns True if at least one dataset produced results, False if all
    failed, and None when no dataset has multi-run results.
    """
    log_print("\n🌟 Starting multi-run evaluation for all datasets...")
    multi_run_datasets = [
        dataset_key
        for dataset_key in evaluator.detect_available_datasets()
        if evaluator.has_multiple_runs(dataset_key)
    ]

    if not multi_run_datasets:
        log_print("❌ No datasets with multi-run results were found")
        return None

    log_print(f"📊 Found {len(multi_run_datasets)} datasets with multiple runs")

    all_multi_run_results = {}
    for dataset_key in multi_run_datasets:
        log_print(f"\n{'='*60}")
        log_print(f"🔄 Evaluating multiple runs: {evaluator.DATASET_CONFIGS[dataset_key]['display_name']}")
        log_print(f"{'='*60}")

        try:
            multi_run_results = evaluator.evaluate_multiple_runs(dataset_key, judge_kwargs)
            all_multi_run_results[dataset_key] = multi_run_results

            if multi_run_results:
                overall = multi_run_results['overall_statistics']
                log_print(f"✅ Completed: mean score rate {overall['mean_score_rate']:.2f}% ± {overall['std_score_rate']:.2f}%")

        except Exception as e:
            log_print(f"❌ Multi-run evaluation failed for {dataset_key}: {e}")
            all_multi_run_results[dataset_key] = None

    # Every attempted dataset has an entry, so the summary is never empty here.
    evaluator._save_all_multi_run_summary(all_multi_run_results)

    if not any(all_multi_run_results.values()):
        log_print("❌ Multi-run evaluation failed for every dataset")
        return False
    return True


def _run_single(evaluator, dataset, judge_kwargs):
    """Evaluate one dataset. Returns True on success, False on failure."""
    log_print(f"\n🎯 Starting evaluation for dataset: {dataset}")
    results = evaluator.evaluate_dataset(dataset, judge_kwargs)

    if not results:
        log_print(f"❌ Evaluation failed for dataset {dataset}")
        return False

    config = evaluator.DATASET_CONFIGS[dataset]
    log_print(f"\n✅ {config['display_name']} evaluation complete!")
    log_print(f"🏆 Total score: {results['total_score']:.2f} / {results['max_possible_score']:.2f} ({results['score_rate']:.2f}%)")
    return True


def _run_all(evaluator, judge_kwargs):
    """Evaluate all available datasets.

    Returns True if at least one dataset succeeded (non-None result).
    """
    log_print("\n🌟 Starting evaluation for all available datasets...")
    all_results = evaluator.evaluate_all_datasets(judge_kwargs)

    if not all_results:
        log_print("❌ Failed to evaluate any dataset")
        return False

    successful_count = sum(1 for r in all_results.values() if r is not None)
    if successful_count == 0:
        # Don't announce success when every dataset failed.
        log_print(f"❌ Failed to evaluate any dataset (0/{len(all_results)} succeeded)")
        return False

    log_print("\n🎉 All datasets evaluated!")
    log_print(f"📊 Successfully evaluated {successful_count}/{len(all_results)} datasets")
    return True


def main():
    """Main entry point: parse args, set up logging/environment, run evaluation.

    Exit status is 1 in these cases:
    - the evaluator cannot be created;
    - the run is interrupted or raises;
    - the requested evaluation fails.

    Exit status is 0 in these cases:
    - the evaluation succeeds;
    - a dry run completes;
    - no multi-run results exist.
    """
    args = parse_args()

    log_file = setup_logging(
        log_dir=args.log_dir,
        dataset_name=args.dataset,
        multi_runs=args.multi_runs
    )
    _log_startup_banner(args, log_file)

    setup_environment(args)
    judge_kwargs = build_judge_kwargs(args)

    try:
        evaluator = UniversalPhysicsEvaluator(
            results_dir=args.results_dir,
            output_dir=args.output_dir,
            dataset_name=args.dataset,
            nproc=args.nproc
        )
        log_print("✅ Evaluator initialized successfully")
    except Exception as e:
        log_print(f"❌ Failed to initialize evaluator: {e}")
        sys.exit(1)

    if args.dry_run:
        _run_dry_run(evaluator)
        return

    # status: True = success, False = failure, None = nothing to evaluate.
    try:
        if args.multi_runs:
            if args.dataset:
                status = _run_multi_runs_single(evaluator, args.dataset, judge_kwargs)
            else:
                status = _run_multi_runs_all(evaluator, judge_kwargs)
        elif args.dataset:
            status = _run_single(evaluator, args.dataset, judge_kwargs)
        else:
            status = _run_all(evaluator, judge_kwargs)
    except KeyboardInterrupt:
        log_print("\n⏹️  Evaluation interrupted by user")
        sys.exit(1)
    except Exception as e:
        log_print(f"\n❌ An error occurred during evaluation: {e}")
        if args.verbose:
            import traceback
            log_print(f"Detailed traceback:\n{traceback.format_exc()}")
        sys.exit(1)

    if status is None:
        return
    if not status:
        # Report failure to the shell instead of the success footer.
        log_print(f"📝 Full log saved to: {log_file}")
        sys.exit(1)

    log_print(f"\n🎯 Evaluation complete! Results saved to: {args.output_dir}")
    log_print(f"📝 Full log saved to: {log_file}")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
