#!/usr/bin/env python3
"""
Comprehensive Baseline Evaluation Script with Valid-Cases-Only Metrics
Compares 5 baselines: OR-LLM-Agent, OptiMUS, Simple Zero-Shot, Chain-of-Experts, Hierarchical-Graph-Agent
All metrics calculated based on valid cases only (excluding Type 2)
Improved LLM extraction with 5-attempt majority voting
"""

import os
import sys
import csv
import json
import re
import time
import argparse
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed
import threading
from dataclasses import dataclass
from typing import Optional, Dict, Any, List, Tuple
from collections import Counter
import multiprocessing as mp

# Setup paste API client for LLM justification
import openai
from openai import OpenAI

# Active endpoint and model identifier. BASE_URL is handed to the OpenAI client
# as its base_url and MODEL_NAME is sent as the `model` of every chat request.
BASE_URL = "API_ENDPOINT_PLACEHOLDER/deepseek-v3-h200/v1"
MODEL_NAME =  "deepseek-ai/DeepSeek-V3"

# Alternative endpoint/model pair, currently disabled. To switch, comment out
# the pair above and uncomment the two lines below.
#BASE_URL = "API_ENDPOINT_PLACEHOLDER/llama-4-scout-17b-16e-instruct/v1"
#MODEL_NAME = "meta-llama/Llama-4-Scout-17B-16E-Instruct"



def setup_paste_api_client():
    """Build the OpenAI-compatible client used for every LLM call.

    The API key is taken from the ``RITS_API_KEY`` environment variable. If
    the variable is not set, the placeholder value embedded in this script is
    installed as a fallback so existing invocations keep working. A key that
    is already present in the environment is never overwritten.

    Returns:
        An ``OpenAI`` client bound to ``BASE_URL`` that sends the key in a
        ``RITS_API_KEY`` default header and is created with ``timeout=300``.

    Raises:
        ValueError: If ``RITS_API_KEY`` resolves to an empty value.
    """
    # setdefault (rather than plain assignment) lets a key exported by the
    # caller take precedence over the in-file placeholder, and keeps the
    # emptiness check below meaningful.
    os.environ.setdefault('RITS_API_KEY', 'RITS_API_PLACEHOLDER')
    api_key = os.environ.get("RITS_API_KEY")

    if not api_key:
        raise ValueError("Please set RITS_API_KEY environment variable")

    client = OpenAI(
        # The real credential travels in the header below; "dummy" only fills
        # the client's api_key argument.
        api_key="dummy",
        base_url=BASE_URL,
        default_headers={"RITS_API_KEY": api_key},
        timeout=300
    )
    return client

# Global client for multiprocessing.
# Starts as None; it is assigned by init_worker() and, when still None, created
# lazily on the first call to get_llm_response().
paste_client = None

def init_worker():
    """Create an API client for the current process.

    The freshly built client is published through the module-level
    ``paste_client`` variable so later LLM calls in this process reuse it.
    """
    global paste_client
    fresh_client = setup_paste_api_client()
    paste_client = fresh_client

def get_llm_response(prompt, max_retries=3):
    """Send ``prompt`` as a single user message and return the reply text.

    Args:
        prompt: Text sent as the only message of the chat request.
        max_retries: Maximum number of request attempts.

    Returns:
        The content of the first choice as soon as an attempt yields a
        non-empty one; ``None`` when every attempt fails or comes back empty.
    """
    global paste_client
    if paste_client is None:
        # No client yet in this process: build one on demand.
        paste_client = setup_paste_api_client()

    final_attempt = max_retries - 1
    for attempt in range(max_retries):
        try:
            completion = paste_client.chat.completions.create(
                model=MODEL_NAME,
                messages=[{"role": "user", "content": prompt}],
                max_tokens=256,  # short classification answers only
                temperature=0.0,
                top_p=0.9,
                frequency_penalty=0.0,
                presence_penalty=0.0,
                stream=False,
            )
            choices = completion.choices
            text = choices[0].message.content if choices else None
            if text:
                return text
            # An empty reply is retried immediately, without a pause.
            print(f"WARNING: Empty response on attempt {attempt + 1}")
        except Exception as e:
            print(f"WARNING: API call failed on attempt {attempt + 1}: {e}")
            if attempt >= final_attempt:
                return None
            # Exponential back-off between failed calls, capped at 5 seconds.
            time.sleep(min(2 ** attempt, 5))

    return None

@dataclass
class BaselineResult:
    """Outcome of evaluating one baseline on one database."""
    # Name of the database directory that was evaluated.
    database_name: str
    # Name of the baseline the result belongs to.
    baseline_name: str
    # Ground truth extracted from the solution description, or None when none could be determined.
    ground_truth_value: Optional[str]
    # Value extracted from the baseline's code_output.txt, or None when nothing was extracted.
    baseline_value: Optional[str]
    # Whether the baseline is considered to have produced a result.
    has_result: bool
    # Case classification. classify_case_type() emits 2 (no ground truth), 3 (match),
    # 4 (mismatch) and 5 (baseline failed); the summary table labels type 1 "Data Issues".
    case_type: int  # 1-5 classification
    # Human-readable explanation accompanying case_type.
    case_description: str
    # True for type 3, False for type 4, None for every other type.
    accuracy_match: Optional[bool]
    # Description of how the reference solvers agreed (or why no ground truth exists).
    solver_consistency: str

@dataclass
class ComparisonMetrics:
    """Aggregated metrics for one baseline, computed over valid (non-Type-2) cases."""
    # Name of the baseline these metrics describe.
    baseline_name: str
    valid_cases: int          # Cases with reliable ground truth (exclude Type 2)
    valid_successful_cases: int  # Successful cases within valid cases only
    correct_cases: int        # Correct cases (always within valid cases)
    success_rate: float       # valid_successful_cases / valid_cases
    accuracy_rate: float      # correct_cases / valid_cases
    accuracy_among_successful: float  # correct_cases / valid_successful_cases
    # Number of results per case type (keys 1-5), counted over ALL cases including Type 2.
    type_distribution: Dict[int, int]

