#!/usr/bin/env python3
"""
LLMBar Flexible Evaluation Script

Evaluates datasets using flexible dimension selection,
metrics+reference approach, and expert model evaluation.
"""

import os
import json
import time
import logging
import argparse
import asyncio
import aiohttp
import re
import pickle
from pathlib import Path
from collections import defaultdict
from dataclasses import dataclass
from typing import List, Dict, Any, Optional
from tqdm import tqdm
import yaml

class GlobalRateLimitHandler:
    """Process-wide "pause all requests until a deadline" gate.

    ``rate_limited_until`` holds a wall-clock deadline (``time.time()``
    seconds); 0 means no pause has ever been set.
    """

    def __init__(self):
        self.rate_limited_until = 0
        self.lock = asyncio.Lock()

    async def check_and_wait_if_rate_limited(self):
        """Sleep until the current rate-limit window, if any, has passed.

        The lock only protects reading the deadline. The sleep happens outside
        it, so ``set_rate_limited`` and other waiters are not blocked for the
        whole window. The deadline is re-read after every sleep in case it was
        extended in the meantime.
        """
        while True:
            async with self.lock:
                wait_time = self.rate_limited_until - time.time()
            if wait_time <= 0:
                return
            logging.warning(f"Global rate limit active. Waiting {wait_time:.1f} seconds...")
            await asyncio.sleep(wait_time)

    async def set_rate_limited(self, wait_duration: int = 60):
        """Start (or restart) a pause lasting ``wait_duration`` seconds from now."""
        async with self.lock:
            self.rate_limited_until = time.time() + wait_duration
            logging.warning(f"Global rate limit set. All requests will wait {wait_duration} seconds.")

class DualAPIKeyManager:
    """Hand out API keys round-robin from a pool of one or two keys.

    ``get_next_api_key`` picks its pool in this order:
    1. keys confirmed working via ``mark_key_as_valid``;
    2. otherwise, configured keys not marked invalid;
    3. otherwise (every key marked invalid), all configured keys, so the run
       keeps going and failures surface as API errors.
    """

    def __init__(self, api_key_1: str, api_key_2: Optional[str] = None):
        # Blank or missing keys are dropped, so passing ""/None as the second key
        # yields single-key mode.
        self.api_keys = [key for key in (api_key_1, api_key_2) if key and key.strip()]
        self.current_index = 0
        self.lock = asyncio.Lock()
        self.valid_keys = []
        self.invalid_keys = []

        if not self.api_keys:
            raise ValueError("At least one valid API key must be provided")

        logging.info(f"Initialized API key manager with {len(self.api_keys)} key(s)")
        if len(self.api_keys) == 1:
            logging.info("Single API key mode - will use rate limiting for 10 RPM")

    def _usable_keys(self) -> List[str]:
        """Return the pool that get_next_api_key rotates over (see class docstring)."""
        if self.valid_keys:
            return self.valid_keys
        healthy = [key for key in self.api_keys if key not in self.invalid_keys]
        return healthy or self.api_keys

    async def get_next_api_key(self) -> str:
        """Return the next key from the usable pool and advance the rotation."""
        async with self.lock:
            pool = self._usable_keys()
            # Modulo keeps the index in range even if the pool shrank since the last call.
            api_key = pool[self.current_index % len(pool)]
            self.current_index = (self.current_index + 1) % len(pool)
            return api_key

    def get_api_key_for_worker(self, worker_id: int) -> str:
        """Return a fixed key for ``worker_id`` (ignores valid/invalid state)."""
        return self.api_keys[worker_id % len(self.api_keys)]

    async def mark_key_as_valid(self, api_key: str):
        """Record ``api_key`` as working and clear any earlier invalid mark."""
        async with self.lock:
            if api_key in self.invalid_keys:
                self.invalid_keys.remove(api_key)
            if api_key not in self.valid_keys:
                self.valid_keys.append(api_key)
                logging.info("API key marked as valid")

    async def mark_key_as_invalid(self, api_key: str):
        """Record ``api_key`` as failing and drop it from the validated pool."""
        async with self.lock:
            if api_key not in self.invalid_keys:
                self.invalid_keys.append(api_key)
                logging.warning("API key marked as invalid")

            if api_key in self.valid_keys:
                self.valid_keys.remove(api_key)

    def get_valid_key_count(self) -> int:
        """Number of keys currently marked valid."""
        return len(self.valid_keys)

    def get_invalid_key_count(self) -> int:
        """Number of keys currently marked invalid."""
        return len(self.invalid_keys)

    def is_single_key_mode(self) -> bool:
        """True when only one usable key was configured."""
        return len(self.api_keys) == 1

@dataclass
class LLMBarInstance:
    """One pairwise-comparison example: an instruction and two candidate outputs."""
    instance_id: str  # from the record's "id", or an index-based "instance_NNNN" when absent
    input: str  # the instruction both outputs respond to
    output_1: str  # first candidate output
    output_2: str  # second candidate output
    gold_label: int  # record's "label" (loader defaults to 1); presumably 1 or 2 = preferred output — confirm

@dataclass
class EvaluationResult:
    """Outcome of evaluating one instance, together with the intermediate artefacts."""
    instance_id: str
    input: str
    output_1: str
    output_2: str
    gold_label: int
    predicted_label: int  # presumably uses the same encoding as gold_label — confirm in the producer
    reasoning: str
    selected_dimensions: List[str]  # evaluation dimensions chosen for this instance
    metrics_questions: List[str]  # presumably the questions injected into the main evaluation prompt
    reference_output: str  # presumably the generated reference answer — confirm in the producer
    dimension_evaluations: Dict[str, str]  # presumably dimension name -> per-dimension verdict; TODO confirm
    expert_models_used: List[str]
    evaluation_time: float  # NOTE(review): unit not visible here, presumably seconds
    correct: bool  # presumably predicted_label == gold_label; TODO confirm
    confidence: float = 0.0

# Module-level rate-limit gate intended to be shared by every request coroutine.
# NOTE(review): nothing in this part of the file calls it (call_api_async only reports
# 429 via "rate_limited": True); confirm the retry logic elsewhere uses it.
rate_limit_handler = GlobalRateLimitHandler()

class SingleKeyRateLimiter:
    """Space requests so that at most ``requests_per_minute`` start per minute."""

    def __init__(self, requests_per_minute: int = 10):
        if requests_per_minute <= 0:
            raise ValueError("requests_per_minute must be positive")
        self.requests_per_minute = requests_per_minute
        self.min_interval = 60.0 / requests_per_minute
        # Monotonic timestamp of the last granted request; -inf means "never",
        # so the first call does not wait.
        self.last_request_time = float("-inf")
        self.lock = asyncio.Lock()

    async def wait_if_needed(self):
        """Wait until at least ``min_interval`` seconds have passed since the last grant.

        The lock is deliberately held across the sleep: concurrent callers queue
        up and are released one interval apart. A monotonic clock is used so
        that wall-clock adjustments cannot cause huge or negative waits.
        """
        async with self.lock:
            time_since_last = time.monotonic() - self.last_request_time
            if time_since_last < self.min_interval:
                wait_time = self.min_interval - time_since_last
                logging.debug(f"Rate limiting: waiting {wait_time:.2f} seconds")
                await asyncio.sleep(wait_time)
            self.last_request_time = time.monotonic()

