#!/usr/bin/env python3
# ==============================================================================
# Pass@K, Avg@K, Best@K Evaluation Script
# ==============================================================================
# This script evaluates model generations using:
# - avg@k: Average score of the first k samples
# - best@k: Whether at least one of the first k samples is correct (score > 0)
# - pass@k: Unbiased estimator using combinatorial formula
#
# Pass@k formula: 1 - C(n-c, k) / C(n, k)
# where n = total samples, c = correct samples, k = samples to consider
#
# Supports parallel evaluation using multiprocessing (default) or Ray for acceleration.
# ==============================================================================

import argparse
import json
import math
import os
import sys
from collections import defaultdict
from concurrent.futures import ProcessPoolExecutor, as_completed
from datetime import datetime
from typing import List, Tuple, Optional, Dict, Any

import numpy as np
import pandas as pd
from tqdm import tqdm

# Add the project root (the parent of this script's directory) to the front of
# sys.path, so project-local packages -- such as `verl`, imported inside
# load_reward_function -- can be imported when this file is run as a script.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
PROJECT_ROOT = os.path.dirname(SCRIPT_DIR)
sys.path.insert(0, PROJECT_ROOT)


def compute_combinations(n: int, k: int) -> float:
    """Compute the binomial coefficient C(n, k) = n! / (k! * (n-k)!) as a float.

    The coefficient is computed exactly with integer arithmetic
    (``math.comb``) and converted to float only at the end. Results such as
    C(30, 15) are therefore exact, rather than carrying the rounding error of
    an exp/log round trip.

    Args:
        n: Size of the set being chosen from.
        k: Number of elements chosen.

    Returns:
        C(n, k) as a float, or 0.0 if ``k > n`` or ``k < 0``.

    Raises:
        OverflowError: If C(n, k) is too large to be represented as a float.
    """
    # Covers every negative input too: n < 0 with k >= 0 implies k > n, so
    # math.comb is never called with a negative argument.
    if k > n or k < 0:
        return 0.0
    return float(math.comb(n, k))


def pass_at_k(n: int, c: int, k: int) -> float:
    """
    Compute pass@k using the unbiased estimator.

    Formula: pass@k = 1 - C(n-c, k) / C(n, k)

    This is one minus the probability that k samples, drawn without
    replacement from the n generated samples, are all incorrect.

    When fewer than k samples exist (n < k), k is clamped to n, so the result
    is pass@n: 1.0 if any sample is correct, else 0.0. This matches
    avg_at_k/best_at_k, which also fall back to all available samples. It
    also keeps the estimate non-decreasing in k. A c/n fallback would report,
    for n=4 and c=1, pass@4 = 1.0 but pass@8 = 0.25.

    Args:
        n: Total number of samples
        c: Number of correct samples
        k: Number of samples to consider

    Returns:
        pass@k probability in [0, 1]. It is 0.0 when n <= 0 or c <= 0, and
        1.0 when c >= n > 0. Otherwise it is 0.0 for k <= 0.
    """
    if n <= 0:
        return 0.0
    # Fewer samples than requested: estimate with all of them (pass@n).
    k = min(k, n)
    if c <= 0:
        return 0.0
    if c >= n:
        return 1.0
    if k <= 0:
        return 0.0
    if n - c < k:
        # Not enough incorrect samples to fill k slots, must have at least one correct
        return 1.0

    # Exact integer binomials. Python's int / int true division is correctly
    # rounded, so there is no overflow and no exp/log round-off, even for large n.
    return 1.0 - math.comb(n - c, k) / math.comb(n, k)


def avg_at_k(scores: List[float], k: int) -> float:
    """
    Compute avg@k: the mean score of the first k samples.

    If there are fewer than k samples, all available samples are averaged.

    Args:
        scores: Per-sample scores (0 or 1 for binary rewards, or continuous).
        k: Number of leading samples to average; must be >= 1.

    Returns:
        Mean of ``scores[:k]``, or 0.0 when ``scores`` is empty.

    Raises:
        ValueError: If ``k < 1`` and ``scores`` is non-empty. A k of 0 would
            otherwise divide by zero. A negative k would silently average a
            slice taken from the wrong end and divide by a negative count.
    """
    # len() rather than truthiness so array-like inputs also work.
    if len(scores) == 0:
        return 0.0
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    window = scores[:k]
    return sum(window) / len(window)


