#!/usr/bin/env python3
"""
Unified evaluation script for the reasoning framework.

This script provides a unified interface for evaluating different reasoning approaches
(two-stage, adaptive) using the reasoning framework with evalscope integration.
"""

import sys
import os
import argparse
import json
import time
import glob
import pandas as pd
import pickle
from pathlib import Path
from typing import Dict, Any, List, Optional

# Make the repository packages importable: the framework root is three
# directory levels above this file and is prepended to sys.path so the
# `reasoning_frameworks` imports below resolve.
current_file = Path(__file__).resolve()
framework_root = current_file.parent.parent.parent
sys.path.insert(0, str(framework_root))

# Add the evalscope source tree (<framework_root>/evaluation/evalscope) to the
# path. It is inserted at index 0 after the framework root, so it is searched
# first -- ahead of any separately installed evalscope package.
evalscope_root = framework_root / "evaluation" / "evalscope"
sys.path.insert(0, str(evalscope_root))

from reasoning_frameworks.core.config import (
    EvaluationConfig, 
    APIConfig, 
    ReproducibilityConfig,
    FrameworkConfig
)
from reasoning_frameworks.core.vlm_interface import VLMConfig
from reasoning_frameworks.core.reasoner_interface import ReasonerConfig
from reasoning_frameworks.utils.config_utils import load_config, find_config_file, merge_configs, load_logging_config, apply_logging_config_to_args
from reasoning_frameworks.utils.reproducibility import set_seed, configure_deterministic

import yaml
import requests
from requests.exceptions import RequestException, Timeout, ConnectionError


def create_vlm_only_model_config(args) -> Dict[str, Any]:
    """Build the evalscope model entry for direct VLM evaluation.

    No reasoning framework is involved: the VLM is queried through evalscope's
    ``CustomAPIModel`` against the endpoint at ``args.vlm_api_base``.

    Args:
        args: Parsed CLI namespace. Reads ``seed``, ``deterministic``,
            ``vlm_model_name``, ``vlm_api_base``, ``vlm_temperature``,
            ``vlm_top_p``, ``vlm_top_k``, ``vlm_max_tokens`` and ``vlm_timeout``.

    Returns:
        The model-config dict. Falsy sampling values fall back to
        temperature 0.0, top_p 0.001, max_tokens 8192 and a 900 s timeout;
        ``extra_body={'top_k': ...}`` is only added for a positive ``vlm_top_k``.

    Raises:
        ValueError: If no VLM model name is available.
    """
    # Seed RNGs first so everything downstream is reproducible.
    set_seed(args.seed)
    if args.deterministic:
        configure_deterministic()

    model_name = args.vlm_model_name
    if not model_name:
        raise ValueError("VLM model name is required. Provide via --vlm_model_name or config file.")

    print(f"🔍 DEBUG: Creating VLM-only config for: {model_name}")
    print(f"🔍 DEBUG: VLM params: temperature={args.vlm_temperature}, top_p={args.vlm_top_p}, top_k={args.vlm_top_k}")

    api_base = args.vlm_api_base
    temperature = args.vlm_temperature or 0.0
    top_p = args.vlm_top_p or 0.001
    max_tokens = args.vlm_max_tokens or 8192
    timeout = int(args.vlm_timeout or 900)

    config = dict(
        type=model_name,  # must match the vLLM --served-model-name
        name='CustomAPIModel',  # evalscope wrapper for API-served models
        api_base=api_base,
        key='EMPTY',
        temperature=temperature,
        top_p=top_p,
        max_tokens=max_tokens,
        timeout=timeout,
    )

    # top_k is passed via extra_body rather than as a top-level request field.
    top_k = args.vlm_top_k
    if top_k is not None and top_k > 0:
        config['extra_body'] = {'top_k': top_k}

    print(f"🔍 DEBUG: VLM-only model config created: {config}")
    return config


