"""Usage:
python eval_judge.py --input ./output.csv --questions ../evaluation/first_plot_questions.yaml --output ./output_with_scores.csv

This script reads a CSV file with questions and answers (produced by eval_inference.py),
applies the judges to score the answers, and saves the results with score columns added.

This script can be run on CPU nodes as it only does scoring via API calls, no GPU required.
"""

import asyncio
import logging
import os
import yaml
import json
import pandas as pd
from typing import Optional, Dict
from tqdm import tqdm

from judge import OpenAiJudge


class QuestionConfig:
    """Per-question settings: the question id and one judge per metric.

    Attributes:
        id: Identifier of the question.
        judges: Mapping of metric name -> ``OpenAiJudge`` built from that
            metric's prompt. Empty when no judge prompts were supplied.
    """

    def __init__(
        self,
        id: str,
        judge_prompts: Optional[dict] = None,
        judge: str = "gpt-4o",
        **ignored_extra_args,
    ):
        """Build the judges for one question.

        Args:
            id: Identifier of the question.
            judge_prompts: Mapping of metric name -> prompt. ``None`` or an
                empty mapping yields a question with no judges.
            judge: First argument handed to every ``OpenAiJudge``
                (presumably the judge model name -- confirm against
                ``judge.OpenAiJudge``).
            **ignored_extra_args: Accepted and discarded, so a full question
                entry can be splatted in even when it carries extra keys.
        """
        self.id = id

        # Treat None and {} alike: nothing to build.
        prompts_by_metric = judge_prompts if judge_prompts else {}

        built = {}
        for metric, prompt in prompts_by_metric.items():
            built[metric] = OpenAiJudge(judge, prompt)
        self.judges = built


def load_question_configs(path):
    """Load question configurations, including judge prompts, from *path*.

    Two formats are supported, chosen by file extension:

    * ``.jsonl`` -- one JSON object per line. These entries get no judges;
      each question is assigned the id ``"{path}_{i}"`` where ``i`` is the
      zero-based index among the non-blank lines.
    * anything else -- parsed as YAML (safe loader) holding a list of
      mappings, each of which is passed to ``QuestionConfig`` as keyword
      arguments.

    Args:
        path: Path to the questions file.

    Returns:
        Dict mapping question id -> ``QuestionConfig``. Empty when the file
        contains no questions.
    """
    is_json = path.endswith(".jsonl")

    # Both JSON and YAML are UTF-8 by convention; do not depend on the
    # platform's default encoding.
    with open(path, "r", encoding="utf-8") as f:
        if is_json:
            # Blank lines (e.g. a trailing newline at the end of the file)
            # are not JSON documents and would make json.loads raise.
            data = [json.loads(line) for line in f if line.strip()]
        else:
            # An empty YAML document loads as None; treat it as "no questions".
            data = yaml.load(f, Loader=yaml.SafeLoader) or []

    questions = {}
    for i, question in enumerate(data):
        if is_json:
            # JSONL format typically doesn't have judge prompts
            question_id = f"{path}_{i}"
            questions[question_id] = QuestionConfig(
                id=question_id,
                judge_prompts={},
            )
        else:
            # YAML format with full configuration
            q_config = QuestionConfig(**question)
            questions[q_config.id] = q_config

    return questions


async def score_row(row, question_config):
    """Apply every judge of *question_config* to one row.

    Args:
        row: Mapping-like record providing ``'question'`` and ``'answer'``.
        question_config: Object whose ``judges`` attribute maps metric name
            to an awaitable callable taking ``question=`` and ``answer=``.

    Returns:
        Dict of metric name -> score. A judge that raises is reported on
        stdout and recorded as ``None``; the dict is empty when the question
        has no judges.
    """
    judges = question_config.judges
    if not judges:
        # Nothing configured for this question.
        return {}

    outcome = {}
    for metric, judge_fn in judges.items():
        try:
            value = await judge_fn(question=row['question'], answer=row['answer'])
        except Exception as exc:
            # One failing judge must not prevent the others from running.
            print(f"Error scoring row for {metric}: {exc}")
            value = None
        outcome[metric] = value

    return outcome


