#!/usr/bin/env python3
"""
Multimodal Model Physics Property Analysis - Complete Test Runner

Usage Instructions:
- Supports GPT and Qwen models for physics property analysis
- Supports setting sample sizes for absolute value tests and relative comparison tests
- Absolute value tests: Test N videos per test set
- Relative comparison tests: Test N pairs of videos per test set
"""

import os
import sys
import argparse
from datetime import datetime
from gpt_physics_analyzer import GPTPhysicsAnalyzer


def get_default_output_dir(model_version: str, concat_mode: bool = False, frame_index: bool = False, icl_example: bool = False, follow_oracle_test: bool = False) -> str:
    """
    Build the default results directory name for a model and feature set.

    The model name is lower-cased and every "-" and "." is mapped to "_".
    Names starting with a known family (gpt, qwen, gemini, claude) get a
    family prefix, plus a variant tag (e.g. "mini", "max") when the name
    contains one; the first matching tag in priority order wins. Any other
    model uses its first "_"-separated token as the prefix. One suffix per
    enabled feature is then appended in a fixed order
    (concat, frameindex, icl, oracle).

    Args:
        model_version: Model version string (e.g. "gpt-4o", "qwen-vl-max").
        concat_mode: Append the "concat" suffix.
        frame_index: Append the "frameindex" suffix.
        icl_example: Append the "icl" suffix.
        follow_oracle_test: Append the "oracle" suffix.

    Returns:
        Directory name, e.g. "qwen_max_physics_analysis_results_concat_icl".
    """
    normalized = model_version.lower().translate(str.maketrans("-.", "__"))

    # Family prefix -> variant tags, checked in priority order (first hit wins).
    families = (
        ("gpt", ("mini", "turbo")),
        ("qwen", ("max", "plus", "turbo")),
        ("gemini", ()),
        ("claude", ()),
    )

    # Fallback for unknown models: first "_"-separated token of the name.
    stem = normalized.split("_")[0]
    for family, variants in families:
        if normalized.startswith(family):
            variant = next((tag for tag in variants if tag in normalized), None)
            stem = family if variant is None else f"{family}_{variant}"
            break

    feature_tags = (
        (concat_mode, "concat"),
        (frame_index, "frameindex"),
        (icl_example, "icl"),
        (follow_oracle_test, "oracle"),
    )
    parts = [f"{stem}_physics_analysis_results"]
    parts.extend(tag for enabled, tag in feature_tags if enabled)
    return "_".join(parts)


def _int_at_least(minimum: int):
    """Build an argparse ``type=`` converter that accepts integers >= *minimum*."""
    def convert(text: str) -> int:
        try:
            value = int(text)
        except ValueError:
            raise argparse.ArgumentTypeError(f"invalid integer value: {text!r}") from None
        if value < minimum:
            raise argparse.ArgumentTypeError(f"must be an integer >= {minimum}, got {value}")
        return value
    return convert


def parse_args():
    """
    Parse command line arguments for the test runner.

    Count, retry and delay options must be >= 0. Frame dimensions must be
    >= 1. An out-of-range or non-integer value makes argparse print a usage
    error and exit with status 2, instead of passing a nonsensical value on.

    Returns:
        argparse.Namespace with the parsed options. ``output_dir`` and
        ``properties`` default to None and are resolved later by the caller.
    """
    non_negative_int = _int_at_least(0)
    positive_int = _int_at_least(1)

    parser = argparse.ArgumentParser(description="Multimodal Model Physics Property Analysis - Complete Test Runner")
    
    parser.add_argument(
        "--absolute-samples", 
        type=non_negative_int, 
        default=50,
        help="Number of samples per test set for absolute value tests (default: 50)"
    )
    
    parser.add_argument(
        "--relative-pairs", 
        type=non_negative_int, 
        default=50,
        help="Number of video pairs per test set for relative comparison tests (default: 50)"
    )
    
    parser.add_argument(
        "--model-version",
        type=str,
        default="gpt-4o",
        help="Model version (supports: gpt-4o, qwen-vl-max, qwen-vl-plus, qwen-vl-turbo, default: gpt-4o)"
    )
    
    parser.add_argument(
        "--frame-width", 
        type=positive_int, 
        default=640,
        help="Video frame width (default: 640)"
    )
    
    parser.add_argument(
        "--frame-height", 
        type=positive_int, 
        default=480,
        help="Video frame height (default: 480)"
    )
    
    parser.add_argument(
        "--output-dir", 
        type=str, 
        default=None,  # None -> generated later from model version and enabled features
        help="Output directory for results (default: auto-generated based on model version)"
    )
    
    parser.add_argument(
        "--dry-run", 
        action="store_true",
        help="Only show test plan without executing actual tests"
    )
    
    parser.add_argument(
        "--auto-confirm", 
        action="store_true",
        help="Auto-confirm tests without user input"
    )
    
    parser.add_argument(
        "--concat-mode", 
        action="store_true",
        help="Use concatenation mode for relative comparison (add black frames between two videos)"
    )
    
    parser.add_argument(
        "--properties",
        type=str,
        nargs="*",
        default=None,
        choices=["friction", "viscosity", "elasticity"],
        help="Specify physics properties to test (options: friction, viscosity, elasticity). Default tests all properties. Example: --properties friction viscosity"
    )
    
    parser.add_argument(
        "--frame-index",
        action="store_true",
        help="Add explicit frame index labels for each video frame (e.g., frame1: image, frame2: image)"
    )
    
    parser.add_argument(
        "--icl-example",
        action="store_true",
        help="Enable in-context learning, sampling 3 examples from training set as input-output pairs"
    )
    
    parser.add_argument(
        "--follow-oracle-test",
        action="store_true",
        help="Enable step-by-step guidance, telling the model specific analysis methods and visual cues"
    )
    
    # Retry configuration
    parser.add_argument(
        "--max-retries",
        type=non_negative_int,
        default=5,
        help="Maximum number of retries for each API call (default: 5)"
    )
    
    parser.add_argument(
        "--retry-delay",
        type=non_negative_int,
        default=15,
        help="Retry interval in seconds (default: 15)"
    )
    
    parser.add_argument(
        "--validation-retries",
        type=non_negative_int,
        default=3,
        help="Number of validation retries, will retry API call if valid values cannot be extracted (default: 3)"
    )
    
    parser.add_argument(
        "--aggressive-retry",
        action="store_true",
        help="Enable aggressive retry mode: increase retry counts and validation counts to get more valid data on first run"
    )
    
    return parser.parse_args()



