from datasets import load_from_disk, load_dataset
from typing import Dict, Any, List
import numpy as np
from tabulate import tabulate
import sys
from operator import itemgetter
import argparse

def analyze_correct_answers(dataset) -> Dict[str, int]:
    """Summarise how many generated answers are marked correct.

    Args:
        dataset: Sized iterable of rows; each row exposes an
            ``'answer_correct'`` sequence whose truthy entries mark correct
            samples.

    Returns:
        Dict with ``total_problems`` (number of rows), ``problems_with_correct``
        (rows with at least one truthy entry) and ``total_correct_answers``
        (truthy entries summed over all rows).
    """
    per_problem_counts = [
        sum(1 for flag in row['answer_correct'] if flag) for row in dataset
    ]
    return {
        'total_problems': len(dataset),
        'problems_with_correct': sum(1 for count in per_problem_counts if count > 0),
        'total_correct_answers': sum(per_problem_counts),
    }

def preprocess_perm_scores(dataset) -> Any:
    """Collapse per-step PRM scores to one score per sample by taking the minimum.

    For the ``eurus_prm_scores`` and ``eurus_prm2_scores`` columns, every
    per-sample entry that is a list/array of step scores is replaced by its
    minimum. ``None`` step scores are ignored, and a sample with no usable
    step score becomes ``None`` instead of raising. Entries that are already
    scalars are left untouched, so calling this function repeatedly is safe.

    Args:
        dataset: Either an object exposing ``map`` and ``column_names``
            (e.g. a HuggingFace ``Dataset``) or a list of mutable row dicts.

    Returns:
        The preprocessed dataset. For ``map``-capable datasets this is the new
        object returned by ``map``; a list of dicts is updated in place and
        returned.
    """
    prm_columns = ('eurus_prm_scores', 'eurus_prm2_scores')

    def _min_step_score(step_scores):
        # Reduce one sample's step scores; scalars pass through unchanged.
        if not isinstance(step_scores, (list, np.ndarray)):
            return step_scores
        usable = [s for s in step_scores if s is not None]
        return min(usable) if usable else None

    def _reduce_column(value):
        # A column value holds one entry per sample.
        if isinstance(value, (list, np.ndarray)):
            return [_min_step_score(scores) for scores in value]
        return value

    if hasattr(dataset, 'map') and hasattr(dataset, 'column_names'):
        # Rows obtained via dataset[i] from a HuggingFace Dataset are fresh
        # dicts, so assigning into them would be lost; rebuild columns via map.
        present = [col for col in prm_columns if col in dataset.column_names]
        if not present:
            return dataset
        return dataset.map(lambda row: {col: _reduce_column(row[col]) for col in present})

    for row in dataset:
        for col in prm_columns:
            if col in row:
                row[col] = _reduce_column(row[col])
    return dataset

