# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

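"""
Utility functions for processing LLM judge scores.

Provides consistent 3-point conversion and aggregation logic shared across scripts.

Aggregation rule:
- Exactly 3 trials: majority vote
- Any other number of trials: arithmetic mean

Supports two orderings of aggregation and 3-point conversion, selected by the
``convert_first`` argument of ``aggregate_llm_scores``:
1. Convert-first (recommended; ``convert_first=True``): convert each score to
   the 3-point scale, then aggregate
2. Aggregate-first (the default; ``convert_first=False``): aggregate the raw
   scores, then convert the result to the 3-point scale
"""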
import math
from collections import Counter
from typing import List, Optional, Union

import numpy as np


def convert_to_3_point_scale(score: Optional[Union[float, str]]) -> Optional[float]:
    """
    Convert a score to 3-point scale to match human scoring.

    Conversion rules:
    - score is None → None (pass through)
    - score is a string that does not parse as a float → None
    - score is NaN (including the string "nan") → None
    - score <= 0.25 → 0.0 (incorrect)
    - 0.25 < score < 0.75 → 0.5 (partially correct)
    - score >= 0.75 → 1.0 (correct)

    Args:
        score: Score expected in the [0, 1] range. May be None or a numeric
            string such as "0.8". Values outside [0, 1] are bucketed by the
            same thresholds (negative → 0.0, above 1 → 1.0).

    Returns:
        Converted score: 0.0, 0.5, 1.0, or None if the score is missing or invalid
    """
    if score is None:
        return None

    # Accept numeric strings (e.g. "0.8") as well as numbers.
    if isinstance(score, str):
        try:
            score = float(score)
        except (ValueError, TypeError):
            return None

    # NaN compares False against every threshold below, so without this check
    # it would fall through to the last branch and be scored as fully correct.
    if math.isnan(score):
        return None

    if score <= 0.25:
        return 0.0
    if score < 0.75:
        return 0.5
    return 1.0


def majority_vote(scores: List[float]) -> Optional[float]:
    """
    Compute majority vote from a list of scores.

    Ties between equally frequent scores are broken in favour of the score
    that appears first in ``scores`` (``Counter.most_common`` keeps
    first-encountered order for equal counts). With no repeated values, the
    first score is therefore returned.

    Args:
        scores: List of scores (values must be hashable)

    Returns:
        The most frequent score, or None if ``scores`` is empty
    """
    if not scores:
        return None

    (winner, _count), = Counter(scores).most_common(1)
    return winner


def _is_missing_score(score) -> bool:
    """Return True if ``score`` carries no value: ``None`` or a numeric NaN."""
    if score is None:
        return True
    try:
        return math.isnan(score)
    except TypeError:
        # Non-numeric values (e.g. numeric strings) are not NaN here; the
        # convert-first path parses them later.
        return False


def _aggregate_trials(scores: List[float]) -> float:
    """
    Aggregate a non-empty list of trial scores.

    - Exactly 3 scores: majority vote. If all three differ there is no
      majority, so fall back to the median instead of letting trial order
      decide. For 3 values the median equals the majority value whenever
      one exists, so the fallback never contradicts the vote.
    - Any other count: arithmetic mean.
    """
    if len(scores) == 3:
        if len(set(scores)) < 3:
            # At least two trials agree, so a strict majority exists.
            return majority_vote(scores)
        return np.median(np.asarray(scores, dtype=float))
    return np.mean(np.asarray(scores, dtype=float))


def aggregate_llm_scores(
    trial_scores: List[float], convert_to_3_point: bool = True, convert_first: bool = False
) -> Optional[float]:
    """
    Aggregate multiple trial scores using the optimal strategy.

    Strategy:
    - For exactly 3 trials: use majority vote. Three distinct values have no
      majority; in that case the median is used so the result does not
      depend on trial order.
    - Otherwise: use averaging (more robust for variable trial counts).

    ``None`` and NaN trial scores are treated as missing and ignored.

    Args:
        trial_scores: List of scores from multiple trials
        convert_to_3_point: Whether to apply 3-point conversion
        convert_first: If True (and ``convert_to_3_point`` is True), convert
                      each score to 3-point before aggregating. In that mode
                      the mean of 3-point values is returned as-is, so it may
                      lie between scale points (e.g. 0.75). If False
                      (default), aggregate first then convert. Ignored when
                      ``convert_to_3_point`` is False.

    Returns:
        Aggregated score (optionally converted to 3-point scale), or None if no valid scores
    """
    if not trial_scores:
        return None

    # Dropping NaN alongside None keeps a single missing trial from turning
    # the whole mean into NaN (which the 3-point conversion would map to 1.0).
    valid_scores = [s for s in trial_scores if not _is_missing_score(s)]
    if not valid_scores:
        return None

    if convert_first and convert_to_3_point:
        # Convert each score first, discarding any that fail to convert.
        converted_scores = [convert_to_3_point_scale(s) for s in valid_scores]
        converted_scores = [s for s in converted_scores if s is not None]
        if not converted_scores:
            return None
        return _aggregate_trials(converted_scores)

    # Aggregate-first: combine raw scores, then optionally convert.
    aggregated_score = _aggregate_trials(valid_scores)
    if convert_to_3_point:
        return convert_to_3_point_scale(aggregated_score)
    return aggregated_score