def extract_baseline_result_llm_majority_vote(code_output_file: Path) -> Optional[str]:
    """Classify a solver output file by querying the LLM five times and voting.

    Args:
        code_output_file: Path of the solver output text file to classify.

    Returns:
        The winning classification string from ``majority_vote``; the regex
        fallback's result when every LLM call failed; ``None`` when the file
        cannot be read.
    """
    try:
        with open(code_output_file, 'r', encoding='utf-8') as f:
            content = f.read()
    except Exception as e:
        print(f"Error reading file {code_output_file}: {e}")
        return None

    # Keep the prompt small: long multi-line outputs keep their first 80 and
    # last 70 lines, long outputs with few lines keep their first 8000 chars.
    truncated_content = content
    if len(content) > 8000:
        lines = content.split('\n')
        if len(lines) > 150:
            head = '\n'.join(lines[:80])
            tail = '\n'.join(lines[-70:])
            truncated_content = head + '\n\n[... middle content truncated ...]\n\n' + tail
        else:
            truncated_content = content[:8000]

    # Classification prompt; the solver output is embedded verbatim.
    prompt = f"""You are analyzing optimization solver output. Your task is to classify the final result into exactly ONE category.

CLASSIFICATION RULES:
1. ERROR: If the solver failed, crashed, timed out, or couldn't complete
   - Look for: "ERROR", "failed", "exception", "timeout", "crashed"
   - Examples: "ERROR: Hierarchical Graph Agent optimization failed"
             "Final Status: failed_incomplete_data"
             "Cannot proceed: parameters missing"
             "All code generation attempts failed"

2. INFEASIBLE: If the problem has no feasible solution
   - Look for: "infeasible", "no feasible solution"

3. UNBOUNDED: If the solution is unbounded
   - Look for: "unbounded", "unbounded solution"

4. NUMERICAL: If you find a final objective value (number)
   - Look for: "Optimal Objective Value: 123.45"
             "Optimal value: -67.8"
             "Best objective: 0.0"
   - Return the exact number (include negative sign if present)

CRITICAL: 
- ERROR takes priority over everything else
- Only return NUMERICAL if there are NO error indicators
- Look at the OVERALL STATUS, not just individual numbers

SOLVER OUTPUT:
{truncated_content}

RESPONSE FORMAT: Return ONLY one of these:
- ERROR (if failed/crashed/timeout)
- INFEASIBLE (if no solution exists)
- UNBOUNDED (if solution is unbounded)  
- [NUMBER] (exact numerical value if successful, e.g., "123.45", "-67.8", "0.0")

CLASSIFICATION:"""

    # Five independent calls; only calls that returned text contribute a vote.
    results = []
    for attempt in range(5):
        response = get_llm_response(prompt)
        if not response:
            print(f"  LLM attempt {attempt+1}: FAILED")
            continue
        parsed_result = parse_llm_response(response.strip())
        results.append(parsed_result)
        print(f"  LLM attempt {attempt+1}: {parsed_result}")

    if results:
        final_result = majority_vote(results)
        print(f"  Majority vote result: {final_result}")
        return final_result

    # No usable LLM answer at all: fall back to pattern matching on the file.
    print(f"  All LLM attempts failed for {code_output_file.name}")
    return extract_with_regex_fallback(code_output_file)

# Canonical status labels the LLM is asked to return.
_LLM_STATUS_LABELS = ('ERROR', 'INFEASIBLE', 'UNBOUNDED')

# A number must contain at least one digit (a lone "." is not a number).
_LLM_NUMBER_RE = re.compile(r'[+-]?(?:\d+(?:\.\d*)?|\.\d+)(?:[eE][+-]?\d+)?')

# A comma that sits between a digit and a group of exactly three digits.
_LLM_THOUSANDS_SEP_RE = re.compile(r'(?<=\d),(?=\d{3}(?!\d))')


def parse_llm_response(response: str) -> str:
    """Normalize one raw LLM reply into a canonical result string.

    Resolution order:
      1. The reply is exactly a status label.
      2. The reply starts with a status label (after leading punctuation),
         e.g. ``"ERROR: exit code 1"``. The status wins over any digits that
         follow, matching the prompt's "ERROR takes priority" rule.
      3. The first number found in the reply. Thousands separators are
         removed first, so ``"1,234.5"`` is read as ``1234.5``.
      4. A status or failure keyword anywhere in the reply.
      5. ``'ERROR'`` as the conservative default.

    Args:
        response: Raw text returned by the LLM.

    Returns:
        ``'ERROR'``, ``'INFEASIBLE'``, ``'UNBOUNDED'``, or the number rendered
        through ``str(float(...))`` (for example ``'123.45'`` or ``'5.0'``).
    """
    response = response.strip().upper()

    # 1. Exact status label.
    if response in _LLM_STATUS_LABELS:
        return response

    # 2. Leading status label, ignoring list/markdown decoration in front.
    lead = response.lstrip(' \t\r\n-*[("\'`')
    for label in _LLM_STATUS_LABELS:
        if lead.startswith(label):
            return label

    # 3. First well-formed number.
    cleaned = response.replace('NUMERICAL:', '').replace('NUMBER:', '').strip()
    cleaned = _LLM_THOUSANDS_SEP_RE.sub('', cleaned)
    number_match = _LLM_NUMBER_RE.search(cleaned)
    if number_match:
        try:
            return str(float(number_match.group(0)))
        except ValueError:
            # Not expected given the pattern; fall through to keyword checks.
            pass

    # 4. Status keywords anywhere in the reply.
    if any(word in response for word in ['ERROR', 'FAIL', 'CRASH', 'TIMEOUT', 'EXCEPTION']):
        return 'ERROR'
    elif 'INFEASIBLE' in response:
        return 'INFEASIBLE'
    elif 'UNBOUNDED' in response:
        return 'UNBOUNDED'

    # 5. If unclear, default to ERROR (conservative approach).
    return 'ERROR'

def majority_vote(results: List[str]) -> str:
    """Pick the final classification from the parsed LLM votes.

    A value wins outright when it holds a strict majority of the votes that
    were actually collected. Failed LLM calls contribute no vote, so the
    threshold scales with ``len(results)``; with the usual five votes this
    means at least three. Without a majority the tie-breaking rules apply:

      1. ``'ERROR'`` wins if it received at least two votes.
      2. Otherwise the most frequent non-numerical status wins.
      3. Otherwise the most frequent value overall wins.

    Args:
        results: Parsed votes (status labels or numeric strings).

    Returns:
        The winning vote, or ``'ERROR'`` when ``results`` is empty.
    """
    if not results:
        return 'ERROR'

    counter = Counter(results)
    most_common_result, most_common_count = counter.most_common(1)[0]

    # Strict majority of the votes cast (>= 3 of 5, 2 of 3, ...).
    if most_common_count * 2 > len(results):
        return most_common_result

    # Rule 1: ERROR takes priority in ties.
    if counter['ERROR'] >= 2:
        return 'ERROR'

    # Rule 2: Non-numerical statuses take priority over numbers in ties.
    status_votes = Counter(r for r in results if not is_numerical(r))
    if status_votes:
        return status_votes.most_common(1)[0][0]

    # Rule 3: Fall back to the most common result.
    return most_common_result

