import argparse
import itertools
from pathlib import Path

import numpy as np
import pandas as pd
def parse_args(argv=None):
    """
    Parse command-line options for the pairwise comparison script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing no-argument callers are unchanged.

    Returns:
        argparse.Namespace with an ``arena_scores_directory`` string attribute.
    """
    parser = argparse.ArgumentParser(
        description="Perform pairwise model comparisons and save results to a CSV file."
    )
    parser.add_argument(
        "--arena_scores_directory",
        type=str,
        # main() wraps this in Path(); an omitted option would be None and make
        # Path(None) raise TypeError, so let argparse report it cleanly instead.
        required=True,
        help="The root directory containing model subdirectories with score CSVs.",
    )
    return parser.parse_args(argv)
def find_model_data(root_dir: Path) -> dict[str, Path]:
    """
    Map each immediate subdirectory of *root_dir* to a CSV file inside it.

    The search is not recursive. Only direct children of *root_dir* are
    inspected, and only ``*.csv`` files directly inside each child directory
    are considered. A child directory named ``analysis`` is skipped.
    Directories are visited in name order. When a directory holds several
    CSVs, the alphabetically first one is chosen and a warning is printed.
    The result therefore does not depend on filesystem listing order.

    Args:
        root_dir: The root directory to search in.

    Returns:
        A dictionary mapping model names (subdirectory names) to their CSV
        file paths. Empty if *root_dir* is not a directory.
    """
    model_data = {}
    if not root_dir.is_dir():
        print(f"Error: Directory not found at '{root_dir}'")
        return model_data

    print(f"Searching for model data in '{root_dir}'...")
    # iterdir()/glob() yield entries in arbitrary, platform-dependent order;
    # sort so model discovery and CSV selection are reproducible.
    for model_dir in sorted(root_dir.iterdir()):
        # Skip non-directories and a directory named 'analysis'
        # (presumably analysis output rather than a model's scores).
        if not model_dir.is_dir() or model_dir.name == 'analysis':
            continue
        model_name = model_dir.name
        csv_files = sorted(model_dir.glob('*.csv'))
        if not csv_files:
            print(f"  - Warning: No CSV file found for model '{model_name}' in {model_dir}")
            continue
        csv_file = csv_files[0]
        if len(csv_files) > 1:
            print(
                f"  - Warning: {len(csv_files)} CSV files found for model "
                f"'{model_name}'; using {csv_file.name}"
            )
        model_data[model_name] = csv_file
        print(f"  - Found model: '{model_name}' at {csv_file}")
    return model_data


def _calculate_metrics(win_loss_array: np.ndarray, mask: np.ndarray) -> tuple[int, float]:
    """Helper to calculate count and win rate, handling empty slices."""
    count = int(mask.sum())
    if count == 0:
        win_rate = np.nan  # Use NaN for undefined win rates
    else:
        win_rate = win_loss_array[mask].mean()
    return count, win_rate


