"""
Evaluation script for EMBGuard heldout set results
Evaluates model outputs against ground truth from CSV files

Evaluation criteria:
- HR: potential_risk should be "unsafe" (hazard present)
- NHR, HNR: potential_risk should be "safe" (no hazard)
- risk_type: Should match CSV Category
- hazard: Evaluated using LLM-as-a-judge (compared with CSV Related Hazard)
"""

import json
import sys
import argparse
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
from tqdm import tqdm
from multiprocessing import Process, Manager

# Make the project root importable before the project-local imports below.
# The root is taken to be the directory two levels above the directory that
# contains this file; it is prepended to sys.path only if not already present,
# so repeated imports of this module do not add duplicate entries.
project_root = Path(__file__).resolve().parent.parent.parent
project_root_str = str(project_root)
if project_root_str not in sys.path:
    sys.path.insert(0, project_root_str)

from utils.config import get_config
from utils.path import get_project_path
from src.models import create_model, BaseLLMModel

# Import ResultsEvaluator from results_evaluator (reuse the same logic)
from src.evals.results_evaluator import ResultsEvaluator


def _extract_provider_model(results_path: Path) -> str:
    """
    Derive the ``{provider}_{model_name}`` directory label from a results path

    The filename stem is expected to look like
    ``{provider}_{model_name}_{csv_name}_{condition_suffix}_results``, e.g.
    ``openai_gpt-4o_mini_dataset_safe_no-few-shot_non-thinking_results``,
    which yields ``openai_gpt-4o_mini``.

    Args:
        results_path: Path to the JSONL results file (only its name and the
            name of its parent directory are inspected; nothing is read)

    Returns:
        The label to use as the per-model output directory name
    """
    known_providers = ("openai", "openrouter", "claude", "gemini", "vllm")
    # Tokens that mark the end of the model name: dataset keywords followed by
    # condition keywords. Model names may themselves contain hyphens, so only
    # exact matches against this set terminate the model name.
    stop_tokens = {
        "dataset", "safe", "unsafe",
        "no-few-shot", "few-shot", "non-thinking", "thinking",
    }
    results_suffix = "_results"

    stem = results_path.stem  # filename without the final extension
    if stem.endswith(results_suffix):
        stem = stem[: -len(results_suffix)]

    parts = stem.split("_")

    if len(parts) < 2:
        # The filename carries no usable structure; fall back to the name of
        # the directory that contains the results file.
        parent_name = results_path.parent.name
        if "_" not in parent_name:
            return "unknown"
        # NOTE(review): only the text before the first underscore is kept,
        # which drops everything after it in the parent directory name --
        # confirm whether the full directory name was intended.
        # An empty prefix (parent name starting with "_") would collapse the
        # output directory into its parent, so treat it as unknown.
        return parent_name.split("_", 1)[0] or "unknown"

    provider = parts[0].lower()
    if provider in known_providers:
        # The model name is every token after the provider up to (but not
        # including) the first dataset/condition keyword.
        model_parts = []
        for part in parts[1:]:
            if part.lower() in stop_tokens:
                break
            model_parts.append(part)
        if model_parts:
            model_name = "_".join(model_parts)
            return f"{provider}_{model_name}"

    # Unknown provider, or no model tokens before the first keyword:
    # use the first two filename tokens as-is.
    return "_".join(parts[:2])


def get_output_dir_from_results_file(results_file: str) -> Path:
    """
    Extract provider and model name from results file path and build the output directory path

    The directory is not created here; callers are responsible for that.

    Args:
        results_file: Path to JSONL results file

    Returns:
        Path to output directory
        ({project_path}/results/heldout_set/{provider}_{model_name}/)
    """
    provider_model = _extract_provider_model(Path(results_file))
    return get_project_path() / "results" / "heldout_set" / provider_model


