#!/usr/bin/env python3

"""
Mixture of Experts (MoE) Candidate Profiling Script
Evaluates model responses using specialized expert models for different evaluation dimensions.
"""

import os
import json
import time
import logging
import argparse
import yaml
import asyncio
import aiohttp
from collections import defaultdict
from tqdm import tqdm
from typing import List, Dict, Any
import re
from datetime import datetime

# Global rate limit handler
class GlobalRateLimitHandler:
    """Process-wide cool-down shared by all concurrent API requests.

    ``set_rate_limited`` records a wall-clock deadline (``time.time()`` based);
    ``check_and_wait_if_rate_limited`` sleeps until that deadline has passed.
    """

    def __init__(self):
        # Epoch seconds before which requests should not be sent; 0 = not limited.
        self.rate_limited_until = 0
        # Guards rate_limited_until. Never held across a sleep (see below).
        self.lock = asyncio.Lock()

    async def check_and_wait_if_rate_limited(self):
        """Check if we're in a rate limited period and wait if necessary.

        The lock is held only while reading the deadline, not while sleeping,
        so concurrent waiters and ``set_rate_limited`` are never blocked behind
        a sleeping coroutine. The deadline is re-read after every sleep in case
        another rate-limit event extended it in the meantime.
        """
        while True:
            async with self.lock:
                wait_time = self.rate_limited_until - time.time()
            if wait_time <= 0:
                return
            logging.warning(f"Global rate limit active. Waiting {wait_time:.1f} seconds...")
            await asyncio.sleep(wait_time)

    async def set_rate_limited(self, wait_duration: int = 60):
        """Set global rate limit state.

        Args:
            wait_duration: Seconds from now during which requests should wait.
                A cool-down that already ends later than this is kept as is.
        """
        async with self.lock:
            new_deadline = time.time() + wait_duration
            # Never shorten a longer cool-down that is already in force.
            if new_deadline > self.rate_limited_until:
                self.rate_limited_until = new_deadline
            logging.warning(f"Global rate limit set. All requests will wait {wait_duration} seconds.")

# Dual API key manager for round-robin distribution
class DualAPIKeyManager:
    """Hands out two API keys alternately so load is split evenly between them."""

    def __init__(self, api_key_1: str, api_key_2: str):
        self.api_keys = [api_key_1, api_key_2]
        # Index of the key the next get_next_api_key() call will return.
        self.current_index = 0
        # Serializes the read-then-advance of current_index across coroutines.
        self.lock = asyncio.Lock()

    async def get_next_api_key(self) -> str:
        """Return the key at the cursor and advance the cursor (wrapping around)."""
        async with self.lock:
            chosen = self.current_index
            self.current_index = (chosen + 1) % len(self.api_keys)
        return self.api_keys[chosen]

    def get_api_key_for_worker(self, worker_id: int) -> str:
        """Return a fixed key for *worker_id*: even ids get the first key, odd the second."""
        slot = worker_id % len(self.api_keys)
        return self.api_keys[slot]

# Module-wide rate-limit state; every API call in this module consults this one object.
rate_limit_handler: GlobalRateLimitHandler = GlobalRateLimitHandler()

# --- Logging Setup ---
def setup_logging(log_file: str = "moe_candidate_profiling_500_winners.log", level: int = logging.INFO):
    """Configure the root logger to write to the console and to *log_file*.

    Any handlers already attached to the root logger are detached and closed
    first, so repeated calls neither duplicate output nor leak open log files.

    Args:
        log_file: Path of the log file; it is truncated (mode ``'w'``).
        level: Root logger level (default ``logging.INFO``).
    """
    for handler in logging.root.handlers[:]:
        logging.root.removeHandler(handler)
        # Closing releases the file descriptor held by a previous FileHandler.
        handler.close()
    logging.basicConfig(
        level=level,
        format='[%(levelname)s] %(message)s',
        handlers=[
            logging.StreamHandler(),
            # Explicit UTF-8: logged model outputs contain arbitrary Unicode,
            # which the platform default encoding (e.g. cp1252) may not represent.
            logging.FileHandler(log_file, mode='w', encoding='utf-8'),
        ],
    )

# Load metrics configuration
def load_metrics_config(metrics_path: str) -> Dict[str, Any]:
    """Load metrics configuration from YAML file.

    The file is opened in binary mode so PyYAML detects the encoding itself
    (UTF-8, or UTF-16 via BOM) instead of relying on the platform default
    text encoding.

    Args:
        metrics_path: Path to the metrics YAML file.

    Returns:
        The parsed top-level mapping, or ``{}`` when the document is empty
        (``yaml.safe_load`` returns ``None`` in that case, and callers use
        ``in`` on the result).
    """
    with open(metrics_path, 'rb') as f:
        metrics_config = yaml.safe_load(f)
    if metrics_config is None:
        return {}
    return metrics_config

def get_scenario_dimensions_with_descriptions(metrics_config: Dict[str, Any], scenario: str) -> List[Dict[str, str]]:
    """Parse the ``attributes`` text of *scenario* into name/description pairs.

    Each line of the form ``- Name: Description`` (leading/trailing whitespace
    ignored) yields ``{'name': 'Name', 'description': 'Description'}``; other
    lines are skipped. Returns ``[]`` when the scenario or its ``attributes``
    entry is missing.
    """
    if scenario not in metrics_config or 'attributes' not in metrics_config[scenario]:
        return []

    parsed: List[Dict[str, str]] = []
    for raw_line in metrics_config[scenario]['attributes'].split('\n'):
        # The pattern is anchored with "- ", so it also enforces the bullet prefix.
        entry = re.match(r'- (\w+(?:\s+\w+)*):\s*(.+)', raw_line.strip())
        if entry is None:
            continue
        parsed.append({
            'name': entry.group(1).strip(),
            'description': entry.group(2).strip(),
        })
    return parsed

def get_scenario_description(metrics_config: Dict[str, Any], scenario: str) -> str:
    """Return the ``scenario`` description text configured for *scenario*.

    Falls back to ``"Scenario: <name>"`` when the scenario is absent from
    *metrics_config* or has no ``scenario`` key.
    """
    fallback = f"Scenario: {scenario}"
    if scenario in metrics_config:
        return metrics_config[scenario].get('scenario', fallback)
    return fallback

# Expert model mapping for different evaluation dimensions.
# Maps an evaluation-dimension name to the model id of the "expert" judge for it.
# NOTE(review): keys presumably must match the dimension names parsed from
# metrics.yaml `attributes` lines (see get_scenario_dimensions_with_descriptions),
# character for character; confirm how lookups of unknown dimensions are handled.
# NOTE(review): both 'Professional' and 'Professionalism' are present and route to
# different models — confirm both spellings really occur in metrics.yaml.
DIMENSION_TO_MODEL_MAPPING = {
    'Accuracy': 'provider-3/qwen-2.5-72b',
    'Admit Uncertainty': 'provider-1/deepseek-r1-0528',
    'Attractive': 'provider-1/deepseek-r1-0528',
    'Audience Friendly': 'provider-1/deepseek-r1-0528',
    'Authenticity': 'provider-3/qwen-2.5-72b',
    'Being Friendly': 'provider-1/deepseek-r1-0528',
    'Citation': 'provider-1/deepseek-r1-0528',
    'Clarity': 'provider-3/kimi-k2',
    'Code Correctness': 'provider-3/qwen-2.5-72b',
    'Code Readability': 'provider-3/mistral-large-latest',
    'Coherence': 'provider-3/kimi-k2',
    'Completeness': 'provider-3/gpt-5-chat',
    'Coverage': 'provider-1/deepseek-r1-0528',
    'Creativity': 'provider-3/sonar',
    'Depth': 'provider-3/sonar-pro',
    'Emojis': 'provider-3/deepseek-v3',
    'Emotion': 'provider-3/gpt-5-chat',
    'Faithfulness': 'provider-3/llama-3.1-70b',
    'Feasibility': 'provider-3/gpt-5-chat',
    'Harmlessness': 'provider-3/sonar-pro',
    'Information Richness': 'provider-3/sonar-pro',
    'Insight': 'provider-1/deepseek-r1-0528',
    'Instruction Following': 'provider-3/gpt-5-chat',
    'Interactivity': 'provider-1/deepseek-r1-0528',
    'Layout': 'provider-6/llama-4-maverick',
    'Length': 'provider-3/mistral-medium-latest',
    'Logic': 'provider-3/kimi-k2',
    'Modularity': 'provider-3/llama-3.3-70b',
    'Multiple Aspects': 'provider-3/gpt-5-chat',
    'Objectivity': 'provider-6/o3-medium',
    'Originality': 'provider-1/deepseek-r1-0528',
    'Pacing': 'provider-3/kimi-k2',
    'Pointing Out': 'provider-3/gpt-5-nano',
    'Professional': 'provider-1/deepseek-r1-0528',
    'Professionalism': 'provider-6/o3-high',
    'Relevance': 'provider-3/sonar-pro',
    'Result at the Beginning': 'provider-6/claude-sonnet-4-20250514-thinking',
    'Step by Step Explanation': 'provider-3/gpt-4.1-mini',
    'Style': 'provider-3/claude-sonnet-4',
    'Timeliness': 'provider-3/kimi-k2',
    'Vivid': 'provider-3/kimi-k2'
}

