"""Usage:
python rejudge.py --input_csv eval_result.csv --questions ../evaluation/first_plot_questions.yaml --output rejudged_result.csv

# With custom judge model:
python rejudge.py --input_csv eval_result.csv --questions ../evaluation/first_plot_questions.yaml --judge_model gpt-4o-mini

This script takes an existing CSV file with questions and answers, and recomputes
the judge scores using LLM judges without running inference again.
"""

import asyncio
import logging
import os
import yaml
import json
import pandas as pd
from typing import Optional

from judge import OpenAiJudge


class Question:
    """An evaluation question together with the judges that score its answers.

    One judge is built per entry of ``judge_prompts``; the entry's key becomes
    the name of the score column written by :meth:`rejudge_rows`.
    """

    def __init__(
        self,
        id: str,
        paraphrases: Optional[list[str]] = None,
        judge_prompts: Optional[dict] = None,
        temperature: float = 1,
        system: Optional[str] = None,
        judge: str = "gpt-4o",
        **ignored_extra_args,
    ):
        """Store the question fields and build one judge per metric.

        Args:
            id: Identifier used to match rows of the results CSV.
            paraphrases: Stored as-is; not used when rejudging.
            judge_prompts: Mapping of metric name -> judge prompt. ``None``
                (the default) or an empty mapping means "no judges".
            temperature: Stored as-is; not used when rejudging.
            system: Stored as-is; not used when rejudging.
            judge: Model name passed to every ``OpenAiJudge``.
            **ignored_extra_args: Extra keys from the questions file that this
                script does not need; accepted and discarded.
        """
        self.id = id
        self.paraphrases = paraphrases
        self.temperature = temperature
        self.system = system
        # judge_prompts defaults to None, so guard before iterating: a question
        # without prompts simply has no judges instead of raising.
        self.judges = {
            metric: OpenAiJudge(judge, prompt)
            for metric, prompt in (judge_prompts or {}).items()
        }

    async def rejudge_rows(self, rows_df):
        """Re-evaluate existing question-answer pairs with the judges.

        Args:
            rows_df: DataFrame containing rows with 'question' and 'answer'
                     columns for this specific question_id.

        Returns:
            A copy of ``rows_df`` with one column per judge metric holding
            whatever each judge call returned, in row order. The input frame
            is not modified.
        """
        questions = rows_df['question'].tolist()
        answers = rows_df['answer'].tolist()

        result_df = rows_df.copy()

        # Metrics are processed one after another; within a metric all rows
        # are judged concurrently. gather() preserves input order, so the
        # resulting list lines up positionally with the rows.
        for metric, judge in self.judges.items():
            scores = await asyncio.gather(
                *(
                    judge(question=question, answer=answer)
                    for question, answer in zip(questions, answers)
                )
            )
            result_df[metric] = scores

        return result_df


def load_questions(path, judge_model_override=None):
    """Load questions from YAML or JSONL file.

    A path ending in ``.jsonl`` is read as one JSON object per line; each
    object becomes a judge-less ``Question`` whose only paraphrase is
    ``messages[0].content`` and whose id is ``"{path}_{index}"``. Any other
    path is parsed as YAML holding a sequence of mappings that are passed to
    ``Question`` as keyword arguments.

    Args:
        path: Path to the questions file (YAML or JSONL)
        judge_model_override: Optional model name to override the judge model
                             specified in the questions file (YAML only; JSONL
                             questions carry no judges)

    Returns:
        List of ``Question`` objects in file order. Empty if the file contains
        no entries.
    """
    is_json = path.endswith(".jsonl")
    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # default encoding.
    with open(path, "r", encoding="utf-8") as f:
        if is_json:
            # Skip blank lines (e.g. a trailing newline at end of file), which
            # json.loads would otherwise reject.
            data = [json.loads(line) for line in f if line.strip()]
        else:
            # An empty YAML document parses to None; treat it as "no questions".
            data = yaml.load(f, Loader=yaml.SafeLoader) or []

    questions = []
    for i, question in enumerate(data):
        if is_json:
            questions.append(
                Question(
                    id=f"{path}_{i}",
                    paraphrases=[question["messages"][0]["content"]],
                    judge_prompts={},
                )
            )
            continue

        # Copy so the override below does not mutate the parsed YAML data.
        question_data = dict(question)
        # Override the judge model if specified
        if judge_model_override is not None:
            question_data['judge'] = judge_model_override
        # Convert id to string to match CSV format (CSV stores question_id as strings)
        if 'id' in question_data:
            question_data['id'] = str(question_data['id'])
        questions.append(Question(**question_data))
    return questions