# Shared limiter that call_api_async waits on before each request when the key manager
# is in single-key mode: 10 requests per minute, i.e. at least 6 s between requests.
single_key_rate_limiter = SingleKeyRateLimiter(10)

def setup_logging() -> logging.Logger:
    """Configure run logging and return the dedicated ``'detailed'`` logger.

    * Root logger: INFO level, ``[LEVEL] message`` format, written to a console
      StreamHandler and to ``llmbar_flexible_evaluation.log`` (truncated).
    * ``'detailed'`` logger: INFO level, ``LEVEL:name:message`` format, written to
      ``llmbar_flexible_evaluation_detailed.log`` (truncated). It is not set to
      stop propagating, so its records also reach the root handlers.

    Safe to call repeatedly: handlers installed by a previous call are removed
    *and closed*, so their log files are not left open.
    """
    detailed_logger = logging.getLogger('detailed')
    detailed_logger.setLevel(logging.INFO)

    # A handler that is only detached keeps its file open; close it as well.
    for handler in detailed_logger.handlers[:]:
        detailed_logger.removeHandler(handler)
        handler.close()

    detailed_handler = logging.FileHandler('llmbar_flexible_evaluation_detailed.log', mode='w', encoding='utf-8')
    detailed_handler.setLevel(logging.INFO)
    detailed_handler.setFormatter(logging.Formatter('%(levelname)s:%(name)s:%(message)s'))
    detailed_logger.addHandler(detailed_handler)

    # force=True removes and closes any existing root handlers before installing these.
    logging.basicConfig(
        level=logging.INFO,
        format='[%(levelname)s] %(message)s',
        handlers=[
            logging.StreamHandler(),
            logging.FileHandler("llmbar_flexible_evaluation.log", mode='w', encoding='utf-8'),
        ],
        force=True,
    )

    return detailed_logger

def load_metrics_config(metrics_path: str) -> Dict[str, Any]:
    """Parse a YAML metrics configuration file.

    The file is read as UTF-8, matching the other file readers in this script,
    instead of the platform's default encoding. ``yaml.safe_load`` builds only
    plain YAML types. An empty file yields ``{}`` rather than ``None``.
    """
    with open(metrics_path, 'r', encoding='utf-8') as f:
        metrics_config = yaml.safe_load(f)
    return metrics_config if metrics_config is not None else {}

# Candidate evaluation dimensions offered to the dimension-selection model.
# Same names, in the same order, as the keys of DIMENSION_TO_MODEL_MAPPING below.
# NOTE(review): 'Professional' and 'Professionalism' look like near-duplicates — confirm both are intended.
ALL_DIMENSIONS = [
    'Accuracy', 'Admit Uncertainty', 'Attractive', 'Audience Friendly', 'Authenticity',
    'Being Friendly', 'Citation', 'Clarity', 'Code Correctness', 'Code Readability',
    'Coherence', 'Completeness', 'Coverage', 'Creativity', 'Depth', 'Emojis',
    'Emotion', 'Faithfulness', 'Feasibility', 'Harmlessness', 'Information Richness',
    'Insight', 'Instruction Following', 'Interactivity', 'Layout', 'Length', 'Logic',
    'Modularity', 'Multiple Aspects', 'Objectivity', 'Originality', 'Pacing',
    'Pointing Out', 'Professional', 'Professionalism', 'Relevance', 'Result at the Beginning',
    'Step by Step Explanation', 'Style', 'Timeliness', 'Vivid'
]

# Expert model that judges each dimension (looked up by evaluate_dimension; a dimension
# without an entry is scored "tie"). The keys also act as the whitelist of dimension
# names that select_relevant_dimensions accepts from the selection model.
DIMENSION_TO_MODEL_MAPPING = {
    'Accuracy': 'provider-3/qwen-2.5-72b',
    'Admit Uncertainty': 'provider-3/deepseek-v3',
    'Attractive': 'provider-3/llama-3.3-70b',
    'Audience Friendly': 'provider-3/gpt-5-chat',
    'Authenticity': 'provider-3/qwen-2.5-72b',
    'Being Friendly': 'provider-3/sonar-pro',
    'Citation': 'provider-6/gpt-4o',
    'Clarity': 'provider-3/kimi-k2',
    'Code Correctness': 'provider-3/qwen-2.5-72b',
    'Code Readability': 'provider-3/mistral-large-latest',
    'Coherence': 'provider-3/kimi-k2',
    'Completeness': 'provider-3/gpt-5-chat',
    'Coverage': 'provider-3/sonar-pro',
    'Creativity': 'provider-3/sonar',
    'Depth': 'provider-3/sonar-pro',
    'Emojis': 'provider-3/deepseek-v3',
    'Emotion': 'provider-3/gpt-5-chat',
    'Faithfulness': 'provider-3/llama-3.1-70b',
    'Feasibility': 'provider-3/gpt-5-chat',
    'Harmlessness': 'provider-3/sonar-pro',
    'Information Richness': 'provider-3/sonar-pro',
    'Insight': 'provider-3/sonar-pro',
    'Instruction Following': 'provider-3/gpt-5-chat',
    'Interactivity': 'provider-3/sonar-pro',
    'Layout': 'provider-6/llama-4-maverick',
    'Length': 'provider-3/mistral-medium-latest',
    'Logic': 'provider-3/kimi-k2',
    'Modularity': 'provider-3/llama-3.3-70b',
    'Multiple Aspects': 'provider-3/gpt-5-chat',
    'Objectivity': 'provider-6/o3-medium',
    'Originality': 'provider-3/sonar',
    'Pacing': 'provider-3/kimi-k2',
    'Pointing Out': 'provider-3/gpt-5-nano',
    'Professional': 'provider-3/qwen-2.5-72b',
    'Professionalism': 'provider-6/o3-high',
    'Relevance': 'provider-3/sonar-pro',
    'Result at the Beginning': 'provider-6/claude-sonnet-4-20250514-thinking',
    'Step by Step Explanation': 'provider-3/gpt-4.1-mini',
    'Style': 'provider-3/mistral-medium-latest',
    'Timeliness': 'provider-3/kimi-k2',
    'Vivid': 'provider-3/kimi-k2'
}

# Dynamic dimension-selection prompt (based on the research paper).
# Placeholders: {all_dimensions} (the candidate dimension names) and {input} (the instruction).
# The model is asked to reply with only a JSON array of 5-10 dimension names.
DIMENSION_SELECTION_PROMPT = """You are an expert evaluator specialized in selecting the most relevant evaluation dimensions for a given instruction.

Your task is to select 5-10 dimensions from the available list that are MOST RELEVANT for evaluating outputs for the given instruction.

AVAILABLE DIMENSIONS:
{all_dimensions}

INSTRUCTION:
{input}

SELECTION CRITERIA:
1. Choose dimensions that directly relate to whether the output correctly follows the instruction
2. Prioritize dimensions that measure task completion and instruction adherence
3. Include dimensions that assess the quality and appropriateness of the response
4. Select 5-10 dimensions (preferably 7-8 for balanced evaluation)

OUTPUT FORMAT:
Return ONLY a JSON array of dimension names, like:
["dimension1", "dimension2", "dimension3"]

Example:
["Instruction Following", "Accuracy", "Relevance", "Completeness", "Clarity"]

Your selection:"""