def best_at_k(scores: List[float], k: int) -> float:
    """
    Compute best@k: whether at least one of the first k samples is correct.

    A sample counts as correct when its score is > 0, which is the same rule
    used for the correct-count in evaluate_single_problem. If there are fewer
    than k samples, all available samples are considered.

    Args:
        scores: Per-sample scores (typically 0 or 1).
        k: Number of leading samples to consider; must be >= 1.

    Returns:
        1.0 if any of the first k samples is correct, 0.0 otherwise (also 0.0
        for empty ``scores``).

    Raises:
        ValueError: If ``k < 1`` and ``scores`` is non-empty. A negative k
            would otherwise silently inspect a slice taken from the wrong end.
    """
    if len(scores) == 0:
        return 0.0
    if k < 1:
        raise ValueError(f"k must be >= 1, got {k}")
    # any() instead of max(): max() is order-dependent with NaN, e.g.
    # max([nan, 1.0]) is nan, which would hide a correct sample.
    return float(any(s > 0 for s in scores[:k]))


def _load_custom_reward_fn(reward_fn_path: str, reward_fn_name: str):
    """Return ``reward_fn_name`` from the Python file at ``reward_fn_path``, or None.

    Returns None when any of these hold:
    - the file does not exist;
    - importlib cannot build a loader for it (spec_from_file_location yields
      no spec/loader, e.g. for an unrecognised file suffix);
    - the module does not define ``reward_fn_name``.

    Exceptions raised while executing the module itself propagate.
    """
    if not os.path.exists(reward_fn_path):
        return None
    import importlib.util
    spec = importlib.util.spec_from_file_location("custom_module", reward_fn_path)
    if spec is None or spec.loader is None:
        return None
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, reward_fn_name, None)


def load_reward_function(reward_fn_path: Optional[str], reward_fn_name: str = "compute_score", quiet: bool = False):
    """Load a custom reward function, or fall back to the default one.

    Args:
        reward_fn_path: Path to a Python file defining the reward function.
            ``None`` or empty selects the default directly.
        reward_fn_name: Name of the callable to fetch from that file.
        quiet: If True, suppress the informational and warning messages.
            Worker processes pass True because they call this per task.

    Returns:
        The custom callable when it can be loaded. Otherwise
        ``verl.utils.reward_score.default_compute_score``.
    """
    if reward_fn_path:
        custom_fn = _load_custom_reward_fn(reward_fn_path, reward_fn_name)
        if custom_fn is not None:
            if not quiet:
                print(f"Using custom reward function '{reward_fn_name}' from '{reward_fn_path}'")
            return custom_fn
        # An explicitly requested reward function could not be loaded. Falling
        # back silently would make results look like they came from it.
        if not quiet:
            print(
                f"Warning: could not load reward function '{reward_fn_name}' from "
                f"'{reward_fn_path}'; falling back to the default",
                file=sys.stderr,
            )

    # Use default
    from verl.utils.reward_score import default_compute_score
    if not quiet:
        print("Using default compute_score function")
    return default_compute_score


def evaluate_single_problem(
    reward_fn,
    data_source: str,
    responses: List[str],
    ground_truth: str,
    k_values: List[int]
) -> Tuple[List[float], dict]:
    """
    Score every response to one problem and derive its per-k metrics.

    Each response is scored with ``reward_fn(data_source, response, ground_truth)``.
    A dict result is reduced to its 'score' entry, falling back to 'reward'
    and then 0. If scoring or float conversion raises, a warning is printed
    and that response scores 0.0. A response counts as correct when its
    score is > 0.

    Returns:
        Tuple of (all_scores, metrics_dict). metrics_dict holds 'avg@k',
        'best@k' and 'pass@k' for every k in k_values.
    """
    def _score_one(response) -> float:
        # Score a single response, degrading any failure to 0.0.
        try:
            raw = reward_fn(data_source, response, ground_truth)
            if isinstance(raw, dict):
                raw = raw.get('score', raw.get('reward', 0))
            return float(raw)
        except Exception as e:
            print(f"Warning: Error computing score: {e}")
            return 0.0

    scores = [_score_one(response) for response in responses]
    num_samples = len(scores)
    num_correct = len([s for s in scores if s > 0])

    metrics = {}
    for k in k_values:
        metrics.update({
            f'avg@{k}': avg_at_k(scores, k),
            f'best@{k}': best_at_k(scores, k),
            f'pass@{k}': pass_at_k(num_samples, num_correct, k),
        })
    return scores, metrics