# Prompts for the pairwise judge.
# SYSTEM_PROMPT_TEMPLATE: instructions asking for a JSON object mapping each
# dimension name to "1", "2" or "tie" (the verdicts parse_evaluations accepts).
# NOTE(review): this string contains literal single braces in its example JSON,
# so it must be used verbatim — passing it through str.format() would raise.
SYSTEM_PROMPT_TEMPLATE = """You are an expert judge for language model outputs. You will be given a prompt and two responses (Response 1 and Response 2) from different models. Your task is to evaluate which response is better for each specified dimension.

For each dimension, you must choose:
- "1" if Response 1 is better
- "2" if Response 2 is better  
- "tie" if both responses are equally good

IMPORTANT: Return ONLY a JSON object with the EXACT dimension names as keys. Do not use generic names like "dimension1", "dimension2", etc. Use the actual dimension names provided in the prompt.

Example format:
{"Accuracy": "1", "Clarity": "2", "Depth": "tie", "Relevance": "1"}

Do not include any explanations or extra text outside the JSON object."""

# USER_PROMPT_TEMPLATE: per-instance message with placeholders scenario_description,
# dimensions_with_descriptions, prompt, response_a and response_b. The doubled
# braces in the format line indicate it is meant to be filled with str.format().
USER_PROMPT_TEMPLATE = """Scenario: {scenario_description}

Dimensions to evaluate:
{dimensions_with_descriptions}

Prompt: {prompt}

Response 1: {response_a}

Response 2: {response_b}

Evaluate which response is better for each dimension. Return ONLY a JSON object in this exact format:
{{"DimensionName1": "1/2/tie", "DimensionName2": "1/2/tie", ...}}

Where each dimension should be evaluated as:
- "1" if Response 1 is better
- "2" if Response 2 is better
- "tie" if both responses are equally good

IMPORTANT: Use the EXACT dimension names from the list above as JSON keys. Do not use generic names like "dimension1", "dimension2", etc."""

# Scenario mapping from classifier output to metrics.yaml scenario names.
# Keys are scenario labels produced by the classifier; values are the top-level
# scenario keys in metrics.yaml. map_scenario_to_metrics() returns unknown labels
# unchanged, so a label missing here must already be a valid metrics.yaml key.
# NOTE(review): 'seeking_legal_advice' maps to the generic 'seeking_advice' — confirm
# this is intended. The classifier prompt also lists labels such as "reasoning" and
# "functional writing" that have no entry here; verify how they are normalized.
SCENARIO_MAPPING = {
    'open_question_answering': 'open_question',
    'general_explanation': 'explaining_general',
    'solving_math_exam_problem': 'solving_exam_question_with_math',
    'solving_general_exam_problem': 'solving_exam_question_without_math',
    'code_writing': 'code_writing',
    'text_translation': 'text_to_text_translation',
    'general_analysis': 'analyzing_general',
    'text_summarization': 'text_summarization',
    'category_identification': 'classification_identification',
    'title_generation': 'title_generation',
    'question_generation': 'question_generation',
    'reading_comprehension': 'reading_comprehension',
    'keywords_extraction': 'keywords_extraction',
    'information_extraction': 'information_extraction',
    'topic_modeling': 'topic_modeling',
    'ranking': 'ranking',
    'data_analysis': 'data_analysis',
    'language_polish': 'language_polishing',
    'fact_verification': 'verifying_fact',
    'writing_legal_document': 'writing_legal_document',
    'writing_social_media_post': 'writing_social_media_post',
    'seeking_medical_advice': 'seeking_medical_advice',
    'seeking_legal_advice': 'seeking_advice',
    'inquire_safe_experimental_practices': 'safe_experimental_practices',
    'instructional_rewriting': 'instructional_rewriting',
    'literary_appreciation': 'literary_appreciation',
    'creative_writing': 'creative_writing',
    'chitchat': 'chitchat',
    'roleplay': 'roleplay',
    'brainstorming': 'brainstorming',
    'seeking_advice': 'seeking_advice',
    'planning': 'planning',
    'value_judgement': 'value_judgement',
    'recommendation': 'recommendation'
}

def map_scenario_to_metrics(scenario: str) -> str:
    """Translate a classifier scenario label into its metrics.yaml key.

    Labels not listed in SCENARIO_MAPPING are passed through unchanged.
    """
    try:
        return SCENARIO_MAPPING[scenario]
    except KeyError:
        return scenario