def create_framework_model_config(args) -> Dict[str, Any]:
    """Create the evalscope model configuration for the selected reasoning approach.

    ``vlm_only`` is delegated to :func:`create_vlm_only_model_config`. For the
    framework approaches this builds a ``VLMConfig`` and a ``ReasonerConfig`` and
    wraps them, together with logging/experiment flags from ``args``, in the
    model entry evalscope instantiates:

    * ``two_stage``   -> ``TwoStageFrameworkModel``
    * ``three_stage`` -> ``ThreeStageVLMFrameworkModel``
    * ``adaptive``    -> ``AdaptiveVLMFrameworkModel``

    Side effects:
        Seeds RNGs via ``set_seed`` (plus ``configure_deterministic`` when
        ``args.deterministic``). If ``args.reasoner_api_key`` is set it is
        exported as ``DEEPSEEK_API_KEY`` (deepseek) or as ``OPENAI_API_KEY`` and
        ``CUSTOM_REASONER_API_KEY`` (openai/custom_api, unless it is 'EMPTY').

    Args:
        args: Parsed CLI namespace.

    Returns:
        The model-config dict placed in ``eval_config['model']``.

    Raises:
        ValueError: If the VLM or reasoner model name is missing, or the
            reasoner type / reasoning approach is unknown.
    """
    # Handle VLM-only evaluation (no reasoning framework)
    if args.reasoning_approach == 'vlm_only':
        return create_vlm_only_model_config(args)

    # Export the reasoner key for clients that read it from the environment.
    if args.reasoner_api_key:
        if args.reasoner_type == 'deepseek':
            os.environ['DEEPSEEK_API_KEY'] = args.reasoner_api_key
        elif args.reasoner_type in ('openai', 'custom_api') and args.reasoner_api_key != 'EMPTY':
            os.environ['OPENAI_API_KEY'] = args.reasoner_api_key
            os.environ['CUSTOM_REASONER_API_KEY'] = args.reasoner_api_key

    # Set up reproducibility
    set_seed(args.seed)
    if args.deterministic:
        configure_deterministic()

    vlm_model_name = args.vlm_model_name
    if not vlm_model_name:
        raise ValueError("VLM model name is required. Provide via --vlm_model_name or config file.")

    print(f"🔍 DEBUG: Creating VLMConfig with: temperature={args.vlm_temperature}, top_p={args.vlm_top_p}, top_k={args.vlm_top_k}")
    vlm_config = VLMConfig(
        model_name=vlm_model_name,
        model_type="vllm",
        api_base=args.vlm_api_base,
        api_key="EMPTY",
        timeout=int(args.vlm_timeout),
        max_tokens=args.vlm_max_tokens,
        temperature=args.vlm_temperature,
        top_p=args.vlm_top_p,
        top_k=args.vlm_top_k
    )
    print(f"🔍 DEBUG: VLMConfig created: {vlm_config}")

    reasoner_model_name = args.reasoner_model
    if not reasoner_model_name:
        raise ValueError("Reasoner model name is required. Provide via --reasoner_model or config file.")

    # CLI --reasoner_type -> ReasonerConfig.model_type. All reasoner types share
    # every other field; custom_api goes through the generic 'api' interface.
    reasoner_model_types = {
        'deepseek': 'deepseek',
        'openai': 'openai',
        'custom_api': 'api',
    }
    reasoner_model_type = reasoner_model_types.get(args.reasoner_type)
    if reasoner_model_type is None:
        raise ValueError(f"Unknown reasoner type: {args.reasoner_type}")

    reasoner_config = ReasonerConfig(
        model_name=reasoner_model_name,
        model_type=reasoner_model_type,
        api_base=args.reasoner_api_base,
        api_key=args.reasoner_api_key,
        timeout=int(args.reasoner_timeout),
        max_tokens=args.reasoner_max_tokens,
        temperature=args.reasoner_temperature,
        top_p=args.reasoner_top_p,
        top_k=args.reasoner_top_k
    )

    # Common logging configuration.
    # NOTE(review): two_stage receives the HTML dir as 'html_report_dir_base',
    # adaptive as 'html_report_dir'; three_stage gets None for both -- confirm
    # ThreeStageVLMFrameworkModel does not expect one.
    logging_config = {
        'html_report_dir_base': args.html_report_dir if args.reasoning_approach == 'two_stage' else None,
        'html_report_dir': args.html_report_dir if args.reasoning_approach == 'adaptive' else None,
        'enable_html_reports': args.enable_html_reports,
        'debug_data_dir': args.debug_data_dir,
        'seed': args.seed,
        'deterministic': args.deterministic,
        # Unified logging configuration
        'enable_mlflow': args.enable_mlflow,
        'enable_wandb': args.enable_wandb,
        'mlflow_tracking_uri': args.mlflow_tracking_uri,
        'wandb_project': args.wandb_project,
        'wandb_entity': args.wandb_entity,
    }

    # Confidence-estimation experiment flags shared by every framework model.
    confidence_flags = {
        'enable_vlm_confidence': getattr(args, 'enable_vlm_confidence', False),
        'use_confidence_in_reasoner': getattr(args, 'use_confidence_in_reasoner', False),
        'use_logprobs_confidence': getattr(args, 'use_logprobs_confidence', False),
    }

    # Reasoner sampling parameters, also passed top-level to three_stage/adaptive.
    gen_param_block = {
        'temperature': args.reasoner_temperature,
        'top_p': args.reasoner_top_p,
        'top_k': args.reasoner_top_k,
        'max_tokens': args.reasoner_max_tokens,
    }

    if args.reasoning_approach == 'two_stage':
        model_config = {
            'name': 'TwoStageFrameworkModel',
            'vlm_config': vlm_config.to_dict(),
            'reasoner_config': reasoner_config.to_dict(),
            **confidence_flags,
            **logging_config,
        }
    elif args.reasoning_approach == 'three_stage':
        model_config = {
            'name': 'ThreeStageVLMFrameworkModel',
            'vlm_config': vlm_config.to_dict(),
            'reasoner_config': reasoner_config.to_dict(),
            'prompt_template_name': getattr(args, 'prompt_template_name', 'three_stage_math_v1'),
            **confidence_flags,
            # Clarification experiment flags
            'deny_clarifications': getattr(args, 'deny_clarifications', False),
            'track_clarification_requests': getattr(args, 'track_clarification_requests', False),
            **logging_config,
            **gen_param_block,
        }
    elif args.reasoning_approach == 'adaptive':
        model_config = {
            'name': 'AdaptiveVLMFrameworkModel',
            'vlm_config': vlm_config.to_dict(),
            'reasoner_config': reasoner_config.to_dict(),
            'max_iterations': args.max_iterations,
            'enable_verification': args.enable_verification,
            'prompt_template_name': getattr(args, 'prompt_template_name', 'adaptive_v1'),
            **confidence_flags,
            **logging_config,
            **gen_param_block,
        }
    else:
        raise ValueError(f"Unknown reasoning approach: {args.reasoning_approach}")

    return model_config


def build_evalscope_task_config(args) -> Dict[str, Any]:
    """Assemble the evalscope task config that drives a VLMEvalKit run.

    The single model entry comes from :func:`create_framework_model_config`.
    When no explicit ``--limit`` is given, dry runs are capped at 10 samples
    and real runs use 1_000_000_000, which in practice means "all".
    Optional extras are only added when set: ``indices``, ``shuffle``
    (plus ``seed``) and the top-level ``use_cache``.

    Args:
        args: Parsed CLI namespace.

    Returns:
        ``{'eval_backend': 'VLMEvalKit', 'eval_config': {...}}``, optionally
        with ``use_cache``.
    """
    vlmeval_config = {
        'model': [create_framework_model_config(args)],
        'data': args.datasets,
        'nproc': args.nproc,
        'mode': 'all',
        'judge': args.judge,
        'work_dir': args.work_dir,
        'reuse': True,  # keep existing VLMEvalKit results instead of deleting them
    }

    fallback_limit = 10 if args.dry_run else 1000000000
    vlmeval_config['limit'] = args.limit or fallback_limit

    selected = args.indices
    if selected:
        vlmeval_config['indices'] = selected
        print(f"🎯 Evaluating specific indices: {selected}")

    # Shuffling gives a representative subset when a limit is applied.
    if getattr(args, 'shuffle', False):
        vlmeval_config['shuffle'] = args.shuffle
        shuffle_seed = getattr(args, 'seed', None)
        if shuffle_seed is not None:
            vlmeval_config['seed'] = shuffle_seed
        print(f"🔀 Shuffle enabled for representative sampling with seed={getattr(args, 'seed', 42)}")

    task_config = {'eval_backend': 'VLMEvalKit', 'eval_config': vlmeval_config}
    if args.use_cache:
        task_config['use_cache'] = args.use_cache
    return task_config


def _dataset_from_result_filename(file_name: str, model_name: str) -> str:
    """Return the dataset name encoded in a VLMEvalKit result file name.

    Names look like ``<model_name>_<dataset>[_<judge suffix>].xlsx``. The dataset
    is every ``_``-separated token after the model prefix up to (not including)
    the first judge-looking token -- one containing ``gpt`` or equal to
    ``exact``/``matching`` -- and always keeps at least one token. The file
    extension is dropped so raw and judged files map to the same dataset name.
    """
    stem = os.path.splitext(file_name)[0]
    prefix = f'{model_name}_'
    if not stem.startswith(prefix):
        return stem.split('_')[0]
    tokens = stem[len(prefix):].split('_')
    end = len(tokens)
    for i, token in enumerate(tokens[1:], start=1):
        if 'gpt' in token.lower() or token in ('exact', 'matching'):
            end = i
            break
    return '_'.join(tokens[:end])


