# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""
Common utilities for loading judge results across multiple trials.

This module provides functions to:
- Detect available judges
- Load judge results from any number of trials
- Create lookup dictionaries for efficient score retrieval
"""

import json
import os
from typing import Dict, List, Tuple

import pandas as pd

from src.llm_score_utils import aggregate_llm_scores


def get_available_judges(results_dir: str) -> List[str]:
    """Return the sorted names of judge subdirectories that contain a ``trial1.json``.

    ``trial1.json`` is the first file ``load_judge_trial_data`` reads, so a judge
    directory without it is treated as having no results. A missing
    ``results_dir`` yields an empty list.
    """
    if not os.path.exists(results_dir):
        return []

    return sorted(
        entry
        for entry in os.listdir(results_dir)
        if os.path.isdir(os.path.join(results_dir, entry))
        and os.path.exists(os.path.join(results_dir, entry, "trial1.json"))
    )


def load_judge_trial_data(judge_name: str, results_dir: str) -> Tuple[List[pd.DataFrame], int]:
    """
    Load all available trial data for a judge.

    Trial files are read in order as ``trial1.json``, ``trial2.json``, ... from
    ``<results_dir>/<judge_name>/``; loading stops at the first missing number,
    so a gap (e.g. no ``trial3.json``) hides any later trials.

    Args:
        judge_name: Name of the judge (e.g., "nvdev_meta_llama-3.1-70b-instruct")
        results_dir: Base directory containing judge subdirectories

    Returns:
        Tuple of (list of trial DataFrames, number of trials found). A JSON list
        becomes one row per element; a single JSON object becomes a one-row
        DataFrame.
    """
    judge_dir = os.path.join(results_dir, judge_name)

    trial_results: List[pd.DataFrame] = []
    trial_num = 1
    while True:
        trial_path = os.path.join(judge_dir, f"trial{trial_num}.json")
        if not os.path.exists(trial_path):
            break
        # JSON is UTF-8 by spec and these files hold free-form model text; relying
        # on the platform default encoding (e.g. cp1252 on Windows) would mis-decode it.
        with open(trial_path, "r", encoding="utf-8") as f:
            data = json.load(f)
        # Accept both a list of records and a single top-level record.
        records = data if isinstance(data, list) else [data]
        trial_results.append(pd.DataFrame(records))
        trial_num += 1

    return trial_results, len(trial_results)


def _first_valid(row: pd.Series, columns: Tuple[str, ...], default: object) -> object:
    """Return the value of the first column in ``columns`` that exists in ``row`` and is not NA."""
    for col in columns:
        if col in row.index:
            value = row[col]
            # pandas fills fields absent from some JSON records with NaN, so a
            # present-but-NA cell is treated the same as a missing column.
            if not (pd.api.types.is_scalar(value) and pd.isna(value)):
                return value
    return default


def create_score_lookup(
    results_df: pd.DataFrame,
    question_col: str = "user_input",
    reference_col: str = "reference",
    response_col: str = "response",
    score_col: str = "nv_accuracy",
) -> Dict[Tuple[str, str, str], float]:
    """
    Create lookup dictionary from results dataframe.

    Each field is taken from the preferred column, falling back to alternative
    column names when the preferred column is absent OR its cell is NA
    (NaN/None):

    - question: ``question_col`` -> "question" -> ""
    - reference: ``reference_col`` -> "gt_answer" -> "ground_truth" -> ""
    - response: ``response_col`` -> "gen_answer" -> "answer" -> ""
    - score: ``score_col`` -> "score" -> "Score" -> None

    Args:
        results_df: DataFrame with judge results
        question_col: Column name for question/user_input
        reference_col: Column name for reference/ground_truth
        response_col: Column name for model response
        score_col: Column name for score

    Returns:
        Dictionary mapping (question, reference, response) tuples to scores.
        A value is None when no score column holds a usable value for that row.
        If several rows share a key, the last row wins.
    """
    lookup = {}
    for _, row in results_df.iterrows():
        question = _first_valid(row, (question_col, "question"), "")
        reference = _first_valid(row, (reference_col, "gt_answer", "ground_truth"), "")
        response = _first_valid(row, (response_col, "gen_answer", "answer"), "")

        # Same field order as the human-annotation keys: (question, gt_answer, gen_answer).
        key = (question, reference, response)
        lookup[key] = _first_valid(row, (score_col, "score", "Score"), None)
    return lookup


def load_and_aggregate_judge_scores(
    judge_name: str, results_dir: str, convert_to_3_point: bool = True, convert_first: bool = True
) -> Dict[Tuple[str, str, str], float]:
    """
    Load all trials for a judge and aggregate the scores.

    Every (question, reference, response) key seen in any trial is aggregated
    over the trials that scored it; trials where the score is missing (None or
    NaN) are skipped for that key. Keys for which no trial has a score, or for
    which ``aggregate_llm_scores`` returns None, are omitted from the result.

    Args:
        judge_name: Name of the judge
        results_dir: Base directory containing judge subdirectories
        convert_to_3_point: Whether to convert scores to 3-point scale
        convert_first: Whether to convert to 3-point scale before aggregation

    Returns:
        Dictionary mapping (question, reference, response) tuples to aggregated
        scores, ordered by first appearance across trials. Empty when the judge
        has no trial files.
    """
    trial_dfs, num_trials = load_judge_trial_data(judge_name, results_dir)

    if num_trials == 0:
        return {}

    lookups = [create_score_lookup(trial_df) for trial_df in trial_dfs]

    # Union of keys in first-seen order. A set would iterate in an order that
    # depends on string hash randomization, so the result order would change
    # from run to run.
    all_keys = dict.fromkeys(key for lookup in lookups for key in lookup)

    aggregated_scores = {}
    for key in all_keys:
        trial_scores = []
        for lookup in lookups:
            score = lookup.get(key)
            # Skip trials that did not score this sample. None marks a missing
            # score; NaN can also appear when pandas fills gaps in sparse records.
            if score is None or (pd.api.types.is_scalar(score) and pd.isna(score)):
                continue
            trial_scores.append(score)

        if not trial_scores:
            continue

        aggregated_score = aggregate_llm_scores(
            trial_scores, convert_to_3_point=convert_to_3_point, convert_first=convert_first
        )
        if aggregated_score is not None:
            aggregated_scores[key] = aggregated_score

    return aggregated_scores


def get_judge_config(judge_name: str, results_dir: str = None) -> Dict:
    """
    Look up the centralized configuration for a judge model.

    Args:
        judge_name: Name of the judge
        results_dir: Unused; accepted only so existing callers keep working

    Returns:
        The configuration returned by ``get_judge_model_config`` converted with
        ``to_dict()``, or an empty dict when that lookup yields a falsy value
    """
    # Imported lazily, as in the original, so this module can be imported
    # without pulling in the config module up front.
    from src.judge_config import get_judge_model_config

    config = get_judge_model_config(judge_name)
    return config.to_dict() if config else {}
