#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Script to check the percentage of NaN scores for each LLM judge.
"""

import json
import math
from collections import defaultdict
from pathlib import Path
from typing import Tuple

import pandas as pd


def is_nan_score(score):
    """Return True if *score* represents a missing/NaN value.

    Treated as NaN:
      * ``None``;
      * the strings "nan", "null", "none" (case-insensitive) and "";
      * a float NaN.

    Integers (and bools) are never NaN. Any other type (dict, list, ...) is
    treated as a present, non-NaN value.
    """
    if score is None:
        return True
    if isinstance(score, str):
        return score.lower() in ("nan", "null", "none", "")
    if isinstance(score, float):
        return math.isnan(score)
    # ints cannot be NaN; passing them to math.isnan would raise OverflowError
    # for integers too large to convert to float.
    return False


def extract_judge_name(filename: str) -> str:
    """Extract the judge name from a result filename.

    Strips a trailing ".json" extension (only at the end of the name), then
    drops everything from the first "_trial" onward.

    Examples: "gpt4_trial1.json" -> "gpt4", "gpt4.json" -> "gpt4".
    """
    suffix = ".json"
    # Only remove the extension itself; ".json" elsewhere in the name
    # (e.g. "x.jsonl", "a.json.bak") is part of the name.
    name = filename[: -len(suffix)] if filename.endswith(suffix) else filename
    # partition returns the whole string as the head when "_trial" is absent.
    name, _sep, _rest = name.partition("_trial")
    return name


# Candidate per-record score keys, in priority order: a record contributes the
# value of the first of these keys it contains.
_SCORE_FIELDS = (
    "nv_accuracy",
    "score",
    "llm_score",
    "judge_score",
    "rating",
    "evaluation_score",
)

# Top-level dict keys that may hold the list of per-record results, in
# priority order.
_RESULT_LIST_KEYS = ("results", "data", "evaluations", "scores")


def _count_item_scores(items) -> Tuple[int, int]:
    """Return (total_scores, nan_count) over a list of result records.

    Each dict item contributes at most one score: the value of the first key
    from ``_SCORE_FIELDS`` it contains. Non-dict items and dicts with none of
    those keys are ignored.
    """
    total_scores = 0
    nan_count = 0
    for item in items:
        if not isinstance(item, dict):
            continue
        field = next((name for name in _SCORE_FIELDS if name in item), None)
        if field is None:
            continue
        total_scores += 1
        if is_nan_score(item[field]):
            nan_count += 1
    return total_scores, nan_count


def _count_top_level_scores(data: dict) -> Tuple[int, int]:
    """Return (total_scores, nan_count) for the score-like keys of a flat dict.

    A key counts when it contains "score" (case-insensitive) or equals
    "nv_accuracy", unless it ends with "_tokens".
    """
    total_scores = 0
    nan_count = 0
    for key, value in data.items():
        if ("score" in key.lower() or key == "nv_accuracy") and not key.endswith("_tokens"):
            total_scores += 1
            if is_nan_score(value):
                nan_count += 1
    return total_scores, nan_count


def analyze_judge_file(filepath: Path) -> Tuple[int, int]:
    """
    Analyze a single judge file and return (total_scores, nan_count).

    Supported JSON layouts:
      * a list of records -> each dict record is scored via ``_SCORE_FIELDS``;
      * a dict holding a list under one of ``_RESULT_LIST_KEYS`` (first match
        wins) -> that list is scored the same way;
      * any other dict -> its own score-like keys are counted directly.
    Any other top-level JSON value yields (0, 0).

    Any error while reading, parsing or analyzing the file is printed and
    (0, 0) is returned instead of raising.
    """
    try:
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform's
        # locale-default encoding.
        with open(filepath, "r", encoding="utf-8") as f:
            data = json.load(f)

        if isinstance(data, list):
            return _count_item_scores(data)
        if isinstance(data, dict):
            results_key = next(
                (key for key in _RESULT_LIST_KEYS if isinstance(data.get(key), list)),
                None,
            )
            if results_key is not None:
                return _count_item_scores(data[results_key])
            return _count_top_level_scores(data)
    except Exception as e:  # deliberate per-file best-effort boundary
        print(f"Error processing {filepath}: {e}")
        return 0, 0

    return 0, 0


# Trial file stems reported as individual columns. Other trial*.json files
# still count toward a judge's overall totals.
_TRIAL_NAMES = ("trial1", "trial2", "trial3")


def _collect_judge_stats(benchmark_dir: Path):
    """Analyze every ``<judge_dir>/trial*.json`` under *benchmark_dir*.

    Returns ``{judge_name: {trial_stem: {"total": int, "nan": int}}}``.
    Directories and files are visited in sorted order so the progress log is
    reproducible.
    """
    judge_stats = defaultdict(lambda: defaultdict(lambda: {"total": 0, "nan": 0}))
    judge_dirs = sorted(path for path in benchmark_dir.iterdir() if path.is_dir())
    for judge_dir in judge_dirs:
        judge_name = judge_dir.name
        for trial_file in sorted(judge_dir.glob("trial*.json")):
            trial_num = trial_file.stem  # e.g., "trial1", "trial2", "trial3"
            print(f"Processing {judge_name}/{trial_file.name}...")
            total, nan = analyze_judge_file(trial_file)
            judge_stats[judge_name][trial_num]["total"] = total
            judge_stats[judge_name][trial_num]["nan"] = nan
    return judge_stats


def _build_result_rows(judge_stats):
    """Turn per-trial counts into display rows, one per judge (sorted by name).

    Returns ``(rows, grand_total, grand_nan)``. Each row holds the formatted
    per-trial and overall NaN percentages, the raw counts, and a private
    ``"_sort_key"`` with the unrounded overall percentage.
    """
    rows = []
    grand_total = 0
    grand_nan = 0
    for judge, trials in sorted(judge_stats.items()):
        overall_total = sum(stats["total"] for stats in trials.values())
        overall_nan = sum(stats["nan"] for stats in trials.values())
        overall_percentage = (overall_nan / overall_total * 100) if overall_total > 0 else 0.0
        grand_total += overall_total
        grand_nan += overall_nan

        row = {"Judge": judge}
        for index, trial_num in enumerate(_TRIAL_NAMES, start=1):
            stats = trials.get(trial_num)  # .get does not create defaultdict entries
            if stats is not None and stats["total"] > 0:
                row[f"Trial {index} NaN%"] = f"{stats['nan'] / stats['total'] * 100:.1f}%"
            else:
                row[f"Trial {index} NaN%"] = "N/A"
        row["Overall NaN%"] = f"{overall_percentage:.2f}%"
        row["Total Scores"] = overall_total
        row["Total NaN"] = overall_nan
        row["_sort_key"] = overall_percentage
        rows.append(row)
    return rows, grand_total, grand_nan


def main(benchmark_dir=Path("benchmark/judge_results")):
    """Analyze all judge files and report per-judge NaN-score percentages.

    Prints a per-trial breakdown table and summary statistics, and writes the
    table to ``<benchmark_dir>/nan_score_analysis.csv``.

    Args:
        benchmark_dir: Directory (str or Path) containing one sub-directory per
            judge, each holding ``trial*.json`` files. Relative paths resolve
            against the current working directory. Defaults to
            ``benchmark/judge_results``.
    """
    benchmark_dir = Path(benchmark_dir)

    # is_dir (not exists): a plain file at this path would crash iterdir().
    if not benchmark_dir.is_dir():
        print(f"Error: Directory {benchmark_dir} does not exist!")
        return

    judge_stats = _collect_judge_stats(benchmark_dir)
    rows, total_scores, total_nans = _build_result_rows(judge_stats)

    df = pd.DataFrame(rows)
    if df.empty:
        print("No judge score files found!")
        return

    # Sort by Overall NaN percentage (descending) on the unrounded value rather
    # than the formatted "x.xx%" string; mergesort is stable, so ties keep the
    # alphabetical judge order.
    df = df.sort_values("_sort_key", ascending=False, kind="mergesort").drop(columns="_sort_key")

    print("\n" + "=" * 100)
    print("LLM Judge NaN Score Analysis (Per-Trial Breakdown)")
    print("=" * 100)
    print(df.to_string(index=False))

    # Summary statistics
    print("\n" + "-" * 100)
    print("Summary Statistics:")
    print("-" * 100)

    overall_nan_percentage = (total_nans / total_scores * 100) if total_scores > 0 else 0
    print(f"Total number of judges analyzed: {len(judge_stats)}")
    print(f"Total scores across all judges: {total_scores:,}")
    print(f"Total NaN scores: {total_nans:,}")
    print(f"Overall NaN percentage: {overall_nan_percentage:.2f}%")

    # Save results to CSV
    output_file = benchmark_dir / "nan_score_analysis.csv"
    df.to_csv(output_file, index=False)
    print(f"\nResults saved to: {output_file}")


if __name__ == "__main__":
    # Run the analysis only when executed as a script, not when imported.
    main()