def _is_judge_file(path: str) -> bool:
    """Treat a result file as judge-scored if its name mentions 'gpt' or 'judge'."""
    base = os.path.basename(path)
    return 'gpt' in base or 'judge' in base


def _find_debug_files(debug_data_dir: Optional[str], model_name: str,
                      only_recent_session: bool) -> List[str]:
    """List ``<model_name>_<session>_debug_data_*.pkl`` files, optionally newest session only."""
    if not debug_data_dir or not os.path.exists(debug_data_dir):
        return []
    all_files = glob.glob(os.path.join(debug_data_dir, f'{model_name}_*_debug_data_*.pkl'))
    if not only_recent_session or not all_files:
        return all_files

    # The session id is the integer token right after the model prefix. Strip the
    # prefix first so model names containing '_' are parsed correctly.
    prefix = f'{model_name}_'
    by_session: Dict[int, List[str]] = {}
    for path in all_files:
        name = os.path.basename(path)
        if not name.startswith(prefix):
            continue
        try:
            session = int(name[len(prefix):].split('_', 1)[0])
        except ValueError:
            continue
        by_session.setdefault(session, []).append(path)

    if not by_session:
        return all_files
    return by_session[max(by_session)]


def _find_result_xlsx_files(model_dir: str, model_name: str,
                            datasets: Optional[List[str]]) -> List[str]:
    """List result workbooks to merge, preferring judge-scored files per dataset.

    Workbooks previously written by this export (``*_merged_data_*``) are skipped
    so an export never re-ingests its own output.
    """
    if not os.path.exists(model_dir):
        return []

    def _candidates(pattern: str) -> List[str]:
        return [f for f in glob.glob(os.path.join(model_dir, pattern))
                if '_merged_data_' not in os.path.basename(f)]

    if datasets is not None:
        groups = [_candidates(f'{model_name}_{dataset}_*.xlsx') for dataset in datasets]
    else:
        # No dataset list given: group every result workbook by parsed dataset.
        by_dataset: Dict[str, List[str]] = {}
        for path in sorted(_candidates(f'{model_name}_*.xlsx')):
            dataset = _dataset_from_result_filename(os.path.basename(path), model_name)
            by_dataset.setdefault(dataset, []).append(path)
        groups = list(by_dataset.values())

    selected: List[str] = []
    for files in groups:
        judged = [f for f in files if _is_judge_file(f)]
        selected.extend(judged or files)
    return selected


def _merge_result_files(xlsx_files: List[str], model_name: str,
                        all_data: Dict[Any, Dict[Any, Dict[Any, Any]]]) -> None:
    """Merge the rows of each result workbook into ``all_data[dataset][index]``."""
    for xlsx_file in xlsx_files:
        try:
            dataset = _dataset_from_result_filename(os.path.basename(xlsx_file), model_name)
            df = pd.read_excel(xlsx_file)
            rows_by_index = all_data.setdefault(dataset, {})
            if 'index' not in df.columns:
                # Without an index column the rows cannot be joined with debug data.
                continue
            for _, row in df.iterrows():
                rows_by_index.setdefault(row['index'], {}).update(row.to_dict())
        except Exception as e:
            print(f"Error processing {xlsx_file}: {e}")


def _merge_debug_files(debug_pkl_files: List[str],
                       all_data: Dict[Any, Dict[Any, Dict[Any, Any]]]) -> None:
    """Merge debug-pickle entries into ``all_data``, prefixing fields with ``debug_``."""
    for debug_file in debug_pkl_files:
        try:
            # pickle.load can execute arbitrary code: only point debug_data_dir at
            # trusted, locally produced dumps.
            with open(debug_file, 'rb') as f:
                debug_data = pickle.load(f)
            for index, item_data in debug_data.items():
                if not isinstance(item_data, dict) or "dataset" not in item_data:
                    continue
                target = all_data.setdefault(item_data["dataset"], {}).setdefault(index, {})
                for key, value in item_data.items():
                    target[key if key in ('index', 'dataset') else f"debug_{key}"] = value
        except Exception as e:
            print(f"Error processing debug file {debug_file}: {e}")


def _write_merged_data(all_data: Dict[Any, Dict[Any, Dict[Any, Any]]],
                       model_name: str, output_dir: str) -> None:
    """Write a timestamped and a ``_latest`` CSV (and XLSX if possible) per dataset."""
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    for dataset, dataset_data in all_data.items():
        if not dataset_data:
            continue

        records = [{'dataset': dataset, 'index': index, **item_data}
                   for index, item_data in dataset_data.items()]
        df_dataset = pd.DataFrame(records)
        base = f"{model_name}_{dataset}_merged_data"

        df_dataset.to_csv(os.path.join(output_dir, f"{base}_{timestamp}.csv"), index=False)
        df_dataset.to_csv(os.path.join(output_dir, f"{base}_latest.csv"), index=False)

        try:
            df_dataset.to_excel(os.path.join(output_dir, f"{base}_{timestamp}.xlsx"), index=False)
            df_dataset.to_excel(os.path.join(output_dir, f"{base}_latest.xlsx"), index=False)
        except Exception as e:
            # Excel output is optional (it needs an Excel engine to be installed).
            print(f"Could not save Excel file: {e}")

        print(f"Exported data for {dataset}: {len(records)} samples")


def export_debug_data(work_dir: str, debug_data_dir: str, model_name: str, 
                     only_recent_session: bool = True,
                     datasets: Optional[List[str]] = None):
    """Merge VLMEvalKit result workbooks with debug pickles and export them per dataset.

    Rows from ``<work_dir>/<model_name>/<model_name>_<dataset>_*.xlsx`` (judge-
    scored files preferred) and entries from
    ``<debug_data_dir>/<model_name>_<session>_debug_data_*.pkl`` are joined on the
    sample index; debug fields are prefixed with ``debug_``. Output files
    ``<model_name>_<dataset>_merged_data_<timestamp>`` and ``..._latest`` (CSV,
    plus XLSX when it can be written) go to the model directory, or to
    ``work_dir`` if that directory does not exist.

    The export is best-effort: errors are printed, never raised.

    Args:
        work_dir: VLMEvalKit work directory.
        debug_data_dir: Directory holding the debug pickles (may be None/missing).
        model_name: Model name used as the file prefix.
        only_recent_session: If True, only load pickles from the newest session
            (the integer token right after ``<model_name>_``).
        datasets: Dataset names whose result files should be collected. When None,
            every ``<model_name>_*.xlsx`` in the model directory is considered and
            grouped by the dataset parsed from its file name.
    """
    print(f"\n>> Exporting {model_name} debug data...")

    try:
        debug_pkl_files = _find_debug_files(debug_data_dir, model_name, only_recent_session)

        model_dir = os.path.join(work_dir, model_name)
        xlsx_files = _find_result_xlsx_files(model_dir, model_name, datasets)

        if not xlsx_files and not debug_pkl_files:
            print("No data files found to export.")
            return

        # dataset -> sample index -> merged column values
        all_data: Dict[Any, Dict[Any, Dict[Any, Any]]] = {}
        _merge_result_files(xlsx_files, model_name, all_data)
        _merge_debug_files(debug_pkl_files, all_data)

        if all_data:
            output_dir = model_dir if os.path.exists(model_dir) else work_dir
            _write_merged_data(all_data, model_name, output_dir)

    except Exception as e:
        print(f"Error exporting debug data: {e}")