def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with output, api_key_1, api_key_2, base_url,
        seeds_path, metrics_path, nums, max_retries, workers and debug.
    """
    cli = argparse.ArgumentParser(description="Mixture of Experts Candidate Profiling - Updated 500-Instance Winners Version")
    # (flag, keyword options) pairs, registered in this order so --help is unchanged.
    option_specs = (
        ('--output', dict(type=str, default='moe_candidate_profiling_500_winners_updated_results.json', help='Output file for profiling results')),
        ('--api_key_1', dict(type=str, required=True, help='First API key')),
        ('--api_key_2', dict(type=str, required=True, help='Second API key')),
        ('--base_url', dict(type=str, default='https://api.example.com/v1', help='Base URL for API calls')),
        ('--seeds_path', dict(type=str, default='MOE-Judge/seeds.json', help='Path to seeds.json')),
        ('--metrics_path', dict(type=str, default='MOE-Judge/metrics.yaml', help='Path to metrics.yaml')),
        ('--nums', dict(type=int, default=500, help='Number of instances to process (default: 500, use 0 for all instances)')),
        ('--max_retries', dict(type=int, default=3, help='Maximum number of retries for failed API calls')),
        ('--workers', dict(type=int, default=19, help='Number of parallel workers (default: 19 for 20 RPM with 2 API keys)')),
        ('--debug', dict(action='store_true', help='Enable debug logging')),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def load_seeds_data(path: str, nums: int = 500, start_index: int = 0) -> List[Dict]:
    """Load the seeds JSON array, optionally skipping a prefix and capping the count.

    Args:
        path: Path to the seeds JSON file (expected to hold a JSON array).
        nums: Maximum number of instances to return; ``0`` (or negative) means all.
        start_index: Number of leading instances to skip (resume support). If it
            is ``>=`` the number of instances, a warning is printed and loading
            starts from the beginning instead.

    Returns:
        The selected instances, or ``[]`` if the file could not be decoded.

    Raises:
        json.JSONDecodeError: if the decoded text is not valid JSON.
    """
    # 'utf-8-sig' comes first: it reads plain UTF-8 and also strips a leading BOM.
    # With plain 'utf-8' first, a BOM-prefixed file decodes without error but
    # json.load then raises JSONDecodeError ("Unexpected UTF-8 BOM"), which the
    # UnicodeDecodeError handler below does not catch, so loading would crash.
    encodings = ['utf-8-sig', 'cp1252', 'latin-1']
    data = None

    for encoding in encodings:
        try:
            with open(path, 'r', encoding=encoding) as f:
                data = json.load(f)
                break
        except UnicodeDecodeError:
            continue

    if data is None:
        print("Error: Could not decode seeds.json with any encoding")
        return []

    # Apply start_index for resume functionality
    if start_index > 0:
        if start_index >= len(data):
            print(f"Warning: start_index ({start_index}) is >= total instances ({len(data)}). Starting from beginning.")
            start_index = 0
        else:
            print(f"Resuming from instance {start_index} (0-indexed)")
            data = data[start_index:]

    # Apply nums limit if specified
    if nums > 0:
        data = data[:nums]

    return data

# Verdict strings the judge prompts ask for.
_VALID_VERDICTS = ('1', '2', 'tie')


def _normalize_dimension_key(name: Any) -> str:
    """Lower-case *name* and drop every non-alphanumeric character.

    "Step by Step Explanation", "step_by_step_explanation" and
    "Step-by-Step Explanation" all become "stepbystepexplanation".
    """
    return re.sub(r'[\W_]+', '', str(name).lower())


def _normalize_verdict(value: Any) -> Any:
    """Map a raw judge value to '1', '2' or 'tie'; return None for anything else.

    Accepts integers (``1``) as well as any letter case / surrounding
    whitespace (``" Tie "``).
    """
    text = str(value).strip().lower()
    return text if text in _VALID_VERDICTS else None


def _balanced_json_objects(text: str) -> List[str]:
    """Return every top-level ``{...}`` span of *text* whose braces balance.

    Single linear pass. Braces inside double-quoted strings are ignored, and a
    ``{`` that is never closed (e.g. truncated output) simply yields nothing.
    """
    objects = []
    depth = 0
    start = 0
    in_string = False
    escaped = False
    for index, char in enumerate(text):
        if in_string:
            if escaped:
                escaped = False
            elif char == '\\':
                escaped = True
            elif char == '"':
                in_string = False
        elif char == '"':
            # Only quotes inside an object open a string; stray prose quotes are ignored.
            in_string = depth > 0
        elif char == '{':
            if depth == 0:
                start = index
            depth += 1
        elif char == '}' and depth > 0:
            depth -= 1
            if depth == 0:
                objects.append(text[start:index + 1])
    return objects


def _json_candidates(output: str) -> List[str]:
    """List substrings of *output* that may contain the verdict JSON, best first.

    The text after the last ``</think>`` tag is searched before the full
    output. Within each text the order is: fenced code block, balanced
    ``{...}`` objects from last to first, first-``{``-to-last-``}`` span, and
    finally the whole stripped text. Duplicates are removed.
    """
    sources = []
    think_end = output.rfind('</think>')
    if think_end != -1:
        sources.append(output[think_end + len('</think>'):])
    sources.append(output)

    candidates = []
    for text in sources:
        fence = re.search(r"```(?:json)?\s*(.*?)```", text, re.DOTALL | re.IGNORECASE)
        if fence:
            candidates.append(fence.group(1).strip())
        candidates.extend(reversed(_balanced_json_objects(text)))
        first, last = text.find('{'), text.rfind('}')
        if first != -1 and last > first:
            candidates.append(text[first:last + 1])
        candidates.append(text.strip())
    return list(dict.fromkeys(c for c in candidates if c))


def _load_json_lenient(json_str: str) -> Any:
    """Parse *json_str*; on failure retry once with single quotes turned into double quotes.

    The quote swap is only a fallback for Python-style dict output. Applying
    it unconditionally would corrupt valid JSON whose values contain
    apostrophes (e.g. "it's").

    Raises:
        json.JSONDecodeError: if neither form parses.
    """
    try:
        return json.loads(json_str)
    except json.JSONDecodeError:
        return json.loads(json_str.replace("'", '"'))


def _match_dimensions(data: Dict[str, Any], expected_dimensions: List[str]) -> Dict[str, Any]:
    """Assign each expected dimension a verdict from *data* (None when not found).

    Three passes, each tried only if the previous found nothing:
    normalized exact key match, substring match, generic ``dimensionN`` keys.
    """
    normalized_data = {_normalize_dimension_key(k): v for k, v in data.items()}
    logging.debug(f"Normalized API response keys: {list(normalized_data.keys())}")

    evaluations = {}
    found_dimensions = []

    # Pass 1: keys equal after normalization (case, spaces, '_' and '-' ignored).
    for dimension in expected_dimensions:
        normalized_dim = _normalize_dimension_key(dimension)
        verdict = _normalize_verdict(normalized_data.get(normalized_dim)) if normalized_dim else None
        evaluations[dimension] = verdict
        if verdict is not None:
            found_dimensions.append(dimension)
            logging.debug(f"Found exact match for '{dimension}': {verdict}")
        else:
            logging.debug(f"No valid evaluation found for dimension '{dimension}' (normalized: '{normalized_dim}'). Available keys: {list(normalized_data.keys())}")

    # Pass 2: a key containing the dimension name or vice versa.
    if not found_dimensions:
        logging.warning("No exact dimension matches found, trying partial matches...")
        for key, value in data.items():
            verdict = _normalize_verdict(value)
            key_norm = _normalize_dimension_key(key)
            if verdict is None or not key_norm:
                continue
            for dimension in expected_dimensions:
                dim_norm = _normalize_dimension_key(dimension)
                if (evaluations[dimension] is None and dim_norm
                        and (dim_norm in key_norm or key_norm in dim_norm)):
                    evaluations[dimension] = verdict
                    found_dimensions.append(dimension)
                    logging.debug(f"Found partial match: '{key}' -> '{dimension}': {verdict}")
                    break

    # Pass 3: positional generic keys "dimension1", "dimension2", ...
    if not found_dimensions:
        logging.warning("No partial matches found, trying generic dimension keys...")
        for position, dimension in enumerate(expected_dimensions, start=1):
            generic_key = f"dimension{position}"
            verdict = _normalize_verdict(normalized_data.get(generic_key))
            evaluations[dimension] = verdict
            if verdict is not None:
                found_dimensions.append(dimension)
                logging.debug(f"Found generic match: '{generic_key}' -> '{dimension}': {verdict}")
            else:
                logging.debug(f"No generic match for '{dimension}' (key '{generic_key}')")

    valid_evaluations = len(found_dimensions)
    logging.info(f"Found {valid_evaluations}/{len(expected_dimensions)} valid evaluations")
    logging.debug(f"Found dimensions: {found_dimensions}")
    logging.debug(f"All data keys: {list(data.keys())}")

    # Flag responses that answered fewer than 80% of the requested dimensions.
    if expected_dimensions:
        success_rate = valid_evaluations / len(expected_dimensions)
        if success_rate < 0.8:
            logging.warning(f"Low success rate in evaluations: {valid_evaluations}/{len(expected_dimensions)} ({success_rate:.1%})")
            logging.warning(f"Available keys in response: {list(data.keys())}")
            logging.warning(f"Expected dimensions: {expected_dimensions}")

    return evaluations


def parse_evaluations(output: str, expected_dimensions: List[str]) -> Dict[str, str]:
    """Parse API response to extract evaluations in 1/2/tie format.

    Args:
        output: Raw model output. It may contain ``<think>`` reasoning, a fenced
            code block or prose around the JSON object.
        expected_dimensions: Dimension names to extract.

    Returns:
        A dict with one entry per expected dimension whose value is '1', '2',
        'tie' or None (not found / invalid). Returns ``{}`` if the output is
        empty or no JSON object can be parsed from it.
    """
    if not output:
        logging.warning("Empty output received from API")
        return {}

    if output.strip() == "":
        logging.warning("Empty string output received from API")
        return {}

    logging.debug(f"Raw API output: {output[:500]}...")
    logging.debug(f"Expected dimensions: {expected_dimensions}")

    try:
        data = None
        last_error = None
        for candidate in _json_candidates(output):
            try:
                parsed = _load_json_lenient(candidate)
            except json.JSONDecodeError as e:
                last_error = e
                continue
            if isinstance(parsed, dict):
                data = parsed
                logging.debug(f"Successfully parsed JSON: {data}")
                break
            logging.debug(f"Skipping parsed JSON that is not a dictionary: {type(parsed)}")

        if data is None:
            if last_error is not None:
                logging.error(f"JSON decode error: {last_error}")
            logging.error(f"Could not extract a JSON object from output: {output[:500]}")
            return {}

        return _match_dimensions(data, expected_dimensions)
    except Exception as e:
        # Boundary handler: callers rely on "{} on failure" rather than an exception.
        logging.error(f"Failed to parse JSON evaluations: {e}")
        logging.error(f"Raw output was: {output}")
        return {}

# HTTP statuses treated as transient server failures and retried with exponential backoff.
_TRANSIENT_SERVER_STATUSES = (500, 502, 503, 504)
# Global cool-down (seconds) after a 429 when the server sends no usable Retry-After.
_DEFAULT_RATE_LIMIT_WAIT = 60
# Upper bound on an honoured Retry-After value so a bogus header cannot stall the run.
_MAX_RATE_LIMIT_WAIT = 600


def _retry_after_seconds(headers: Any) -> float:
    """Return the delay requested by a ``Retry-After`` header, else the default wait.

    Only the delta-seconds form is understood. A missing header, an HTTP-date,
    or a non-positive or non-numeric value falls back to
    ``_DEFAULT_RATE_LIMIT_WAIT``. Larger values are clamped to
    ``_MAX_RATE_LIMIT_WAIT``.
    """
    raw = headers.get("Retry-After") if headers else None
    try:
        seconds = float(raw)
    except (TypeError, ValueError):
        return _DEFAULT_RATE_LIMIT_WAIT
    if not seconds > 0:  # also rejects NaN
        return _DEFAULT_RATE_LIMIT_WAIT
    return min(seconds, _MAX_RATE_LIMIT_WAIT)


async def call_expert_api_async(session, messages: List[Dict], model: str, api_key_manager: DualAPIKeyManager, base_url: str, max_retries: int = 3) -> Dict:
    """Make a single API call to expert model with global rate limiting and retry for all errors.

    Args:
        session: aiohttp client session used for the POST.
        messages: Chat messages sent as the ``messages`` field.
        model: Model id sent as the ``model`` field.
        api_key_manager: Supplies a key per attempt (round-robin).
        base_url: API root; ``/chat/completions`` is appended.
        max_retries: Total number of attempts.

    Returns:
        On success ``{"success": True, "content": str, "usage": ...}``; otherwise
        ``{"success": False, "error": str, "content": None, "usage": None}``.

    Behaviour per attempt:
        * 429 sets the global cool-down (Retry-After if given, else 60 s) and retries.
        * 500/502/503/504 retry with exponential backoff.
        * Empty or missing content retries with exponential backoff.
        * Other errors retry; the last one is reported.
    """
    def _failure(error: str) -> Dict:
        return {"success": False, "error": error, "content": None, "usage": None}

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        try:
            # Check global rate limit before making request
            await rate_limit_handler.check_and_wait_if_rate_limited()

            # Get API key for this request
            api_key = await api_key_manager.get_next_api_key()

            headers = {
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json"
            }
            payload = {
                "model": model,
                "messages": messages,
                "temperature": 0.0,  # More deterministic
                "max_tokens": 12000,  # Increased for very complex multi-dimension evaluations
                "response_format": {"type": "json_object"}
            }

            # Serializing the full payload is costly, so only do it when it will be logged.
            if logging.getLogger().isEnabledFor(logging.DEBUG):
                logging.debug(f"Making API call to {model} with payload: {json.dumps(payload, indent=2)}")

            async with session.post(
                f"{base_url}/chat/completions",
                headers=headers,
                json=payload,
                timeout=aiohttp.ClientTimeout(total=60)
            ) as response:
                # Handle 429 rate limit error - set global rate limit
                if response.status == 429:
                    await rate_limit_handler.set_rate_limited(_retry_after_seconds(response.headers))
                    logging.warning(f"Rate limit hit (429) for model {model}. Global rate limit set. Retry {attempt + 1}/{max_retries}")
                    continue

                # Transient server errors: retry with exponential backoff
                if response.status in _TRANSIENT_SERVER_STATUSES:
                    logging.warning(f"Server error {response.status} for model {model}. Retry {attempt + 1}/{max_retries}")
                    if not is_last_attempt:
                        await asyncio.sleep(2 ** attempt)
                    continue

                logging.debug(f"API response status: {response.status}")
                logging.debug(f"API response headers: {dict(response.headers)}")

                response.raise_for_status()
                response_data = await response.json()

                choices = response_data.get("choices") if isinstance(response_data, dict) else None
                if isinstance(response_data, dict):
                    logging.debug(f"API response data keys: {list(response_data.keys())}")

                if not choices:
                    logging.error(f"Invalid response structure from {model}: {response_data}")
                    return _failure("Invalid response structure - no choices found")

                first_choice = choices[0]
                content = (first_choice.get("message") or {}).get("content")

                # Validate before any string operation: content may be null or non-text.
                if not isinstance(content, str) or content.strip() == "":
                    logging.warning(f"Empty response from {model}, retrying... (attempt {attempt + 1}/{max_retries})")
                    if is_last_attempt:
                        return _failure("Empty response after all retries")
                    await asyncio.sleep(2 ** attempt)
                    continue

                logging.debug(f"API response content: {content[:500]}...")

                # finish_reason == "length" is the API's own truncation signal; the
                # character-count/ellipsis checks are a heuristic fallback.
                if (first_choice.get("finish_reason") == "length"
                        or content.endswith('...') or len(content) > 11500):
                    logging.warning(f"Response may be truncated for model {model}. Content length: {len(content)}")

                return {
                    "success": True,
                    "content": content,
                    "usage": response_data.get("usage")
                }

        except aiohttp.ClientResponseError as e:
            logging.error(f"HTTP error {e.status} for model {model}: {e.message}")

            if e.status == 429:
                await rate_limit_handler.set_rate_limited(_retry_after_seconds(e.headers))
                logging.warning(f"Rate limit hit (429) for model {model}. Global rate limit set. Retry {attempt + 1}/{max_retries}")
                continue

            if e.status in _TRANSIENT_SERVER_STATUSES:
                logging.warning(f"Server error {e.status} for model {model}. Retry {attempt + 1}/{max_retries}")
                if not is_last_attempt:
                    await asyncio.sleep(2 ** attempt)
                continue

            # Other HTTP errors are retried immediately (the next attempt may use
            # the other API key); only the last one is reported.
            if is_last_attempt:
                return _failure(f"HTTP {e.status}: {e.message}")
        except Exception as e:
            # Some client errors only expose the rate limit through their message text.
            if "429" in str(e) or "rate limit" in str(e).lower():
                await rate_limit_handler.set_rate_limited(_DEFAULT_RATE_LIMIT_WAIT)
                logging.warning(f"Rate limit hit (429) for model {model}. Global rate limit set. Retry {attempt + 1}/{max_retries}")
                continue

            logging.warning(f"API call attempt {attempt + 1} failed for {model}: {e}")
            if is_last_attempt:
                return _failure(str(e))
            await asyncio.sleep(2 ** attempt)  # Exponential backoff for other errors

    return _failure("Max retries exceeded")

async def classify_scenario_async(session, instance: Dict, api_key_manager: DualAPIKeyManager, base_url: str) -> str:
    """Classify the scenario of an instance using the scenario classifier.

    Sends the instance's prompt and both responses (JSON-encoded) to the
    ``provider-6/o3-pro`` model together with a fixed classification system
    prompt, then normalizes the reply into a lowercase, underscore-separated
    label (e.g. ``"Solving math exam problem."`` -> ``"solving_math_exam_problem"``).

    Args:
        session: Open aiohttp client session used for the request.
        instance: Record providing ``prompt``, ``response_a`` and
            ``response_b`` (each defaults to ``""`` when missing).
        api_key_manager: Key manager; the worker-0 key is always used here.
        base_url: API root; ``/chat/completions`` is appended to it.

    Returns:
        The normalized scenario label, or ``"default"`` on any failure
        (HTTP 429 or other errors, malformed payload, empty/non-text reply).
        A 429 additionally arms the global rate limiter for 60 seconds.
    """
    prompt = instance.get('prompt', '')
    response_a = instance.get('response_a', '')
    response_b = instance.get('response_b', '')
    classification_payload = {
        "model": "provider-6/o3-pro",
        "messages": [
            {
                "role": "system",
                "content": """You are an AI assistant specialized in scenario classification.\nYour goal is to read each user message and assign it to exactly one of the predefined scenarios below. Each scenario belongs to a broader human‑need category. If none fit, label it \"Other.\"\n\nAvailable Scenarios:\n1. Safety Needs\n  • writing legal document – Drafting, reviewing or finalizing legal texts (contracts, wills, deeds, agreements).\n  • seeking medical advice – Asking for health information, guidance on symptoms, conditions or treatments.\n  • inquire safe experimental practices – Requesting guidelines to conduct experiments safely, minimize risks, and follow regulations.\n\n2. Social Needs\n  • chitchat – Casual, informal conversation without problem‑solving or detailed information exchange.\n  • roleplay – Interactive dialogue where participants assume roles to simulate scenarios.\n  • writing social media post – Crafting messages for platforms like Facebook, Twitter, or Instagram.\n\n3. 
Cognitive Needs\n  • reasoning – Logical, systematic thinking to solve problems.\n  • solving general exam problem – Answering non‑mathematical exam questions via reasoning.\n  • solving math exam problem – Step‑by‑step application of mathematical concepts to solve problems.\n  • instructional rewriting – Revising text for clarity, accuracy, or coherence while preserving intent.\n  • code writing – Creating, editing or debugging code for problem‑solving or automation.\n  • text translation – Converting text between languages, preserving meaning and tone.\n  • general explanation – Providing clear, detailed explanations to inform and educate.\n  • functional writing – Writing precise instructions or information for a specific purpose.\n  • fact verification – Confirming the accuracy and reliability of information.\n  • general analysis – Examining or interpreting data broadly to deepen understanding.\n  • text summarization – Condensing text while retaining key information.\n  • category identification – Classifying items into predefined categories.\n  • title generation – Creating an appropriate, compelling title.\n  • question generation – Forming questions from given content.\n  • reading comprehension – Answering questions from information in a passage.\n  • keywords extraction – Identifying the most important words or phrases.\n  • information extraction – Pulling out specific information categories as requested.\n  • topic modeling – Identifying high‑level themes or subjects in a text.\n  • ranking – Sorting items based on specified criteria.\n  • data analysis – Inspecting, cleansing, transforming, and modeling data to extract insights.\n\n4. 
Aesthetic Needs\n  • literary appreciation – Analyzing and evaluating literary works for artistic qualities, themes, and style.\n  • language polish – Refining text for clarity, readability, and overall quality.\n  • recommendation – Suggesting products, services, or actions based on preferences.\n  • creative writing – Crafting imaginative stories, poems, or other original expressions.\n\n5. Self‑Actualization Needs\n  • seeking advice – Requesting guidance or solutions for a specific problem.\n  • planning – Organizing or strategizing for future actions or events.\n  • brainstorming – Generating a wide range of ideas collaboratively.\n  • open question answering – Responding to broad, open‑ended queries.\n  • value judgement – Evaluating or making decisions based on personal beliefs or ethics.\n\n6. Other\n  • Other – Any message that does not fit any of the above scenarios.\n\nClassification Instructions:\n1. Read the user's message in full.\n2. Compare its intent and content against each scenario definition.\n3. Select exactly one scenario whose definition best matches the message.\n4. If none match, choose \"Other.\"\n5. Output only the scenario name (e.g., solving math exam problem) with no additional commentary.\n\nExample:\nUser: \"Can you help me rewrite this paragraph so it's clearer?\"\nAssistant: instructional rewriting"""
            },
            {
                "role": "user",
                "content": json.dumps({"prompt": prompt, "response_a": response_a, "response_b": response_b})
            }
        ],
        "temperature": 0.1
    }
    headers = {
        "Authorization": f"Bearer {api_key_manager.get_api_key_for_worker(0)}", # Use a specific API key for classification
        "Content-Type": "application/json"
    }
    try:
        # Check global rate limit before making request
        await rate_limit_handler.check_and_wait_if_rate_limited()
        
        async with session.post(
            f"{base_url}/chat/completions",
            headers=headers,
            json=classification_payload,
            timeout=aiohttp.ClientTimeout(total=60)
        ) as response:
            # Handle 429 rate limit error - set global rate limit
            if response.status == 429:
                await rate_limit_handler.set_rate_limited(60)
                logging.warning(f"Rate limit hit (429) during scenario classification. Global rate limit set.")
                return "default"
            
            response.raise_for_status()
            response_data = await response.json()
            content = response_data["choices"][0]["message"]["content"]

        if not isinstance(content, str):
            logging.warning(f"Scenario classifier returned non-text content: {content!r}")
            return "default"

        # Models sometimes add commentary lines, wrap the label in quotes or
        # end it with a period despite the instructions; keep only the first
        # line and strip that decoration so the label still matches the table.
        lines = content.strip().splitlines()
        label = lines[0].strip().strip('"\'`.').strip() if lines else ""
        if not label:
            logging.warning("Scenario classifier returned an empty label")
            return "default"

        # Lowercase and join words with single underscores (runs of
        # whitespace collapse to one separator).
        return "_".join(label.lower().split())
    except Exception as e:
        # Check if it's a 429 error (rate limit) - set global rate limit
        if "429" in str(e) or "rate limit" in str(e).lower():
            await rate_limit_handler.set_rate_limited(60)
            logging.warning(f"Rate limit hit (429) during scenario classification. Global rate limit set.")
            return "default"
        
        logging.warning(f"Failed to classify scenario: {e}")
        return "default"