# Metrics-generation prompt (based on the research paper).
# Placeholder: {input} (the instruction). Asks for at most three instruction-specific
# questions, most important first, with no surrounding text.
METRICS_GENERATION_PROMPT = """You are a helpful assistant in evaluating the quality of outputs for a given instruction.

Please propose at most three concise questions about whether a potential output is a good output for the given instruction. Another assistant will evaluate different aspects of the output by answering all the questions.

Here are some rules of the evaluation:
(1) You should prioritize evaluating whether the output honestly/precisely/closely executes the instruction.
(2) Outputs should NOT contain more/less than what the instruction asks for, as such outputs do NOT precisely execute the instruction.

# Instruction:
{input}

# Requirements for Your Output:
(1) The questions should **specifically** target the given instruction instead of some general standards, so the questions may revolve around key points of the instruction.
(2) You should directly give the questions without any other words.
(3) Questions are presented from most important to least important.

Your questions:"""

# Reference-answer prompt (based on the research paper).
# Placeholder: {input} (the instruction). The model answers it directly and concisely.
REFERENCE_GENERATION_PROMPT = """You are a helpful assistant that responds to the user in a concise way.

{input}

Your response:"""

# Main pairwise-judgement prompt (based on the research paper).
# Placeholders: {input}, {output_1} (shown as "Output (a)"), {output_2} (shown as
# "Output (b)"), {metrics_questions} and {reference_output}.
# The judge is instructed to answer with exactly "Output (a)" or "Output (b)".
MAIN_EVALUATION_PROMPT = """You are a helpful assistant in evaluating the quality of the outputs for a given instruction. Your goal is to select the best output for the given instruction.

Select the Output (a) or Output (b) that is better for the given instruction. The two outputs are generated by two different AI chatbots respectively.

CRITICAL EVALUATION RULES:
(1) **TASK CORRECTNESS FIRST**: Prioritize whether the output honestly/precisely/closely executes the instruction above all else.
(2) **INSTRUCTION ADHERENCE**: Outputs should NOT contain more/less than what the instruction asks for. Overly comprehensive or off-topic outputs are NOT better.
(3) **APPROPRIATE SCOPE**: The output should match the scope and specificity requested in the instruction.
(4) **QUALITY WITHIN CORRECTNESS**: Only consider helpfulness, accuracy, detail, etc. AFTER confirming task correctness.
(5) **AVOID BIAS**: The order of presentation should NOT affect your judgment. Output (a) and Output (b) are equally likely to be better.

IMPORTANT: 
- Do NOT favor longer, more comprehensive outputs if they don't match the instruction
- Do NOT prioritize "helpfulness" over task correctness
- Focus on whether the output correctly addresses the specific task requested

Do NOT provide any explanation for your choice.
Do NOT say both / neither are good.
You should answer using ONLY "Output (a)" or "Output (b)". Do NOT output any other words.

# Instruction:
{input}

# Output (a):
{output_1}

# Output (b):
{output_2}

# Questions about Outputs:
Here are at most three questions about the outputs, which are presented from most important to least important. You can do the evaluation based on thinking about all the questions.
{metrics_questions}

# A reference output generated by a strong AI assistant:
{reference_output}

# Which is better, Output (a) or Output (b)? Your response should be either "Output (a)" or "Output (b)":"""

# Per-dimension expert prompt used by evaluate_dimension.
# Placeholders: {input}, {output_1}, {output_2}, {dimension} (used several times) and
# {dimension_description}. The reply must start with "1" or "2"; evaluate_dimension reads
# only that leading character and treats anything else as a tie.
DIMENSION_EVALUATION_PROMPT = """You are an expert evaluator specialized in assessing a specific dimension of output quality.

Your task is to evaluate which output (1 or 2) better fulfills the given instruction according to the specified dimension.

INSTRUCTION: {input}

OUTPUT 1: {output_1}

OUTPUT 2: {output_2}

DIMENSION TO EVALUATE: {dimension}

DIMENSION DESCRIPTION: {dimension_description}

CRITICAL EVALUATION CRITERIA (IN ORDER OF PRIORITY):
1. **TASK CORRECTNESS FIRST**: Does the output actually do what was asked for in the instruction?
2. **INSTRUCTION ADHERENCE**: Does the output follow the exact instructions given?
3. **APPROPRIATE SCOPE**: Does the output provide what was requested without adding unnecessary extras?
4. **DIMENSION-SPECIFIC QUALITY**: Only AFTER confirming 1-3, evaluate the {dimension} aspect

VALIDATION STEPS:
- First, verify: Does Output 1 correctly address the instruction?
- Then, verify: Does Output 2 correctly address the instruction?
- If one output is factually incorrect or off-topic, it CANNOT win on any dimension
- If both outputs are correct, then evaluate the {dimension} aspect

IMPORTANT: 
- Do NOT prioritize comprehensiveness over task correctness
- Do NOT favor longer outputs if they don't match the instruction
- Do NOT give high scores to outputs that are factually wrong or off-topic
- Focus on whether the output correctly addresses the specific task requested

Your evaluation: Respond with ONLY "1" or "2" followed by a brief explanation of why that output better fulfills the task according to the {dimension} dimension."""

# Load LLMBar dataset
def load_llmbar_dataset(dataset_path: str, max_instances: Optional[int] = None) -> List[LLMBarInstance]:
    """Read LLMBar instances from a UTF-8 JSON file containing a list of records.

    Each record becomes an LLMBarInstance. Missing fields fall back to defaults:
    an index-based ``instance_NNNN`` id, empty strings, and label 1. Loading
    stops once ``max_instances`` instances are collected, when that argument is
    truthy. Any failure (missing file, invalid JSON, non-dict records) is logged
    and an empty list is returned.
    """
    def to_instance(position, record):
        # Missing keys get defaults instead of aborting the whole load.
        return LLMBarInstance(
            instance_id=record.get('id', f'instance_{position:04d}'),
            input=record.get('input', ''),
            output_1=record.get('output_1', ''),
            output_2=record.get('output_2', ''),
            gold_label=record.get('label', 1),
        )

    try:
        with open(dataset_path, 'r', encoding='utf-8') as handle:
            records = json.load(handle)

        loaded = []
        for position, record in enumerate(records):
            loaded.append(to_instance(position, record))
            reached_cap = bool(max_instances) and len(loaded) >= max_instances
            if reached_cap:
                break

        logging.info(f"Loaded {len(loaded)} instances from dataset")
        return loaded
    except Exception as e:
        logging.error(f"Failed to load dataset: {e}")
        return []

# Checkpoint and resume functionality
def save_checkpoint(checkpoint_path: str, results: List[EvaluationResult], processed_instances: set, total_instances: int):
    """Persist evaluation progress so an interrupted run can resume.

    Pickles a dict with keys ``results``, ``processed_instances`` (stored as a
    list), ``total_instances`` and ``timestamp`` (``time.time()``). The data is
    first written to ``<checkpoint_path>.tmp`` and then moved over
    ``checkpoint_path`` with ``os.replace``. A crash mid-write therefore leaves
    the previous checkpoint intact instead of a truncated file. Failures are
    logged, not raised (best effort), and a leftover temp file is removed.
    """
    checkpoint_data = {
        'results': results,
        'processed_instances': list(processed_instances),
        'total_instances': total_instances,
        'timestamp': time.time()
    }

    tmp_path = f"{checkpoint_path}.tmp"
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(checkpoint_data, f)
        # Swap in the complete file only after it has been fully written.
        os.replace(tmp_path, checkpoint_path)
        logging.info(f"Checkpoint saved to {checkpoint_path}")
    except Exception as e:
        logging.error(f"Failed to save checkpoint: {e}")
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # temp file absent or not removable; the failure is already logged