async def _rejudge_groups(grouped, question_map):
    """Rejudge each question_id group in turn; return (frames, metric names).

    Groups whose id is unknown, or whose question has no judges, are passed
    through unchanged. Metric names are collected in first-seen order.
    """
    outputs = []
    judged_metrics = {}  # dict used as an insertion-ordered set
    total_groups = len(grouped)

    for idx, (question_id, group_df) in enumerate(grouped, 1):
        logging.info(f"Processing question {idx}/{total_groups}: {question_id}")

        question = question_map.get(question_id)
        if question is None:
            logging.warning(f"Question ID '{question_id}' not found in questions file. Skipping.")
            # Keep the original data without rejudging
            outputs.append(group_df)
            continue

        if not question.judges:
            logging.warning(f"Question ID '{question_id}' has no judges defined. Skipping.")
            outputs.append(group_df)
            continue

        outputs.append(await question.rejudge_rows(group_df))
        judged_metrics.update(dict.fromkeys(question.judges))

    return outputs, list(judged_metrics)


def main(
    input_csv,
    questions,
    output=None,
    judge_model=None,
):
    """Re-evaluate existing CSV with LLM judges

    Args:
        input_csv: Path to the existing CSV file with questions and answers
        questions: Path to the questions YAML/JSONL file with judge definitions
        output: Path to save the rejudged results (default: input_csv with _rejudged suffix)
        judge_model: Optional model name to override the judge model specified in questions file
                     (e.g., 'gpt-4o-mini', 'gpt-4-turbo', 'gpt-3.5-turbo')

    Raises:
        ValueError: If the input CSV lacks a 'question', 'answer' or
            'question_id' column.
    """
    # NOTE: the docstring must be the first statement of the function;
    # otherwise main.__doc__ is None and the CLI help has no description.
    logging.basicConfig(level=logging.INFO)

    # Set output path if not provided
    if output is None:
        base_name = os.path.splitext(input_csv)[0]
        output = f"{base_name}_rejudged.csv"

    # Load the existing CSV
    logging.info(f"Loading existing results from {input_csv}")
    df = pd.read_csv(input_csv)

    # Verify required columns exist
    required_cols = ['question', 'answer', 'question_id']
    missing_cols = [col for col in required_cols if col not in df.columns]
    if missing_cols:
        raise ValueError(f"Input CSV is missing required columns: {missing_cols}")

    logging.info(f"Loaded {len(df)} rows from CSV")

    # Load questions to get judge prompts
    logging.info(f"Loading question definitions from {questions}")
    if judge_model:
        logging.info(f"Overriding judge model with: {judge_model}")
    questions_list = load_questions(questions, judge_model_override=judge_model)

    # Create a mapping of question_id to Question object
    # Use string keys to match CSV format
    question_map = {str(q.id): q for q in questions_list}

    # Also ensure question_id in DataFrame is string type
    df['question_id'] = df['question_id'].astype(str)

    # Log available question IDs for debugging
    logging.info(f"Loaded {len(question_map)} questions from {questions}")
    logging.info(f"Question IDs in YAML: {sorted(list(question_map.keys())[:10])}{'...' if len(question_map) > 10 else ''}")

    unique_csv_ids = df['question_id'].unique()
    logging.info(f"Unique question IDs in CSV: {sorted(list(unique_csv_ids)[:10])}{'...' if len(unique_csv_ids) > 10 else ''}")

    # Group by question_id and rejudge every group inside ONE event loop.
    # Calling asyncio.run() per group would create and close a loop each time,
    # which can break async clients that stay bound to the first loop.
    grouped = df.groupby('question_id')
    outputs, judge_cols = asyncio.run(_rejudge_groups(grouped, question_map))

    # Combine all results. pd.concat rejects an empty list, which happens when
    # the CSV has no rows; in that case write the (empty) input back out.
    if outputs:
        final_df = pd.concat(outputs, ignore_index=True)
    else:
        final_df = df.copy()

    # Save to output CSV
    logging.info(f"Saving rejudged results to {output}")
    final_df.to_csv(output, index=False)
    logging.info(f"Done! Rejudged {len(final_df)} rows.")

    # Print summary of the columns that were actually written by judges
    if judge_cols:
        logging.info(f"Judge columns in output: {', '.join(map(str, judge_cols))}")


if __name__ == "__main__":
    # Imported lazily so that importing this module does not require `fire`.
    from fire import Fire

    # Expose main()'s parameters as command-line flags.
    Fire(main)