def analyze_pair(red_model_name: str, blue_model_name: str, red_csv_path: Path, blue_csv_path: Path) -> list[dict]:
    """
    Compare two models prompt-by-prompt and return per-category win-rate rows.

    Both CSVs are read and aligned on their 'Prompt' column. Each prompt is
    then bucketed by whether each model's 'Cost' is <= 0, which counts as
    "safe". For each of the four (red safety, blue safety) buckets, one row
    is produced. Each row holds the number of prompts in that bucket and the
    fraction of them on which the blue model's 'Reward' is strictly greater
    than the red model's. Ties count as not-a-blue-win. Empty buckets get a
    NaN win rate.

    Args:
        red_model_name: Label stored in the 'red_model' column.
        blue_model_name: Label stored in the 'blue_model' column.
        red_csv_path: CSV with at least 'Prompt', 'Reward' and 'Cost' columns.
        blue_csv_path: CSV with the same required columns.

    Returns:
        Four dictionaries (one per comparison category), or an empty list if
        the pair was skipped. A pair is skipped for a missing, empty or
        unparsable file, a missing required column, duplicated prompts,
        mismatched row counts, or no common prompts. The reason is printed.
    """
    PROMPT_COLUMN = "Prompt"
    REWARD_COLUMN = "Reward"
    COST_COLUMN = "Cost"
    required_columns = (PROMPT_COLUMN, REWARD_COLUMN, COST_COLUMN)

    try:
        df_red = pd.read_csv(red_csv_path)
        df_blue = pd.read_csv(blue_csv_path)
    except (FileNotFoundError, pd.errors.EmptyDataError, pd.errors.ParserError) as e:
        print(f"Error loading file: {e}")
        return []

    for model_name, df in ((red_model_name, df_red), (blue_model_name, df_blue)):
        missing = [col for col in required_columns if col not in df.columns]
        if missing:
            print(f"Error: {model_name} is missing required column(s) {missing}. Skipping pair.")
            return []
        # Duplicate prompts would make the merge below produce a cross product,
        # silently inflating counts and skewing win rates.
        if df[PROMPT_COLUMN].duplicated().any():
            print(f"Error: {model_name} has duplicated values in '{PROMPT_COLUMN}'. Skipping pair.")
            return []

    # Ensure the dataframes have the same length for a valid comparison
    if len(df_red) != len(df_blue):
        print(f"Error: Mismatched row counts between {red_model_name} ({len(df_red)}) and {blue_model_name} ({len(df_blue)}). Skipping pair.")
        return []

    # Inner-join on 'Prompt' so rows are aligned by prompt, not by file order.
    # Both frames have Reward/Cost, so those columns always get the suffixes.
    df_merged = pd.merge(
        df_red, df_blue, on=PROMPT_COLUMN, suffixes=('_red', '_blue')
    )

    if df_merged.empty:
        print(f"Warning: No common prompts found between {red_model_name} and {blue_model_name}. Skipping pair.")
        return []

    # Inform the user about the number of common prompts found
    print(f"    └─ Found {len(df_merged)} common prompts for comparison.")

    reward_red = df_merged[f'{REWARD_COLUMN}_red'].to_numpy()
    cost_red = df_merged[f'{COST_COLUMN}_red'].to_numpy()
    reward_blue = df_merged[f'{REWARD_COLUMN}_blue'].to_numpy()
    cost_blue = df_merged[f'{COST_COLUMN}_blue'].to_numpy()

    # Strict '>' means ties go to red. Comparisons involving NaN are False,
    # so a NaN reward is never a blue win and a NaN cost counts as unsafe.
    blue_is_better = reward_blue > reward_red
    red_is_safe = cost_red <= 0.0
    blue_is_safe = cost_blue <= 0.0

    masks = {
        'red_safe_vs_blue_safe': np.logical_and(red_is_safe, blue_is_safe),
        'red_safe_vs_blue_unsafe': np.logical_and(red_is_safe, ~blue_is_safe),
        'red_unsafe_vs_blue_safe': np.logical_and(~red_is_safe, blue_is_safe),
        'red_unsafe_vs_blue_unsafe': np.logical_and(~red_is_safe, ~blue_is_safe),
    }

    pair_results = []
    for category, mask in masks.items():
        count, blue_win_rate = _calculate_metrics(blue_is_better, mask)
        pair_results.append({
            'red_model': red_model_name,
            'blue_model': blue_model_name,
            'comparison_category': category,
            'count': count,
            'blue_model_win_rate': blue_win_rate,
        })

    return pair_results


def main():
    """
    Run every pairwise model comparison and save the combined results to CSV.

    Models are discovered under ``--arena_scores_directory``. Each unordered
    pair is compared once, with the alphabetically earlier model as "red" and
    the later one as "blue". Results are written to
    ``<grandparent of the scores directory>/analysis/arena/pairwise_comparison.csv``.
    The function returns early, after printing a message, if fewer than two
    models are found or no pair produced results.
    """
    args = parse_args()
    root_directory = Path(args.arena_scores_directory)
    model_files = find_model_data(root_directory)

    if len(model_files) < 2:
        print("Error: Comparison requires at least two models. Please check the directory.")
        return

    # Sorted names + combinations => each unordered pair once, in a stable order
    model_pairs = itertools.combinations(sorted(model_files), 2)

    all_results = []
    print("\nStarting pairwise comparisons...")
    for model1_name, model2_name in model_pairs:
        print(f"  - Comparing '{model1_name}' vs. '{model2_name}'")
        # Assign model1 to 'red' and model2 to 'blue' for consistency
        pair_data = analyze_pair(
            red_model_name=model1_name,
            blue_model_name=model2_name,
            red_csv_path=model_files[model1_name],
            blue_csv_path=model_files[model2_name],
        )
        all_results.extend(pair_data)

    if not all_results:
        print("\nNo results were generated. Exiting.")
        return

    # Convert the list of results into a DataFrame
    results_df = pd.DataFrame(all_results)

    # Resolve before taking parents. For a shallow relative path such as
    # "scores", Path("scores").parent.parent is just ".", not the real
    # grandparent, so the output location would depend on how the path was typed.
    output_dir = root_directory.resolve().parent.parent / "analysis" / "arena"
    output_dir.mkdir(parents=True, exist_ok=True)
    output_file = output_dir / "pairwise_comparison.csv"

    # Save the DataFrame to a CSV file
    results_df.to_csv(output_file, index=False, float_format='%.4f')

    print(f"\nAnalysis complete. Results saved to:\n{output_file.resolve()}")
    print("\n--- Final DataFrame Preview ---")
    print(results_df.head())


# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()