def get_user_prompt(instance: Dict, metrics_config: Dict[str, Any], scenario: str) -> str:
    """Render ``USER_PROMPT_TEMPLATE`` for one instance under ``scenario``.

    The scenario description and its dimension list come from
    ``metrics_config``. When the scenario has no dimensions, a generic
    ``"Scenario: <name>"`` description and a single "Overall winner"
    dimension are used instead.

    Args:
        instance: Record providing ``prompt``, ``response_a`` and
            ``response_b`` (each defaults to ``""`` when missing).
        metrics_config: Parsed metrics configuration.
        scenario: Scenario key used to look up description and dimensions.

    Returns:
        The formatted user prompt string.
    """
    description = get_scenario_description(metrics_config, scenario)
    dimension_specs = get_scenario_dimensions_with_descriptions(metrics_config, scenario)

    if not dimension_specs:
        # Fallback used when the scenario is unknown to the metrics config.
        description = f"Scenario: {scenario}"
        dimension_specs = [{"name": "Overall winner", "description": "Overall quality comparison"}]

    # One "- name: description" line per dimension.
    dimension_block = "\n".join(
        f"- {spec['name']}: {spec['description']}" for spec in dimension_specs
    ).strip()

    return USER_PROMPT_TEMPLATE.format(
        scenario_description=description,
        dimensions_with_descriptions=dimension_block,
        prompt=instance.get('prompt', ''),
        response_a=instance.get('response_a', ''),
        response_b=instance.get('response_b', ''),
    )