def is_numerical(result: str) -> bool:
    """Return True when ``result`` can be converted with ``float()``.

    Non-numeric strings and values ``float()`` rejects with ``TypeError``
    (for example ``None``) yield False.
    """
    try:
        float(result)
        return True
    except (ValueError, TypeError):
        return False

def extract_with_regex_fallback(code_output_file: Path) -> Optional[str]:
    """Classify a solver output file with plain pattern matching.

    Used when every LLM call failed. Checks run in this order: error
    indicators, then infeasible/unbounded, then a reported objective value.

    Args:
        code_output_file: Path of the solver output text file.

    Returns:
        ``'ERROR'``, ``'INFEASIBLE'``, ``'UNBOUNDED'``, the objective value
        rendered through ``str(float(...))``, or ``None`` when nothing is
        recognized or the file cannot be processed.
    """
    try:
        with open(code_output_file, 'r', encoding='utf-8') as f:
            content = f.read()

        content_lower = content.lower()

        # Check for obvious errors first (plain substring indicators).
        error_substrings = (
            'error:', 'failed', 'exception', 'timeout', 'crashed',
            'hierarchical graph agent optimization failed',
            'final status: failed', 'final status: completed_with_errors',
            'cannot proceed',
        )
        if any(indicator in content_lower for indicator in error_substrings):
            return 'ERROR'
        # This indicator is a pattern, so it has to go through the regex
        # engine rather than a substring test.
        if re.search(r'all .* attempts failed', content_lower):
            return 'ERROR'

        # Check for status
        if 'infeasible' in content_lower:
            return 'INFEASIBLE'
        elif 'unbounded' in content_lower:
            return 'UNBOUNDED'

        # A number needs at least one digit; the integer part may carry
        # well-formed thousands separators ("1,234.5"), removed before float().
        number = (
            r'([+-]?(?:(?:\d{1,3}(?:,\d{3})+|\d+)(?:\.\d*)?|\.\d+)'
            r'(?:[eE][+-]?\d+)?)'
        )
        value_patterns = [
            r'Optimal\s+Objective\s+Value:\s*' + number,
            r'Optimal\s+value:\s*' + number,
            r'Best\s+objective\s*[:\s]*' + number,
        ]

        for pattern in value_patterns:
            match = re.search(pattern, content, re.IGNORECASE)
            if match:
                try:
                    value = match.group(1).replace(',', '')
                    float_val = float(value)
                    return str(float_val)
                except ValueError:
                    continue

        return None

    except Exception as e:
        # Best-effort fallback: report the problem and signal "no result".
        print(f"Regex fallback error for {code_output_file.name}: {e}")
        return None

def extract_ground_truth_from_solution(solution_file_path: Path) -> Tuple[Optional[str], str]:
    """Derive the consensus ground truth from a solution description file.

    Reads the solver table in section 8 ("Cross-Solver Analysis and Final
    Recommendation") and requires at least two rows to share a status. For
    an OPTIMAL consensus, at least two rows must also share the same value
    string.

    Args:
        solution_file_path: Path to ``problem_solution_description.md``.

    Returns:
        ``(ground_truth, description)``. ``ground_truth`` is the agreed value
        string, ``'INFEASIBLE'``/``'UNBOUNDED'``, or ``None``; ``description``
        explains the outcome, including parse errors.
    """
    try:
        with open(solution_file_path, 'r', encoding='utf-8') as f:
            content = f.read()

        # Section 8 runs until the next "## <digit>" heading or end of file.
        section = re.search(
            r'## 8\.\s*Cross-Solver Analysis and Final Recommendation(.*?)(?=## \d+|\Z)',
            content,
            re.DOTALL | re.IGNORECASE,
        )
        if section is None:
            return None, "Section 8 not found"

        # The table starts at "| Solver" and ends at a blank line, a "###"
        # heading, or the end of the section.
        table = re.search(r'\| Solver.*?\|(.*?)(?=\n\n|\n###|\Z)', section.group(1), re.DOTALL)
        if table is None:
            return None, "Solver table not found"

        # Tally the status and value columns, skipping header and separator.
        status_counts = Counter()
        value_counts = Counter()
        for row in table.group(0).split('\n')[2:]:
            if '|' not in row or not row.strip():
                continue
            cells = [cell.strip() for cell in row.split('|')[1:-1]]
            if len(cells) < 3:
                continue
            status_counts[cells[1]] += 1
            if cells[2].upper() not in ('N/A', 'ERROR', ''):
                value_counts[cells[2]] += 1

        by_votes = lambda item: item[1]
        most_common_status = max(status_counts.items(), key=by_votes) if status_counts else (None, 0)

        if most_common_status[1] < 2:
            return None, "No consensus among solvers"

        agreed_status = most_common_status[0].upper()
        if agreed_status == 'OPTIMAL':
            if not value_counts:
                return None, "OPTIMAL but no values found"
            ground_truth, votes = max(value_counts.items(), key=by_votes)
            if votes < 2:
                return None, "OPTIMAL but inconsistent values"
            return ground_truth, f"OPTIMAL: {ground_truth}"

        if agreed_status in ('INFEASIBLE', 'UNBOUNDED'):
            return agreed_status, agreed_status

        return None, f"Consistent status: {most_common_status[0]}"

    except Exception as e:
        return None, f"Error parsing: {str(e)}"