def load_checkpoint(checkpoint_path: str) -> tuple[List[EvaluationResult], set, int]:
    """Restore progress saved by save_checkpoint.

    Returns ``(results, processed_instance_ids, total_instances)``. A missing,
    unreadable or malformed checkpoint yields ``([], set(), 0)``, meaning
    "start from scratch"; errors are logged, never raised.

    NOTE(review): pickle.load can execute arbitrary code embedded in the file;
    only load checkpoints this script wrote itself.
    """
    try:
        if not os.path.exists(checkpoint_path):
            logging.info("No checkpoint found, starting fresh")
            return [], set(), 0

        with open(checkpoint_path, 'rb') as handle:
            payload = pickle.load(handle)

        saved_results = payload['results']
        seen_ids = set(payload['processed_instances'])
        expected_total = payload['total_instances']

        logging.info(f"Checkpoint loaded from {checkpoint_path}")
        logging.info(f"Resuming with {len(saved_results)} completed evaluations and {len(seen_ids)} processed instances")
        return saved_results, seen_ids, expected_total
    except Exception as e:
        logging.error(f"Failed to load checkpoint: {e}")
        return [], set(), 0

def get_checkpoint_path(output_path: str) -> str:
    """Return the checkpoint file name for ``output_path`` (``<output_path>_checkpoint.pkl``)."""
    return "{}_checkpoint.pkl".format(output_path)

async def call_api_async(session: aiohttp.ClientSession, messages: List[Dict], model: str, api_key: str, base_url: str, max_tokens: int = 150, temperature: float = 0.0, api_key_manager: DualAPIKeyManager = None) -> Dict:
    """POST a chat-completions request to ``model`` and normalise the outcome.

    Always returns a dict with a boolean ``"success"`` key:

    * success: ``{"success": True, "content": <stripped message text>}``
    * HTTP 401/403: ``"auth_error": True``; the key is also marked invalid
      when ``api_key_manager`` is given
    * HTTP 429: ``"rate_limited": True``; no retry or back-off happens here
    * any other failure, including a 200 reply without text ``content``:
      ``{"success": False, "error": ...}``

    ``Exception`` subclasses are logged and turned into a failure dict rather
    than raised. In single-key mode the call first waits on the shared
    ``single_key_rate_limiter``.
    """
    try:
        # Apply rate limiting if in single key mode
        if api_key_manager and api_key_manager.is_single_key_mode():
            await single_key_rate_limiter.wait_if_needed()

        payload = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens
        }

        async with session.post(
            f"{base_url}/chat/completions",
            json=payload,
            headers={"Authorization": f"Bearer {api_key}"},
            # NOTE(review): ssl=False disables TLS certificate verification while a bearer
            # token is being sent; confirm the endpoint genuinely requires this.
            ssl=False
        ) as response:
            if response.status == 200:
                result = await response.json()
                try:
                    content = result['choices'][0]['message']['content']
                except (KeyError, IndexError, TypeError):
                    content = None
                if not isinstance(content, str):
                    # The body parsed but carries no text (missing fields or null content):
                    # report a clear failure instead of an AttributeError from .strip().
                    logging.error(f"Response for model {model} contained no text content")
                    return {"success": False, "error": "Empty or malformed response content"}
                return {"success": True, "content": content.strip()}
            elif response.status == 401:
                # Authentication error - mark key as invalid
                logging.error(f"HTTP 401 (Unauthorized) for model {model} - API key may be invalid or expired")
                if api_key_manager:
                    await api_key_manager.mark_key_as_invalid(api_key)
                return {"success": False, "error": f"HTTP {response.status}", "auth_error": True}
            elif response.status == 403:
                # Forbidden - model access issue (the key is still marked invalid)
                logging.error(f"HTTP 403 (Forbidden) for model {model} - API key may not have access to this model")
                if api_key_manager:
                    await api_key_manager.mark_key_as_invalid(api_key)
                return {"success": False, "error": f"HTTP {response.status}", "auth_error": True}
            elif response.status == 429:
                # Rate limited - reported to the caller, not retried here
                logging.warning(f"HTTP 429 (Rate Limited) for model {model}")
                return {"success": False, "error": f"HTTP {response.status}", "rate_limited": True}
            else:
                logging.error(f"HTTP {response.status} for model {model}")
                return {"success": False, "error": f"HTTP {response.status}"}
    except Exception as e:
        logging.error(f"Exception during API call to {model}: {e}")
        return {"success": False, "error": str(e)}

async def validate_api_key(session: aiohttp.ClientSession, api_key: str, base_url: str, api_key_manager: DualAPIKeyManager) -> bool:
    """Probe ``api_key`` with a tiny kimi-k2 request and record the result.

    * Success: marks the key valid and returns True.
    * Authentication failure (``auth_error``): marks the key invalid and
      returns False.
    * Any other failure: returns False and leaves the key state unchanged.
    * Unexpected exception: marks the key invalid and returns False.
    """
    try:
        probe = await call_api_async(
            session,
            [{"role": "user", "content": "Hello, this is a test message. Please respond with 'OK'."}],
            "provider-3/kimi-k2",
            api_key,
            base_url,
            max_tokens=10,
            temperature=0.0,
            api_key_manager=api_key_manager,
        )

        if probe["success"]:
            await api_key_manager.mark_key_as_valid(api_key)
            logging.info("API key validation successful")
            return True

        if not probe.get("auth_error"):
            logging.warning(f"API key validation failed: {probe.get('error', 'Unknown error')}")
            return False

        await api_key_manager.mark_key_as_invalid(api_key)
        logging.error("API key validation failed: Authentication error")
        return False

    except Exception as e:
        logging.error(f"Exception during API key validation: {e}")
        await api_key_manager.mark_key_as_invalid(api_key)
        return False