async def score_dataframe(df, question_configs):
    """Score each row of *df* in order, one row at a time.

    Args:
        df: DataFrame with at least a ``'question_id'`` column, plus whatever
            columns ``score_row`` reads.
        question_configs: Mapping of question id -> question configuration.

    Returns:
        List with one score dict per row, in row order. Rows whose
        ``question_id`` has no configuration get an empty dict and a warning
        on stdout.
    """
    print(f"Scoring {len(df)} rows...")

    per_row_scores = []
    for _, record in tqdm(df.iterrows(), total=len(df), desc="Scoring"):
        qid = record['question_id']

        if qid in question_configs:
            row_scores = await score_row(record, question_configs[qid])
        else:
            print(f"Warning: No configuration found for question_id {qid}")
            row_scores = {}

        per_row_scores.append(row_scores)

    return per_row_scores


def main(
    input: str,
    questions: str,
    output: Optional[str] = None,
    overwrite: bool = False,
):
    """
    Score the answers in a CSV file using judges.

    Reads the CSV, runs every configured judge on every row, writes the CSV
    back with one column per score, and prints mean/std/count per score.

    Args:
        input: Path to the CSV file with questions and answers (from eval_inference.py).
            Must contain 'question', 'answer' and 'question_id' columns.
        questions: Path to the questions YAML/JSONL file with judge configurations
        output: Path to save the scored results. If None, overwrites the input file.
        overwrite: Only relevant when output is None. If True, the notice about
            overwriting the input file is not printed. No interactive
            confirmation is requested in either case.

    Raises:
        FileNotFoundError: If the input CSV does not exist.
        ValueError: If the input CSV lacks one of the required columns.
    """
    logging.basicConfig(level=logging.INFO)

    if output is None:
        output = input
        if not overwrite:
            print(f"Warning: Will overwrite {input} with scored results.")
            print("Use --output to specify a different output file, or --overwrite=True to skip this warning.")

    # Load the CSV file
    if not os.path.exists(input):
        raise FileNotFoundError(f"Input file not found: {input}")

    print(f"Loading data from {input}")
    df = pd.read_csv(input)

    required_columns = ('question', 'answer', 'question_id')
    if any(col not in df.columns for col in required_columns):
        raise ValueError("Input CSV must have 'question', 'answer', and 'question_id' columns")

    # Load question configurations
    print(f"Loading question configurations from {questions}")
    question_configs = load_question_configs(questions)

    # Collect score column names in first-seen order. A dict is used instead
    # of a set so that the order of new columns in the output CSV is the same
    # on every run (set iteration order of strings varies with hash
    # randomization).
    seen_columns = {}
    for q_config in question_configs.values():
        seen_columns.update(dict.fromkeys(q_config.judges))
    all_score_columns = list(seen_columns)

    # Check if scores already exist
    existing_scores = set(df.columns) & set(all_score_columns)
    if existing_scores:
        print(f"Found existing score columns: {existing_scores}")
        print("These will be overwritten.")

    # Score the dataframe
    scores_list = asyncio.run(score_dataframe(df, question_configs))

    # Add scores to the dataframe; rows without a value for a metric get None
    for score_col in all_score_columns:
        df[score_col] = [scores.get(score_col) for scores in scores_list]

    # Save the results
    df.to_csv(output, index=False)
    print(f"Scored results saved to {output}")

    # Print summary statistics (every score column exists in df at this point)
    print("\nScoring Summary:")
    for score_col in sorted(all_score_columns):
        valid_scores = df[score_col].dropna()
        if len(valid_scores) > 0:
            print(f"  {score_col}: mean={valid_scores.mean():.2f}, std={valid_scores.std():.2f}, n={len(valid_scores)}")
        else:
            print(f"  {score_col}: no valid scores")


if __name__ == "__main__":
    # Imported here rather than at module top, so `fire` is only required
    # when this file is executed as a script.
    import fire

    # Expose `main` as the command-line entry point; its parameters are the
    # flags shown in the module usage (--input, --questions, --output).
    fire.Fire(main)