# Per-process cache of loaded reward functions, keyed by (path, name).
# Parallel workers call evaluate_single_item once per problem. Without this
# cache, a custom reward file would be re-executed for every problem. The
# sequential path in main() likewise loads the reward function once and
# reuses it.
_REWARD_FN_CACHE: Dict[Tuple[Optional[str], str], Any] = {}


def evaluate_single_item(args_tuple) -> Dict[str, Any]:
    """
    Evaluate one problem; a picklable wrapper used by the parallel evaluators.

    Args:
        args_tuple: (idx, data_source, responses, ground_truth, k_values, reward_fn_path, reward_fn_name)

    Returns:
        Dictionary with 'index', 'data_source', 'n_responses' and
        'n_correct', plus the 'avg@k' / 'best@k' / 'pass@k' metrics from
        evaluate_single_problem.
    """
    idx, data_source, responses, ground_truth, k_values, reward_fn_path, reward_fn_name = args_tuple

    # Load the reward function once per process, then reuse it.
    cache_key = (reward_fn_path, reward_fn_name)
    reward_fn = _REWARD_FN_CACHE.get(cache_key)
    if reward_fn is None:
        reward_fn = load_reward_function(reward_fn_path, reward_fn_name, quiet=True)
        _REWARD_FN_CACHE[cache_key] = reward_fn

    # Ensure responses is a list: wrap a lone string, materialise other iterables.
    if isinstance(responses, str):
        responses = [responses]
    elif not isinstance(responses, list):
        responses = list(responses)

    scores, metrics = evaluate_single_problem(
        reward_fn, data_source, responses, ground_truth, k_values
    )

    result = {
        'index': idx,
        'data_source': data_source,
        'n_responses': len(responses),
        'n_correct': sum(1 for s in scores if s > 0),
    }
    result.update(metrics)
    return result


def evaluate_with_ray(
    dataset: pd.DataFrame,
    args,
    reward_fn_path: Optional[str],
    reward_fn_name: str
) -> Tuple[List[Dict], Dict[str, Dict[str, List]]]:
    """
    Evaluate every row of ``dataset`` in parallel with Ray tasks.

    One remote task per row runs evaluate_single_item. Results are collected
    as tasks finish and then aggregated per data source.

    Args:
        dataset: Rows to evaluate. The columns named by args.data_source_key,
            args.response_key and args.reward_model_key are read.
        args: Parsed CLI namespace providing the column keys and the ``k`` list.
        reward_fn_path: Optional path of a custom reward function file.
        reward_fn_name: Name of the reward function in that file.

    Returns:
        (all_results, data_source_metrics). all_results holds the per-row
        result dicts, sorted by 'index' so output order matches the input
        rather than task completion order. data_source_metrics maps
        {data_source: {metric_name: [per-row values]}}.
    """
    import ray

    # Initialize Ray if needed; an already-initialized session is reused.
    if not ray.is_initialized():
        ray.init()

    @ray.remote
    def evaluate_item_ray(idx, data_source, responses, ground_truth, k_values, reward_fn_path, reward_fn_name):
        return evaluate_single_item((idx, data_source, responses, ground_truth, k_values, reward_fn_path, reward_fn_name))

    # Create one task per row
    pending = []
    for idx in range(len(dataset)):
        row = dataset.iloc[idx]
        data_source = row[args.data_source_key]
        responses = row[args.response_key]
        reward_data = row[args.reward_model_key]

        # Handle different formats for ground truth
        if isinstance(reward_data, dict):
            ground_truth = reward_data.get("ground_truth", reward_data.get("answer", ""))
        else:
            ground_truth = str(reward_data)

        # Ensure responses is a list
        if isinstance(responses, str):
            responses = [responses]
        elif not isinstance(responses, list):
            responses = list(responses)

        pending.append(evaluate_item_ray.remote(
            idx, data_source, responses, ground_truth,
            args.k, reward_fn_path, reward_fn_name
        ))

    # Collect results with progress bar as tasks complete
    all_results = []
    with tqdm(total=len(pending), desc="Evaluating") as pbar:
        while pending:
            done, pending = ray.wait(pending)
            all_results.extend(ray.get(done))
            pbar.update(len(done))

    # Restore input order so the output is deterministic.
    all_results.sort(key=lambda r: r['index'])

    # Aggregate by data source
    data_source_metrics = defaultdict(lambda: defaultdict(list))
    for result in all_results:
        data_source = result['data_source']
        for key, value in result.items():
            if key.startswith(('avg@', 'best@', 'pass@')):
                data_source_metrics[data_source][key].append(value)

    return all_results, data_source_metrics