# Rough wall-clock cost of one model query (one video for an absolute test,
# one video pair for a relative test). Used only for the printed estimate.
_SECONDS_PER_QUERY = 30


def _option_given_on_cli(option, argv=None):
    """Return True if the long *option* was explicitly passed on the command line.

    Recognises ``--opt value`` and ``--opt=value``. It also recognises the
    prefix abbreviations argparse accepts by default (e.g. ``--max-r 3`` for
    ``--max-retries``); argparse has already rejected ambiguous prefixes by
    the time this runs. Tokens after a bare ``--`` are not options.
    """
    tokens = sys.argv[1:] if argv is None else argv
    for token in tokens:
        if token == "--":
            break
        name = token.split("=", 1)[0]
        if name.startswith("--") and len(name) > 2 and option.startswith(name):
            return True
    return False


def _print_test_plan(args):
    """Print the run parameters (model, sample counts, feature flags, retries)."""
    print("\n================================================================================")
    print("Multimodal Model Physics Property Analysis - Test Plan")
    print("================================================================================")
    print(f"Model version: {args.model_version}")
    print(f"Absolute test samples: {args.absolute_samples}")
    print(f"Relative comparison pairs: {args.relative_pairs}")
    print(f"Video frame size: {args.frame_width}x{args.frame_height}")
    print(f"Concatenation mode: {'Enabled' if args.concat_mode else 'Disabled'}")
    print(f"Frame index mode: {'Enabled' if args.frame_index else 'Disabled'}")
    print(f"ICL examples: {'Enabled' if args.icl_example else 'Disabled'}")
    print(f"Step-by-step guidance: {'Enabled' if args.follow_oracle_test else 'Disabled'}")
    print(f"API retry count: {args.max_retries}")
    print(f"Retry delay: {args.retry_delay} seconds")
    print(f"Validation retry count: {args.validation_retries}")
    print(f"Aggressive retry mode: {'Enabled' if args.aggressive_retry else 'Disabled'}")
    if args.properties:
        print(f"Test properties: {', '.join(args.properties)}")
    else:
        print("Test properties: All properties (friction, viscosity, elasticity)")
    print(f"Output directory: {args.output_dir}")
    print()


def _print_test_breakdown(config, args):
    """Print each selected test version and count its non-training data sets.

    A version is skipped when it is excluded by ``args.properties`` or when
    its type's sample count is 0. Any version whose type is not 'absolute' is
    counted as relative.

    Returns:
        (absolute_sets, relative_sets): number of selected test sets per type.
    """
    print("Test configuration details:")
    print()

    absolute_sets = 0
    relative_sets = 0
    for version_name, version_config in config.items():
        test_type = version_config['type']
        if args.properties is not None and version_config['property'] not in args.properties:
            continue
        if test_type == 'absolute' and args.absolute_samples == 0:
            continue
        if test_type == 'relative' and args.relative_pairs == 0:
            continue

        # Training sets are only used as ICL examples, never evaluated.
        test_sets = [name for name in version_config['data_dirs'] if name != 'train']

        print(f"{version_name}:")
        print(f"  Property: {version_config['property']}")
        print(f"  Type: {test_type}")
        print(f"  Test sets: {test_sets}")
        if test_type == 'absolute':
            print(f"  Samples per test set: {args.absolute_samples}")
            absolute_sets += len(test_sets)
        else:
            print(f"  Pairs per test set: {args.relative_pairs}")
            relative_sets += len(test_sets)
        print(f"  Total tests: {len(test_sets)}")
        print()

    return absolute_sets, relative_sets