def classify_case_type(has_ground_truth: bool, has_baseline_result: bool, 
                      ground_truth: Optional[str], baseline_result: Optional[str], 
                      solver_consistency: str) -> Tuple[int, str]:
    """Assign a case type and an explanatory message to one evaluation.

    Types produced here:
      2 - no unified ground truth, or not enough data to compare
      3 - baseline matches the ground truth
      4 - baseline differs from the ground truth
      5 - baseline produced no result

    Values match when their stripped, upper-cased strings are equal, or when
    both parse as floats within 1e-6 of each other (relative error when
    ``abs(ground truth) > 1e-6``, absolute error otherwise).

    Returns:
        ``(case_type, description)``.
    """
    if not has_ground_truth or "no consensus" in solver_consistency.lower():
        return 2, f"No unified ground truth ({solver_consistency})"

    if not has_baseline_result:
        return 5, "Baseline failed to generate result"

    if not (ground_truth and baseline_result):
        return 2, "Unable to classify due to missing data"

    matched = (3, "Baseline matches ground truth (SUCCESS)")
    differs = (4, f"Baseline differs from ground truth (GT: {ground_truth}, Baseline: {baseline_result})")

    gt_text = str(ground_truth).strip().upper()
    baseline_text = str(baseline_result).strip().upper()

    # Textual equality covers status labels such as INFEASIBLE / UNBOUNDED.
    if gt_text == baseline_text:
        return matched

    try:
        gt_number = float(gt_text)
        baseline_number = float(baseline_text)
    except ValueError:
        # At least one side is not numeric and the strings already differ.
        return differs

    if abs(gt_number) > 1e-6:
        # Relative tolerance (0.0001%) for ground truths away from zero.
        close_enough = abs((gt_number - baseline_number) / gt_number) < 1e-6
    else:
        # Absolute tolerance when the ground truth is (nearly) zero.
        close_enough = abs(gt_number - baseline_number) < 1e-6

    return matched if close_enough else differs

def evaluate_single_database(database_info: Tuple[str, Path, List[Tuple[str, Path]]]) -> List[BaselineResult]:
    """Evaluate one database against every baseline.

    Args:
        database_info: ``(database_name, synthetic_path, baseline_paths)``
            where ``baseline_paths`` is a list of ``(baseline_name,
            baseline_root)`` pairs. Each baseline's output is expected at
            ``baseline_root / database_name / "code_output.txt"``.

    Returns:
        One ``BaselineResult`` per baseline, in the order of ``baseline_paths``.
    """
    database_name, synthetic_path, baseline_paths = database_info

    print(f"Evaluating {database_name}...")

    # Ground truth is shared by all baselines of this database.
    solution_file = synthetic_path / "problem_solution_description.md"
    if solution_file.exists():
        ground_truth_value, solver_consistency = extract_ground_truth_from_solution(solution_file)
    else:
        ground_truth_value, solver_consistency = None, "No synthetic data file found"

    has_ground_truth = ground_truth_value is not None

    results = []

    for baseline_name, baseline_path in baseline_paths:
        print(f"  Processing {baseline_name}...")

        code_output_file = baseline_path / database_name / "code_output.txt"

        baseline_value = None
        if code_output_file.exists():
            # Use improved LLM majority voting
            baseline_value = extract_baseline_result_llm_majority_vote(code_output_file)

        # An extracted 'ERROR' means the baseline run failed, so it must not
        # count as a produced result. The ground truth is never 'ERROR', so
        # counting it would inflate the success rate and file the case as a
        # mismatch (Type 4) instead of a failure (Type 5). The raw value is
        # still stored on the record for reporting.
        has_baseline_result = baseline_value is not None and baseline_value != 'ERROR'

        # Classify case type
        case_type, case_description = classify_case_type(
            has_ground_truth, has_baseline_result, ground_truth_value, baseline_value, solver_consistency
        )

        # accuracy_match is only defined for comparable cases (types 3 and 4).
        accuracy_match = None
        if case_type == 3:
            accuracy_match = True
        elif case_type == 4:
            accuracy_match = False

        results.append(BaselineResult(
            database_name=database_name,
            baseline_name=baseline_name,
            ground_truth_value=ground_truth_value,
            baseline_value=baseline_value,
            has_result=has_baseline_result,
            case_type=case_type,
            case_description=case_description,
            accuracy_match=accuracy_match,
            solver_consistency=solver_consistency
        ))

    return results

def find_common_databases(synthetic_dir: Path, baseline_dirs: Dict[str, Path]) -> List[Tuple[str, Path, List[Tuple[str, Path]]]]:
    """List the databases present in the synthetic dir and in every baseline dir.

    A database is a sub-directory; membership is decided by directory name.

    Args:
        synthetic_dir: Directory whose sub-directories are candidate databases.
        baseline_dirs: Mapping of baseline name to that baseline's root directory.

    Returns:
        One ``(db_name, synthetic_dir / db_name, [(baseline_name, baseline_dir), ...])``
        tuple per shared database, sorted by name.
    """
    def subdirectory_names(root: Path) -> set:
        return {entry.name for entry in root.iterdir() if entry.is_dir()}

    # Narrow the candidate set with each baseline in turn.
    shared = subdirectory_names(synthetic_dir)
    for baseline_dir in baseline_dirs.values():
        shared = shared & subdirectory_names(baseline_dir)

    return [
        (db_name, synthetic_dir / db_name, list(baseline_dirs.items()))
        for db_name in sorted(shared)
    ]

def calculate_comparison_metrics(results: List[BaselineResult]) -> Dict[str, ComparisonMetrics]:
    """Aggregate per-baseline metrics over valid cases only.

    A case is valid when its type is not 2 (no reliable ground truth). Rates
    are 0 when their denominator is 0. The type distribution still counts
    every case, including Type 2.

    Args:
        results: Evaluation records for any number of baselines.

    Returns:
        Mapping of baseline name to its ``ComparisonMetrics``, in order of
        first appearance in ``results``.
    """
    # Bucket the records by baseline, preserving first-seen order.
    grouped = {}
    for record in results:
        grouped.setdefault(record.baseline_name, []).append(record)

    summary = {}
    for baseline_name, records in grouped.items():
        valid_records = [r for r in records if r.case_type != 2]
        valid_cases = len(valid_records)
        valid_successful_cases = sum(1 for r in valid_records if r.has_result)
        correct_cases = sum(1 for r in valid_records if r.case_type == 3)

        # Histogram over all five types, Type 2 included.
        type_distribution = dict.fromkeys(range(1, 6), 0)
        for r in records:
            type_distribution[r.case_type] += 1

        summary[baseline_name] = ComparisonMetrics(
            baseline_name=baseline_name,
            valid_cases=valid_cases,
            valid_successful_cases=valid_successful_cases,
            correct_cases=correct_cases,
            success_rate=valid_successful_cases / valid_cases if valid_cases > 0 else 0,
            accuracy_rate=correct_cases / valid_cases if valid_cases > 0 else 0,
            accuracy_among_successful=(
                correct_cases / valid_successful_cases if valid_successful_cases > 0 else 0
            ),
            type_distribution=type_distribution,
        )

    return summary