def evaluate_with_multiprocessing(
    dataset: pd.DataFrame,
    args,
    reward_fn_path: Optional[str],
    reward_fn_name: str,
    num_workers: Optional[int] = None
) -> Tuple[List[Dict], Dict[str, Dict[str, List]]]:
    """
    Evaluate every row of ``dataset`` in parallel with a process pool.

    Args:
        dataset: Rows to evaluate. The columns named by args.data_source_key,
            args.response_key and args.reward_model_key are read.
        args: Parsed CLI namespace providing the column keys and the ``k`` list.
        reward_fn_path: Optional path of a custom reward function file.
        reward_fn_name: Name of the reward function in that file.
        num_workers: Pool size; defaults to min(multiprocessing.cpu_count(), 32).

    Returns:
        (all_results, data_source_metrics). all_results holds the per-row
        result dicts, sorted by 'index' so output order matches the input
        rather than completion order. data_source_metrics maps
        {data_source: {metric_name: [per-row values]}}.
    """
    import multiprocessing

    if num_workers is None:
        num_workers = min(multiprocessing.cpu_count(), 32)

    # Prepare tasks (responses are normalised to a list inside the worker).
    tasks = []
    for idx in range(len(dataset)):
        row = dataset.iloc[idx]
        data_source = row[args.data_source_key]
        responses = row[args.response_key]
        reward_data = row[args.reward_model_key]

        # Handle different formats for ground truth
        if isinstance(reward_data, dict):
            ground_truth = reward_data.get("ground_truth", reward_data.get("answer", ""))
        else:
            ground_truth = str(reward_data)

        tasks.append((idx, data_source, responses, ground_truth, args.k, reward_fn_path, reward_fn_name))

    # Process with multiprocessing
    all_results = []
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(evaluate_single_item, task) for task in tasks]
        with tqdm(total=len(futures), desc="Evaluating") as pbar:
            for future in as_completed(futures):
                all_results.append(future.result())
                pbar.update(1)

    # Restore input order so the output is deterministic.
    all_results.sort(key=lambda r: r['index'])

    # Aggregate by data source
    data_source_metrics = defaultdict(lambda: defaultdict(list))
    for result in all_results:
        data_source = result['data_source']
        for key, value in result.items():
            if key.startswith(('avg@', 'best@', 'pass@')):
                data_source_metrics[data_source][key].append(value)

    return all_results, data_source_metrics


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line interface for this script."""
    parser = argparse.ArgumentParser(
        description="Evaluate model generations with Pass@K, Avg@K, and Best@K metrics",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Evaluate with k=1,4,8
  python eval_pass_at_k.py --data_path results.parquet --k 1 4 8
  
  # Use custom reward function
  python eval_pass_at_k.py --data_path results.parquet --k 1 8 --reward_fn_path my_reward.py
  
  # Specify column names
  python eval_pass_at_k.py --data_path results.parquet --k 1 8 \\
      --response_key generations --data_source_key source
        """
    )

    parser.add_argument(
        "--data_path", "-d",
        type=str,
        required=True,
        help="Path to the parquet file containing generations"
    )
    parser.add_argument(
        "--k", "-k",
        type=int,
        nargs="+",
        default=[1, 4, 8],
        help="Values of k to compute metrics for (default: 1 4 8)"
    )
    parser.add_argument(
        "--response_key",
        type=str,
        default="responses",
        help="Column name for responses (default: responses)"
    )
    parser.add_argument(
        "--data_source_key",
        type=str,
        default="data_source",
        help="Column name for data source (default: data_source)"
    )
    parser.add_argument(
        "--reward_model_key",
        type=str,
        default="reward_model",
        help="Column name for reward model data containing ground_truth (default: reward_model)"
    )
    parser.add_argument(
        "--reward_fn_path",
        type=str,
        default=None,
        help="Path to custom reward function file (optional)"
    )
    parser.add_argument(
        "--reward_fn_name",
        type=str,
        default="compute_score",
        help="Name of reward function in custom file (default: compute_score)"
    )
    parser.add_argument(
        "--output_path", "-o",
        type=str,
        default=None,
        help="Path to save detailed results as CSV (optional)"
    )
    parser.add_argument(
        "--verbose", "-v",
        action="store_true",
        help="Print detailed per-problem results"
    )
    parser.add_argument(
        "--parallel", "-p",
        type=str,
        choices=["none", "multiprocessing", "ray"],
        default="multiprocessing",
        help="Parallelization method (default: multiprocessing)"
    )
    parser.add_argument(
        "--num_workers", "-w",
        type=int,
        default=None,
        help="Number of workers for parallel processing (default: auto)"
    )
    return parser