def check_api_server(api_base: str, server_name: str, timeout: int = 10) -> bool:
    """Probe an HTTP API server and report whether anything answers.

    A GET is sent to the base URL, then ``<base>/health``, then
    ``<base>/v1/models``. Any HTTP response, whatever its status code, counts
    as reachable. Connection errors, timeouts and response-less request errors
    move on to the next URL.

    Args:
        api_base: API base URL (a trailing ``/`` is ignored).
        server_name: Human-readable label used in the printed messages.
        timeout: Per-request timeout in seconds.

    Returns:
        True as soon as one probe gets a response. False if none do, or if
        an unexpected error occurs (e.g. ``api_base`` is not a string).
    """
    try:
        base = api_base.rstrip('/')
        probes = (base, base + "/health", base + "/v1/models")

        for url in probes:
            try:
                status = requests.get(url, timeout=timeout).status_code
            except (ConnectionError, Timeout):
                continue
            except RequestException as err:
                # An error that still carries a response means the server answered.
                answered = getattr(err, 'response', None)
                if answered is None:
                    continue
                status = answered.status_code
            print(f"✅ {server_name} server reachable at {api_base} (status: {status})")
            return True

        print(f"❌ {server_name} server not reachable at {api_base}")
        return False

    except Exception as e:
        print(f"❌ Error checking {server_name} server at {api_base}: {e}")
        return False


def check_mlflow_setup(tracking_uri: str) -> bool:
    """Check that an MLflow tracking server is usable.

    The global tracking URI is switched to ``tracking_uri`` just long enough to
    list experiments through an ``MlflowClient``. It is then restored to its
    previous value, whether or not the listing succeeded.

    Args:
        tracking_uri: MLflow tracking URI to probe.

    Returns:
        True if experiments could be listed. False if MLflow is missing,
        unreachable, or any other error occurs.
    """
    try:
        import mlflow

        previous_uri = mlflow.get_tracking_uri()
        mlflow.set_tracking_uri(tracking_uri)
        try:
            found = mlflow.tracking.MlflowClient().search_experiments()
            print(f"✅ MLflow accessible at {tracking_uri} ({len(found)} experiments found)")
            return True
        except Exception as e:
            print(f"❌ MLflow not accessible at {tracking_uri}: {e}")
            return False
        finally:
            # Leave the process-wide tracking URI exactly as we found it.
            mlflow.set_tracking_uri(previous_uri)

    except ImportError:
        print("❌ MLflow not installed")
        return False
    except Exception as e:
        print(f"❌ Error checking MLflow setup: {e}")
        return False


def check_wandb_setup(project: str = None, entity: str = None) -> bool:
    """Check that Weights & Biases is installed, logged in and reachable.

    The check requires an API key (``wandb.api.api_key``) and then queries the
    current user via ``wandb.Api().viewer``. No run is started.

    Args:
        project: W&B project name (currently unused by the check).
        entity: W&B entity name (currently unused by the check).

    Returns:
        True if the viewer could be fetched. False if wandb is missing, no API
        key is configured, or the API call fails.
    """
    try:
        import wandb

        if not wandb.api.api_key:
            print("❌ Wandb API key not set. Run 'wandb login' or set WANDB_API_KEY environment variable")
            return False

        try:
            viewer = wandb.Api().viewer
            print(f"✅ Wandb accessible (logged in as {viewer.username or 'unknown'})")
            return True
        except Exception as e:
            print(f"❌ Wandb API not accessible: {e}")
            return False

    except ImportError:
        print("❌ Wandb not installed")
        return False
    except Exception as e:
        print(f"❌ Error checking Wandb setup: {e}")
        return False


def run_sanity_checks(args) -> bool:
    """Run pre-flight checks before an evaluation and report whether they passed.

    Checks, in order:
      1. The VLM API server answers at ``args.vlm_api_base``.
      2. The reasoner API server answers at ``args.reasoner_api_base``. This is
         skipped for ``vlm_only`` runs, which never build a reasoner config.
      3. MLflow / W&B are reachable when enabled. Failures here only print a
         warning and never fail the checks.
      4. ``work_dir``, ``debug_data_dir`` and, with HTML reports enabled,
         ``html_report_dir`` can be created.

    Args:
        args: Parsed command line arguments.

    Returns:
        True if every blocking check passed, False otherwise.
    """
    print("🔍 Running sanity checks...")
    all_checks_passed = True

    print("\n📡 Checking API servers...")

    if not check_api_server(args.vlm_api_base, "VLM", timeout=15):
        all_checks_passed = False

    # Direct VLM evaluation never contacts a reasoner, so an unset or offline
    # reasoner endpoint must not fail the checks in that mode.
    if args.reasoning_approach == 'vlm_only':
        print("⏭️  Skipping reasoner server check (vlm_only mode)")
    elif not check_api_server(args.reasoner_api_base, "Reasoner", timeout=15):
        all_checks_passed = False

    # Logging services are optional: failures only warn, they never block the run.
    if args.enable_mlflow or args.enable_wandb:
        print("\n📊 Checking logging services...")

        if args.enable_mlflow and not check_mlflow_setup(args.mlflow_tracking_uri):
            print("⚠️  MLflow check failed - consider disabling with --enable_mlflow=False")

        if args.enable_wandb and not check_wandb_setup(args.wandb_project, args.wandb_entity):
            print("⚠️  Wandb check failed - consider disabling with --enable_wandb=False")

    print("\n📁 Checking output directories...")
    try:
        os.makedirs(args.work_dir, exist_ok=True)
        os.makedirs(args.debug_data_dir, exist_ok=True)
        if args.enable_html_reports:
            os.makedirs(args.html_report_dir, exist_ok=True)
        print("✅ Output directories accessible")
    except Exception as e:
        print(f"❌ Error with output directories: {e}")
        all_checks_passed = False

    if all_checks_passed:
        print("✅ All sanity checks passed!")
    else:
        print("❌ Some sanity checks failed!")

    print("-" * 60)
    return all_checks_passed