def create_enhanced_summary_table(baseline_metrics: Dict[str, ComparisonMetrics], output_dir: Path):
    """Write the text summary report ``baseline_comparison_table.txt``.

    The report has three parts: a performance table sorted by accuracy rate,
    the case-type distribution per baseline, and short per-baseline insights.
    Percentages in the last two parts use the average number of cases per
    baseline as denominator.

    Args:
        baseline_metrics: Mapping of baseline name to its metrics. It may be
            empty; the report then contains only headers and the legend.
        output_dir: Existing directory that receives the report file.
    """
    table_path = output_dir / 'baseline_comparison_table.txt'

    # Sort by accuracy rate, best first.
    sorted_baselines = sorted(baseline_metrics.items(), key=lambda x: x[1].accuracy_rate, reverse=True)

    # Average number of cases per baseline; 0 when there is nothing to report
    # (avoids dividing by len({}) == 0).
    if baseline_metrics:
        total_cases = sum(sum(metrics.type_distribution.values()) for metrics in baseline_metrics.values()) // len(baseline_metrics)
    else:
        total_cases = 0

    def share(count):
        """Percentage of ``total_cases`` represented by ``count`` (0 if no cases)."""
        return (count / total_cases) * 100 if total_cases else 0.0

    with open(table_path, 'w', encoding='utf-8') as f:
        f.write("BASELINE PERFORMANCE COMPARISON TABLE\n")
        f.write("=" * 100 + "\n\n")

        # Header
        f.write(f"{'Baseline':<25} {'Valid':<8} {'Success':<8} {'Correct':<8} {'Success%':<10} {'Accuracy%':<10} {'Acc/Success%':<12}\n")
        f.write("-" * 100 + "\n")

        for name, metrics in sorted_baselines:
            f.write(f"{name:<25} {metrics.valid_cases:<8} {metrics.valid_successful_cases:<8} "
                   f"{metrics.correct_cases:<8} {metrics.success_rate*100:<9.1f}% {metrics.accuracy_rate*100:<9.1f}% "
                   f"{metrics.accuracy_among_successful*100:<11.1f}%\n")

        f.write("\nLegend:\n")
        f.write("- Valid: Cases with reliable ground truth (excludes Type 2)\n")
        f.write("- Success: Cases that produced any result (within valid cases)\n")
        f.write("- Correct: Cases that matched ground truth (within valid cases)\n")
        f.write("- Success%: Percentage of valid cases that produced any result\n")
        f.write("- Accuracy%: Percentage of valid cases that were correct\n")
        f.write("- Acc/Success%: Percentage of successful cases that were correct\n\n")

        f.write("DETAILED CASE TYPE DISTRIBUTION:\n")
        f.write("=" * 90 + "\n")
        type_names = {1: "Data Issues", 2: "No GT", 3: "Success", 4: "Mismatch", 5: "Failed"}

        # Header for type distribution
        f.write(f"{'Baseline':<25}")
        for i in range(1, 6):
            f.write(f" {type_names[i]:<10}")
        f.write("\n")
        f.write("-" * 90 + "\n")

        for name, metrics in sorted_baselines:
            f.write(f"{name:<25}")
            for i in range(1, 6):
                count = metrics.type_distribution[i]
                pct = share(count)
                f.write(f" {count:2d}({pct:4.1f}%)")
            f.write("\n")

        # Add analysis of problem areas
        f.write("\nKEY INSIGHTS:\n")
        f.write("=" * 50 + "\n")

        for name, metrics in sorted_baselines:
            f.write(f"\n{name}:\n")

            # Calculate key metrics
            failure_rate = share(metrics.type_distribution[5])
            mismatch_rate = share(metrics.type_distribution[4])
            no_gt_rate = share(metrics.type_distribution[2])

            if failure_rate > 10:
                f.write(f"  - High failure rate: {failure_rate:.1f}% (implementation issues)\n")
            if mismatch_rate > 15:
                f.write(f"  - High mismatch rate: {mismatch_rate:.1f}% (accuracy issues)\n")
            if no_gt_rate > 20:
                f.write(f"  - Many cases without ground truth: {no_gt_rate:.1f}% (excluded from metrics)\n")

            # Success analysis
            if metrics.accuracy_among_successful > 0.8:
                f.write(f"  - High accuracy when successful: {metrics.accuracy_among_successful*100:.1f}%\n")
            elif metrics.accuracy_among_successful < 0.5:
                f.write(f"  - Low accuracy even when successful: {metrics.accuracy_among_successful*100:.1f}%\n")

    print(f"Enhanced summary table saved to {table_path}")

def print_enhanced_console_summary(baseline_metrics: Dict[str, ComparisonMetrics], database_count: int, total_time: float):
    """Print the evaluation summary to stdout.

    Shows the evaluation scope, a ranking of baselines by accuracy on valid
    cases, and overall findings. Rates whose denominator is zero (for example
    when every case is Type 2) are reported as 0 instead of raising.

    Args:
        baseline_metrics: Mapping of baseline name to its metrics. When empty,
            only the header is printed, followed by a short notice.
        database_count: Number of databases that were evaluated.
        total_time: Processing time in seconds.
    """
    print("\n" + "=" * 80)
    print("COMPREHENSIVE BASELINE EVALUATION RESULTS")
    print("=" * 80)
    print(f"Total databases evaluated: {database_count}")
    print(f"Processing time: {total_time:.1f} seconds\n")

    if not baseline_metrics:
        # Nothing to rank or average; avoid dividing by len({}) == 0.
        print("No baseline results to summarize.")
        return

    def ratio(numerator, denominator):
        """Safe division that yields 0.0 for an empty denominator."""
        return numerator / denominator if denominator else 0.0

    # Show the valid cases summary (averages per baseline).
    total_cases = sum(sum(m.type_distribution.values()) for m in baseline_metrics.values()) // len(baseline_metrics)
    valid_cases = sum(m.valid_cases for m in baseline_metrics.values()) // len(baseline_metrics)
    excluded_cases = total_cases - valid_cases

    print("EVALUATION SCOPE:")
    print(f"- Total cases per method: {total_cases}")
    print(f"- Valid cases per method: {valid_cases}")
    print(f"- Excluded cases (Type 2): {excluded_cases} ({ratio(excluded_cases, total_cases)*100:.1f}%)")
    print("- All metrics calculated on valid cases only\n")

    print("BASELINE PERFORMANCE RANKING (by accuracy on valid cases):")
    print("-" * 70)
    sorted_baselines = sorted(baseline_metrics.items(), key=lambda x: x[1].accuracy_rate, reverse=True)

    for i, (name, metrics) in enumerate(sorted_baselines, 1):
        print(f"{i}. {name}")
        print(f"   Success Rate: {metrics.success_rate:.1%} ({metrics.valid_successful_cases}/{metrics.valid_cases})")
        print(f"   Accuracy Rate: {metrics.accuracy_rate:.1%} ({metrics.correct_cases}/{metrics.valid_cases})")
        print(f"   Accuracy when Successful: {metrics.accuracy_among_successful:.1%}")

        # Quick insight
        if metrics.success_rate > 0.8 and metrics.accuracy_rate > 0.8:
            print("   → Reliable and accurate")
        elif metrics.success_rate > 0.8:
            print("   → Reliable but accuracy issues")
        elif metrics.accuracy_rate > 0.8:
            print("   → Accurate when works, but reliability issues")
        else:
            print("   → Both reliability and accuracy challenges")
        print()

    # Additional analysis
    print("KEY FINDINGS:")
    print("-" * 30)

    best_accuracy = sorted_baselines[0]
    best_success = max(baseline_metrics.items(), key=lambda x: x[1].success_rate)

    print(f"• Highest accuracy: {best_accuracy[0]} ({best_accuracy[1].accuracy_rate:.1%})")
    print(f"• Highest success rate: {best_success[0]} ({best_success[1].success_rate:.1%})")

    # Check if they're different
    if best_accuracy[0] != best_success[0]:
        print("• Trade-off detected: Different methods excel at reliability vs accuracy")

    # Pool the counts of all methods for the overall rates.
    total_valid = sum(m.valid_cases for m in baseline_metrics.values())
    total_valid_successes = sum(m.valid_successful_cases for m in baseline_metrics.values())
    total_correct = sum(m.correct_cases for m in baseline_metrics.values())

    print(f"• Overall success rate: {ratio(total_valid_successes, total_valid):.1%} (on valid cases)")
    print(f"• Overall accuracy rate: {ratio(total_correct, total_valid):.1%} (on valid cases)")