# Global dimension usage tracking
class DimensionUsageTracker:
    """Count how often each dimension is selected and propose more varied selections.

    Every coroutine takes ``self.lock`` before touching shared state, so one
    tracker can be shared by concurrent evaluation tasks.
    """

    def __init__(self, all_dimensions: List[str]):
        self.all_dimensions = all_dimensions
        self.usage_counts = {dim: 0 for dim in all_dimensions}
        self.recent_selections = []  # last 10 recorded selections, oldest first
        self.lock = asyncio.Lock()

    async def record_usage(self, selected_dimensions: List[str]):
        """Increment counts for known dimensions and remember the selection."""
        async with self.lock:
            for dim in selected_dimensions:
                if dim in self.usage_counts:
                    self.usage_counts[dim] += 1
            self.recent_selections.append(selected_dimensions)
            # Keep only last 10 selections to avoid memory bloat
            if len(self.recent_selections) > 10:
                self.recent_selections.pop(0)

    async def get_diversity_boosted_dimensions(self, base_selection: List[str], target_count: int = 10) -> List[str]:
        """Swap some over-used dimensions in ``base_selection`` for rarely used ones.

        A dimension is over-used when its count exceeds
        ``max(1, highest_count // 2)``. At most
        ``max(0, target_count - len(base_selection))`` swaps are made. Each
        over-used dimension is replaced by the least-used dimension (among the
        15 least used) that is neither already in the selection nor over-used
        itself. The result is truncated to ``target_count``.
        """
        async with self.lock:
            if not self.usage_counts:
                return list(base_selection)[:target_count]

            # Sort dimensions by usage (ascending - least used first)
            sorted_dimensions = sorted(self.usage_counts.items(), key=lambda x: x[1])
            least_used = [dim for dim, _ in sorted_dimensions[:15]]

            overused_threshold = max(1, max(self.usage_counts.values()) // 2)
            overused_in_selection = [dim for dim in base_selection if self.usage_counts.get(dim, 0) > overused_threshold]

            final_selection = list(base_selection)
            # Clamp so that a base selection longer than target_count cannot produce a
            # negative slice bound below.
            replacement_count = max(0, min(len(overused_in_selection), target_count - len(base_selection)))

            for overused_dim in overused_in_selection[:replacement_count]:
                # Try further candidates instead of giving up when the first is already chosen.
                replacement_dim = next(
                    (cand for cand in least_used
                     if cand not in final_selection and self.usage_counts[cand] <= overused_threshold),
                    None,
                )
                if replacement_dim is None:
                    break
                final_selection[final_selection.index(overused_dim)] = replacement_dim
                logging.info(f"Diversity boost: Replaced {overused_dim} with {replacement_dim}")

            return final_selection[:target_count]

    async def get_rotation_dimensions(self, instance_id: str, target_count: int = 10) -> List[str]:
        """Pick a category-balanced dimension set that is deterministic per instance.

        Two dimensions are sampled from each of five fixed categories. Any slots
        still free (only when ``target_count > 10``) are filled with the
        least-used dimensions not yet chosen. The result is truncated to
        ``target_count``.

        Sampling uses a private ``random.Random`` seeded from a CRC32 of the
        instance id. This leaves the global ``random`` state untouched, and
        unlike the built-in str ``hash()``, which is salted per process, it
        gives the same choice on every run.
        """
        import random
        import zlib

        async with self.lock:
            rng = random.Random(zlib.crc32(str(instance_id).encode('utf-8')))

            # Select dimensions ensuring category diversity
            categories = {
                'task_core': ['Instruction Following', 'Accuracy', 'Relevance', 'Completeness'],
                'quality': ['Clarity', 'Coherence', 'Logic', 'Depth'],
                'style': ['Style', 'Professionalism', 'Audience Friendly', 'Creativity'],
                'technical': ['Code Correctness', 'Code Readability', 'Modularity', 'Feasibility'],
                'content': ['Coverage', 'Information Richness', 'Insight', 'Originality']
            }

            selected = []
            for dims in categories.values():
                selected.extend(rng.sample(dims, min(2, len(dims))))

            # Fill remaining slots with least used dimensions
            remaining_slots = target_count - len(selected)
            if remaining_slots > 0:
                sorted_by_usage = sorted(self.usage_counts.items(), key=lambda x: x[1])
                least_used = [dim for dim, _ in sorted_by_usage if dim not in selected]
                selected.extend(least_used[:remaining_slots])

            return selected[:target_count]

# Process-wide usage tracker shared by all evaluation tasks; select_relevant_dimensions
# records every selection it returns here.
dimension_tracker = DimensionUsageTracker(ALL_DIMENSIONS)

async def select_relevant_dimensions(session: aiohttp.ClientSession, instance: LLMBarInstance, api_key_manager: DualAPIKeyManager, base_url: str) -> List[str]:
    """Ask kimi-k2 which dimensions matter for ``instance.input`` (no bias correction).

    The model's choice is used as-is; no diversity boosting is applied. The
    reply is processed as follows:

    * the first JSON array in the reply is parsed;
    * entries that are not strings, or not keys of DIMENSION_TO_MODEL_MAPPING,
      are dropped, and surrounding whitespace is stripped;
    * duplicates are removed (first occurrence wins) and at most 10 are kept.

    If the call fails, the reply cannot be parsed, or fewer than 3 valid
    dimensions remain, a fixed five-dimension fallback is used. Exactly the
    returned list is recorded in the global ``dimension_tracker``.
    """
    # Built fresh on every call, so callers may mutate the returned list.
    selected = ['Instruction Following', 'Accuracy', 'Relevance', 'Completeness', 'Clarity']

    # Use the shared template rather than an inline copy, so the two cannot drift apart.
    prompt = DIMENSION_SELECTION_PROMPT.format(
        all_dimensions="\n".join(f"- {dim}" for dim in ALL_DIMENSIONS),
        input=instance.input,
    )

    messages = [
        {"role": "system", "content": "You are an expert evaluator specialized in selecting relevant evaluation dimensions."},
        {"role": "user", "content": prompt}
    ]

    api_key = await api_key_manager.get_next_api_key()

    result = await call_api_async(
        session, messages, "provider-3/kimi-k2", api_key, base_url,
        max_tokens=200, temperature=0.1, api_key_manager=api_key_manager  # Lower temperature for more consistent selection
    )

    if result["success"]:
        try:
            # Non-greedy: take the first [...] span and ignore any prose around it.
            json_match = re.search(r'\[.*?\]', result["content"], re.DOTALL)
            if json_match:
                valid_dimensions: List[str] = []
                for dim in json.loads(json_match.group(0)):
                    if not isinstance(dim, str):
                        continue
                    dim = dim.strip()
                    if dim in DIMENSION_TO_MODEL_MAPPING and dim not in valid_dimensions:
                        valid_dimensions.append(dim)

                if len(valid_dimensions) >= 3:
                    # Truncate before recording, so usage counts match what is actually evaluated.
                    selected = valid_dimensions[:10]
                else:
                    logging.warning(f"Too few valid dimensions selected: {valid_dimensions}")
        except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
            logging.error(f"Error parsing dimension selection: {e}")

    await dimension_tracker.record_usage(selected)
    return selected

async def generate_metrics_questions(session: aiohttp.ClientSession, instance: LLMBarInstance, api_key_manager: DualAPIKeyManager, base_url: str) -> List[str]:
    """Generate up to three instruction-specific evaluation questions with kimi-k2.

    Each non-blank line of the reply, with whitespace stripped, counts as one
    question, and the first three are kept. If the call fails, or the reply has
    no non-blank lines, three generic fallback questions are returned instead.
    """
    prompt = METRICS_GENERATION_PROMPT.format(input=instance.input)

    messages = [
        {"role": "system", "content": "You are a helpful assistant in evaluating output quality."},
        {"role": "user", "content": prompt}
    ]

    api_key = await api_key_manager.get_next_api_key()

    result = await call_api_async(
        session, messages, "provider-3/kimi-k2", api_key, base_url,
        max_tokens=150, temperature=0.0, api_key_manager=api_key_manager
    )

    if result["success"]:
        questions = [line.strip() for line in result["content"].splitlines() if line.strip()][:3]
        if questions:
            return questions
        # An empty reply would otherwise leave the main evaluation prompt with no questions.
        logging.warning("Metrics question generation returned an empty response; using fallback questions")

    return [
        "Does the output correctly address the instruction?",
        "Is the output complete and accurate?",
        "Is the output helpful and well-structured?"
    ]

async def generate_reference_output(session: aiohttp.ClientSession, instance: LLMBarInstance, api_key_manager: DualAPIKeyManager, base_url: str) -> str:
    """Have kimi-k2 answer the instruction concisely, for use as a reference output.

    Runs at temperature 0 with up to 384 tokens. On success the model text is
    returned as-is, which may be an empty string. On failure a fixed generic
    placeholder sentence is returned.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful assistant that responds to the user in a concise way."},
        {"role": "user", "content": REFERENCE_GENERATION_PROMPT.format(input=instance.input)},
    ]

    key = await api_key_manager.get_next_api_key()
    reply = await call_api_async(
        session,
        conversation,
        "provider-3/kimi-k2",
        key,
        base_url,
        max_tokens=384,
        temperature=0.0,
        api_key_manager=api_key_manager,
    )

    if not reply["success"]:
        # Placeholder used when no model answer is available.
        return "A well-structured response that directly addresses the instruction with appropriate detail and accuracy."
    return reply["content"]

def _parse_dimension_verdict(content: Optional[str]) -> str:
    """Map an expert model's raw reply to "1", "2" or "tie".

    The reply is expected to begin with the number of the winning output.
    Leading whitespace and common wrapping characters (markdown emphasis,
    backticks, quotes, brackets) are skipped first, so replies like a
    leading newline followed by "1", or "**2**", are not misread as ties.
    Anything else, including an empty or missing reply, yields "tie".
    """
    if not content:
        return "tie"
    head = content.strip().lstrip("*_`\"'([ ")
    if head.startswith("1"):
        return "1"
    if head.startswith("2"):
        return "2"
    return "tie"


async def evaluate_dimension(session: aiohttp.ClientSession, dimension: str, instance: LLMBarInstance, api_key_manager: DualAPIKeyManager, base_url: str) -> Dict:
    """Ask the expert model mapped to ``dimension`` which output wins on it.

    Args:
        session: Shared aiohttp session used for the API call.
        dimension: Name of the quality dimension; looked up in
            ``DIMENSION_TO_MODEL_MAPPING`` to pick the expert model.
        instance: The pairwise instance (input and two candidate outputs).
        api_key_manager: Supplies the API key; a key that fails
            authentication is marked invalid here.
        base_url: API base URL.

    Returns:
        Dict with keys "dimension", "evaluation" ("1", "2" or "tie") and
        "model". "raw_response" is added whenever an API call was made (the
        reply text, or an "Error: ..." marker on failure). A dimension with
        no mapped expert model returns a "tie" with model "unknown" without
        calling the API.
    """
    expert_model = DIMENSION_TO_MODEL_MAPPING.get(dimension)
    if not expert_model:
        return {"dimension": dimension, "evaluation": "tie", "model": "unknown"}

    prompt = DIMENSION_EVALUATION_PROMPT.format(
        input=instance.input,
        output_1=instance.output_1,
        output_2=instance.output_2,
        dimension=dimension,
        dimension_description=f"Evaluate the {dimension} aspect of the output quality",
    )

    messages = [
        {"role": "system", "content": "You are an expert evaluator specialized in assessing output quality dimensions."},
        {"role": "user", "content": prompt}
    ]

    api_key = await api_key_manager.get_next_api_key()

    result = await call_api_async(
        session, messages, expert_model, api_key, base_url,
        max_tokens=100, temperature=0.1, api_key_manager=api_key_manager
    )

    if result["success"]:
        content = result["content"]
        return {
            "dimension": dimension,
            "evaluation": _parse_dimension_verdict(content),
            "model": expert_model,
            "raw_response": content
        }

    # Every failure counts as a tie; only the logging and error marker differ.
    if result.get("auth_error"):
        # Stop handing out a key the provider rejected.
        await api_key_manager.mark_key_as_invalid(api_key)
        logging.error(f"Authentication error for dimension {dimension} with model {expert_model}")
        raw_response = "Error: Authentication failed"
    elif result.get("rate_limited"):
        logging.warning(f"Rate limited for dimension {dimension} with model {expert_model}")
        raw_response = "Error: Rate limited"
    else:
        raw_response = "Error: Failed to evaluate"

    return {
        "dimension": dimension,
        "evaluation": "tie",
        "model": expert_model,
        "raw_response": raw_response
    }



def _parse_main_verdict(content: Optional[str]) -> int:
    """Map the main evaluator's raw reply to a predicted label (1 or 2).

    Matching is case-insensitive. If the reply begins with "Output (a)" or
    "Output (b)" (after whitespace and markdown/quote characters), that
    decides the label. Otherwise, a reply mentioning exactly one of the two
    decides it. Replies that mention both, mention neither, or are empty
    default to label 1.
    """
    if not content:
        return 1
    text = content.strip().lstrip("*_`\"'# ").lower()
    if text.startswith("output (a)"):
        return 1
    if text.startswith("output (b)"):
        return 2
    mentions_a = "output (a)" in text
    mentions_b = "output (b)" in text
    if mentions_b and not mentions_a:
        return 2
    return 1


async def evaluate_instance_main(session: aiohttp.ClientSession, instance: LLMBarInstance, api_key_manager: DualAPIKeyManager, base_url: str) -> Dict:
    """Evaluate the instance using the main evaluation approach (Metrics+Reference).

    Generates metrics questions and a reference output, embeds both in
    ``MAIN_EVALUATION_PROMPT``, and asks the kimi-k2 model to pick a winner.

    Returns:
        Dict with "predicted_label" (1 or 2), the "metrics_questions" and
        "reference_output" that were used in the prompt, and "raw_response".
        If the final API call fails, the label defaults to 1 and
        "raw_response" is an error marker.
    """
    metrics_questions = await generate_metrics_questions(session, instance, api_key_manager, base_url)
    reference_output = await generate_reference_output(session, instance, api_key_manager, base_url)

    # Number the questions so the prompt can refer to them individually.
    metrics_text = "\n".join(f"{i + 1}. {q}" for i, q in enumerate(metrics_questions))

    prompt = MAIN_EVALUATION_PROMPT.format(
        input=instance.input,
        output_1=instance.output_1,
        output_2=instance.output_2,
        metrics_questions=metrics_text,
        reference_output=reference_output
    )

    messages = [
        {"role": "system", "content": "You are a helpful assistant in evaluating output quality."},
        {"role": "user", "content": prompt}
    ]

    api_key = await api_key_manager.get_next_api_key()

    result = await call_api_async(
        session, messages, "provider-3/kimi-k2", api_key, base_url,
        max_tokens=50, temperature=0.0, api_key_manager=api_key_manager
    )

    if result["success"]:
        content = result["content"]
        return {
            "predicted_label": _parse_main_verdict(content),
            "metrics_questions": metrics_questions,
            "reference_output": reference_output,
            "raw_response": content
        }

    # Fallback when the evaluation call itself failed.
    return {
        "predicted_label": 1,
        "metrics_questions": metrics_questions,
        "reference_output": reference_output,
        "raw_response": "Error: Failed to evaluate"
    }

async def evaluate_instance_flexible(session: aiohttp.ClientSession, instance: LLMBarInstance, api_key_manager: DualAPIKeyManager, base_url: str) -> EvaluationResult:
    """Evaluate an instance using the flexible dimension selection approach.

    Pipeline:
      1. Select the relevant quality dimensions for this instance.
      2. Ask each dimension's expert model (concurrently) which output wins.
      3. Run the main Metrics+Reference evaluation.
      4. Predict the output that wins more dimensions. If the dimension vote
         is tied, use the main evaluation's label.

    The metrics questions and reference output recorded in the result are
    the ones the main evaluation actually used. They are taken from its
    return value rather than generated a second time, which avoids extra API
    calls and mismatched records.

    Returns:
        An ``EvaluationResult``. ``confidence`` is the share of evaluated
        dimensions won by the leading output, or 0.5 when no dimension was
        evaluated.
    """
    start_time = time.time()

    selected_dimensions = await select_relevant_dimensions(session, instance, api_key_manager, base_url)

    dimension_results = await asyncio.gather(
        *(evaluate_dimension(session, dim, instance, api_key_manager, base_url) for dim in selected_dimensions),
        return_exceptions=True,
    )

    dimension_evaluations: Dict[str, str] = {}
    expert_models_used: List[str] = []

    # gather() preserves input order, so results pair up with selected_dimensions.
    for dim, result in zip(selected_dimensions, dimension_results):
        # BaseException (not Exception) so a returned CancelledError is also
        # treated as a failed dimension instead of being indexed like a dict.
        if isinstance(result, BaseException):
            logging.error(f"Error evaluating dimension {dim}: {result}")
            dimension_evaluations[dim] = "tie"
        else:
            dimension_evaluations[result["dimension"]] = result["evaluation"]
            if result["model"]:
                expert_models_used.append(result["model"])

    main_result = await evaluate_instance_main(session, instance, api_key_manager, base_url)
    metrics_questions = main_result["metrics_questions"]
    reference_output = main_result["reference_output"]

    verdicts = list(dimension_evaluations.values())
    output_1_wins = verdicts.count("1")
    output_2_wins = verdicts.count("2")
    ties = verdicts.count("tie")

    if output_1_wins > output_2_wins:
        final_prediction = 1
    elif output_2_wins > output_1_wins:
        final_prediction = 2
    else:
        # Tied dimension vote: defer to the Metrics+Reference judgement.
        final_prediction = main_result["predicted_label"]

    # Use the number of evaluated dimensions so the denominator matches the
    # win counts above (duplicate dimension names collapse into one entry).
    total_dimensions = len(dimension_evaluations)
    if total_dimensions > 0:
        confidence = max(output_1_wins, output_2_wins) / total_dimensions
    else:
        confidence = 0.5

    reasoning = f"Dimension wins: Output 1 ({output_1_wins}), Output 2 ({output_2_wins}), Ties ({ties}). Main evaluation: {main_result['predicted_label']}"

    return EvaluationResult(
        instance_id=instance.instance_id,
        input=instance.input,
        output_1=instance.output_1,
        output_2=instance.output_2,
        gold_label=instance.gold_label,
        predicted_label=final_prediction,
        reasoning=reasoning,
        selected_dimensions=selected_dimensions,
        metrics_questions=metrics_questions,
        reference_output=reference_output,
        dimension_evaluations=dimension_evaluations,
        expert_models_used=expert_models_used,
        evaluation_time=time.time() - start_time,
        correct=(final_prediction == instance.gold_label),
        confidence=confidence
    )

async def check_api_key_health(api_key_manager: DualAPIKeyManager) -> bool:
    """Report whether evaluation can continue with the current key pool.

    Returns False (after logging an error) when the manager reports zero
    valid keys. Otherwise returns True, logging a warning first if any keys
    have been marked invalid.
    """
    healthy = api_key_manager.get_valid_key_count()
    broken = api_key_manager.get_invalid_key_count()

    if healthy == 0:
        logging.error("No valid API keys remaining. Cannot continue evaluation.")
        return False

    if broken > 0:
        logging.warning(f"Some API keys are invalid: {healthy} valid, {broken} invalid")
    return True

async def process_instances_flexible(instances: List[LLMBarInstance], api_key_manager: DualAPIKeyManager, base_url: str, max_concurrent: int = 19, session: Optional[aiohttp.ClientSession] = None, delay_seconds: float = 60.0) -> List[EvaluationResult]:
    """Evaluate instances one at a time with the flexible dimension approach.

    Args:
        instances: Instances to evaluate, in order.
        api_key_manager: Checked for at least one valid key before each
            instance. Processing stops early (returning fewer results than
            instances) when no valid key remains.
        base_url: API base URL.
        max_concurrent: Accepted for interface compatibility. It is not used
            here; connection limits are expected to live on ``session``.
        session: Shared aiohttp session. When None, a session is created for
            this call and closed afterwards.
        delay_seconds: Pause between consecutive instances (not after the
            last one). Defaults to 60 s; values <= 0 disable the pause.

    Returns:
        One ``EvaluationResult`` per instance processed.
    """
    if session is None:
        # evaluate_instance_flexible needs a real session; own one for this call.
        async with aiohttp.ClientSession() as owned_session:
            return await process_instances_flexible(
                instances, api_key_manager, base_url, max_concurrent, owned_session, delay_seconds
            )

    results: List[EvaluationResult] = []
    total = len(instances)

    for i, instance in enumerate(instances):
        if not await check_api_key_health(api_key_manager):
            logging.error(f"Stopping evaluation due to API key issues at instance {i+1}")
            break

        logging.info(f"Processing instance {i+1}/{total}: {instance.instance_id}")
        results.append(await evaluate_instance_flexible(session, instance, api_key_manager, base_url))

        # Throttle between instances to stay under the provider's rate limit.
        if delay_seconds > 0 and i < total - 1:
            await asyncio.sleep(delay_seconds)

    return results

def save_results(results: List[EvaluationResult], output_path: str):
    """Write per-instance records and aggregate metrics to two JSON files.

    Produces ``{output_path}_results.json`` (one object per result) and
    ``{output_path}_metrics.json`` (accuracy, counts, average confidence and
    average number of selected dimensions). All averages are 0 for an empty
    result list.
    """
    records = [
        {
            "instance_id": r.instance_id,
            "input": r.input,
            "output_1": r.output_1,
            "output_2": r.output_2,
            "gold_label": r.gold_label,
            "predicted_label": r.predicted_label,
            "confidence": r.confidence,
            "selected_dimensions": r.selected_dimensions,
            "metrics_questions": r.metrics_questions,
            "reference_output": r.reference_output,
            "dimension_evaluations": r.dimension_evaluations,
            "expert_models_used": r.expert_models_used,
            "evaluation_time": r.evaluation_time,
            "reasoning": r.reasoning,
            "correct": r.correct,
        }
        for r in results
    ]

    with open(f"{output_path}_results.json", "w", encoding="utf-8") as fh:
        json.dump(records, fh, indent=2, ensure_ascii=False)

    total = len(results)
    n_correct = sum(1 for r in results if r.correct)
    metrics = {
        "overall_accuracy": n_correct / total if results else 0,
        "total_instances": total,
        "correct_predictions": n_correct,
        "incorrect_predictions": total - n_correct,
        "average_confidence": sum(r.confidence for r in results) / total if results else 0,
        "average_dimensions_per_instance": sum(len(r.selected_dimensions) for r in results) / total if results else 0,
    }

    with open(f"{output_path}_metrics.json", "w", encoding="utf-8") as fh:
        json.dump(metrics, fh, indent=2, ensure_ascii=False)

def print_summary_table(results: List[EvaluationResult]):
    """Print a one-line accuracy summary; prints nothing for an empty list."""
    if results:
        total = len(results)
        hits = sum(1 for r in results if r.correct)
        print(f"Accuracy: {hits / total * 100:.1f}% ({hits}/{total})")



def parse_args():
    """Parse command line arguments for the flexible LLMBar evaluation.

    ``--api_key_2`` is optional and defaults to None, which means
    single-key mode; the key manager and ``main`` both skip a missing
    second key. ``--chunk_size`` must be a positive integer because it is
    the step of the chunking loop.

    Returns:
        argparse.Namespace with the parsed options.
    """
    def positive_int(value: str) -> int:
        # int() raises ValueError for non-numeric input; argparse turns that
        # into a usage error, as it does for ArgumentTypeError below.
        number = int(value)
        if number <= 0:
            raise argparse.ArgumentTypeError(f"must be a positive integer, got {value!r}")
        return number

    parser = argparse.ArgumentParser(description="LLMBar Flexible Evaluation Script with Dynamic Dimension Selection")
    parser.add_argument(
        "--dataset_path",
        type=str,
        required=True,
        help="Path to LLMBar dataset JSON file"
    )
    parser.add_argument(
        "--api_key_1",
        type=str,
        required=True,
        help="First API key"
    )
    parser.add_argument(
        "--api_key_2",
        type=str,
        default=None,
        help="Second API key (optional; a single key is used when omitted)"
    )
    parser.add_argument(
        "--base_url",
        type=str,
        default="https://api.openai.com/v1",
        help="API base URL"
    )
    parser.add_argument(
        "--max_concurrent",
        type=int,
        default=19,
        help="Maximum concurrent API calls"
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="llmbar_flexible_evaluation",
        help="Output path prefix for results and metrics"
    )
    parser.add_argument(
        "--max_instances",
        type=int,
        default=0,
        help="Limit processing to the first N instances (0 means all)"
    )
    parser.add_argument(
        "--chunk_size",
        type=positive_int,
        default=10,
        help="Number of instances to process in each chunk before saving checkpoint"
    )
    parser.add_argument(
        "--resume",
        action='store_true',
        help="Resume from checkpoint if available"
    )

    return parser.parse_args()

async def main():
    """Run the flexible LLMBar evaluation end to end.

    Steps:
      1. Parse the CLI arguments and load the dataset.
      2. With ``--resume``, drop instances already recorded in the checkpoint.
      3. Validate the API keys.
      4. Evaluate the remaining instances chunk by chunk, saving a checkpoint
         after each chunk.
      5. Write the final results/metrics files and print the summary.

    The checkpoint is deleted only when every instance has been evaluated.
    If evaluation stops early (e.g. no valid API key remains), the partial
    results are still written and the checkpoint is kept for ``--resume``.
    """
    import ssl

    args = parse_args()

    # Load dataset
    instances = load_llmbar_dataset(args.dataset_path, args.max_instances)
    if not instances:
        print("No instances loaded. Exiting.")
        return

    print(f"[INFO] Loaded {len(instances)} instances from dataset")

    # Initialize checkpoint handling
    checkpoint_path = get_checkpoint_path(args.output_path)
    results = []
    processed_instances = set()

    def finalize(remove_checkpoint: bool) -> None:
        # Write final outputs; drop the checkpoint only for a complete run.
        # Reads `results` at call time, so a rebinding in the resume branch is seen.
        save_results(results, args.output_path)
        print_summary_table(results)
        if remove_checkpoint and os.path.exists(checkpoint_path):
            os.remove(checkpoint_path)

    if args.resume:
        results, processed_instances, _ = load_checkpoint(checkpoint_path)
        if results:
            print(f"[INFO] Resuming from checkpoint with {len(results)} completed evaluations")
            instances = [inst for inst in instances if inst.instance_id not in processed_instances]
            print(f"[INFO] {len(instances)} instances remaining to process")

    if not instances:
        print("No instances remaining to process. Evaluation complete!")
        # All instances are already in the checkpoint (e.g. a previous run saved
        # its last chunk but stopped before writing final results). Write them
        # now instead of exiting without any output.
        if results:
            finalize(remove_checkpoint=True)
        return

    # Set up logging before the key manager is constructed, because
    # DualAPIKeyManager.__init__ already emits log messages.
    setup_logging()
    api_key_manager = DualAPIKeyManager(args.api_key_1, args.api_key_2)

    # Keep TLS certificate verification enabled: every request carries an API
    # key, so an unverified connection would expose it to interception. A
    # custom CA bundle can be supplied through OpenSSL's default locations
    # (e.g. the SSL_CERT_FILE environment variable).
    ssl_context = ssl.create_default_context()

    connector = aiohttp.TCPConnector(limit=args.max_concurrent, limit_per_host=args.max_concurrent, ssl=ssl_context)
    timeout = aiohttp.ClientTimeout(total=60)

    completed = True
    async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
        key1_valid = False
        key2_valid = False

        if args.api_key_1 and args.api_key_1.strip():
            key1_valid = await validate_api_key(session, args.api_key_1, args.base_url, api_key_manager)

        if args.api_key_2 and args.api_key_2.strip():
            key2_valid = await validate_api_key(session, args.api_key_2, args.base_url, api_key_manager)

        if not key1_valid and not key2_valid:
            print("[ERROR] No valid API keys found. Please check your API keys and try again.")
            return

        if api_key_manager.get_valid_key_count() == 0:
            print("[ERROR] No valid API keys found. Cannot proceed with evaluation.")
            return

        total_instances = len(instances)

        for chunk_start in range(0, total_instances, args.chunk_size):
            chunk_instances = instances[chunk_start:chunk_start + args.chunk_size]

            try:
                chunk_results = await process_instances_flexible(chunk_instances, api_key_manager, args.base_url, args.max_concurrent, session)

                results.extend(chunk_results)
                for result in chunk_results:
                    processed_instances.add(result.instance_id)

                save_checkpoint(checkpoint_path, results, processed_instances, total_instances)

            except Exception as e:
                print(f"[ERROR] Error processing chunk {chunk_start//args.chunk_size + 1}: {e}")
                print("[INFO] Checkpoint saved. You can resume with --resume flag")
                return

            # process_instances_flexible returns fewer results than instances
            # when it stops because no valid API key remains. Later chunks
            # would stop immediately too, so finish now and keep the checkpoint.
            if len(chunk_results) < len(chunk_instances):
                completed = False
                print("[WARNING] Evaluation stopped before all instances were processed. "
                      "Checkpoint kept; rerun with --resume to continue.")
                break

    finalize(remove_checkpoint=completed)

if __name__ == "__main__":
    # Script entry point: build the top-level coroutine and run it to completion on a fresh event loop.
    main_coroutine = main()
    asyncio.run(main_coroutine)