def run_evaluation(args):
    """Run an evaluation through evalscope and export the merged results.

    Steps:
      1. Build the task config and print a run summary.
      2. Optionally log the experiment parameters to MLflow/W&B.
      3. Run the evalscope task and print the report list.
      4. Export debug data via :func:`export_debug_data`.
      5. Log completion metrics if a tracker was set up.

    Args:
        args: Parsed CLI namespace.
    """
    task_config = build_evalscope_task_config(args)

    print("=" * 60)
    if args.reasoning_approach == 'vlm_only':
        print("🚀 Running VLM-Only Direct Evaluation")
    elif args.reasoning_approach == 'three_stage':
        print("🚀 Running Three-Stage Reasoning Evaluation")
    else:
        print(f"🚀 Running {args.reasoning_approach.title()} Reasoning Evaluation")
    print("=" * 60)
    print(f"VLM: {args.vlm_model_name} at {args.vlm_api_base}")
    print(f"    ├─ Max tokens: {args.vlm_max_tokens}")
    print(f"    └─ Params: temp={args.vlm_temperature}, top_p={args.vlm_top_p}, top_k={args.vlm_top_k}")

    if args.reasoning_approach != 'vlm_only':
        print(f"Reasoner: {args.reasoner_model} ({args.reasoner_type}) at {args.reasoner_api_base}")
        print(f"    ├─ Max tokens: {args.reasoner_max_tokens}")
        print(f"    └─ Params: temp={args.reasoner_temperature}, top_p={args.reasoner_top_p}, top_k={args.reasoner_top_k}")
    else:
        print("Mode: Direct VLM evaluation (no reasoning framework)")
    print(f"Datasets: {args.datasets}")
    print(f"Judge: {args.judge}")
    print(f"Limit: {args.limit or 'unlimited'}")

    # Shuffle information
    if getattr(args, 'shuffle', False):
        print(f"🔀 Sampling: shuffle enabled (seed={getattr(args, 'seed', 42)}) for representative sampling")
    elif args.limit:
        print("⚠️  Sampling: sequential (no shuffle) - may introduce bias for ordered datasets")

    print(f"Reproducibility: seed={args.seed}, deterministic={'enabled' if args.deterministic else 'disabled'}")

    logging_features = []
    if args.enable_html_reports:
        logging_features.append("HTML")
    if args.enable_mlflow:
        logging_features.append("MLflow")
    if args.enable_wandb:
        logging_features.append("Wandb")
    if logging_features:
        print(f"📊 Unified logging: {', '.join(logging_features)}")

    if args.reasoning_approach == 'adaptive':
        print(f"Adaptive: max_iter={args.max_iterations}, verification={'enabled' if args.enable_verification else 'disabled'}")
    elif args.reasoning_approach == 'three_stage':
        clarification_mode = "denied" if getattr(args, 'deny_clarifications', False) else "allowed"
        track_mode = "tracked" if getattr(args, 'track_clarification_requests', False) else "not tracked"
        print(f"Three-stage: clarifications {clarification_mode}, requests {track_mode}")

    print("=" * 60)

    # Stays None when tracking is disabled or the logger could not be created,
    # so the completion step below knows whether there is anything to finish.
    experiment_logger = None
    if args.enable_mlflow or args.enable_wandb:
        try:
            from reasoning_frameworks.utils.unified_logger import UnifiedReasoningLogger

            experiment_logger = UnifiedReasoningLogger(
                experiment_name=f"{args.reasoning_approach}_evaluation",
                output_dir=args.html_report_dir or "./outputs",
                enable_mlflow=args.enable_mlflow,
                enable_wandb=args.enable_wandb,
                mlflow_tracking_uri=args.mlflow_tracking_uri,
                wandb_project=args.wandb_project,
                wandb_entity=args.wandb_entity
            )

            experiment_params = {
                'reasoning_approach': args.reasoning_approach,
                'datasets': ','.join(args.datasets),
                'vlm_model': args.vlm_model_name,
                'vlm_api_base': args.vlm_api_base,
                'reasoner_type': args.reasoner_type,
                'reasoner_model': args.reasoner_model,
                'reasoner_api_base': args.reasoner_api_base,
                'judge': args.judge,
                'nproc': args.nproc,
                'limit': args.limit or 'unlimited',
                'seed': args.seed,
                'deterministic': args.deterministic,
                'vlm_max_tokens': args.vlm_max_tokens,
                'vlm_temperature': args.vlm_temperature,
                'vlm_top_p': args.vlm_top_p,
                'vlm_top_k': args.vlm_top_k,
                'reasoner_max_tokens': args.reasoner_max_tokens,
                'reasoner_temperature': args.reasoner_temperature,
                'reasoner_top_p': args.reasoner_top_p,
                'reasoner_top_k': args.reasoner_top_k,
            }

            if args.reasoning_approach == 'adaptive':
                experiment_params.update({
                    'max_iterations': args.max_iterations,
                    'enable_verification': args.enable_verification
                })
            elif args.reasoning_approach == 'three_stage':
                # Mirror the clarification settings printed in the summary above.
                experiment_params.update({
                    'deny_clarifications': getattr(args, 'deny_clarifications', False),
                    'track_clarification_requests': getattr(args, 'track_clarification_requests', False),
                })

            experiment_logger.log_experiment_params(experiment_params)
            print("📊 Logged experiment parameters to tracking systems")

        except Exception as e:
            print(f"Warning: Could not set up experiment tracking: {e}")

    # Imported lazily: evalscope is resolved via the sys.path entry added at module top.
    from evalscope.run import run_task
    from evalscope.summarizer import Summarizer

    print('>> Running evaluation...')
    run_task(task_cfg=task_config)

    print('>> Getting evaluation report...')
    report_list = Summarizer.get_report_from_cfg(task_config)
    print(f'\n>> Report list: {report_list}')

    model_name = task_config['eval_config']['model'][0]['name']
    export_debug_data(
        work_dir=args.work_dir,
        debug_data_dir=args.debug_data_dir,
        model_name=model_name,
        only_recent_session=not args.load_all_debug_sessions
    )

    if experiment_logger is not None:
        try:
            completion_metrics = {
                'evaluation_completed': 1,
                'total_datasets': len(args.datasets)
            }
            experiment_logger.log_metrics(completion_metrics)
            experiment_logger.finish()
            print("📊 Finished experiment tracking")
        except Exception as e:
            print(f"Warning: Could not finish experiment tracking: {e}")

    print("✅ Evaluation completed!")