def test_extraction_sample(baseline_dirs: Dict[str, Path], database_names: List[str], num_samples: int = 3):
    """Run the LLM majority-vote extraction on a few files and print the results.

    For each baseline, the first ``num_samples`` databases (in the order of
    ``database_names``) that have a ``code_output.txt`` are processed. The
    extracted result is printed together with a short preview of the end of
    the file.

    Args:
        baseline_dirs: Mapping of baseline name to that baseline's root directory.
        database_names: Candidate database names, tried in order.
        num_samples: Maximum number of files to process per baseline.
    """
    print("\n" + "="*70)
    print("TESTING LLM MAJORITY VOTE EXTRACTION ON SAMPLE FILES")
    print("="*70)

    for baseline_name, baseline_dir in baseline_dirs.items():
        print(f"\n{baseline_name}:")
        print("-" * 40)

        sample_count = 0
        for db_name in database_names:
            if sample_count >= num_samples:
                break

            code_output_file = baseline_dir / db_name / "code_output.txt"
            if not code_output_file.exists():
                continue

            print(f"\n  Testing {db_name}:")

            # Test LLM majority vote extraction
            result = extract_baseline_result_llm_majority_vote(code_output_file)
            print(f"    Final Result: {result}")

            # Show a snippet of the file for context. The preview is best
            # effort: a read failure is reported instead of being swallowed,
            # and undecodable bytes are replaced rather than raising.
            try:
                with open(code_output_file, 'r', encoding='utf-8', errors='replace') as f:
                    lines = f.readlines()
            except OSError as e:
                print(f"    File End Preview unavailable: {e}")
            else:
                preview = ' '.join(line.strip() for line in lines[-5:] if line.strip())[:120]
                print(f"    File End Preview: {preview}...")

            sample_count += 1

        if sample_count == 0:
            print(f"    No files found for {baseline_name}")

# Number of generic --bN_name/--bN_folder argument pairs accepted on the command line.
_MAX_BASELINES = 6

# Deprecated per-method directory flags: (display name, argparse attribute).
# The position in this tuple (1-based) is the --bN slot the help text points to.
_LEGACY_BASELINE_ARGS = (
    ("OR-LLM-Agent", "or_llm_agent_dir"),
    ("OptiMUS", "optimus_dir"),
    ("Simple Zero-Shot", "simple_zero_shot_dir"),
    ("Chain-of-Experts", "chain_of_experts_dir"),
    ("Hierarchical-Graph-Agent", "hierarchical_graph_agent_dir"),
)


def _build_arg_parser():
    """Create the command-line parser for the evaluation script."""
    parser = argparse.ArgumentParser(description="Comprehensive Baseline Evaluation - Valid Cases Only with LLM Majority Vote")
    parser.add_argument("--synthetic_dir", type=str,
                        default="/u/frankhlchi/text2opt_dataset_alternating_optimization_parallel",
                        help="Path to synthetic data directory")

    # Generic baseline arguments - up to 6 baselines
    for slot in range(1, _MAX_BASELINES + 1):
        parser.add_argument(f"--b{slot}_name", type=str, default="", help=f"Name for baseline {slot}")
        parser.add_argument(f"--b{slot}_folder", type=str, default="", help=f"Path to baseline {slot} results")

    # Legacy arguments for backward compatibility (deprecated)
    for slot, (_, attr) in enumerate(_LEGACY_BASELINE_ARGS, start=1):
        parser.add_argument(f"--{attr}", type=str, default="",
                            help=f"[DEPRECATED] Use --b{slot}_name and --b{slot}_folder instead")

    parser.add_argument("--output_dir", type=str,
                        default="/u/frankhlchi/comprehensive_evaluation_results",
                        help="Output directory for results")
    parser.add_argument("--max_workers", type=int,
                        default=min(8, max(1, mp.cpu_count() - 1)),
                        help="Maximum number of parallel workers for database processing")
    parser.add_argument("--test_mode", action="store_true",
                        help="Run in test mode with sample extraction only")
    parser.add_argument("--sample_only", type=int, default=0,
                        help="Process only N databases for testing (0 = process all)")
    parser.add_argument("--auto_proceed", action="store_true",
                        help="Skip confirmation and proceed automatically")
    return parser