def evaluate_heldout_results_file(
    results_file: str,
    output_file: Optional[str] = None,
    judge_provider: str = "openai",
    judge_model: str = "gpt-4o-mini",
    judge_config: Optional[Dict[str, Any]] = None,
    num_workers: int = 1,
) -> Dict[str, Any]:
    """
    Run the shared ResultsEvaluator over a heldout set results file

    Args:
        results_file: Path to JSONL results file
        output_file: Where to save the evaluation results; when None, a path
            is derived from the results filename inside the directory given
            by get_output_dir_from_results_file (created if missing)
        judge_provider: LLM provider for judge model
        judge_model: Model name for judge
        judge_config: Additional config for judge model
        num_workers: Number of worker processes for parallel evaluation

    Returns:
        The dictionary returned by ResultsEvaluator.evaluate_results_file
        (overall evaluation statistics, evaluations, and original results)

    Raises:
        FileNotFoundError: If results_file does not exist
    """
    source_path = Path(results_file)
    if not source_path.exists():
        raise FileNotFoundError(f"Results file not found: {results_file}")

    # Same evaluator as the test set, so both sets are scored identically
    judge = ResultsEvaluator(
        judge_provider=judge_provider,
        judge_model=judge_model,
        judge_config=judge_config,
    )

    if output_file is None:
        # Derive "<name>_evaluation.json" from "<name>_results.jsonl"
        target_dir = get_output_dir_from_results_file(results_file)
        target_dir.mkdir(parents=True, exist_ok=True)

        suffix = "_results"
        stem = source_path.stem
        base_name = stem[: -len(suffix)] if stem.endswith(suffix) else stem
        output_file = str(target_dir / f"{base_name}_evaluation.json")

    return judge.evaluate_results_file(
        results_file=results_file,
        output_file=output_file,
        num_workers=num_workers,
    )


def _format_accuracy(value: float) -> str:
    """Render an accuracy as a 4-decimal fraction followed by a percentage."""
    return f"{value:.4f} ({value*100:.2f}%)"


def _print_accuracy_block(stats: Dict[str, Any], indent: str = "") -> None:
    """
    Print the four accuracy lines for one statistics dictionary

    Args:
        stats: Statistics dictionary; must contain 'total',
            'overall_accuracy' and 'potential_risk_accuracy', and may
            contain 'risk_type_accuracy'/'risk_type_total' and
            'hazard_accuracy'/'hazard_total'
        indent: Prefix placed before every printed line
    """
    print(f"{indent}Overall Accuracy: {_format_accuracy(stats['overall_accuracy'])}")
    print(f"{indent}Potential Risk Accuracy: {_format_accuracy(stats['potential_risk_accuracy'])}")

    # Risk type and hazard are only scored on a subset of items, so their
    # accuracy may be absent or None; report the evaluated count alongside.
    for label, key in (("Risk Type Accuracy", "risk_type"), ("Hazard Accuracy", "hazard")):
        accuracy = stats.get(f"{key}_accuracy")
        if accuracy is None:
            print(f"{indent}{label}: N/A (no items evaluated)")
            continue
        evaluated = stats.get(f"{key}_total", 0)
        print(
            f"{indent}{label}: {_format_accuracy(accuracy)} "
            f"[Evaluated: {evaluated}/{stats['total']}]"
        )


def main() -> None:
    """
    Command-line entry point: evaluate a results file and print the statistics

    Parses the CLI arguments, runs evaluate_heldout_results_file, then prints
    the overall statistics followed by the per-type breakdown found under
    the "by_type" key of the returned statistics.
    """
    parser = argparse.ArgumentParser(description="Evaluate EMBGuard heldout set results")
    parser.add_argument(
        "--results-file",
        type=str,
        required=True,
        help="Path to JSONL results file to evaluate"
    )
    parser.add_argument(
        "--output-file",
        type=str,
        default=None,
        help="Path to save evaluation results (JSON format). If not specified, auto-generated to results/heldout_set/{provider}_{model_name}/"
    )
    parser.add_argument(
        "--judge-provider",
        type=str,
        default="openai",
        help="LLM provider for judge model (default: openai)"
    )
    parser.add_argument(
        "--judge-model",
        type=str,
        default="gpt-4o-mini",
        help="Model name for judge (default: gpt-4o-mini)"
    )
    parser.add_argument(
        "--judge-temperature",
        type=float,
        default=0.0,
        help="Temperature for judge model (default: 0.0)"
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=1,
        help="Number of worker processes for parallel judge evaluation (default: 1)"
    )

    args = parser.parse_args()

    # Run evaluation
    eval_output = evaluate_heldout_results_file(
        results_file=args.results_file,
        output_file=args.output_file,
        judge_provider=args.judge_provider,
        judge_model=args.judge_model,
        judge_config={
            "temperature": args.judge_temperature,
        },
        num_workers=args.num_workers,
    )

    # Only the aggregate statistics are reported on the console
    stats = eval_output["statistics"]

    # Overall statistics
    print("\n" + "="*60)
    print("Evaluation Statistics")
    print("="*60)
    print(f"Total items: {stats['total']}")
    print()
    _print_accuracy_block(stats)

    # Per-type breakdown
    print("\n" + "-"*60)
    print("Statistics by Type")
    print("-"*60)
    for result_type, type_stats in stats["by_type"].items():
        print(f"\n{result_type}:")
        print(f"  Total: {type_stats['total']}")
        _print_accuracy_block(type_stats, indent="  ")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