def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Options:
        --dataset (required): dataset location handed to the loader.
        --reward_model_threshold_for_accuracy (float, default 0.7): cut-off
            applied to normalised reward-model scores.

    Returns:
        argparse.Namespace with ``dataset`` and
        ``reward_model_threshold_for_accuracy`` attributes.
    """
    cli = argparse.ArgumentParser(description='Evaluate model selection strategies')
    option_specs = (
        ('--dataset',
         dict(type=str, required=True, help='Path to HuggingFace dataset')),
        ('--reward_model_threshold_for_accuracy',
         dict(type=float, default=0.7, help='Threshold for reward model score classification')),
    )
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli.parse_args()

def load_evaluation_dataset(filepath: str, split: str = 'data'):
    """Load the evaluation dataset from a local directory or the HuggingFace hub.

    ``load_from_disk`` is tried first. If it raises ``FileNotFoundError`` the
    path is treated as a hub repository id and passed to ``load_dataset``.

    Args:
        filepath: Local dataset directory or hub repository id.
        split: Split requested from the hub. Defaults to ``'data'``, the split
            this script has always used; it is ignored for local loads.

    Returns:
        The object returned by ``load_from_disk`` or ``load_dataset``.

    Exits:
        Prints the error to stderr and terminates with status 1 if loading
        fails for any other reason.
    """
    try:
        return load_from_disk(filepath)
    except FileNotFoundError:
        print(f"Local dataset not found, attempting to load from HuggingFace hub: {filepath}")
    except Exception as e:
        print(f"Error loading dataset: {str(e)}", file=sys.stderr)
        sys.exit(1)

    try:
        return load_dataset(filepath, split=split)
    except Exception as e:
        print(f"Error loading dataset: {str(e)}", file=sys.stderr)
        sys.exit(1)

def get_top_performers(results: List[Dict[str, Any]], n: int, exclude_aggregated: bool = True) -> List[str]:
    """Return the model/judge names behind the ``n`` most accurate strategies.

    Args:
        results: Strategy result dicts with ``'strategy'`` and ``'accuracy'``.
        n: Number of entries to keep (applied as a slice after sorting).
        exclude_aggregated: Drop combined strategies (names containing
            ``'Aggregated'`` or ``'Majority'``) before ranking.

    Returns:
        Names ordered by descending accuracy; ties keep their input order.
        The name is the text after the last ``'Best '`` when present,
        otherwise the text after the last ``'First True from '`` (or the
        whole strategy string when neither marker occurs).
    """
    def _is_combined(strategy: str) -> bool:
        return 'Aggregated' in strategy or 'Majority' in strategy

    candidates = results
    if exclude_aggregated:
        candidates = [entry for entry in results if not _is_combined(entry['strategy'])]

    ranked = sorted(candidates, key=lambda entry: entry['accuracy'], reverse=True)[:n]

    names = []
    for entry in ranked:
        label = entry['strategy']
        marker = 'Best ' if 'Best ' in label else 'First True from '
        # rpartition(...)[2] is the text after the last marker (whole label if absent).
        names.append(label.rpartition(marker)[2])
    return names

def evaluate_first_sample_strategy(dataset) -> Dict[str, Any]:
    """Baseline strategy: always select each problem's first sample.

    Args:
        dataset: Sized iterable of rows with an ``'answer_correct'`` list.

    Returns:
        Result dict with ``strategy`` (``'First Sample'``), ``accuracy``,
        ``correct`` and ``total`` (number of problems). A problem with no
        samples counts as incorrect rather than raising ``IndexError``.
    """
    total_problems = len(dataset)
    correct_predictions = 0

    for row in dataset:
        answers = row['answer_correct']
        # Guard against an empty (or missing) sample list.
        if answers and answers[0]:
            correct_predictions += 1

    accuracy = correct_predictions / total_problems if total_problems > 0 else 0
    return {
        'strategy': 'First Sample',
        'accuracy': accuracy,
        'correct': correct_predictions,
        'total': total_problems
    }

def get_available_reward_models(dataset) -> List[str]:
    """List the dataset columns that hold reward-model scores.

    A column qualifies when its name ends with ``'_scores'`` or is exactly
    ``'armor_rm_score'``. Order follows ``dataset.features``.
    """
    def _is_score_column(name: str) -> bool:
        return name == 'armor_rm_score' or name.endswith('_scores')

    return list(filter(_is_score_column, dataset.features.keys()))

def _first_true_index(verdicts):
    """Return the index of the first verdict that ``is True``, else None.

    Non-list ``verdicts`` (e.g. a missing value) also yield None.
    """
    if not isinstance(verdicts, list):
        return None
    return next((idx for idx, verdict in enumerate(verdicts) if verdict is True), None)


def _picked_sample_is_correct(answer_correct, idx) -> bool:
    """Whether sample ``idx`` exists and is marked correct (out of range counts as wrong)."""
    return 0 <= idx < len(answer_correct) and bool(answer_correct[idx])


def evaluate_lm_judge_strategies(dataset) -> List[Dict[str, Any]]:
    """Evaluate sample-selection strategies driven by LM-judge verdicts.

    Every column whose name starts with ``'judge_'`` is read as a list of
    per-sample verdicts. Two families of strategies are scored:

    * ``First True from <judge>``: pick the first sample the judge marked
      ``True``. A problem with no ``True`` verdict, or whose first ``True``
      index has no matching sample, counts as incorrect.
    * ``Majority Vote <group>``: count ``True`` verdicts per sample across the
      judges in the group (all judges, then the top 3 and top 5 by individual
      accuracy) and pick the sample with the most votes (earliest on ties).
      Judges whose verdict list length differs from the number of samples are
      ignored for that problem. Problems with no ``True`` vote count as
      incorrect.

    Accuracy is computed over all problems; an empty dataset yields 0.

    Args:
        dataset: Sized, indexable dataset with a ``features`` mapping of column
            names; every row has an ``'answer_correct'`` list.

    Returns:
        List of result dicts with ``strategy``, ``accuracy``, ``correct`` and
        ``total`` keys.
    """
    judge_columns = [col for col in dataset.features.keys() if col.startswith('judge_')]
    results = []
    total_problems = len(dataset)

    def _result(strategy: str, correct: int) -> Dict[str, Any]:
        # Avoid ZeroDivisionError on an empty dataset.
        accuracy = correct / total_problems if total_problems > 0 else 0
        return {
            'strategy': strategy,
            'accuracy': accuracy,
            'correct': correct,
            'total': total_problems
        }

    # Individual judges first
    for judge in judge_columns:
        correct_predictions = 0
        for i in range(total_problems):
            # Fetch the row once; for HuggingFace datasets each dataset[i] decodes a full row.
            row = dataset[i]
            idx = _first_true_index(row[judge])
            if idx is not None and _picked_sample_is_correct(row['answer_correct'], idx):
                correct_predictions += 1
        results.append(_result(f'First True from {judge}', correct_predictions))

    # Rank the individual judges (results holds only those at this point).
    top_3_judges = get_top_performers(results, 3)
    top_5_judges = get_top_performers(results, 5)

    print("\nTop 3 Judges:", top_3_judges)
    print("Top 5 Judges:", top_5_judges)

    # Evaluate voting strategies
    for judges, name in [
        (judge_columns, 'All Judges'),
        (top_3_judges, 'Top 3 Judges'),
        (top_5_judges, 'Top 5 Judges')
    ]:
        correct_predictions = 0
        for i in range(total_problems):
            row = dataset[i]
            answer_correct = row['answer_correct']
            num_samples = len(answer_correct)
            true_counts = np.zeros(num_samples)

            for judge in judges:
                verdicts = row[judge]
                # Only judges whose verdicts line up one-to-one with the samples vote.
                if isinstance(verdicts, list) and len(verdicts) == num_samples:
                    true_counts += np.array([1 if v is True else 0 for v in verdicts])

            # Predict only if some judge voted True; argmax takes the earliest tied sample.
            if np.any(true_counts > 0) and answer_correct[int(np.argmax(true_counts))]:
                correct_predictions += 1

        results.append(_result(f'Majority Vote {name}', correct_predictions))

    return results

def _rm_scores_array(row, rm: str, num_samples: int):
    """Return ``row[rm]`` as a float array (None -> NaN), or None if unusable.

    The value is unusable when the column is missing, it is not a list, an
    entry cannot be converted with ``float``, its length differs from
    ``num_samples``, or every entry is missing/NaN.
    """
    if rm not in row or not isinstance(row[rm], list):
        return None
    try:
        scores = np.array([float(s) if s is not None else np.nan for s in row[rm]])
    except (TypeError, ValueError):
        return None
    if len(scores) != num_samples or np.all(np.isnan(scores)):
        return None
    return scores


def evaluate_reward_model_strategies(dataset) -> List[Dict[str, Any]]:
    """Evaluate best-of-n selection strategies driven by reward-model scores.

    Strategies:

    * ``Best <rm>``: pick the sample with the highest score from a single
      reward model (missing scores are ignored).
    * ``Aggregated Normalized <group>``: min-max normalise each reward model's
      scores per problem (missing scores become 0), average them across the
      group (all RMs, then the top 3 and top 5 by individual accuracy) and
      pick the highest average. A reward model whose scores are constant for a
      problem is left out of that problem's average.

    Problems for which a strategy has no usable scores are skipped. Each
    strategy's ``accuracy`` and ``total`` count only the problems it
    evaluated, and a strategy that evaluated no problems is omitted.

    Args:
        dataset: Sized, indexable dataset with a ``features`` mapping; every
            row has an ``'answer_correct'`` list.

    Returns:
        List of result dicts with ``strategy``, ``accuracy``, ``correct`` and
        ``total`` keys (empty if no reward-model columns exist).
    """
    reward_models = get_available_reward_models(dataset)

    if not reward_models:
        print("Warning: No reward model scores found in dataset")
        return []

    dataset = preprocess_perm_scores(dataset)
    results = []
    total_problems = len(dataset)

    def _append_result(strategy: str, correct: int, skipped: int) -> None:
        evaluated = total_problems - skipped
        if evaluated > 0:
            results.append({
                'strategy': strategy,
                'accuracy': correct / evaluated,
                'correct': correct,
                'total': evaluated
            })

    # Individual reward models first
    for rm in reward_models:
        correct_predictions = 0
        skipped_problems = 0

        for i in range(total_problems):
            # Fetch the row once; for HuggingFace datasets each dataset[i] decodes a full row.
            row = dataset[i]
            answer_correct = row['answer_correct']
            scores = _rm_scores_array(row, rm, len(answer_correct))
            if scores is None:
                skipped_problems += 1
                continue
            if answer_correct[int(np.nanargmax(scores))]:
                correct_predictions += 1

        _append_result(f'Best {rm}', correct_predictions, skipped_problems)

    # Rank the individual reward models (results holds only those at this point).
    top_3_rms = get_top_performers(results, 3)
    top_5_rms = get_top_performers(results, 5)

    print("\nTop 3 Reward Models:", top_3_rms)
    print("Top 5 Reward Models:", top_5_rms)

    # Evaluate aggregated strategies
    for rms, name in [
        (reward_models, 'All RMs'),
        (top_3_rms, 'Top 3 RMs'),
        (top_5_rms, 'Top 5 RMs')
    ]:
        correct_predictions = 0
        skipped_problems = 0

        for i in range(total_problems):
            row = dataset[i]
            answer_correct = row['answer_correct']
            normalized = []

            for rm in rms:
                scores = _rm_scores_array(row, rm, len(answer_correct))
                if scores is None:
                    continue
                min_val = np.nanmin(scores)
                max_val = np.nanmax(scores)
                if min_val == max_val:
                    # Constant scores carry no ranking signal for this problem.
                    continue
                norm_scores = (scores - min_val) / (max_val - min_val)
                # Missing scores rank lowest. Pass nan= by keyword: the second
                # positional parameter of nan_to_num is `copy`.
                normalized.append(np.nan_to_num(norm_scores, nan=0.0))

            if not normalized:
                skipped_problems += 1
                continue

            # Average the normalized scores
            avg_scores = np.mean(normalized, axis=0)
            if answer_correct[int(np.argmax(avg_scores))]:
                correct_predictions += 1

        _append_result(f'Aggregated Normalized {name}', correct_predictions, skipped_problems)

    return results

def normalize_scores(scores: List[float]) -> np.ndarray:
    """Min-max scale ``scores`` into the range [0, 1].

    ``None`` entries (and NaNs) count as missing and come back as 0.0. If
    every entry is missing the result is all zeros; if all present values are
    equal, those values map to 1.0.

    Args:
        scores: Numeric values, possibly containing ``None``.

    Returns:
        Float array with the same length as ``scores``.
    """
    values = np.array([np.nan if s is None else float(s) for s in scores])
    missing = np.isnan(values)

    if missing.all():
        return np.zeros_like(values)

    lowest, highest = np.nanmin(values), np.nanmax(values)
    if lowest == highest:
        return np.where(missing, 0.0, 1.0)

    scaled = (values - lowest) / (highest - lowest)
    scaled[np.isnan(scaled)] = 0.0
    return scaled

def evaluate_reward_model_thresholds(dataset, threshold: float) -> List[Dict[str, Any]]:
    """Score each reward model as a per-sample correct/incorrect classifier.

    For every problem, the model's scores are min-max normalised with
    ``normalize_scores``. Each sample whose normalised score is
    ``>= threshold`` is predicted correct and compared with ``answer_correct``.

    A problem is skipped when its score value:

    * is not a list/array, or is empty;
    * contains an entry ``float`` cannot convert;
    * has a length different from ``answer_correct``;
    * has no spread (all missing or all equal).

    Metrics and ``total`` cover only the samples of problems that were
    scored, so skipped problems do not deflate accuracy.

    Args:
        dataset: Sized, indexable dataset with a ``features`` mapping; every
            row has an ``'answer_correct'`` list.
        threshold: Cut-off applied to the normalised [0, 1] score.

    Returns:
        One dict per reward model with ``model``, ``accuracy``, ``precision``,
        ``recall``, ``f1``, ``tp``, ``fp``, ``tn``, ``fn`` and ``total`` keys
        (empty if no reward-model columns exist).
    """
    reward_models = get_available_reward_models(dataset)

    if not reward_models:
        print("Warning: No reward model scores found in dataset")
        return []

    dataset = preprocess_perm_scores(dataset)

    results = []

    for rm in reward_models:
        tp = tn = fp = fn = 0

        for i in range(len(dataset)):
            # Fetch the row once; for HuggingFace datasets each dataset[i] decodes a full row.
            row = dataset[i]
            scores = row[rm]
            ground_truth = row['answer_correct']

            # Skip values that are not per-sample score sequences (None, scalars, empty).
            if not isinstance(scores, (list, np.ndarray)) or len(scores) == 0:
                continue
            try:
                raw = np.array([float(s) if s is not None else np.nan for s in scores])
            except (TypeError, ValueError):
                # e.g. per-step PRM lists that were not reduced to one value
                continue
            # zip() below would silently truncate misaligned lists; all-missing
            # or constant scores cannot be normalised meaningfully.
            if len(raw) != len(ground_truth) or np.all(np.isnan(raw)):
                continue
            if np.nanmin(raw) == np.nanmax(raw):
                continue

            predicted_correct = normalize_scores(scores) >= threshold

            for predicted, is_correct in zip(predicted_correct, ground_truth):
                if predicted and is_correct:
                    tp += 1
                elif predicted:
                    fp += 1
                elif is_correct:
                    fn += 1
                else:
                    tn += 1

        # Metrics over the samples that were actually classified.
        evaluated_samples = tp + fp + tn + fn
        accuracy = (tp + tn) / evaluated_samples if evaluated_samples > 0 else 0
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

        results.append({
            'model': rm,
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1': f1,
            'tp': tp,
            'fp': fp,
            'tn': tn,
            'fn': fn,
            'total': evaluated_samples
        })

    return results

def format_results_table(results: List[Dict[str, Any]], title: str) -> str:
    """Render strategy results as a grid table preceded by a heading line.

    Args:
        results: Dicts with ``strategy``, ``accuracy`` (a fraction, shown as a
            percentage with two decimals), ``correct`` and ``total`` keys.
        title: Heading printed, followed by a colon, on its own line.

    Returns:
        A newline, the heading, and the tabulate ``grid`` table.
    """
    table_rows = []
    for entry in results:
        table_rows.append([
            entry['strategy'],
            f"{entry['accuracy']:.2%}",
            entry['correct'],
            entry['total'],
        ])

    heading = f"\n{title}:\n"
    table = tabulate(table_rows, headers=['Strategy', 'Accuracy', 'Correct', 'Total'], tablefmt='grid')
    return heading + table

def format_threshold_table(results: List[Dict[str, Any]], threshold: float) -> str:
    """Render per-model threshold metrics as a grid table with a heading.

    Args:
        results: Dicts with ``model``, ``accuracy``, ``precision``, ``recall``
            and ``f1`` (fractions, shown as percentages with two decimals),
            plus ``tp``, ``fp``, ``tn``, ``fn`` and ``total`` counts.
        threshold: Value echoed in the ``Threshold <value> Results:`` heading.

    Returns:
        A newline, the heading, and the tabulate ``grid`` table.
    """
    percent_keys = ('accuracy', 'precision', 'recall', 'f1')
    count_keys = ('tp', 'fp', 'tn', 'fn', 'total')
    table_rows = [
        [entry['model']]
        + [f"{entry[key]:.2%}" for key in percent_keys]
        + [entry[key] for key in count_keys]
        for entry in results
    ]

    headers = ['Model', 'Acc', 'Prec', 'Recall', 'F1', 'TP', 'FP', 'TN', 'FN', 'Total']
    heading = f"\nThreshold {threshold} Results:\n"
    return heading + tabulate(table_rows, headers=headers, tablefmt='grid')

def main():
    """Command-line entry point: run every evaluation and print the result tables.

    Loads the dataset named by ``--dataset`` and prints summary statistics
    about correct answers. It then prints the baseline, reward-model,
    LM-judge and threshold tables. An empty dataset stops after the summary.
    Any exception raised during evaluation is reported on stderr with its
    traceback, and the process exits with status 1.
    """
    args = parse_args()
    threshold = args.reward_model_threshold_for_accuracy

    try:
        dataset = load_evaluation_dataset(args.dataset)

        # Preprocess PRM scores before any evaluation
        dataset = preprocess_perm_scores(dataset)

        # Add correct answer analysis
        correct_stats = analyze_correct_answers(dataset)
        total_problems = correct_stats['total_problems']
        print("\nCorrect Answer Statistics:")
        print(f"Total problems: {total_problems}")
        if total_problems == 0:
            # The ratios below (and some evaluators) would divide by zero.
            print("Dataset is empty; nothing to evaluate.")
            return

        with_correct = correct_stats['problems_with_correct']
        total_correct = correct_stats['total_correct_answers']
        print(f"Problems with at least one correct answer: {with_correct} ({with_correct/total_problems:.1%})")
        print(f"Total correct answers across all problems: {total_correct}")
        print(f"Average correct answers per problem: {total_correct/total_problems:.2f}")

        # Original evaluations
        baseline_results = [evaluate_first_sample_strategy(dataset)]
        rm_results = evaluate_reward_model_strategies(dataset)
        lm_judge_results = evaluate_lm_judge_strategies(dataset)

        # New threshold-based evaluation
        threshold_results = evaluate_reward_model_thresholds(dataset, threshold)

        # Print all results
        print(format_results_table(baseline_results, "Baseline Strategies"))
        print(format_results_table(rm_results, "Reward Model Strategies"))
        print(format_results_table(lm_judge_results, "LM Judge Strategies"))
        print(format_threshold_table(threshold_results, threshold))

    except Exception as e:
        import traceback
        # Report on stderr and exit non-zero so callers/shell scripts see the failure.
        print(f"Error during evaluation: {str(e)}", file=sys.stderr)
        print("\nFull traceback:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        sys.exit(1)

if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