def _load_dataset(data_path: str) -> pd.DataFrame:
    """Read a .parquet, .csv or .jsonl file into a DataFrame.

    Raises:
        ValueError: For any other file extension.
    """
    if data_path.endswith('.parquet'):
        return pd.read_parquet(data_path)
    if data_path.endswith('.csv'):
        return pd.read_csv(data_path)
    if data_path.endswith('.jsonl'):
        return pd.read_json(data_path, lines=True)
    raise ValueError(f"Unsupported file format: {data_path}")


def _print_problem_details(result: Dict[str, Any]) -> None:
    """Print one problem's response/correct counts and its metrics (for --verbose)."""
    print(f"\nProblem {result['index']} ({result['data_source']}):")
    print(f"  Responses: {result['n_responses']}, Correct: {result['n_correct']}")
    for metric_name, value in result.items():
        if metric_name.startswith(('avg@', 'best@', 'pass@')):
            print(f"  {metric_name}: {value:.4f}")


def _evaluate_sequentially(
    dataset: pd.DataFrame, args, reward_fn
) -> Tuple[List[Dict], Dict[str, Dict[str, List]]]:
    """Evaluate every row in-process.

    Returns the same (all_results, data_source_metrics) shape as the parallel
    evaluators. When ``args.verbose`` is set, each problem's details are
    printed as it is evaluated.
    """
    all_results = []
    data_source_metrics = defaultdict(lambda: defaultdict(list))

    for idx in tqdm(range(len(dataset))):
        row = dataset.iloc[idx]

        data_source = row[args.data_source_key]
        responses = row[args.response_key]
        reward_data = row[args.reward_model_key]

        # Handle different formats for ground truth
        if isinstance(reward_data, dict):
            ground_truth = reward_data.get("ground_truth", reward_data.get("answer", ""))
        else:
            ground_truth = str(reward_data)

        # Ensure responses is a list
        if isinstance(responses, str):
            responses = [responses]
        elif not isinstance(responses, list):
            responses = list(responses)

        scores, metrics = evaluate_single_problem(
            reward_fn, data_source, responses, ground_truth, args.k
        )

        result = {
            'index': idx,
            'data_source': data_source,
            'n_responses': len(responses),
            'n_correct': sum(1 for s in scores if s > 0),
        }
        result.update(metrics)
        all_results.append(result)

        # Aggregate by data source
        for metric_name, value in metrics.items():
            data_source_metrics[data_source][metric_name].append(value)

        if args.verbose:
            _print_problem_details(result)

    return all_results, data_source_metrics


def _mean_metrics(metrics: Dict[str, List[float]], k_values: List[int]) -> Dict[str, Dict[str, float]]:
    """Average per-problem metric lists into {"k=<k>": {"avg@<k>": .., "best@<k>": .., "pass@<k>": ..}}."""
    return {
        f"k={k}": {
            f"{prefix}@{k}": float(np.mean(metrics[f"{prefix}@{k}"]))
            for prefix in ("avg", "best", "pass")
        }
        for k in k_values
    }


def _print_k_means(k_means: Dict[str, Dict[str, float]], k_values: List[int]) -> None:
    """Print one summary line per k from the output of _mean_metrics."""
    for k in k_values:
        means = k_means[f"k={k}"]
        avg_k = means[f"avg@{k}"]
        best_k = means[f"best@{k}"]
        pass_k = means[f"pass@{k}"]
        print(f"  k={k}:  avg@{k}={avg_k:.4f}  best@{k}={best_k:.4f}  pass@{k}={pass_k:.4f}")