def _resolve_baseline_dirs(args):
    """Map baseline names to result directories from the parsed arguments.

    Generic ``--bN_name``/``--bN_folder`` pairs take precedence; the deprecated
    per-method flags are consulted only when no generic pair is complete.
    """
    baseline_dirs = {}

    # Add baselines that have both name and folder specified
    for slot in range(1, _MAX_BASELINES + 1):
        name = getattr(args, f"b{slot}_name").strip()
        folder = getattr(args, f"b{slot}_folder").strip()
        if name and folder:
            baseline_dirs[name] = Path(folder)

    # Handle legacy arguments for backward compatibility
    if not baseline_dirs:
        print("No baselines specified with --bX_name/--bX_folder arguments.")
        print("Checking for legacy arguments...")

        for name, attr in _LEGACY_BASELINE_ARGS:
            folder = getattr(args, attr)
            if folder.strip():
                baseline_dirs[name] = Path(folder.strip())
                print(f"Using legacy argument: {name} -> {folder}")

    return baseline_dirs


def _write_detailed_csv(csv_path, all_results):
    """Write one CSV row per (database, baseline) result, sorted by those two keys."""
    fieldnames = [
        'database_name', 'baseline_name', 'ground_truth_value', 'baseline_value',
        'has_result', 'case_type', 'case_description', 'accuracy_match', 'solver_consistency'
    ]
    with open(csv_path, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        writer.writeheader()

        # Every CSV column is read from the result attribute of the same name.
        for result in sorted(all_results, key=lambda x: (x.database_name, x.baseline_name)):
            writer.writerow({field: getattr(result, field) for field in fieldnames})


def _write_summary_report(summary_path, csv_path, output_dir, num_databases, baseline_dirs,
                          baseline_metrics, total_time, max_workers, total_llm_calls):
    """Write the plain-text summary report for the evaluation run."""
    with open(summary_path, 'w', encoding='utf-8') as f:
        f.write("COMPREHENSIVE BASELINE COMPARISON REPORT\n")
        f.write("=" * 60 + "\n\n")
        f.write(f"Evaluation completed: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Total databases evaluated: {num_databases}\n")
        f.write(f"Active baselines: {len(baseline_dirs)} ({', '.join(baseline_dirs.keys())})\n")
        f.write(f"Processing time: {total_time:.1f} seconds\n")
        f.write(f"Average time per database: {total_time/num_databases:.1f} seconds\n")
        f.write("Extraction method: LLM Majority Vote (5 attempts with voting per extraction)\n")
        f.write(f"Parallel configuration: {max_workers} DB workers\n")
        f.write(f"Total LLM API calls: {total_llm_calls}\n")
        f.write("Metrics: All calculated on valid cases only (excludes Type 2)\n\n")

        # Show metric calculation summary
        if baseline_metrics:
            num_methods = len(baseline_metrics)
            total_cases = sum(sum(m.type_distribution.values()) for m in baseline_metrics.values()) // num_methods
            valid_cases = sum(m.valid_cases for m in baseline_metrics.values()) // num_methods
            excluded_cases = total_cases - valid_cases
            # Guard the percentage: total_cases can be 0 if no case was classified.
            excluded_pct = excluded_cases / total_cases * 100 if total_cases else 0.0

            f.write("EVALUATION SCOPE:\n")
            f.write("-" * 20 + "\n")
            f.write(f"Total cases per method: {total_cases}\n")
            f.write(f"Valid cases per method: {valid_cases}\n")
            f.write(f"Excluded cases (Type 2): {excluded_cases} ({excluded_pct:.1f}%)\n")
            f.write("Exclusion reason: No reliable ground truth consensus among professional solvers\n")
            f.write("Benefit: Fairer evaluation by focusing on solvable problems only\n\n")

        f.write("BASELINE PERFORMANCE SUMMARY:\n")
        f.write("-" * 40 + "\n")
        for name, metrics in baseline_metrics.items():
            f.write(f"{name}:\n")
            f.write(f"  Valid Cases: {metrics.valid_cases} (evaluable)\n")
            f.write(f"  Successful: {metrics.valid_successful_cases} ({metrics.success_rate:.1%})\n")
            f.write(f"  Correct: {metrics.correct_cases} ({metrics.accuracy_rate:.1%})\n")
            f.write(f"  Accuracy among Successful: {metrics.accuracy_among_successful:.1%}\n")
            f.write(f"  Type Distribution: {metrics.type_distribution}\n\n")

        f.write("METHODOLOGY NOTES:\n")
        f.write("-" * 40 + "\n")
        f.write("• Extraction Method: LLM Majority Vote\n")
        f.write("  - 5 independent LLM calls per extraction\n")
        f.write("  - Majority voting to determine final result\n")
        f.write("  - Clear classification rules prioritizing ERROR detection\n")
        f.write("  - Regex fallback if all LLM calls fail\n")
        f.write("• Metric Calculation: All metrics based on valid cases only\n")
        f.write("  - Valid cases: Cases with reliable ground truth (excludes Type 2)\n")
        f.write("  - Success rate: Successful cases / Valid cases\n")
        f.write("  - Accuracy rate: Correct cases / Valid cases\n")
        f.write("  - Prevents unfair penalization for unsolvable problems\n")
        f.write("• Case Types:\n")
        f.write("  - Type 3 (Success): Baseline result matches ground truth exactly\n")
        f.write("  - Type 4 (Mismatch): Baseline produces result but differs from ground truth\n")
        f.write("  - Type 5 (Failed): Baseline fails to produce any result\n")
        f.write("  - Type 2 (No GT): No reliable ground truth - excluded from all calculations\n\n")

        f.write("FILES GENERATED:\n")
        f.write(f"- Detailed results: {csv_path}\n")
        f.write(f"- Summary report: {summary_path}\n")
        f.write(f"- Comparison table: {output_dir / 'baseline_comparison_table.txt'}\n")


def main():
    """Command-line entry point for the baseline comparison.

    Parses the arguments, resolves and validates the baseline directories,
    runs a small extraction smoke test, optionally asks for confirmation,
    evaluates all common databases in a process pool, and writes the detailed
    CSV, the summary report and the comparison table to the output directory.

    Exits with status 1 when no baselines are given, a required directory is
    missing, no common databases are found, or no database produced a result.
    """
    args = _build_arg_parser().parse_args()

    # Build baseline directories from arguments
    baseline_dirs = _resolve_baseline_dirs(args)

    if not baseline_dirs:
        print("ERROR: No baselines specified!")
        print("Usage examples:")
        print("  python script.py --b1_name 'Method-A' --b1_folder '/path/to/results1' --b2_name 'Method-B' --b2_folder '/path/to/results2'")
        print("  python script.py --or_llm_agent_dir '/path/to/or_llm' --optimus_dir '/path/to/optimus'  # Legacy mode")
        sys.exit(1)

    print("Comprehensive Baseline Evaluation - Valid Cases Only with LLM Majority Vote")
    print("=" * 80)
    print(f"Synthetic data directory: {args.synthetic_dir}")
    print("Baseline configurations:")
    for name, path in baseline_dirs.items():
        print(f"  {name}: {path}")
    print(f"Output directory: {args.output_dir}")
    print(f"Max workers (database level): {args.max_workers}")
    print("Extraction method: LLM Majority Vote (5 attempts per file)")
    print("Metrics: All calculated on valid cases only (excludes Type 2)")

    # Setup paths
    synthetic_dir = Path(args.synthetic_dir)
    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Validate directories
    if not synthetic_dir.exists():
        print(f"ERROR: Synthetic directory not found: {args.synthetic_dir}")
        sys.exit(1)

    valid_baselines = {}
    for name, path in baseline_dirs.items():
        if path.exists():
            valid_baselines[name] = path
        else:
            print(f"WARNING: {name} directory not found: {path}, skipping this baseline")

    if not valid_baselines:
        print("ERROR: No valid baseline directories found!")
        sys.exit(1)

    baseline_dirs = valid_baselines
    print(f"\nActive baselines: {list(baseline_dirs.keys())}")

    # Find common databases
    database_info = find_common_databases(synthetic_dir, baseline_dirs)

    if not database_info:
        print("ERROR: No common databases found!")
        sys.exit(1)

    print(f"\nFound {len(database_info)} common databases")

    # Limit databases for testing if requested
    if args.sample_only > 0:
        database_info = database_info[:args.sample_only]
        print(f"Limited to {len(database_info)} databases for testing")

    num_databases = len(database_info)

    # Test LLM majority vote extraction on sample files
    print("\nTesting LLM majority vote extraction (5 attempts with voting)...")
    test_extraction_sample(baseline_dirs, [db[0] for db in database_info], num_samples=2)

    # Test mode - only run extraction tests
    if args.test_mode:
        print("\nTest mode completed. Use --help to see full evaluation options.")
        return

    # Ask user confirmation to proceed (unless auto_proceed is set)
    total_extractions = num_databases * len(baseline_dirs)
    total_llm_calls = total_extractions * 5  # 5 LLM calls per extraction
    print(f"\nReady to evaluate {num_databases} databases using LLM majority vote extraction.")
    print(f"Active baselines: {len(baseline_dirs)} ({', '.join(baseline_dirs.keys())})")
    print(f"Total extractions: {total_extractions}")
    print(f"Total LLM API calls: {total_llm_calls} (5 attempts per extraction with majority voting)")
    print(f"Parallel configuration: {args.max_workers} database workers")
    print("All metrics will be calculated on valid cases only")

    if not args.auto_proceed:
        try:
            response = input("Continue with full evaluation? (y/n): ").strip().lower()
        except KeyboardInterrupt:
            print("\nEvaluation cancelled by user.")
            return
        except EOFError:
            # stdin is closed (e.g. batch job or pipe): there is nobody to confirm.
            print("\nNo interactive input available; re-run with --auto_proceed to skip confirmation.")
            return
        if response not in ('y', 'yes'):
            print("Evaluation cancelled.")
            return
    else:
        print("Auto-proceeding with full evaluation...")

    # Process databases in parallel
    print("\nStarting parallel evaluation...")
    print("Using LLM majority vote: 5 attempts per extraction with voting")
    print(f"Processing {num_databases} databases with {args.max_workers} parallel workers")

    all_results = []
    start_time = time.time()

    # Use ProcessPoolExecutor for database-level parallelism
    with ProcessPoolExecutor(max_workers=args.max_workers, initializer=init_worker) as executor:
        future_to_db = {
            executor.submit(evaluate_single_database, db_info): db_info[0]
            for db_info in database_info
        }

        completed = 0
        total = len(future_to_db)

        for future in as_completed(future_to_db):
            db_name = future_to_db[future]
            completed += 1
            try:
                # as_completed only yields finished futures, so result() returns
                # (or re-raises the worker's exception) immediately; a timeout
                # here could never fire.
                results = future.result()
                all_results.extend(results)
                print(f"Progress: {completed}/{total} ({completed/total*100:.1f}%) - Completed {db_name}")
            except Exception as exc:
                print(f'ERROR: {db_name} generated an exception: {exc}')

    total_time = time.time() - start_time

    # Without any result the metrics and rate computations below would divide
    # by zero, so stop with a clear error instead of a traceback.
    if not all_results:
        print("ERROR: No results were produced; every database evaluation failed.")
        sys.exit(1)

    # Calculate metrics based on valid cases only
    print("Calculating performance metrics (valid cases only)...")
    baseline_metrics = calculate_comparison_metrics(all_results)

    # Write detailed CSV
    csv_path = output_dir / 'detailed_comparison_valid_cases.csv'
    print(f"Writing detailed results to {csv_path}")
    _write_detailed_csv(csv_path, all_results)

    # Create enhanced summary report
    summary_path = output_dir / 'comparison_summary_valid_cases.txt'
    print(f"Writing summary report to {summary_path}")
    _write_summary_report(summary_path, csv_path, output_dir, num_databases, baseline_dirs,
                          baseline_metrics, total_time, args.max_workers, total_llm_calls)

    # Create enhanced summary table
    print("Creating enhanced summary table...")
    create_enhanced_summary_table(baseline_metrics, output_dir)

    # Print enhanced console summary
    print_enhanced_console_summary(baseline_metrics, num_databases, total_time)

    print(f"\nResults saved to: {output_dir}")
    print(f"View detailed comparison: {csv_path}")
    print(f"View summary report: {summary_path}")
    print(f"View comparison table: {output_dir / 'baseline_comparison_table.txt'}")

    # Summary statistics
    num_results = len(all_results)
    num_extracted = sum(1 for r in all_results if r.has_result)
    throughput = num_results / total_time if total_time > 0 else 0.0

    print("\nFINAL STATISTICS:")
    print(f"- Total evaluations: {num_results}")
    print(f"- Active baselines: {len(baseline_dirs)}")
    print(f"- Total LLM calls made: {total_llm_calls}")
    print(f"- Successful extractions: {num_extracted}")
    print(f"- Extraction success rate: {num_extracted/num_results*100:.1f}%")
    print(f"- Processing time: {total_time:.1f} seconds")
    print(f"- Average time per database: {total_time/num_databases:.1f} seconds")
    print(f"- Throughput: {throughput:.1f} extractions/second")
    print("- Extraction method: LLM Majority Vote (5 attempts with voting)")
    print("- Evaluation approach: Valid cases only for fair comparison")

# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()