def _build_eval_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for the unified evaluation script.

    Many options default to ``None`` so that values from a YAML config file can
    fill them in later (see ``_apply_config_file_to_args``).
    """
    parser = argparse.ArgumentParser(
        description="Unified evaluation script for reasoning framework"
    )

    # Core approach selection
    parser.add_argument('--reasoning_approach',
                       choices=['two_stage', 'three_stage', 'adaptive', 'vlm_only'],
                       help='Reasoning approach to use')

    # VLM configuration
    parser.add_argument('--vlm_api_base',
                       help='VLM API base URL (required)')
    parser.add_argument('--vlm_model_name',
                       help='VLM model name (required unless specified in config file)')
    parser.add_argument('--vlm_timeout', type=float, default=900.0,
                       help='VLM API timeout in seconds')

    # Reasoner configuration
    parser.add_argument('--reasoner_type',
                       choices=['deepseek', 'openai', 'custom_api'],
                       help='Reasoner type (required)')
    parser.add_argument('--reasoner_api_base',
                       help='Reasoner API base URL')
    parser.add_argument('--reasoner_api_key',
                       help='Reasoner API key')
    parser.add_argument('--reasoner_model',
                       help='Reasoner model name')
    parser.add_argument('--reasoner_timeout', type=float, default=900.0,
                       help='Reasoner API timeout in seconds')

    # Generation parameters (VLM and Reasoner)
    # --- VLM sampling params ---
    parser.add_argument('--vlm_temperature', type=float, default=None,
                       help='Sampling temperature for VLM (overrides config)')
    parser.add_argument('--vlm_top_p', type=float, default=None,
                       help='Top-p for VLM (overrides config)')
    parser.add_argument('--vlm_top_k', type=int, default=None,
                       help='Top-k for VLM (overrides config)')

    # --- Reasoner / generic params ---
    parser.add_argument('--vlm_max_tokens', type=int, default=None,
                       help='Max tokens for VLM responses (omit to let model decide)')
    parser.add_argument('--reasoner_max_tokens', type=int, default=100000,
                       help='Max tokens for reasoner responses')
    parser.add_argument('--temperature', type=float, default=0.6,
                       help='Default sampling temperature for reasoner')
    parser.add_argument('--top_p', type=float, default=0.95,
                       help='Default top-p for reasoner')
    parser.add_argument('--top_k', type=int, default=None,
                       help='Default top-k for reasoner')

    # Adaptive-specific parameters
    parser.add_argument('--max_iterations', type=int, default=7,
                       help='Maximum iterations for adaptive reasoning')
    parser.add_argument('--enable_verification', action='store_true',
                       help='Enable verification for adaptive reasoning')

    # Confidence estimation experiment parameters
    parser.add_argument('--enable_vlm_confidence', action='store_true',
                       help='Enable VLM confidence estimation (experimental)')
    parser.add_argument('--use_confidence_in_reasoner', action='store_true',
                       help='Use VLM confidence in reasoner prompts (experimental)')
    parser.add_argument('--use_logprobs_confidence', action='store_true',
                       help='Use logprobs-based confidence estimation (experimental)')

    # Clarification experiment parameters (three-stage specific)
    parser.add_argument('--deny_clarifications', action='store_true',
                       help='Deny clarification requests in three-stage reasoning (experimental)')
    parser.add_argument('--track_clarification_requests', action='store_true',
                       help='Track whether clarification requests are made (experimental)')

    # Evaluation configuration
    parser.add_argument('--datasets', nargs='+',
                       default=['MathVista_MINI', 'MathVerse_MINI'],
                       help='Datasets to evaluate on')
    parser.add_argument('--judge',
                       choices=['exact_matching', 'gpt-4-0125', 'gpt-4-turbo',
                               'gpt-4o-mini', 'gpt-4.1-mini-2025-04-14', 'gpt-5-nano-2025-08-07'],
                       default='exact_matching',
                       help='Judge model for evaluation')
    parser.add_argument('--nproc', type=int, default=32,
                       help='Number of processes')
    parser.add_argument('--work_dir', default='outputs/',
                       help='Work directory for outputs')
    parser.add_argument('--use_cache',
                       help='Path to cached results to reuse')

    # Sampling configuration
    parser.add_argument('--dry_run', action='store_true',
                       help='Run on small subset for testing')
    parser.add_argument('--limit', type=int,
                       help='Limit number of samples to evaluate')

    # Index-specific evaluation
    parser.add_argument('--indices', type=int, nargs='+',
                       help='Specific indices to evaluate (e.g., --indices 3 10 13 15)')

    # Shuffle parameters for representative sampling
    parser.add_argument('--shuffle', action='store_true',
                       help='Enable shuffling for representative sampling when using --limit')

    # Output configuration
    parser.add_argument('--debug_data_dir',
                       default='/scratch/<ANONYMIZED>/framework_debug_data',
                       help='Directory for debug data')
    parser.add_argument('--html_report_dir',
                       default='/scratch/<ANONYMIZED>/framework_html_reports',
                       help='Directory for HTML reports')
    parser.add_argument('--enable_html_reports', action='store_true',
                       help='Enable HTML report generation')
    parser.add_argument('--load_all_debug_sessions', action='store_true',
                       help='Load all debug sessions, not just recent')

    # Unified logging configuration
    parser.add_argument('--enable_mlflow', action='store_true',
                       help='Enable MLflow experiment tracking')
    parser.add_argument('--mlflow_tracking_uri',
                       default='./mlruns',
                       help='MLflow tracking URI')
    parser.add_argument('--enable_wandb', action='store_true',
                       help='Enable Weights & Biases tracking')
    parser.add_argument('--wandb_project',
                       help='Wandb project name (defaults to experiment name)')
    parser.add_argument('--wandb_entity',
                       help='Wandb entity name')

    # Reproducibility
    parser.add_argument('--seed', type=int, default=42,
                       help='Random seed')
    parser.add_argument('--deterministic', action='store_true',
                       help='Enable deterministic mode')

    # Sanity checks
    parser.add_argument('--skip_sanity_checks', action='store_true',
                       help='Skip API server and logging service sanity checks')

    # Config file support
    parser.add_argument('--config_file',
                       help='Path to YAML config file to load settings from')
    parser.add_argument('--config_name',
                       help='Name of evaluation config to load (e.g., "two_stage_legacy", "adaptive_legacy")')
    parser.add_argument('--logging_config',
                       default='evaluation_logging',
                       help='Name of logging configuration to use (default: evaluation_logging)')

    return parser


def _load_framework_config_ref(ref: str, kind: str) -> Dict[str, Any]:
    """Load a component config referenced by a dotted name.

    ``ref`` is resolved to ``configs/<ref with '.' replaced by '/'>.yaml``.
    The ``configs`` directory sits in the parent of this script's directory.

    Args:
        ref: Dotted reference taken from the ``framework`` section of the config.
        kind: Human-readable label used in log messages (``"VLM"`` / ``"reasoner"``).

    Returns:
        A shallow copy of the loaded mapping, or ``{}`` if the file is missing
        or cannot be loaded. Failures are reported as warnings rather than raised.
    """
    try:
        configs_root = Path(__file__).parent.parent / "configs"
        config_path = configs_root / (ref.replace('.', '/') + '.yaml')
        if not config_path.exists():
            # Previously a missing reference was ignored silently, which made
            # typos in *_config_ref hard to notice.
            print(f"Warning: {kind} config reference '{ref}' not found at {config_path}")
            return {}
        # Copy so later override updates never mutate whatever load_config returned.
        loaded = dict(load_config(config_path) or {})
        print(f"📄 Loaded {kind} config: {ref}")
        return loaded
    except Exception as e:
        print(f"Warning: Could not load {kind} config reference: {e}")
        return {}


def _merge_vlm_config_into_args(args: argparse.Namespace,
                                parser: argparse.ArgumentParser,
                                vlm_config_data: Dict[str, Any]) -> None:
    """Propagate the merged VLM config (ref + overrides) into the ``vlm_*`` args.

    A sampling value from the config is used only when the CLI option still
    equals its parser default. Otherwise the CLI value is kept and a notice is
    printed. Model name and API base are filled only when the CLI left them empty.
    """
    # (config key, args attribute, label used in the notice)
    for config_key, arg_name, label in (
        ('temperature', 'vlm_temperature', 'temperature'),
        ('top_p', 'vlm_top_p', 'top_p'),
        ('top_k', 'vlm_top_k', 'top_k'),
        ('max_tokens', 'vlm_max_tokens', 'vlm_max_tokens'),
    ):
        if config_key not in vlm_config_data:
            continue
        if getattr(args, arg_name) == parser.get_default(arg_name):
            setattr(args, arg_name, vlm_config_data[config_key])
        else:
            print(f"⚠️  CLI overrides config {label} – using CLI value")

    if not args.vlm_model_name and 'model_name' in vlm_config_data:
        args.vlm_model_name = vlm_config_data['model_name']
    if not args.vlm_api_base and 'api_base' in vlm_config_data:
        args.vlm_api_base = vlm_config_data['api_base']


def _merge_reasoner_config_into_args(args: argparse.Namespace,
                                     parser: argparse.ArgumentParser,
                                     reasoner_config_data: Dict[str, Any]) -> None:
    """Propagate the merged reasoner config (ref + overrides) into the ``reasoner_*`` args.

    Connection settings are filled only when the CLI left them empty.
    ``reasoner_temperature``/``reasoner_top_p``/``reasoner_top_k`` are always
    assigned. They take the config value when the generic CLI option
    (``--temperature`` etc.) still equals its parser default, else the CLI value.
    """
    for config_key, arg_name in (
        ('model_name', 'reasoner_model'),
        ('api_base', 'reasoner_api_base'),
        ('api_key', 'reasoner_api_key'),
        ('model_type', 'reasoner_type'),
    ):
        if not getattr(args, arg_name) and config_key in reasoner_config_data:
            setattr(args, arg_name, reasoner_config_data[config_key])

    # argparse always defines reasoner_max_tokens (default 100000), so a
    # hasattr() check cannot detect "not given"; compare with the default.
    if ('max_tokens' in reasoner_config_data
            and args.reasoner_max_tokens == parser.get_default('reasoner_max_tokens')):
        args.reasoner_max_tokens = reasoner_config_data['max_tokens']

    for generic_name, arg_name in (
        ('temperature', 'reasoner_temperature'),
        ('top_p', 'reasoner_top_p'),
        ('top_k', 'reasoner_top_k'),
    ):
        cli_value = getattr(args, generic_name)
        if cli_value == parser.get_default(generic_name) and generic_name in reasoner_config_data:
            setattr(args, arg_name, reasoner_config_data[generic_name])
        else:
            setattr(args, arg_name, cli_value)


def _apply_framework_config_section(args: argparse.Namespace,
                                    parser: argparse.ArgumentParser,
                                    framework_config: Dict[str, Any]) -> None:
    """Handle the ``framework`` config section.

    This covers the prompt template reference plus the VLM and reasoner
    config references and their overrides.
    """
    if 'prompt_template_ref' in framework_config:
        args.prompt_template_name = framework_config['prompt_template_ref']
        print(f"📄 Using prompt template: {args.prompt_template_name}")

    # VLM: referenced config first, then inline overrides on top.
    vlm_config_data: Dict[str, Any] = {}
    if 'vlm_config_ref' in framework_config:
        vlm_config_data = _load_framework_config_ref(framework_config['vlm_config_ref'], 'VLM')
    if 'vlm_overrides' in framework_config:
        vlm_overrides = framework_config['vlm_overrides'] or {}
        vlm_config_data.update(vlm_overrides)
        print(f"⚙️  Applied VLM overrides: {list(vlm_overrides.keys())}")
        print(f"🔍 DEBUG: VLM config after overrides: temperature={vlm_config_data.get('temperature')}, top_p={vlm_config_data.get('top_p')}, top_k={vlm_config_data.get('top_k')}")
    _merge_vlm_config_into_args(args, parser, vlm_config_data)

    # Reasoner: same layering as the VLM.
    reasoner_config_data: Dict[str, Any] = {}
    if 'reasoner_config_ref' in framework_config:
        reasoner_config_data = _load_framework_config_ref(framework_config['reasoner_config_ref'], 'reasoner')
    if 'reasoner_overrides' in framework_config:
        reasoner_overrides = framework_config['reasoner_overrides'] or {}
        reasoner_config_data.update(reasoner_overrides)
        print(f"⚙️  Applied reasoner overrides: {list(reasoner_overrides.keys())}")
    _merge_reasoner_config_into_args(args, parser, reasoner_config_data)


def _fill_unset_args_from_section(args: argparse.Namespace,
                                  parser: argparse.ArgumentParser,
                                  section: Dict[str, Any]) -> None:
    """Copy ``section`` values into ``args`` for options the CLI left unset.

    An option counts as unset when it is ``None``. ``datasets`` and ``judge``,
    whose parser defaults are not ``None``, also count as unset while they still
    equal their parser default. Keys that are not attributes of ``args`` are ignored.

    NOTE: other options with non-None defaults (e.g. ``nproc``, ``work_dir``,
    ``seed``) are therefore never taken from a section merged this way.
    """
    for key, value in section.items():
        if not hasattr(args, key):
            continue
        current = getattr(args, key)
        if current is None or (key in ('datasets', 'judge') and current == parser.get_default(key)):
            setattr(args, key, value)


def _apply_section_over_parser_defaults(args: argparse.Namespace,
                                        parser: argparse.ArgumentParser,
                                        section: Dict[str, Any],
                                        section_name: Optional[str] = None) -> None:
    """Copy ``section`` values into ``args`` where the option still equals its parser default.

    When ``section_name`` is given, every decision (set / kept / unknown key)
    is printed. Otherwise the merge is silent.
    """
    verbose = section_name is not None
    if verbose:
        print(f"📄 Loading {section_name} config: {section}")
    for key, value in section.items():
        known = hasattr(args, key)
        if known and getattr(args, key) == parser.get_default(key):
            if verbose:
                print(f"⚙️  Setting {key} = {value}")
            setattr(args, key, value)
        elif verbose and known:
            print(f"⚠️  Keeping CLI/default value for {key}: {getattr(args, key)} (config has {value})")
        elif verbose:
            print(f"⚠️  Unknown config key in {section_name}: {key} = {value}")


def _apply_config_file_to_args(args: argparse.Namespace,
                               parser: argparse.ArgumentParser,
                               config_file: str) -> None:
    """Merge a YAML evaluation config into ``args``, giving the CLI precedence.

    "Given on the CLI" is detected by comparing against parser defaults. So a
    CLI value that equals its default is indistinguishable from omitting it.
    Sections are applied in this order: evaluation, framework, adaptive,
    two_stage, three_stage, reproducibility, then remaining top-level keys.
    """
    # An empty YAML file (or an empty section) may load as None; treat it as {}.
    config_data = load_config(config_file) or {}

    if 'evaluation' in config_data:
        _fill_unset_args_from_section(args, parser, config_data['evaluation'] or {})

    if 'framework' in config_data:
        _apply_framework_config_section(args, parser, config_data['framework'] or {})

    if 'adaptive' in config_data:
        _apply_section_over_parser_defaults(args, parser, config_data['adaptive'] or {})

    for stage_name in ('two_stage', 'three_stage'):
        if stage_name in config_data:
            _apply_section_over_parser_defaults(
                args, parser, config_data[stage_name] or {}, section_name=stage_name
            )

    if 'reproducibility' in config_data:
        _apply_section_over_parser_defaults(args, parser, config_data['reproducibility'] or {})

    # Remaining top-level keys act as a fallback for matching argument names.
    excluded_sections = ('evaluation', 'framework', 'adaptive', 'reproducibility', 'logging')
    top_level = {key: value for key, value in config_data.items() if key not in excluded_sections}
    _fill_unset_args_from_section(args, parser, top_level)


def _fill_missing_arg_defaults(args: argparse.Namespace) -> None:
    """Fill per-model and experiment attributes that no CLI option or config defined."""
    # NOTE: argparse always defines vlm_temperature/vlm_top_p/vlm_top_k
    # (default None), so these legacy fallbacks never fire. The values stay
    # None unless a config supplied them.
    if not hasattr(args, 'vlm_temperature'):
        args.vlm_temperature = 1.0  # Legacy VLM default
    if not hasattr(args, 'vlm_top_p'):
        args.vlm_top_p = 0.001  # Legacy VLM default
    if not hasattr(args, 'vlm_top_k'):
        args.vlm_top_k = 1  # Legacy VLM default

    # reasoner_* are not argparse options; without a `framework` config section
    # they are derived from the generic sampling options here.
    if not hasattr(args, 'reasoner_temperature'):
        args.reasoner_temperature = args.temperature
    if not hasattr(args, 'reasoner_top_p'):
        args.reasoner_top_p = args.top_p
    if not hasattr(args, 'reasoner_top_k'):
        args.reasoner_top_k = args.top_k

    # ------------------------------------------------------------------
    # 🛠️  Sanity-fix for top_k: the legacy config uses -1 to mean "no limit".
    # Forwarding a negative value to vLLM/OpenAI servers degrades output.
    # Treat any value <= 0 as *unset*. Applied after every source (CLI,
    # framework config, stage sections) has had its say.
    # ------------------------------------------------------------------
    if args.reasoner_top_k is not None and args.reasoner_top_k <= 0:
        args.reasoner_top_k = None

    # Confidence / clarification experiment flags (store_true options, so
    # these guards are defensive only).
    for flag in ('enable_vlm_confidence', 'use_confidence_in_reasoner', 'use_logprobs_confidence',
                 'deny_clarifications', 'track_clarification_requests'):
        if not hasattr(args, flag):
            setattr(args, flag, False)


def _validate_required_eval_args(args: argparse.Namespace,
                                 parser: argparse.ArgumentParser) -> None:
    """Exit via ``parser.error`` (SystemExit) when required settings are missing."""
    if not args.reasoning_approach:
        parser.error("--reasoning_approach is required (or specify via config file)")
    if not args.vlm_api_base:
        parser.error("--vlm_api_base is required")

    # Reasoner parameters are only required for reasoning frameworks (not VLM-only)
    if args.reasoning_approach != 'vlm_only':
        if not args.reasoner_api_base:
            parser.error("--reasoner_api_base is required for reasoning frameworks")
        if not args.reasoner_model:
            parser.error("--reasoner_model is required for reasoning frameworks")
        if not args.reasoner_type:
            parser.error("--reasoner_type is required for reasoning frameworks")


def main():
    """Main evaluation function.

    Steps:
    1. Parse the CLI and apply the named logging config (best effort).
    2. Resolve ``--config_name`` to a file and merge that YAML config.
    3. Fill remaining defaults and validate required settings.
    4. Run the sanity checks (unless skipped) and then the evaluation.
    """
    parser = _build_eval_arg_parser()
    args = parser.parse_args()

    # Load logging configuration first (before other configs)
    try:
        logging_config = load_logging_config(args.logging_config)
        apply_logging_config_to_args(args, logging_config)
        print(f"📝 Loaded logging configuration: {args.logging_config}")
    except Exception as e:
        print(f"Warning: Could not load logging configuration '{args.logging_config}': {e}")
        print("Using command line arguments and defaults.")

    # Load config by name if provided (takes the place of --config_file)
    if args.config_name:
        try:
            config_file_path = find_config_file(args.config_name, "evaluation")
            args.config_file = str(config_file_path)
            print(f"📁 Using evaluation config: {args.config_name}")
        except FileNotFoundError as e:
            print(f"Error: {e}")
            sys.exit(1)

    if args.config_file:
        _apply_config_file_to_args(args, parser, args.config_file)

    _fill_missing_arg_defaults(args)
    _validate_required_eval_args(args, parser)

    # Run sanity checks
    if not args.skip_sanity_checks and not run_sanity_checks(args):
        print("❌ Sanity checks failed. Aborting evaluation.")
        sys.exit(1)

    # Run evaluation
    run_evaluation(args)

if __name__ == '__main__':
    # Script entry point: only runs when executed directly, not on import.
    main()