async def process_dimension_async(session, dimension: str, user_prompt: str, api_key_manager: DualAPIKeyManager, base_url: str, max_retries: int = 3) -> Dict:
    """Evaluate a single dimension with the expert model assigned to it.

    Looks up the expert in ``DIMENSION_TO_MODEL_MAPPING``, then sends the
    shared system prompt plus ``user_prompt`` through
    ``call_expert_api_async``. The verdict for ``dimension`` is extracted
    from the reply with ``parse_evaluations``.

    Args:
        session: Open aiohttp client session.
        dimension: Name of the dimension to evaluate.
        user_prompt: Fully rendered user prompt for the instance.
        api_key_manager: Key manager forwarded to the API call.
        base_url: API root forwarded to the API call.
        max_retries: Retry budget forwarded to the API call.

    Returns:
        A dict that always carries the keys ``dimension``, ``success``,
        ``error``, ``evaluation``, ``model``, ``api_response`` and ``usage``.
        On success ``evaluation`` is one of ``"1"``, ``"2"`` or ``"tie"`` and
        ``error`` is ``None``. On failure ``evaluation`` is ``None`` and
        ``error`` describes the problem.
    """
    # Normalize a falsy mapping value (e.g. "") to None so "no model" is uniform.
    best_model_name = DIMENSION_TO_MODEL_MAPPING.get(dimension) or None

    def build(success: bool, error=None, evaluation=None, api_response=None, usage=None) -> Dict:
        # Single constructor so every return path has the same key set.
        return {
            "dimension": dimension,
            "success": success,
            "error": error,
            "evaluation": evaluation,
            "model": best_model_name,
            "api_response": api_response,
            "usage": usage,
        }

    if best_model_name is None:
        logging.warning(f"No expert model found for dimension: {dimension}")
        return build(False, error="No expert model found")
    
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT_TEMPLATE},
        {"role": "user", "content": user_prompt}
    ]
    
    api_result = await call_expert_api_async(session, messages, best_model_name, api_key_manager, base_url, max_retries)
    
    if not api_result["success"]:
        return build(False, error=api_result.get("error"))

    content = api_result["content"]
    usage = api_result.get("usage")
    raw_evaluation = parse_evaluations(content, [dimension]).get(dimension)

    # Tolerate case/whitespace/type variations such as "Tie", " 1 " or 1.
    evaluation = str(raw_evaluation).strip().lower() if raw_evaluation is not None else None
    if evaluation in ('1', '2', 'tie'):
        return build(True, evaluation=evaluation, api_response=content, usage=usage)

    logging.warning(f"Invalid evaluation for dimension {dimension}: {raw_evaluation}")
    return build(
        False,
        error=f"Invalid evaluation: {raw_evaluation}",
        api_response=content,
        usage=usage,
    )