def main():
    """Command-line entry point: evaluate generations and report avg@k, best@k and pass@k.

    Steps:
    - load the dataset;
    - score every response with the reward function (sequentially, with
      multiprocessing, or with Ray);
    - print per-data-source and overall means;
    - optionally write per-problem results to --output_path;
    - write a JSON summary next to the output file (or the input file when
      no output path is given).

    Returns:
        (overall_metrics, summary). overall_metrics maps each metric name to
        its per-problem values across all data sources; summary is the
        JSON-serializable dict that was written.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()

    # Validate k values: dedupe and sort, then reject k < 1. A k of 0 would
    # divide by zero in avg@k; a negative k would slice from the wrong end.
    args.k = sorted(set(args.k))
    if args.k[0] < 1:
        parser.error(f"--k values must be positive integers, got {args.k}")
    print(f"Computing metrics for k = {args.k}")

    # Load data
    print(f"Loading data from {args.data_path}...")
    dataset = _load_dataset(args.data_path)
    print(f"Loaded {len(dataset)} problems")

    # Check columns exist
    required_cols = [args.response_key, args.data_source_key, args.reward_model_key]
    for col in required_cols:
        if col not in dataset.columns:
            print(f"Error: Column '{col}' not found in dataset")
            print(f"Available columns: {list(dataset.columns)}")
            sys.exit(1)

    # Load reward function (reports which one is in use; the parallel paths
    # load their own copy inside each worker).
    reward_fn = load_reward_function(args.reward_fn_path, args.reward_fn_name)

    # Choose evaluation method
    if args.parallel == "ray":
        print("Using Ray for parallel evaluation...")
        all_results, data_source_metrics = evaluate_with_ray(
            dataset, args, args.reward_fn_path, args.reward_fn_name
        )
    elif args.parallel == "multiprocessing":
        print(f"Using multiprocessing for parallel evaluation (workers: {args.num_workers or 'auto'})...")
        all_results, data_source_metrics = evaluate_with_multiprocessing(
            dataset, args, args.reward_fn_path, args.reward_fn_name, args.num_workers
        )
    else:
        print("Evaluating sequentially...")
        all_results, data_source_metrics = _evaluate_sequentially(dataset, args, reward_fn)

    # Keep per-problem output in input order regardless of completion order.
    all_results.sort(key=lambda r: r['index'])

    # The sequential path prints --verbose details inline. The parallel paths
    # only have results once all workers finish, so print them here.
    if args.verbose and args.parallel != "none":
        for result in all_results:
            _print_problem_details(result)

    # Print summary
    print("\n" + "=" * 70)
    print("EVALUATION RESULTS")
    print("=" * 70)

    # Per data source results
    per_source_summary = {}
    for data_source in sorted(data_source_metrics.keys()):
        metrics = data_source_metrics[data_source]
        n_problems = len(metrics[f'avg@{args.k[0]}'])
        k_means = _mean_metrics(metrics, args.k)
        per_source_summary[data_source] = {
            "n_problems": n_problems,
            "metrics": k_means
        }

        print(f"\n{data_source} ({n_problems} problems):")
        print("-" * 50)
        _print_k_means(k_means, args.k)

    # Overall results
    print("\n" + "=" * 70)
    print("OVERALL RESULTS")
    print("-" * 50)

    overall_metrics = defaultdict(list)
    for metrics in data_source_metrics.values():
        for metric_name, values in metrics.items():
            overall_metrics[metric_name].extend(values)

    overall_means = _mean_metrics(overall_metrics, args.k)
    _print_k_means(overall_means, args.k)

    print("=" * 70)

    # Save detailed results
    if args.output_path:
        results_df = pd.DataFrame(all_results)
        if args.output_path.endswith('.csv'):
            results_df.to_csv(args.output_path, index=False)
        else:
            results_df.to_parquet(args.output_path)
        print(f"\nDetailed results saved to {args.output_path}")

    summary = {
        "timestamp": datetime.now().isoformat(),
        "data_path": args.data_path,
        "k_values": args.k,
        "total_problems": len(all_results),
        "per_data_source": per_source_summary,
        "overall": overall_means
    }

    # Save JSON summary. os.path.splitext only strips a real extension from
    # the final path component. For example, "./results" maps to
    # "./results_summary.json", whereas rsplit('.') would produce
    # "_summary.json".
    if args.output_path:
        json_output_path = os.path.splitext(args.output_path)[0] + '_summary.json'
    else:
        json_output_path = os.path.splitext(args.data_path)[0] + '_eval_summary.json'

    with open(json_output_path, 'w', encoding='utf-8') as f:
        json.dump(summary, f, indent=2, ensure_ascii=False)
    print(f"Summary JSON saved to {json_output_path}")

    return overall_metrics, summary


# Run the CLI only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()
