#!/usr/bin/env python
# coding: utf-8
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Judge's Verdict Scoring Script
This script handles the evaluation of items using Judge's Verdict.
It reads evaluation data and produces judge scores that can be consumed by correlation analysis.

IMPORTANT: Model identifiers in the config YAML must exactly match the desired output folder names.
The script will create output directories using the model identifier as the folder name.
For example, a model with identifier 'gpt-4o' will have results saved to 'results/gpt-4o/'.
"""

import os
import json
import argparse
import time
from pathlib import Path
from typing import List

import pandas as pd
from tqdm import tqdm

from ragas import evaluate
from ragas.run_config import RunConfig
from ragas.llms import LangchainLLMWrapper
from ragas.metrics import AnswerAccuracy
from ragas import EvaluationDataset, SingleTurnSample
# Token tracking imports removed

# Import judge configuration manager directly
from ..judge_config_manager import JudgeConfigManager, ServingFramework
from ..utils.litellm_ragas_wrapper import create_ragas_litellm_wrapper


def load_evaluation_data(file_path: str) -> pd.DataFrame:
    """Read an evaluation JSON file into a DataFrame with canonical column names.

    The JSON payload is handed straight to ``pd.DataFrame``. The input keys
    ``question``, ``gt_answer``, ``gen_answer`` and ``ground_truth_contexts``
    are renamed to ``Question``, ``Ground_Truth_Answer``, ``Generated_Answer``
    and ``Ground_Truth_Contexts``; any other column is left untouched.

    Args:
        file_path: Path to the UTF-8 encoded JSON file.

    Returns:
        The renamed DataFrame. A ``Ground_Truth_Contexts`` column holding one
        empty list per row is added when the input did not provide one.
    """
    with open(file_path, 'r', encoding='utf-8') as handle:
        frame = pd.DataFrame(json.load(handle))

    # Translate the input schema into the names the rest of the pipeline uses.
    frame = frame.rename(columns={
        'question': 'Question',
        'gt_answer': 'Ground_Truth_Answer',
        'gen_answer': 'Generated_Answer',
        'ground_truth_contexts': 'Ground_Truth_Contexts',
    })

    if 'Ground_Truth_Contexts' in frame.columns:
        return frame

    # Each row gets its own fresh list so rows never share one object.
    frame['Ground_Truth_Contexts'] = [[] for _ in frame.index]
    return frame


def process_evaluations(df: pd.DataFrame) -> pd.DataFrame:
    """Reduce the evaluation data to the distinct items that need a score.

    Args:
        df: Evaluation data containing at least the ``Question``,
            ``Ground_Truth_Answer``, ``Generated_Answer`` and ``dataset_name``
            columns.

    Returns:
        A new DataFrame with only those four columns, exact duplicate rows
        removed (first occurrence kept) and a fresh 0-based index.
    """
    identifying_columns = [
        'Question',
        'Ground_Truth_Answer',
        'Generated_Answer',
        'dataset_name',
    ]
    deduplicated = df.loc[:, identifying_columns].drop_duplicates()
    return deduplicated.reset_index(drop=True)


def create_ragas_dataset(df: pd.DataFrame) -> EvaluationDataset:
    """Convert a DataFrame of scored items into a RAGAS EvaluationDataset.

    One ``SingleTurnSample`` is built per row, in row order, from the
    ``Question`` (user input), ``Ground_Truth_Answer`` (reference) and
    ``Generated_Answer`` (response) columns. Every value is passed through
    ``str`` first.

    Args:
        df: Items to score; must expose the three columns named above.

    Returns:
        An ``EvaluationDataset`` with ``len(df)`` samples.
    """
    # itertuples walks the columns in lockstep, so each value keeps its own
    # column's type. Looking rows up with df.iloc[i] builds a temporary Series
    # per access and coerces the whole row to one common dtype, which changes
    # how numeric values are rendered by str().
    return EvaluationDataset([
        SingleTurnSample(
            user_input=str(row.Question),
            reference=str(row.Ground_Truth_Answer),
            response=str(row.Generated_Answer),
        )
        for row in df.itertuples(index=False)
    ])


def _pending_trials(judge_output_dir: str, max_trials: int) -> List[int]:
    """Return the trial numbers in 1..max_trials that have no checkpoint file yet."""
    return [
        trial for trial in range(1, max_trials + 1)
        if not os.path.exists(f"{judge_output_dir}/trial{trial}.json")
    ]


def _reset_generation_methods(wrapped_model, strip_temperature: bool) -> None:
    """Reset a wrapper's generation methods, optionally dropping ``temperature``.

    The unpatched ``generate`` / ``agenerate_text`` are stashed on the wrapper
    the first time it is seen, so that a wrapper that was patched earlier is
    reset from its unpatched methods instead of being wrapped a second time.
    """
    if hasattr(wrapped_model, '_original_generate'):
        # Patched before: start again from the stashed originals.
        original_generate = wrapped_model._original_generate
        original_agenerate_text = wrapped_model._original_agenerate_text
    else:
        original_generate = wrapped_model.generate
        original_agenerate_text = getattr(wrapped_model, 'agenerate_text', None)
        wrapped_model._original_generate = original_generate
        wrapped_model._original_agenerate_text = original_agenerate_text

    wrapped_model.generate = original_generate
    if original_agenerate_text:
        wrapped_model.agenerate_text = original_agenerate_text

    if not strip_temperature:
        return

    def generate_without_temperature(prompt, n=1, temperature=None, stop=None, callbacks=None):
        """Forward to the original generate, discarding ``temperature``."""
        return original_generate(prompt, n=n, stop=stop, callbacks=callbacks)

    async def agenerate_text_without_temperature(prompt, n=1, temperature=None, stop=None, callbacks=None):
        """Forward to the original agenerate_text, discarding ``temperature``."""
        return await original_agenerate_text(prompt, n=n, stop=stop, callbacks=callbacks)

    wrapped_model.generate = generate_without_temperature
    if original_agenerate_text:
        wrapped_model.agenerate_text = agenerate_text_without_temperature


def run_judge_evaluations(
    eval_dataset: EvaluationDataset, 
    judges: List[str], 
    output_dir: str,
    judge_workers: dict = None,
    default_workers: int = 16,
    max_trials: int = 3,
    config_path: str = None
) -> None:
    """Run evaluations with different judge models.

    For every judge, each trial in ``1..max_trials`` that has no
    ``<output_dir>/<judge>/trial<N>.json`` checkpoint is evaluated with the
    RAGAS ``AnswerAccuracy`` metric and its per-sample results are saved to
    that file. Trials whose checkpoint already exists are never re-run or
    overwritten, so an interrupted run can simply be restarted.

    Judges that are missing from the configuration, or whose LLM cannot be
    created, are reported and skipped. Errors raised during the evaluation
    itself are not caught.

    Args:
        eval_dataset: The dataset to evaluate
        judges: List of judge model names; each name is also used as the
            name of its results sub-directory
        output_dir: Directory to save results
        judge_workers: Deprecated and ignored; worker counts come from the YAML config
        default_workers: Deprecated and ignored; worker counts come from the YAML config
        max_trials: Maximum number of trials per judge
        config_path: Optional path to YAML configuration file
    """
    os.makedirs(output_dir, exist_ok=True)

    # Load configuration manager
    config_manager = JudgeConfigManager(config_path) if config_path else JudgeConfigManager()

    print(f"\nRunning evaluations for {len(judges)} judge(s)")
    for judge in tqdm(judges, desc="Evaluating judges", position=0):
        tqdm.write(f"\n{'='*60}")
        tqdm.write(f"Running {judge}")
        tqdm.write(f"{'='*60}")

        # The judge name is used verbatim as the folder name; the config is
        # expected to contain folder-safe identifiers.
        judge_output_dir = os.path.join(output_dir, judge)
        os.makedirs(judge_output_dir, exist_ok=True)

        # Only trials without a checkpoint are run, so a gap (e.g. trial 2
        # missing while trial 3 exists) does not clobber later results.
        pending_trials = _pending_trials(judge_output_dir, max_trials)
        if not pending_trials:
            tqdm.write(f"Skipping {judge} - all {max_trials} trials already exist")
            continue

        # Get model configuration from YAML
        model_config = config_manager.get_model(judge)

        if not model_config:
            tqdm.write(f"ERROR: Model {judge} not found in configuration file")
            tqdm.write(f"Please add it to {config_manager.config_path} or use a different config file")
            continue

        # Create LLM using configuration
        try:
            if model_config.framework == ServingFramework.LITELLM:
                # Direct RAGAS wrapper for liteLLM - no LangChain needed
                wrapped_model = create_ragas_litellm_wrapper(
                    model=model_config.model,
                    temperature=model_config.temperature,
                    max_tokens=model_config.max_tokens,
                    api_key=model_config.get_api_key(),
                    api_base=model_config.get_effective_base_url(),
                    api_version=model_config.api_version,
                    timeout=model_config.timeout,
                    max_retries=model_config.max_retries
                )
                tqdm.write(f"Using liteLLM framework (direct, no LangChain) for {judge}")
                llm = None  # No LangChain LLM needed
            else:
                # Use LangChain for other frameworks
                llm = config_manager.create_langchain_llm(judge)
                wrapped_model = None  # Wrapped below
                tqdm.write(f"Using {model_config.framework.value} framework for {judge}")

            if model_config.deployment:
                tqdm.write(f"  Deployment: {model_config.deployment}")
            if model_config.api_version:
                tqdm.write(f"  API Version: {model_config.api_version}")
        except Exception as e:
            # Boundary handler: one misconfigured judge must not stop the rest.
            tqdm.write(f"Failed to create LLM for {judge}: {e}")
            continue

        # Wrap LangChain LLMs for RAGAS (liteLLM wrappers are already usable)
        if wrapped_model is None and llm is not None:
            wrapped_model = LangchainLLMWrapper(llm)

        # Special handling for o1 models that don't support temperature
        strip_temperature = 'o1' in judge

        # Worker count always comes from the model's YAML configuration
        num_workers = model_config.num_workers

        for trial in pending_trials:
            checkpoint_path = f"{judge_output_dir}/trial{trial}.json"

            # Start every trial from the same method state
            if wrapped_model:
                _reset_generation_methods(wrapped_model, strip_temperature)

            num_samples = len(eval_dataset.samples)
            tqdm.write(f"Processing {num_samples} samples for trial {trial}")
            tqdm.write(f"Using {num_workers} workers for evaluation")

            run_config = RunConfig(max_workers=num_workers)
            metric = AnswerAccuracy()

            # Run evaluation; perf_counter is monotonic, so the measured
            # duration cannot go negative if the wall clock is adjusted.
            start_time = time.perf_counter()
            results = evaluate(
                dataset=eval_dataset,
                metrics=[metric],
                llm=wrapped_model,
                show_progress=True,
                run_config=run_config,
            ).to_pandas()
            evaluation_time = time.perf_counter() - start_time
            num_samples = len(results)

            # Save trial results immediately. Write to a temporary file and
            # move it into place so an interrupted or failed dump never leaves
            # a partial trialN.json that a later run would treat as complete.
            results_dict = results.to_dict(orient='records')
            temp_path = f"{checkpoint_path}.tmp"
            with open(temp_path, 'w', encoding='utf-8') as f:
                json.dump(results_dict, f)
            os.replace(temp_path, checkpoint_path)

            # Guard the rate against a zero-length measurement
            samples_per_second = num_samples / evaluation_time if evaluation_time > 0 else 0.0

            tqdm.write(f"\nEvaluation for {judge} trial {trial} completed")
            tqdm.write(f"  Evaluation time: {evaluation_time:.2f} seconds")
            tqdm.write(f"  Samples per second: {samples_per_second:.2f}")
            tqdm.write(f"  Workers: {num_workers}")


def main():
    """Command-line entry point for the Judge's Verdict scoring pipeline.

    Parses the CLI options, loads the judge configuration, warns about
    environment variables that are not set for the serving frameworks in use,
    loads and de-duplicates the evaluation data, and runs every selected judge
    over it. When ``--judges`` is omitted, all configured models are used.
    """
    parser = argparse.ArgumentParser(description="Run Judge's Verdict scoring on evaluation data")

    # (flag, add_argument keyword options), registered in this order.
    cli_options = (
        ('--evaluation-file', dict(
            type=str,
            default='./data/evaluation.json',
            help='Path to the evaluation JSON file',
        )),
        ('--output-dir', dict(
            type=str,
            default='./results/',
            help='Directory to save judge scoring results',
        )),
        ('--judges', dict(
            type=str,
            nargs='+',
            help='List of judge models to use (space-separated). If not provided, uses default list.',
        )),
        ('--max-trials', dict(
            type=int,
            default=3,
            help='Maximum number of trials to run per judge (default: 3)',
        )),
        ('--default-workers', dict(
            type=int,
            default=16,
            help='Default number of workers for judges not in the config (default: 16)',
        )),
        ('--config-file', dict(
            type=str,
            default=None,
            help='Path to YAML configuration file for judge models',
        )),
    )
    for flag, options in cli_options:
        parser.add_argument(flag, **options)

    args = parser.parse_args()

    # The configuration is needed first to learn which frameworks are involved.
    config_manager = JudgeConfigManager(args.config_file) if args.config_file else JudgeConfigManager()

    if args.judges:
        # Only the requested judges count; names missing from the config are ignored here.
        requested_configs = (config_manager.get_model(name) for name in args.judges)
        frameworks_in_use = {cfg.framework.value for cfg in requested_configs if cfg}
    else:
        frameworks_in_use = {cfg.framework.value for cfg in config_manager.models.values()}

    print(f"Frameworks in use: {sorted(frameworks_in_use)}")

    # framework -> ((environment variable, warning text), ...), checked in this order.
    env_requirements = (
        ('litellm', (
            ("OPENAI_API_KEY", "OPENAI_API_KEY (for OpenAI models via liteLLM)"),
            ("ANTHROPIC_API_KEY", "ANTHROPIC_API_KEY (for Claude models via liteLLM)"),
            ("NVIDIA_NIM_API_KEY", "NVIDIA_NIM_API_KEY (for NVIDIA models via liteLLM)"),
            ("NVIDIA_NIM_API_BASE", "NVIDIA_NIM_API_BASE (for NVIDIA models via liteLLM)"),
        )),
        ('nvdev', (
            ("NVIDIA_API_KEY", "NVIDIA_API_KEY (for NVIDIA Developer models)"),
        )),
        ('openai', (
            ("OPENAI_API_KEY", "OPENAI_API_KEY (for OpenAI models)"),
        )),
        ('oneapi', (
            ("ONE_API_KEY", "ONE_API_KEY (for OneAPI models)"),
        )),
    )

    missing_vars = []
    for framework, requirements in env_requirements:
        if framework not in frameworks_in_use:
            continue
        if framework == 'litellm':
            print("Checking environment variables for liteLLM framework...")
        missing_vars.extend(
            description for variable, description in requirements
            if variable not in os.environ
        )

    if missing_vars:
        # Advisory only: the run continues.
        print("\nWarning: The following environment variables are not set:")
        for entry in missing_vars:
            print(f"  - {entry}")
        print("\nPlease set the required environment variables.")
        print("Models requiring missing API keys will be skipped.")

    # Explicit selection wins; otherwise run every configured model.
    judges = args.judges if args.judges else config_manager.list_models()

    print(f"Running Judge's Verdict scoring with {len(judges)} judge models")
    print(f"Maximum trials per judge: {args.max_trials}")
    print(f"Output directory: {args.output_dir}")

    print(f"\nLoading evaluation data from: {args.evaluation_file}")
    raw_frame = load_evaluation_data(args.evaluation_file)
    print(f"Loaded dataset with shape: {raw_frame.shape}")

    unique_items = process_evaluations(raw_frame)
    print(f"\nUnique items to score: {len(unique_items)}")

    run_judge_evaluations(
        eval_dataset=create_ragas_dataset(unique_items),
        judges=judges,
        output_dir=args.output_dir,
        judge_workers=None,  # deprecated: worker counts are defined in the YAML config
        default_workers=args.default_workers,
        max_trials=args.max_trials,
        config_path=args.config_file or None,
    )

    print("\nJudge's Verdict scoring complete!")


# Run the scoring pipeline only when this file is executed as the main
# program; importing the module has no side effects beyond its definitions.
if __name__ == "__main__":
    main()