def _dimension_failure(dimension: str, error: str) -> Dict:
    """Build a failed per-dimension record for a dimension whose task raised."""
    return {
        "dimension": dimension,
        "success": False,
        "error": error,
        "evaluation": None,
        "model": DIMENSION_TO_MODEL_MAPPING.get(dimension),
        "api_response": None,
        "usage": None,
    }


async def _evaluate_dimensions(session, dimensions: List[str], user_prompt: str, api_key_manager,
                               base_url: str, max_retries: int) -> Dict[str, Dict]:
    """Evaluate ``dimensions`` concurrently and return ``{dimension: result}``.

    A task that raises is logged and turned into a failure record, so every
    requested dimension always gets an entry.
    """
    outcomes = await asyncio.gather(
        *(
            process_dimension_async(session, dimension, user_prompt, api_key_manager, base_url, max_retries)
            for dimension in dimensions
        ),
        return_exceptions=True,
    )
    results: Dict[str, Dict] = {}
    for dimension, outcome in zip(dimensions, outcomes):
        # BaseException so a cancelled child task (CancelledError) is also
        # recorded instead of being treated as a result dict.
        if isinstance(outcome, BaseException):
            logging.error(f"Exception processing dimension {dimension}: {outcome}")
            results[dimension] = _dimension_failure(dimension, str(outcome))
        else:
            results[dimension] = outcome
    return results


async def _evaluate_batch(session, batch: List[str], user_prompt: str, api_key_manager, base_url: str,
                          max_retries: int, min_success_rate: float = 0.8, max_attempts: int = 2,
                          retry_delay: float = 5.0) -> Dict[str, Dict]:
    """Evaluate one batch, re-running only its failed dimensions while too few succeed.

    The success rate counts only results whose ``success`` flag is true.
    Dimensions that already succeeded are never re-requested. Returns the
    latest result for every dimension in ``batch``.
    """
    results: Dict[str, Dict] = {}
    pending = list(batch)
    attempts = max(1, max_attempts)
    for attempt in range(attempts):
        results.update(await _evaluate_dimensions(session, pending, user_prompt, api_key_manager, base_url, max_retries))
        succeeded = sum(1 for result in results.values() if result.get("success"))
        success_rate = succeeded / len(batch)

        if success_rate >= min_success_rate:
            logging.info(f"Batch successful: {succeeded}/{len(batch)} dimensions ({success_rate:.1%})")
            break

        pending = [dimension for dimension in batch if not results[dimension].get("success")]
        logging.warning(f"Batch success rate too low: {succeeded}/{len(batch)} ({success_rate:.1%}). Retrying batch... (attempt {attempt + 1}/{attempts})")
        if attempt < attempts - 1:
            await asyncio.sleep(retry_delay)  # Short delay before retry
        else:
            logging.error(f"Batch failed after {attempts} attempts. Adding to failed dimensions.")
    return results


def _summarize_instance(instance: Dict, scenario: str, mapped_scenario: str,
                        dimension_results: Dict[str, Dict]) -> Dict:
    """Compare expert verdicts with the instance annotations and derive an overall winner.

    ``dimension_results`` holds exactly one (final) result per dimension. The
    overall winner is the response that won more dimensions. Equal counts,
    including zero successful dimensions, give ``"tie"``.
    """
    ground_truth = instance.get('annotation', {})
    expert_results: Dict[str, Dict] = {}
    comparison_results: Dict[str, Dict] = {}
    wins = {"1": 0, "2": 0, "tie": 0}

    for dimension, result in dimension_results.items():
        expert_results[dimension] = {
            "model": result.get("model"),
            "success": result.get("success"),
            "error": result.get("error"),
            "api_response": result.get("api_response"),
            "usage": result.get("usage"),
        }

        evaluation = result.get("evaluation") if result.get("success") else None
        if evaluation in wins:
            wins[evaluation] += 1

        # Annotation entries are expected to be dicts with a 'winner' key;
        # anything else is treated as "no ground truth" rather than crashing.
        gt_entry = ground_truth.get(dimension) if isinstance(ground_truth, dict) else None
        gt_eval = gt_entry.get('winner') if isinstance(gt_entry, dict) else None
        comparison_results[dimension] = {
            "expert_evaluation": evaluation,
            "ground_truth": gt_eval,
            "correct": evaluation == gt_eval if evaluation and gt_eval else None,
            "expert_model": result.get("model"),
            "api_response": result.get("api_response"),
        }

    if wins["1"] > wins["2"]:
        overall_winner = "1"
    elif wins["2"] > wins["1"]:
        overall_winner = "2"
    else:
        overall_winner = "tie"

    return {
        "id": instance.get('id', 'unknown'),
        "scenario": mapped_scenario,
        "original_scenario": scenario,
        "results": comparison_results,
        "overall_winner": overall_winner,
        "expert_wins_1": wins["1"],
        "expert_wins_2": wins["2"],
        "expert_ties": wins["tie"],
        "winner": instance.get('winner', ''),
        "metadata": instance.get('metadata', ''),
        "model_a": instance.get('model_a', ''),
        "model_b": instance.get('model_b', ''),
        "expert_results": expert_results,
        "status": "ok",
    }