def main():
    """
    Run the physics-property evaluation from the command line.

    Parses arguments, resolves the output directory, validates the property
    filter and initialises ``GPTPhysicsAnalyzer``. It then prints the test
    plan and, unless ``--dry-run`` is set, asks for confirmation (skipped
    with ``--auto-confirm``) and delegates to ``analyzer.run_all_tests``.

    Returns:
        Process exit code. 0 on success, dry run or cancellation. 1 on an
        invalid ``--properties`` value, analyzer initialisation failure or
        test execution failure.
    """
    import traceback

    args = parse_args()

    # Derive the output directory from model and features unless given explicitly.
    if args.output_dir is None:
        args.output_dir = get_default_output_dir(
            model_version=args.model_version,
            concat_mode=args.concat_mode,
            frame_index=args.frame_index,
            icl_example=args.icl_example,
            follow_oracle_test=args.follow_oracle_test
        )

    if args.properties is not None:
        if not args.properties:
            # "--properties" given with no values.
            print("Error: --properties parameter requires at least one physics property")
            print("Available properties: friction, viscosity, elasticity")
            print("Example: --properties friction viscosity")
            return 1

        # argparse `choices` already validates these; kept as a defensive check.
        valid_properties = ["friction", "viscosity", "elasticity"]
        invalid_properties = [p for p in args.properties if p not in valid_properties]
        if invalid_properties:
            print(f"Error: Invalid property names: {', '.join(invalid_properties)}")
            print(f"Available properties: {', '.join(valid_properties)}")
            return 1

        print(f"Testing specified properties: {', '.join(args.properties)}")
    else:
        print("Testing all physics properties")

    print("Initializing multimodal model physics property analyzer...")
    try:
        analyzer = GPTPhysicsAnalyzer()
        print(f"Loaded configuration: {len(analyzer.config)} test versions")
        print("✓ Analyzer initialization successful")
    except Exception as e:
        print(f"✗ Analyzer initialization failed: {e}")
        traceback.print_exc()
        return 1

    if args.aggressive_retry:
        print("🚀 Aggressive retry mode enabled")
        # Raise the defaults only for options the user did not set explicitly.
        if not _option_given_on_cli("--max-retries"):
            args.max_retries = 8
        if not _option_given_on_cli("--validation-retries"):
            args.validation_retries = 5
        if not _option_given_on_cli("--retry-delay"):
            args.retry_delay = 20

    _print_test_plan(args)

    frame_size = (args.frame_width, args.frame_height)
    # Legacy single-limit parameter still accepted by run_all_tests.
    max_samples = max(args.absolute_samples, args.relative_pairs)

    total_absolute_tests, total_relative_tests = _print_test_breakdown(analyzer.config, args)

    print("Summary:")
    print(f"  Absolute tests: {total_absolute_tests}")
    print(f"  Relative comparison tests: {total_relative_tests}")
    print(f"  Total tests: {total_absolute_tests + total_relative_tests}")

    # Each test set is assumed to issue roughly one query per sample/pair.
    # Retries are not counted. This is an upper bound if a set holds fewer
    # videos than requested.
    total_queries = (total_absolute_tests * args.absolute_samples
                     + total_relative_tests * args.relative_pairs)
    estimated_minutes = total_queries * _SECONDS_PER_QUERY / 60
    print(f"  Estimated time: {int(estimated_minutes // 60)} hours {int(estimated_minutes % 60)} minutes")
    print("================================================================================")
    print()

    if args.dry_run:
        print("💡 Dry run mode: Not executing actual tests")
        return 0

    print("Ready to start testing...")
    if not args.auto_confirm:
        try:
            response = input("Continue? (y/n, default n): ").lower().strip()
        except EOFError:
            # Non-interactive stdin (pipe/CI): fall back to the default answer "n".
            print()
            response = ""
        if response not in ("y", "yes"):
            print("Test cancelled")
            return 0

    os.makedirs(args.output_dir, exist_ok=True)

    print("Starting test execution...")
    print()

    try:
        analyzer.run_all_tests(
            output_dir=args.output_dir,
            frame_size=frame_size,
            max_samples=max_samples,  # Maintain backward compatibility
            model_version=args.model_version,
            concat_mode=args.concat_mode,
            properties_filter=args.properties,
            frame_index=args.frame_index,
            icl_example=args.icl_example,
            follow_oracle_test=args.follow_oracle_test,
            absolute_samples=args.absolute_samples,
            relative_pairs=args.relative_pairs,
            max_retries=args.max_retries,
            retry_delay=args.retry_delay,
            validation_retries=args.validation_retries,
            aggressive_retry=args.aggressive_retry
        )
    except Exception as e:
        print(f"✗ Execution failed: {e}")
        traceback.print_exc()
        return 1

    print("\n" + "="*50)
    print("✅ All tests completed!")
    print(f"Results saved to: {args.output_dir}")
    print("="*50)
    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())