async def _process_single_instance(session, instance: Dict, metrics_config: Dict[str, Any], api_key_manager,
                                   base_url: str, max_retries: int, batch_size: int,
                                   batch_delay: float) -> Dict:
    """Resolve the scenario, evaluate every dimension in batches and summarize one instance."""
    instance_id = instance.get('id', 'unknown')

    # Use original scenario from seeds.json if available, otherwise classify
    scenario = instance.get('scenario')
    if scenario:
        mapped_scenario = map_scenario_to_metrics(scenario)
        logging.info(f"Instance {instance_id} using original scenario: {scenario} -> mapped to: {mapped_scenario}")
    else:
        scenario = await classify_scenario_async(session, instance, api_key_manager, base_url)
        mapped_scenario = map_scenario_to_metrics(scenario)
        logging.info(f"Instance {instance_id} classified as: {scenario} -> mapped to: {mapped_scenario}")

    dimensions_with_descriptions = get_scenario_dimensions_with_descriptions(metrics_config, mapped_scenario)
    expected_dimensions = [dim['name'] for dim in dimensions_with_descriptions]
    if not expected_dimensions:
        logging.warning(f"No dimensions found for scenario: {mapped_scenario} (original: {scenario})")
        return {
            "id": instance_id,
            "scenario": mapped_scenario,
            "original_scenario": scenario,
            "status": "failed",
            "error": "No dimensions found for scenario",
        }

    user_prompt = get_user_prompt(instance, metrics_config, mapped_scenario)
    total_dims = len(expected_dimensions)

    # One latest result per dimension; retries overwrite instead of appending.
    dimension_results: Dict[str, Dict] = {}
    for batch_start in range(0, total_dims, batch_size):
        batch = expected_dimensions[batch_start:batch_start + batch_size]
        batch_end = batch_start + len(batch)
        logging.info(f"Processing dimensions {batch_start+1}-{batch_end} of {total_dims} for instance {instance_id}")
        dimension_results.update(
            await _evaluate_batch(session, batch, user_prompt, api_key_manager, base_url, max_retries)
        )
        # Pause between batches to respect the provider's per-minute limit
        if batch_end < total_dims:
            logging.info(f"Waiting {batch_delay:g} seconds before next batch of dimensions...")
            await asyncio.sleep(batch_delay)

    # Final retry for anything still failing; latest attempt wins.
    failed = [dimension for dimension, result in dimension_results.items() if not result.get("success")]
    if failed:
        logging.info(f"Retrying {len(failed)} failed dimensions for instance {instance_id}")
        dimension_results.update(
            await _evaluate_dimensions(session, failed, user_prompt, api_key_manager, base_url, max_retries)
        )

    instance_result = _summarize_instance(instance, scenario, mapped_scenario, dimension_results)

    successful_dims = sum(1 for result in dimension_results.values() if result.get("success"))
    logging.info(f"Instance {instance_id} completed: {successful_dims}/{total_dims} dimensions successful")
    return instance_result


async def process_instances_moe_parallel(instances: List[Dict], metrics_config: Dict[str, Any], 
                                       api_key_manager: DualAPIKeyManager, base_url: str, max_retries: int = 3, 
                                       workers: int = 10, *, verify_ssl: bool = False,
                                       batch_delay: float = 60.0) -> List[Dict]:
    """Process instances sequentially, evaluating each instance's dimensions in parallel batches.

    For every instance the scenario is taken from ``instance['scenario']``
    when present and classified otherwise. Its dimensions are evaluated in
    batches of ``workers`` concurrent expert calls:

    - Within a batch, only failed dimensions are re-run while fewer than 80%
      succeed.
    - Batches are separated by ``batch_delay`` seconds.
    - Remaining failures get one last retry.

    An exception while processing one instance is logged and recorded as a
    ``"failed"`` entry instead of aborting the whole run.

    Args:
        instances: Seed records to evaluate.
        metrics_config: Parsed metrics configuration.
        api_key_manager: Key manager used for all API calls.
        base_url: API root URL.
        max_retries: Per-call retry budget for expert API calls.
        workers: Connection limit and dimension batch size.
        verify_ssl: Verify TLS certificates. Defaults to False to preserve the
            historical behavior; pass True whenever the endpoint has a valid
            certificate.
        batch_delay: Seconds to wait between dimension batches.

    Returns:
        One result dict per instance, in input order.
    """
    import ssl

    if verify_ssl:
        ssl_context = ssl.create_default_context()
    else:
        # NOTE(security): certificate and hostname checks are disabled to work
        # around certificate issues; this exposes API keys to MITM attacks.
        ssl_context = ssl.create_default_context()
        ssl_context.check_hostname = False
        ssl_context.verify_mode = ssl.CERT_NONE

    # range() needs a positive step; aiohttp treats limit=0 as "unlimited".
    batch_size = max(1, workers)
    connector = aiohttp.TCPConnector(limit=workers, limit_per_host=workers, ssl=ssl_context)
    timeout = aiohttp.ClientTimeout(total=60)

    per_instance_results: List[Dict] = []
    async with aiohttp.ClientSession(connector=connector, timeout=timeout) as session:
        for i, instance in enumerate(instances):
            instance_id = instance.get('id', 'unknown')
            logging.info(f"Processing instance {i+1}/{len(instances)}: {instance_id}")
            try:
                instance_result = await _process_single_instance(
                    session, instance, metrics_config, api_key_manager, base_url,
                    max_retries, batch_size, batch_delay,
                )
            except Exception as e:
                # Keep the results gathered so far instead of losing the run.
                logging.exception(f"Failed to process instance {instance_id}: {e}")
                instance_result = {"id": instance_id, "status": "failed", "error": str(e)}
            per_instance_results.append(instance_result)

    return per_instance_results

def _normalize_ground_truth_winner(winner):
    """Map a dataset ``winner`` label onto the expert vocabulary ("1"/"2"/"tie").

    ``model_a`` and ``model_b`` become ``"1"`` and ``"2"``. Any string
    starting with ``"tie"``, including a qualified label such as
    ``"tie (bothbad)"``, is compared as a plain ``"tie"``, because the expert
    overall verdict can only be ``"1"``, ``"2"`` or ``"tie"``. Any other value
    is returned unchanged.
    """
    if winner == "model_a":
        return "1"
    if winner == "model_b":
        return "2"
    if isinstance(winner, str) and winner.startswith("tie"):
        return "tie"
    return winner


def save_moe_results(per_instance_results: List[Dict], metrics_config: Dict[str, Any], out_path: str = 'moe_candidate_profiling_500_winners_updated_results.json'):
    """Save MoE profiling results, write a summary file and print a report table.

    Args:
        per_instance_results: Records produced by the MoE pipeline. Only
            records with ``status == "ok"`` contribute to accuracy statistics.
        metrics_config: Unused here; kept for interface compatibility.
        out_path: Destination of the raw per-instance results. The summary is
            written alongside as ``<stem>_summary<ext>``, or
            ``<stem>_summary.json`` when out_path has no extension, so it never
            overwrites the results file.
    """
    dimension_stats = defaultdict(lambda: {"correct": 0, "total": 0, "accuracy": 0.0, "failed_calls": 0, "empty_responses": 0})
    
    # Track instance-level overall accuracy
    instance_overall_correct = 0
    instance_overall_total = 0
    
    for instance_result in per_instance_results:
        if instance_result.get("status") != "ok":
            continue
        
        instance_id = instance_result.get('id', 'unknown')
        expert_overall = instance_result.get("overall_winner")
        ground_truth_overall = _normalize_ground_truth_winner(instance_result.get("winner"))
        
        logging.debug(f"Instance {instance_id}: Expert overall = '{expert_overall}', Ground truth overall = '{ground_truth_overall}'")
        
        if expert_overall and ground_truth_overall:
            instance_overall_total += 1
            if expert_overall == ground_truth_overall:
                instance_overall_correct += 1
                logging.debug(f"Instance {instance_id}: CORRECT overall prediction")
            else:
                logging.debug(f"Instance {instance_id}: WRONG overall prediction")
        else:
            logging.warning(f"Instance {instance_id}: Missing overall winner data - Expert: '{expert_overall}', Ground truth: '{ground_truth_overall}'")
        
        expert_results = instance_result.get("expert_results", {})
        for dimension, result in instance_result["results"].items():
            if result.get("correct") is not None:
                dimension_stats[dimension]["total"] += 1
                if result["correct"]:
                    dimension_stats[dimension]["correct"] += 1
            
            # Track API call issues
            expert_result = expert_results.get(dimension, {})
            if not expert_result.get("success", True):
                dimension_stats[dimension]["failed_calls"] += 1
                # "error" may be present but None, so coerce before searching.
                error_text = expert_result.get("error") or ""
                if "Empty response" in str(error_text):
                    dimension_stats[dimension]["empty_responses"] += 1
    
    # Calculate accuracy for each dimension
    for stats in dimension_stats.values():
        if stats["total"] > 0:
            stats["accuracy"] = stats["correct"] / stats["total"]
    
    instance_overall_accuracy = instance_overall_correct / instance_overall_total if instance_overall_total > 0 else 0
    
    # Dimension mean accuracy excludes dimensions with failed API calls
    valid_dimensions = [stats for stats in dimension_stats.values() if stats.get("failed_calls", 0) == 0 and stats["total"] > 0]
    failed_dimension_count = sum(1 for stats in dimension_stats.values() if stats.get("failed_calls", 0) > 0)
    dimension_mean_accuracy = sum(stats["accuracy"] for stats in valid_dimensions) / len(valid_dimensions) if valid_dimensions else 0
    
    total = len(per_instance_results)
    successful = sum(1 for r in per_instance_results if r.get("status") == "ok")
    summary = {
        "dimension_stats": dict(dimension_stats),
        "overall_stats": {
            "total_instances": total,
            "successful_instances": successful,
            "failed_instances": sum(1 for r in per_instance_results if r.get("status") == "failed"),
            "instance_overall_accuracy": instance_overall_accuracy,
            "instance_overall_correct": instance_overall_correct,
            "instance_overall_total": instance_overall_total,
            "dimension_mean_accuracy": dimension_mean_accuracy,
            "total_dimensions": len(dimension_stats),
            "valid_dimensions": len(valid_dimensions),
            "failed_dimensions": failed_dimension_count
        }
    }
    
    with open(out_path, 'w', encoding='utf-8') as f:
        json.dump(per_instance_results, f, indent=2)
    
    # Derive the summary path from the extension only, so a path without
    # ".json" (or with ".json" earlier in it) cannot clobber out_path.
    root, ext = os.path.splitext(out_path)
    summary_path = f"{root}_summary{ext or '.json'}"
    
    with open(summary_path, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2)
    
    logging.info(f"MoE profiling results saved to {out_path} and {summary_path}")
    
    # Print summary table
    print("\nMixture of Experts Candidate Profiling Summary Table:")
    print("=" * 120)
    print(f"{'Dimension':<25} {'Accuracy':<10} {'Correct':<8} {'Total':<6} {'Failed':<7} {'Empty':<6} {'Expert Model':<40}")
    print("-" * 120)
    
    for dimension, stats in sorted(dimension_stats.items()):
        accuracy_pct = stats["accuracy"] * 100
        expert_model = DIMENSION_TO_MODEL_MAPPING.get(dimension, "Unknown")
        failed_calls = stats.get("failed_calls", 0)
        empty_responses = stats.get("empty_responses", 0)
        print(f"{dimension:<25} {accuracy_pct:>8.1f}% {stats['correct']:>8} {stats['total']:>6} {failed_calls:>7} {empty_responses:>6} {expert_model:<40}")
    
    print("-" * 120)
    instance_accuracy_pct = instance_overall_accuracy * 100
    dimension_mean_pct = dimension_mean_accuracy * 100
    # Guard against an empty run (no instances at all).
    success_pct = successful / total * 100 if total else 0.0
    
    print(f"{'INSTANCE OVERALL':<25} {instance_accuracy_pct:>8.1f}% ({instance_overall_correct}/{instance_overall_total})")
    print(f"{'DIMENSION MEAN':<25} {dimension_mean_pct:>8.1f}% (excludes {failed_dimension_count} failed API dimensions)")
    print(f"Success Rate: {successful}/{total} ({success_pct:.1f}%)")
    print(f"Dimensions: {len(valid_dimensions)}/{len(dimension_stats)} valid (excluded {failed_dimension_count} with API failures)")
    
    # Print consistency warnings
    high_failure_dimensions = [dim for dim, stats in dimension_stats.items() if stats.get("failed_calls", 0) > 0]
    if high_failure_dimensions:
        print(f"\nWarning: Dimensions with API call failures: {', '.join(high_failure_dimensions)}")
        print(f"   These dimensions are excluded from the DIMENSION MEAN calculation")
    
    high_empty_dimensions = [dim for dim, stats in dimension_stats.items() if stats.get("empty_responses", 0) > 0]
    if high_empty_dimensions:
        print(f"Warning: Dimensions with empty responses: {', '.join(high_empty_dimensions)}")

def main():
    """Command-line entry point for the MoE candidate-profiling run.

    Configures logging (DEBUG to console plus a log file when ``--debug`` is
    given, otherwise via ``setup_logging()``). It then loads the metrics config
    and seed instances, evaluates them with ``process_instances_moe_parallel``
    and saves the results.

    Ctrl-C is logged and swallowed. Any other exception is logged and
    re-raised. Logging is always shut down on exit.
    """
    args = parse_args()
    debug_enabled = args.debug

    if not debug_enabled:
        setup_logging()
    else:
        # Drop any pre-installed handlers so basicConfig takes effect.
        root_logger = logging.root
        for existing_handler in list(root_logger.handlers):
            root_logger.removeHandler(existing_handler)
        debug_handlers = [
            logging.StreamHandler(),
            logging.FileHandler("moe_candidate_profiling_500_winners.log", mode='w'),
        ]
        logging.basicConfig(
            level=logging.DEBUG,
            format='[%(levelname)s] %(message)s',
            handlers=debug_handlers,
        )

    try:
        key_manager = DualAPIKeyManager(args.api_key_1, args.api_key_2)

        logging.info("Loading metrics configuration...")
        logging.info(f"MoE processing: {args.workers} workers with dual API keys")
        if debug_enabled:
            logging.debug("Debug mode enabled - detailed logging will be shown")
        config = load_metrics_config(args.metrics_path)

        logging.info("Loading seeds data...")
        seed_instances = load_seeds_data(args.seeds_path, args.nums)
        instance_count = len(seed_instances)
        if args.nums != 0:
            logging.info(f"Loaded {instance_count} instances from seeds.json (limited to {args.nums})")
        else:
            logging.info(f"Loaded all {instance_count} instances from seeds.json")

        # In debug mode, dump the shape of the first record for inspection.
        if seed_instances and debug_enabled:
            sample = seed_instances[0]
            logging.debug(f"First instance structure: {list(sample.keys())}")
            logging.debug(f"First instance winner field: '{sample.get('winner', 'NOT_FOUND')}'")
            logging.debug(f"First instance annotation: {sample.get('annotation', 'NOT_FOUND')}")

        logging.info(f"Processing {instance_count} instances using Mixture of Experts with dual API keys...")

        pipeline = process_instances_moe_parallel(
            seed_instances,
            config,
            key_manager,
            args.base_url,
            args.max_retries,
            args.workers,
        )
        results = asyncio.run(pipeline)

        save_moe_results(results, config, args.output)
        logging.info("Mixture of Experts candidate profiling completed successfully!")

    except KeyboardInterrupt:
        logging.info("Interrupted by user.")
    except Exception as err:
        logging.error(f"Error during MoE profiling: {err}")
        raise
    finally:
        logging.shutdown()

# Script entry point: run main() only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()