import argparse
import logging
from typing import Dict, List, Optional, Union, Tuple
from datasets import load_dataset, load_from_disk, Dataset
from tabulate import tabulate
import numpy as np

from collections import Counter

from dataclasses import asdict

import hashlib 

import os 
import json

from datetime import datetime 

from weaver.config_handler import VerifierHandler, VerifierConfig
from weaver.unsupervised_methods import (
    FirstSample, 
    MajorityVote, 
    HighestScoringRM, 
    HighestScoringLM, 
    NaiveRMEnsemble, 
    NaiveLMEnsemble, 
    NaiveEnsemble,
    NaiveBinaryEnsemble 
)
from weaver.supervised_methods import (
    TopJudges, 
    TopRMs, 
    WeightedEnsemble, 
    Majority_X_then_Top_Y_Judges, 
    Majority_X_then_Top_Y_RMs,
    WeightedJudgeEnsemble,
    WeightedRMEnsemble,
    LogisticRegressionEnsemble,
    NaiveBayesEnsemble
)
from weaver.weak_supervision import WeakSupervision

import glob 

from weaver.constants import DATASET_TO_REWARD_MODELS, DATASET_TO_LM_JUDGES

import time, random 
# NOTE(review): import-time side effect -- sleeps for a random 0-1 s every
# time this module is imported. Presumably this staggers concurrently launched
# jobs (e.g. so they do not race on shared cache files) -- TODO confirm.
# It runs before the seeding further down, so the delay is not reproducible.
time.sleep(random.uniform(0, 1))


# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Registry of verifier implementations keyed by the name used in verifier
# configs; apply_verifiers instantiates `verifier_classes[verifier_config.name]`.
verifier_classes = {
    'FirstSample': FirstSample,
    'MajorityVote': MajorityVote,
    'HighestScoringRM': HighestScoringRM,
    'HighestScoringLM': HighestScoringLM,
    'NaiveRMEnsemble': NaiveRMEnsemble,
    'NaiveLMEnsemble': NaiveLMEnsemble,
    'NaiveEnsemble': NaiveEnsemble,
    'NaiveBinaryEnsemble': NaiveBinaryEnsemble,
    'TopJudges': TopJudges,
    'TopRMs': TopRMs,
    'WeightedEnsemble': WeightedEnsemble,
    'LogisticRegressionEnsemble': LogisticRegressionEnsemble,
    'NaiveBayesEnsemble': NaiveBayesEnsemble,
    'WeightedRMEnsemble': WeightedRMEnsemble,
    'WeightedJudgeEnsemble': WeightedJudgeEnsemble,
    'Majority_X_then_Top_Y_Judges': Majority_X_then_Top_Y_Judges,
    'Majority_X_then_Top_Y_RMs': Majority_X_then_Top_Y_RMs,
    'WeakSupervision': WeakSupervision
}

# Seed NumPy's global RNG and the stdlib `random` module so that randomness
# drawn after import is reproducible.
np.random.seed(0)
random.seed(0)

class NpEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to native Python values.

    Use it as ``json.dump(obj, f, cls=NpEncoder)`` / ``json.dumps(obj, cls=NpEncoder)``.
    It is an *encoder* only: it is not a ``JSONDecoder`` and must not be passed
    as ``cls`` to ``json.load`` / ``json.loads``.
    """

    def default(self, obj):
        # np.bool_ is not a subclass of np.integer, so it needs its own branch;
        # without it json raises "Object of type bool_ is not JSON serializable".
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Anything else: let the base class raise the usual TypeError.
        return super().default(obj)

def normalize_reward_model_scores_global(dataset: Dataset, reward_models: List[str]) -> Dataset:
    """Min-max normalize each reward model's scores over the whole dataset.

    For every column in ``reward_models`` that exists in ``dataset``, the min and
    max are taken over all numeric (int/float, non-None) scores of all rows, and
    each numeric score ``s`` becomes ``(s - min) / (max - min)``. Non-numeric
    entries become ``None``. Rows that are not lists/arrays are replaced by a list
    of ``None`` as long as the column's first row (empty if that row is not a
    list/array either). Columns with no numeric score are left untouched.

    Args:
        dataset: Dataset holding one list-of-scores column per reward model.
        reward_models: Names of the reward-model columns to normalize.

    Returns:
        The dataset with each normalized column replaced.

    Raises:
        ValueError: if all numeric scores of a model are identical (min == max).
    """
    for model in reward_models:
        if model not in dataset.column_names:
            logger.warning(f"Reward model column {model} not found in dataset")
            continue

        # Indexing a Dataset by column name returns the whole column, so fetch
        # it once instead of repeatedly inside the loops below.
        column = dataset[model]

        all_scores = [
            s
            for row_scores in column
            if isinstance(row_scores, (list, np.ndarray))
            for s in row_scores
            if s is not None and isinstance(s, (int, float))
        ]
        if not all_scores:
            continue

        min_score = min(all_scores)
        max_score = max(all_scores)
        if min_score == max_score:
            raise ValueError(f"{model}'s scores are all the same across the dataset.")
        score_range = max_score - min_score

        # Placeholder length for malformed rows: the first row's length
        # (0 when the first row is itself not a list/array).
        first_row = column[0]
        fallback_len = len(first_row) if isinstance(first_row, (list, np.ndarray)) else 0

        normalized_data = []
        for row_scores in column:
            if isinstance(row_scores, (list, np.ndarray)):
                normalized_data.append([
                    (s - min_score) / score_range
                    if s is not None and isinstance(s, (int, float))
                    else None
                    for s in row_scores
                ])
            else:
                normalized_data.append([None] * fallback_len)

        dataset = dataset.remove_columns([model])
        dataset = dataset.add_column(model, normalized_data)

    return dataset
            
def normalize_reward_model_scores_per_problem(dataset: Dataset, reward_models: List[str]) -> Dataset:
    """Min-max normalize each reward model's scores within each problem (row).

    For every column in ``reward_models`` that exists in ``dataset``, each row's
    non-None scores are rescaled to [0, 1] using that row's own min and max.
    ``None`` entries stay ``None`` in place, so every output row has the same
    length as its input row and stays aligned with the per-generation columns.

    Per-row special cases:
      * row is ``None`` -> ``[]``;
      * row contains only ``None`` -> a row of ``None`` of the same length;
      * all non-None scores are equal -> each of them becomes 0.5.

    Args:
        dataset: Dataset holding one list-of-scores column per reward model.
        reward_models: Names of the reward-model columns to normalize.

    Returns:
        The dataset with each normalized column replaced.
    """
    for model in reward_models:
        if model not in dataset.column_names:
            logger.warning(f"Reward model column {model} not found in dataset")
            continue

        normalized_data = []
        for row_scores in dataset[model]:
            if row_scores is None:
                row_scores = []
            valid_scores = [s for s in row_scores if s is not None]

            if not valid_scores:
                # Keep one slot per generation so the row stays aligned.
                normalized_data.append([None] * len(row_scores))
                continue

            min_score = min(valid_scores)
            max_score = max(valid_scores)

            if max_score == min_score:
                # No spread to normalize: map every present score to the midpoint.
                normalized_data.append([0.5 if s is not None else None for s in row_scores])
                continue

            score_range = max_score - min_score
            normalized_data.append(
                [(s - min_score) / score_range if s is not None else None for s in row_scores]
            )

        dataset = dataset.remove_columns([model])
        dataset = dataset.add_column(model, normalized_data)
    return dataset
    
def _class_balance_index(class_balance, n_scores):
    """Return the index into ascending-sorted scores used as a class-balance threshold.

    Roughly a ``class_balance`` fraction of the scores lie at or above the element
    at this index. The result is clamped to ``[0, n_scores - 1]``: the unclamped
    formula yields -1 when ``class_balance`` is 1, which Python indexing would
    read as the *largest* score instead of the smallest.
    """
    index = int(np.ceil((1 - class_balance) * n_scores)) - 1
    return min(max(index, 0), n_scores - 1)


def compute_reward_thresholds(dataset, reward_threshold, reward_models, ws_class_balance=None):
    """Add a ``f"{rm}_threshold"`` column (one threshold per problem) for each reward model.

    ``reward_threshold`` is either numeric (or a numeric string), used as a fixed
    threshold for every problem and model, or a strategy string:

    * containing ``"per_problem"``: the statistic is computed over each problem's
      generations; otherwise over all generations in the dataset;
    * containing ``"mean"`` / ``"median"`` (checked in that order): the mean /
      median score;
    * containing ``"threshold_cb"``: the sorted score at or above which roughly a
      class-balance fraction of scores lies. The class balance is
      ``ws_class_balance`` if given, otherwise the mean of ``answer_correct``
      (per problem or over the dataset, matching the scope above).

    For strategy computations, ``None`` scores are treated as 0.

    Args:
        dataset: Dataset whose reward-model columns hold equal-length per-problem
            score lists (they are stacked into a single array).
        reward_threshold: Number or strategy string as described above.
        reward_models: Names of the reward-model columns.
        ws_class_balance: Optional fixed class balance for ``"threshold_cb"``.

    Returns:
        The dataset with the threshold columns added.

    Raises:
        ValueError: if ``reward_threshold`` is a string that is neither numeric
            nor a recognised strategy.
    """
    n_rms = len(reward_models)
    n_problems = len(dataset)

    try:
        fixed_threshold = float(reward_threshold)
    except ValueError:
        fixed_threshold = None

    if fixed_threshold is not None:
        new_column = {rm: [fixed_threshold] * n_problems for rm in reward_models}
    else:
        strategy = reward_threshold
        unknown = ValueError(f"Unknown reward_threshold strategy: {reward_threshold!r}")

        # Shape (n_problems, n_rms * n_generations).
        stacked = np.column_stack([dataset[col] for col in reward_models])
        if stacked.dtype == object:
            # None entries force an object array; map them to 0 so the numeric
            # reductions below work.
            stacked = np.where(stacked == None, 0, stacked).astype(float)  # noqa: E711
        # -> (n_rms, n_problems, n_generations)
        scores = stacked.reshape(n_problems, n_rms, -1).transpose(1, 0, 2)

        if "per_problem" in strategy:
            if "mean" in strategy:
                per_problem = scores.mean(axis=2)  # (n_rms, n_problems)
            elif "median" in strategy:
                per_problem = np.median(scores, axis=2)
            elif "threshold_cb" in strategy:
                labels = None
                if ws_class_balance is None:
                    labels = np.array(dataset['answer_correct']).reshape(n_problems, -1)
                per_problem = np.empty((n_rms, n_problems))
                for i in range(n_rms):
                    for j in range(n_problems):
                        cb = ws_class_balance if labels is None else labels[j].mean()
                        sorted_row = np.sort(scores[i, j])
                        per_problem[i, j] = sorted_row[_class_balance_index(cb, len(sorted_row))]
            else:
                raise unknown
            new_column = {rm: list(per_problem[i]) for i, rm in enumerate(reward_models)}
        else:
            flat = scores.reshape(n_rms, -1)  # (n_rms, n_problems * n_generations)
            if "mean" in strategy:
                per_rm = list(flat.mean(axis=1))
            elif "median" in strategy:
                per_rm = list(np.median(flat, axis=1))
            elif "threshold_cb" in strategy:
                if ws_class_balance is not None:
                    cb = ws_class_balance
                else:
                    cb = np.array(dataset['answer_correct']).mean()
                index = _class_balance_index(cb, flat.shape[1])
                per_rm = [np.sort(flat[i])[index] for i in range(n_rms)]
            else:
                raise unknown
            new_column = {rm: [per_rm[i]] * n_problems for i, rm in enumerate(reward_models)}

    for rm, thresholds in new_column.items():
        dataset = dataset.add_column(f"{rm}_threshold", thresholds)

    return dataset

def is_list_of_lists(var):
    """Return True when ``var`` is a list and every item in it is a list.

    An empty list qualifies (vacuous truth).
    """
    if not isinstance(var, list):
        return False
    for element in var:
        if not isinstance(element, list):
            return False
    return True

def is_list_of_list_of_lists(var):
    """Return True when ``var`` is a list whose every item passes ``is_list_of_lists``.

    Empty lists qualify at either nesting level.
    """
    if not isinstance(var, list):
        return False
    for element in var:
        if not is_list_of_lists(element):
            return False
    return True

def preprocess_dataset(dataset_path: str, max_rows: Optional[int] = None,
                       normalization: str = 'global', reward_threshold: Union[str, float] = 0.5,
                       ws_class_balance: Optional[float] = None) -> Dataset:
    """
    Load a verifier dataset and prepare its reward-model scores.

    Steps:
    1. Load with ``load_from_disk``; if that fails, load the ``'data'`` split from
       the HuggingFace Hub. For ``anonymous_research/MATH_with_RM_LJ_UT_v1`` the
       ``extracted_answers_using_R.E.`` column replaces ``extracted_answers``.
    2. Keep at most ``max_rows`` rows.
    3. Squeeze reward-model / LM-judge columns whose per-generation entries are
       one-element lists.
    4. Normalize reward-model scores: ``'global'`` min-max over the dataset,
       any other value per problem.
    5. Add ``{rm}_threshold`` columns via ``compute_reward_thresholds``.

    ``None`` judge verdicts are not modified here.

    Args:
        dataset_path: Local path or HuggingFace dataset name; also the key into
            ``DATASET_TO_REWARD_MODELS`` / ``DATASET_TO_LM_JUDGES``.
        max_rows: Maximum number of rows to keep (all rows if ``None``).
        normalization: ``'global'``, or anything else for per-problem normalization.
        reward_threshold: Fixed threshold or strategy name, forwarded to
            ``compute_reward_thresholds``.
        ws_class_balance: Optional class balance forwarded to
            ``compute_reward_thresholds``.

    Returns:
        Processed dataset with normalized scores and threshold columns.
    """
    logger.info(f"Loading dataset from {dataset_path}")
    try:
        dataset = load_from_disk(dataset_path)
    except Exception:
        # Not a dataset saved on disk: fall back to the HuggingFace Hub.
        dataset = load_dataset(dataset_path)['data']
        if dataset_path == "anonymous_research/MATH_with_RM_LJ_UT_v1":
            dataset = dataset.remove_columns('extracted_answers')
            dataset = dataset.rename_columns({
                'extracted_answers_using_R.E.': 'extracted_answers'
            })
            logger.info("Renaming extracted_answers_using_R.E. to extracted_answers")

    if max_rows is not None:
        dataset = dataset.select(range(min(max_rows, len(dataset))))

    reward_models = DATASET_TO_REWARD_MODELS[dataset_path]
    lm_judges = DATASET_TO_LM_JUDGES[dataset_path]
    for v in reward_models + lm_judges:
        column = dataset[v]
        # Columns nested as problem -> generation -> [score] are squeezed to
        # problem -> generation -> score. The emptiness checks avoid an
        # IndexError on an empty dataset or a problem with no generations.
        if (is_list_of_list_of_lists(column) and column and column[0]
                and len(column[0][0]) == 1):
            logger.info(f"Flattening {v} scores.")
            flattened_scores = [[gen[0] for gen in problem] for problem in column]
            dataset = dataset.remove_columns([v])
            dataset = dataset.add_column(v, flattened_scores)

    logger.info(f"Performing {normalization} reward model normalization")
    if normalization == 'global':
        dataset = normalize_reward_model_scores_global(dataset, reward_models)
    else:
        dataset = normalize_reward_model_scores_per_problem(dataset, reward_models)

    # compute reward threshold per problem per RM
    logger.info(f"Computing reward thresholds; strategy: {reward_threshold}")
    dataset = compute_reward_thresholds(dataset, reward_threshold, reward_models, ws_class_balance)
    return dataset

def _fill_none_rows(dataset: Dataset, columns: List[str], fill) -> Dataset:
    """Replace ``None`` entries in each listed column with ``fill``; empty/``None`` rows become ``[]``."""
    for name in columns:
        cleaned = [
            [fill if v is None else v for v in row] if row else []
            for row in dataset[name]
        ]
        dataset = dataset.remove_columns([name])
        dataset = dataset.add_column(name, cleaned)
    return dataset


def clean_verifiers(dataset_path: str, dataset: Dataset) -> Dataset:
    """Remove ``None`` values from verifier outputs.

    * LM-judge columns (``DATASET_TO_LM_JUDGES[dataset_path]``): ``None`` verdicts
      become ``False``.
    * Reward-model columns (``DATASET_TO_REWARD_MODELS[dataset_path]``): ``None``
      scores become ``0``.

    In both cases an empty or ``None`` row becomes ``[]``. Previously the
    reward-model branch crashed on a ``None`` row.

    Args:
        dataset_path: Key into the verifier-column registries.
        dataset: Dataset containing those columns.

    Returns:
        The dataset with the cleaned columns replaced.
    """
    dataset = _fill_none_rows(dataset, DATASET_TO_LM_JUDGES[dataset_path], False)
    dataset = _fill_none_rows(dataset, DATASET_TO_REWARD_MODELS[dataset_path], 0)
    return dataset

def calculate_binary_metrics(predictions: List[bool], ground_truth: List[bool]) -> Dict[str, float]:
    """
    Compute confusion-matrix counts and derived scores for binary predictions.

    Predictions equal to ``None`` are abstentions and are ignored. Predictions and
    labels are paired positionally and judged by truthiness. Ratios with a zero
    denominator are reported as 0.

    Args:
        predictions: Predicted labels (truthy = positive, ``None`` = abstain).
        ground_truth: True labels (truthy = positive).

    Returns:
        Dict with accuracy, precision, recall, f1, the four counts and ``total``
        (the number of non-None predictions).
    """
    tp = tn = fp = fn = 0
    for pred, truth in zip(predictions, ground_truth):
        if pred is None:
            continue
        if pred:
            if truth:
                tp += 1
            else:
                fp += 1
        elif truth:
            fn += 1
        else:
            tn += 1

    # Counted over every prediction, not only those that have a paired label.
    total = sum(1 for pred in predictions if pred is not None)

    def ratio(numerator, denominator):
        return numerator / denominator if denominator > 0 else 0

    accuracy = ratio(tp + tn, total)
    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)
    f1 = ratio(2 * (precision * recall), precision + recall)

    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'true_positives': tp,
        'true_negatives': tn,
        'false_positives': fp,
        'false_negatives': fn,
        'total': total
    }

def calculate_row_precision(predictions: List[bool], ground_truth: List[bool]) -> float:
    """
    Precision of one problem's predictions.

    Args:
        predictions: Per-generation predictions for the problem (truthy = positive).
        ground_truth: Per-generation correctness labels for the problem.

    Returns:
        ``None`` when the problem has no correct answer (the caller skips such
        rows), 0 when nothing is predicted positive, otherwise the fraction of
        positive predictions that are correct.
    """
    if sum(ground_truth) == 0:
        return None  # nothing correct to find: exclude this row

    flagged_truths = [truth for pred, truth in zip(predictions, ground_truth) if pred]
    if not flagged_truths:
        return 0
    hits = sum(1 for truth in flagged_truths if truth)
    return hits / len(flagged_truths)

def calculate_row_recall(predictions: List[bool], ground_truth: List[bool]) -> float:
    """
    Recall of one problem's predictions.

    Args:
        predictions: Per-generation predictions for the problem (truthy = positive).
        ground_truth: Per-generation correctness labels for the problem.

    Returns:
        Fraction of correct generations that were predicted positive, or ``None``
        when the problem has no correct generation (the caller skips such rows).
    """
    positives = sum(ground_truth)
    if not positives > 0:
        return None
    hits = sum(1 for pred, truth in zip(predictions, ground_truth) if pred and truth)
    return hits / positives


def generate_rm_metric(dataset, metrics, rm):
    """Compute verification and selection metrics for a single reward model.

    Reads the columns ``rm`` (per-generation scores), ``answer_correct``
    (per-generation labels) and ``f"{rm}_threshold"`` (one threshold per
    problem). Scores that are ``None`` or not int/float are dropped together
    with their label. A score counts as a positive prediction when it is
    ``>= threshold``.

    ``metrics[rm]`` receives the output of ``calculate_binary_metrics`` plus:
      * ``normalized_precision`` / ``normalized_recall``: per-problem precision /
        recall averaged over problems with at least one correct answer (0.0 if none);
      * ``selection_accuracy``: fraction of problems whose highest-scoring
        generation (first one on ties) is correct;
      * ``solvable_selection_accuracy``: the same, restricted to problems with at
        least one correct generation.

    Args:
        dataset: Dataset with the columns above.
        metrics: Dict updated in place.
        rm: Reward-model column name.

    Returns:
        ``metrics`` (the same object). If there are no valid predictions, nothing is
        stored and a warning is logged.
    """
    # dataset[col] returns the whole column; fetch each one once rather than
    # once per row, which made the loop quadratic in the number of problems.
    score_rows = dataset[rm]
    label_rows = dataset['answer_correct']
    threshold_rows = dataset[f"{rm}_threshold"]

    all_predictions = []
    all_ground_truth = []
    row_precisions = []
    row_recalls = []
    selections = []
    solvable_selections = []

    for row_scores, row_ground_truth, reward_threshold in zip(score_rows, label_rows, threshold_rows):
        # Keep only numeric scores, together with their labels.
        valid_pairs = [
            (score, gt)
            for score, gt in zip(row_scores, row_ground_truth)
            if score is not None and isinstance(score, (int, float))
        ]
        if not valid_pairs:
            continue
        valid_scores = [score for score, _ in valid_pairs]
        valid_ground_truth = [gt for _, gt in valid_pairs]

        # argmax returns the first index among ties.
        best_idx = int(np.argmax(valid_scores))
        selections.append(valid_ground_truth[best_idx])
        if any(valid_ground_truth):
            # Solvable problem: at least one generation is correct.
            solvable_selections.append(valid_ground_truth[best_idx])

        # Convert scores to binary predictions using this problem's threshold.
        row_predictions = [score >= reward_threshold for score in valid_scores]

        row_precision = calculate_row_precision(row_predictions, valid_ground_truth)
        if row_precision is not None:
            row_precisions.append(row_precision)

        row_recall = calculate_row_recall(row_predictions, valid_ground_truth)
        if row_recall is not None:
            row_recalls.append(row_recall)

        all_predictions.extend(row_predictions)
        all_ground_truth.extend(valid_ground_truth)

    if not all_predictions:
        logger.warning(f"No valid predictions found for {rm}")
        return metrics

    rm_metrics = calculate_binary_metrics(all_predictions, all_ground_truth)
    rm_metrics['normalized_precision'] = (
        sum(row_precisions) / len(row_precisions) if row_precisions else 0.0
    )
    rm_metrics['normalized_recall'] = (
        np.array(row_recalls).mean() if row_recalls else 0.0
    )
    rm_metrics['selection_accuracy'] = np.array(selections).mean()
    rm_metrics['solvable_selection_accuracy'] = np.array(solvable_selections).mean()
    metrics[rm] = rm_metrics
    return metrics

def generate_reward_model_metrics(dataset, dataset_path) -> Dict[str, Dict[str, float]]:
    """
    Collect metrics for every reward model registered for ``dataset_path``.

    Columns missing from the dataset are skipped with a warning. Columns whose
    name contains ``"_step"`` are skipped silently. Every other column is scored
    with ``generate_rm_metric``.

    Args:
        dataset: Dataset with verifier scores, ``answer_correct`` and
            ``{rm}_threshold`` columns.
        dataset_path: Key into ``DATASET_TO_REWARD_MODELS``.

    Returns:
        ``{reward_model_name: metric_dict}``.
    """
    collected = {}
    for rm_name in DATASET_TO_REWARD_MODELS[dataset_path]:
        if rm_name in dataset.column_names:
            if "_step" not in rm_name:
                logger.info(f"Processing metrics for reward model: {rm_name}")  # Keep full name
                collected = generate_rm_metric(dataset, collected, rm_name)
        else:
            logger.warning(f"Reward model column {rm_name} not found in dataset")
    return collected

def generate_lm_judge_metrics(dataset, dataset_path) -> Dict[str, Dict[str, float]]:
    """Compute verification and selection metrics for each LM judge.

    For every column in ``DATASET_TO_LM_JUDGES[dataset_path]``, each row's
    verdicts are compared with ``answer_correct``. If a row's first entry is a
    one-element list, all entries of that row are unwrapped. Columns missing
    from the dataset are skipped with a warning, and rows without verdicts are
    skipped.

    Per judge, the result holds ``calculate_binary_metrics`` output plus:
      * ``normalized_precision`` / ``normalized_recall``: per-problem averages
        over problems with at least one correct answer (0.0 if none);
      * ``selection_accuracy``: correctness of the first generation with the
        highest verdict (``None`` treated as 0);
      * ``solvable_selection_accuracy``: the same, restricted to problems with a
        correct generation.

    Args:
        dataset: Dataset with judge columns and ``answer_correct``.
        dataset_path: Key into ``DATASET_TO_LM_JUDGES``.

    Returns:
        ``{judge_name: metric_dict}``.
    """
    metrics = {}

    lm_columns = DATASET_TO_LM_JUDGES[dataset_path]
    if not lm_columns:
        return metrics
    # Fetch the label column once; dataset[col] returns the whole column.
    label_rows = dataset['answer_correct']

    for column in lm_columns:
        if column not in dataset.column_names:
            logger.warning(f"LM judge column {column} not found in dataset")
            continue
        logger.info(f"Processing metrics for LM judge: {column}")
        all_predictions = []
        all_ground_truth = []
        row_precisions = []
        row_recalls = []
        selections = []
        solvable_selections = []

        for verdicts, ground_truth in zip(dataset[column], label_rows):
            if not verdicts:
                # No generation to judge or select in this row.
                continue
            if isinstance(verdicts[0], list) and len(verdicts[0]) == 1:
                verdicts = [v[0] for v in verdicts]

            row_precision = calculate_row_precision(verdicts, ground_truth)
            if row_precision is not None:
                row_precisions.append(row_precision)

            row_recall = calculate_row_recall(verdicts, ground_truth)
            if row_recall is not None:
                row_recalls.append(row_recall)

            verdicts_no_nones = np.array([v if v is not None else 0 for v in verdicts])
            best_idx = int(verdicts_no_nones.argmax())  # first index among ties
            selections.append(ground_truth[best_idx])
            if any(ground_truth):
                # Solvable problem: at least one generation is correct.
                solvable_selections.append(ground_truth[best_idx])

            all_predictions.extend(verdicts)
            all_ground_truth.extend(ground_truth)

        if all_predictions:
            metrics[column] = calculate_binary_metrics(all_predictions, all_ground_truth)
            metrics[column]['normalized_precision'] = (
                sum(row_precisions) / len(row_precisions) if row_precisions else 0.0
            )
            metrics[column]['normalized_recall'] = (
                np.array(row_recalls).mean() if row_recalls else 0.0
            )
            metrics[column]['selection_accuracy'] = np.array(selections).mean()
            metrics[column]['solvable_selection_accuracy'] = np.array(solvable_selections).mean()
        else:
            logger.warning(f"No valid predictions found for {column}")

    return metrics

def display_metrics(metrics: Dict[str, Dict[str, float]]):
    """
    Print a grid table with one row per verifier, ordered by verifier name.

    Rates are printed with three decimals and confusion-matrix counts as-is.

    Args:
        metrics: ``{verifier_name: metric_dict}``. Each metric dict must contain
            every key shown in the table, or a ``KeyError`` is raised.
    """
    headers = ['Model', 'Selection Acc', 'Norm Prec', 'Norm Recall', 'Accuracy', 'Precision', 'Recall', 'F1', 'TP', 'TN', 'FP', 'FN', 'Total']
    rate_keys = ('selection_accuracy', 'normalized_precision', 'normalized_recall',
                 'accuracy', 'precision', 'recall', 'f1')
    count_keys = ('true_positives', 'true_negatives', 'false_positives',
                  'false_negatives', 'total')

    table = []
    for name, values in sorted(metrics.items()):
        formatted_rates = [f"{values[key]:.3f}" for key in rate_keys]
        raw_counts = [values[key] for key in count_keys]
        table.append([name, *formatted_rates, *raw_counts])

    print("\nMetrics:")
    print(tabulate(table, headers=headers, tablefmt='grid'))

def generate_cache_key(dataset_path, config):
    """
    Return a 10-hex-character cache key for ``(dataset_path, config)``.

    The pair is serialized as JSON with sorted keys, so dict ordering does not
    matter. The MD5 digest of that text is truncated. ``config`` must be
    JSON-serializable with the default encoder.
    """
    payload = json.dumps({"config_so_far": config, "dataset": dataset_path}, sort_keys=True)
    digest = hashlib.md5(payload.encode()).hexdigest()
    return digest[:10]

def _mask_metrics(mask_array: np.ndarray, truth_array: np.ndarray) -> Dict[str, float]:
    """Score a 0/1 selection mask against 0/1 labels of the same (problems, generations) shape."""
    selected_correct = (mask_array == 1) & (truth_array == 1)
    selected_wrong = (mask_array == 1) & (truth_array == 0)

    tp = np.sum(selected_correct)
    fp = np.sum(selected_wrong)
    fn = np.sum((mask_array == 0) & (truth_array == 1))
    tn = np.sum((mask_array == 0) & (truth_array == 0))

    # Per-problem precision/recall, averaged over problems with at least one
    # correct answer; a solvable problem with nothing selected has precision 0.
    row_true_positives = np.sum(selected_correct, axis=1)
    row_predicted_positives = row_true_positives + np.sum(selected_wrong, axis=1)
    row_ground_truth_positives = np.sum(truth_array == 1, axis=1)
    valid_rows = row_ground_truth_positives > 0

    precision_mask = valid_rows & (row_predicted_positives > 0)
    precision_values = np.zeros(len(mask_array))
    precision_values[precision_mask] = (row_true_positives[precision_mask] /
                                        row_predicted_positives[precision_mask])
    problem_precisions = precision_values[valid_rows]
    problem_recalls = row_true_positives[valid_rows] / row_ground_truth_positives[valid_rows]

    # The selected answer is the first generation with mask 1; argmax yields
    # index 0 for problems where nothing is selected.
    selected_first_indices = np.argmax(mask_array, axis=1)
    selections = truth_array[np.arange(len(truth_array)), selected_first_indices]
    solvable_selections = selections[np.any(truth_array == 1, axis=1)]

    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    norm_precision = sum(problem_precisions) / len(problem_precisions) if len(problem_precisions) != 0 else 0.0
    norm_recall = sum(problem_recalls) / len(problem_recalls) if len(problem_recalls) != 0 else 0.0

    return {
        'norm_precision': norm_precision,
        'norm_recall': norm_recall,
        'accuracy': accuracy,
        'selection_accuracy': selections.mean(),
        'solvable_selection_accuracy': solvable_selections.mean(),
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'true_positives': tp,
        'true_negatives': tn,
        'false_positives': fp,
        'false_negatives': fn,
        'total': tp + tn + fp + fn
    }


def apply_verifiers(dataset: Dataset, verifiers: List[VerifierConfig],
                   verifier_metrics: Dict, base_config, dataset_path: str, 
                   use_cached_stages: bool, cache_dir: str,
                   verifier_subset: Union[List[str], int], mv_as_voter: bool, return_mask: bool=False) -> Union[Dict, Tuple[Dict, np.ndarray]]:
    """Apply verifiers in sequence; each stage filters the previous stage's selection mask.

    For stage k, ``asdict(verifiers[k-1])`` is appended to ``base_config['verifiers']``
    (the caller's dict is mutated). The accumulated config, hashed together with
    ``dataset_path``, names a JSON cache file in ``cache_dir`` that holds the
    stage's mask and weights. When ``use_cached_stages`` is set and that file
    exists, it is loaded instead of running the verifier.

    Returns:
        ``None`` if a verifier's ``filter`` returns ``None`` (also when ``return_mask``
        is set). Otherwise a dict with one ``"stage_k"`` entry per stage (``mask``,
        ``metrics``, ``weights``) plus ``"Combined"``, which aliases the last stage.
        If ``return_mask`` is set, the result is the tuple
        ``(results, final_mask_array)``.

    Raises:
        ValueError: if ``verifiers`` is empty.
    """
    if not verifiers:
        raise ValueError("apply_verifiers requires at least one verifier")

    results = {}
    current_mask = None
    os.makedirs(cache_dir, exist_ok=True)
    # Labels are identical for every stage; build them once.
    truth_array = np.array(dataset['answer_correct']).astype(int)

    for stage, verifier_config in enumerate(verifiers, 1):
        # The cache key covers every verifier up to and including this stage.
        base_config['verifiers'].append(asdict(verifier_config))
        cache_key = generate_cache_key(dataset_path, base_config)
        cache_path = os.path.join(cache_dir, f"{cache_key}.json")

        if use_cached_stages and os.path.exists(cache_path):
            with open(cache_path, "r") as f:
                # NpEncoder is only needed when dumping; plain json.load reads the cache.
                cached_data = json.load(f)
            # Restore an ndarray so later stages receive the same type as on a fresh run.
            current_mask = np.array(cached_data['current_mask'])
            weights = cached_data['weights']
            logger.info(f"Loaded cached results for verifier sequence up to stage {stage}")
        else:
            logger.info(f"Stage {stage}: Applying {verifier_config.name}")

            verifier_class = verifier_classes[verifier_config.name]
            verifier = verifier_class(
                dataset_path=dataset_path,
                dataset=dataset,
                verifier_metrics=verifier_metrics,
                filter_strategy=verifier_config.params.get('filter_strategy', None),
                filter_strategy_param=verifier_config.params.get('filter_strategy_param', None),
                verifier_subset=verifier_subset,
                mv_as_voter=mv_as_voter
            )
            output = verifier.filter(dataset, current_mask, **verifier_config.params)
            if isinstance(output, tuple):
                filtered_indices, weights = output[0], output[1]
            else:
                filtered_indices, weights = output, None

            if filtered_indices is None:
                return None

            # Per-problem selected indices -> (n_problems, n_generations) 0/1 mask.
            current_mask = np.zeros((len(filtered_indices), verifier.scores.shape[1]))
            for i, indices in enumerate(filtered_indices):
                if len(indices) != 0:
                    current_mask[i, indices] = 1

            with open(cache_path, "w") as f:
                json.dump({
                    "current_mask": current_mask.tolist(),
                    "weights": weights
                }, f, cls=NpEncoder)
            logger.info(f"Cached results for verifier sequence up to stage {stage}")

        mask_array = np.array(current_mask)
        stage_metrics = _mask_metrics(mask_array, truth_array)

        logger.info(f"\nStage {stage} statistics:")
        logger.info(f"Selection accuracy: {stage_metrics['selection_accuracy']}")
        logger.info(f"Solvable selection accuracy: {stage_metrics['solvable_selection_accuracy']}")
        logger.info(f"Normalized precision: {stage_metrics['norm_precision']:.3f}")
        logger.info(
            f"TP: {stage_metrics['true_positives']}, TN: {stage_metrics['true_negatives']}, "
            f"FP: {stage_metrics['false_positives']}, FN: {stage_metrics['false_negatives']}"
        )

        results[f"stage_{stage}"] = {
            'mask': current_mask.copy(),
            'metrics': stage_metrics,
            'weights': weights
        }

    results['Combined'] = results[f"stage_{len(verifiers)}"]
    if return_mask:
        return results, mask_array
    return results


def display_results(verification_results: dict, enabled_verifiers):
    """Print per-stage verification metrics as a pipe-formatted table.

    Each entry other than ``'Combined'`` becomes a row labelled
    ``"Stage {i+1}: {enabled_verifiers[i].name}"``, where ``i`` is the entry's
    position in ``verification_results``. A ``'Final Combined'`` row is appended
    when a ``'Combined'`` entry exists. Rates get three decimals and counts are
    printed as-is.
    """
    print("\nVerification Results:")

    headers = ['Stage', 'Norm Prec', 'Norm Recall', 'Accuracy', 'Sel_Acc', 'Solve_Sel_Acc', 'Precision', 'Recall', 'F1', 'TP', 'TN', 'FP', 'FN']
    rate_keys = ('norm_precision', 'norm_recall', 'accuracy', 'selection_accuracy',
                 'solvable_selection_accuracy', 'precision', 'recall', 'f1')
    count_keys = ('true_positives', 'true_negatives', 'false_positives', 'false_negatives')

    def format_row(label, stage_metrics):
        return ([label]
                + [f"{stage_metrics[key]:.3f}" for key in rate_keys]
                + [stage_metrics[key] for key in count_keys])

    print('-' * 160)  # top rule

    table = []
    for position, (key, entry) in enumerate(verification_results.items()):
        if key == 'Combined':
            continue
        label = f"Stage {position+1}: {enabled_verifiers[position].name}"
        table.append(format_row(label, entry['metrics']))

    if 'Combined' in verification_results:
        table.append(format_row('Final Combined', verification_results['Combined']['metrics']))

    print(tabulate(table, headers=headers, tablefmt='pipe', numalign='right'))
    print('-' * 160)  # bottom rule

def _load_or_build_json(path, label, build):
    """Return the JSON cached at ``path``; if it is absent, compute it with ``build()`` and write it there."""
    if os.path.exists(path):
        logger.info(f"Loading cached {label} metrics from {path}")
        with open(path, 'r') as f:
            return json.load(f)
    logger.info(f"Generating {label} metrics to {path}")
    result = build()
    with open(path, 'w') as f:
        json.dump(result, f, indent=2)
    return result


def load_or_generate_metrics(dataset_path: str, dataset: Dataset, reward_threshold: float, rm_cache_path: str, lm_cache_path: str) -> Dict[str, Dict[str, float]]:
    """
    Return per-verifier metrics, reading JSON caches when present and writing them otherwise.

    Args:
        dataset_path: Path or name of the dataset; key into the verifier registries.
        dataset: The loaded dataset; only used when a cache file is missing.
        reward_threshold: Only used to locate the optional MV-scorer cache
            ``mv_metrics_threshold_{reward_threshold}.json`` in the directory
            of ``rm_cache_path``.
        rm_cache_path: JSON cache file for reward-model metrics.
        lm_cache_path: JSON cache file for LM-judge metrics.

    Returns:
        Merged ``{verifier_name: metrics}``. On duplicate names, LM-judge entries
        override reward-model entries and MV entries override both.
    """
    reward_model_metrics = _load_or_build_json(
        rm_cache_path, "reward model",
        lambda: generate_reward_model_metrics(dataset, dataset_path=dataset_path),
    )
    lm_judge_metrics = _load_or_build_json(
        lm_cache_path, "LM judge",
        lambda: generate_lm_judge_metrics(dataset, dataset_path),
    )
    metrics = {**reward_model_metrics, **lm_judge_metrics}

    # if MV scorer is available, add it.
    metrics_dir = os.path.dirname(rm_cache_path)
    mv_cache_path = os.path.join(metrics_dir, f"mv_metrics_threshold_{reward_threshold}.json")
    if os.path.exists(mv_cache_path):
        logger.info(f"Loading cached mv metrics from {mv_cache_path}")
        with open(mv_cache_path, 'r') as f:
            mv_metrics = json.load(f)
        metrics = {**metrics, **mv_metrics}

    return metrics


def create_output_file(dir, dataset, config):
    """Compute (and prepare) the results path for one dataset/config run.

    The file name is the first 10 hex characters of the MD5 of the key-sorted JSON
    encoding of ``{"dataset": dataset, "config": config}``. Identical runs therefore
    map to the same name. The file lives in a sub-folder of ``dir`` named after
    today's date (``MMDDYYYY``), and that folder is created if it is missing.

    Args:
        dir: Root results directory.
        dataset: Dataset identifier; must be JSON-serialisable.
        config: Verifier configuration; must be JSON-serialisable.

    Returns:
        ``(filepath, date)``, where ``date`` is the ``MMDDYYYY`` folder name used.
    """
    stamp = datetime.now().strftime("%m%d%Y")

    # Hash before touching the filesystem: an unserialisable config fails without side effects.
    canonical = json.dumps({"dataset": dataset, "config": config}, sort_keys=True)
    short_hash = hashlib.md5(canonical.encode()).hexdigest()[:10]

    day_folder = os.path.join(dir, stamp)
    os.makedirs(day_folder, exist_ok=True)
    return os.path.join(day_folder, short_hash + ".json"), stamp


def load_dataset_and_metrics(dataset_name, reward_threshold, global_params, mv_as_voter, metrics_dir="/scr/biggest/anonymous-user/scaling-verification/weaver/metrics"):
    """Load, or build and cache, the preprocessed dataset and its per-verifier metrics.

    Args:
        dataset_name: HuggingFace dataset name or path. '/' is replaced when forming
            cache paths.
        reward_threshold: Forwarded to preprocessing and embedded in cache names.
        global_params: Global config dict. ``'normalization'`` is read with ``[]``
            (so it must be present). ``'max_rows'`` and ``'ws_class_balance'`` are optional.
        mv_as_voter: If True, ensure a majority-vote ``mv_verifier`` column and
            metric are available.
        metrics_dir: Root directory for metric caches. A per-dataset sub-folder of it is used.

    Returns:
        Tuple ``(dataset, verifier_metrics)``.
    """
    # Preprocessed datasets live under $HF_DATASETS_CACHE, with a local fallback folder.
    hf_dataset_cache = os.getenv("HF_DATASETS_CACHE")
    if hf_dataset_cache is None:
        hf_dataset_cache = "../weaver_cache/"
        os.makedirs(hf_dataset_cache, exist_ok=True)
    dataset_path = os.path.join(hf_dataset_cache, dataset_name.replace("/", "_") + f"_{reward_threshold}")
    metrics_dir = f"{metrics_dir}/{dataset_name.replace('/', '-')}"
    os.makedirs(metrics_dir, exist_ok=True)
    rm_cache_path = os.path.join(metrics_dir, f"reward_model_metrics_threshold_{reward_threshold}.json")
    lm_cache_path = os.path.join(metrics_dir, "lm_judge_metrics.json")

    # Non-default normalizations get their own dataset and metric cache files.
    if global_params['normalization'] != 'global':
        suffix = f"_{global_params['normalization']}"
        dataset_path += suffix
        rm_root, rm_ext = os.path.splitext(rm_cache_path)
        rm_cache_path = f"{rm_root}{suffix}{rm_ext}"
        lm_root, lm_ext = os.path.splitext(lm_cache_path)
        lm_cache_path = f"{lm_root}{suffix}{lm_ext}"

    if os.path.exists(dataset_path) and os.path.exists(rm_cache_path) and os.path.exists(lm_cache_path):
        logger.info(f"Loading preprocessed dataset from {dataset_path}")
        dataset = load_from_disk(dataset_path)
        # Both metric caches exist, so this only reads the RM/LM JSON files (plus the
        # mv metrics JSON in metrics_dir, if present). Nothing is generated.
        verifier_metrics = load_or_generate_metrics(
            dataset_name, dataset, reward_threshold,
            rm_cache_path, lm_cache_path
        )
    else:
        logger.info("Preprocessing dataset...")
        dataset = preprocess_dataset(
            dataset_name,
            max_rows=global_params.get('max_rows'),
            normalization=global_params.get('normalization'),
            reward_threshold=reward_threshold,
            ws_class_balance=global_params.get('ws_class_balance', None)
        )
        # Generate initial metrics using dataset (still with None scores)
        logger.info(f"Generating metrics at {rm_cache_path}, {lm_cache_path}")
        verifier_metrics = load_or_generate_metrics(
            dataset_name, dataset, reward_threshold,
            rm_cache_path, lm_cache_path
        )

        logger.info("Cleaning verifiers...")  # remove Nones
        dataset = clean_verifiers(dataset_name, dataset)

        logger.info(f"Saving dataset to {dataset_path}")
        dataset.save_to_disk(dataset_path)

    # Once the dataset and verifier metrics are available, build the majority verifier if requested.
    # NOTE(review): construct_mv_verifier writes its mv metrics JSON to a path it derives itself
    # (relative 'weaver/metrics/...'), which may differ from the metrics_dir read above. Confirm
    # the two locations agree, otherwise the mv metric is recomputed on every run.
    if mv_as_voter:
        if "mv_verifier" not in dataset.column_names or "mv_verifier" not in verifier_metrics:
            logger.info("Constructing majority verifier...")
            dataset, mv_metric = construct_mv_verifier(dataset, dataset_name, dataset_path, reward_threshold)
            verifier_metrics = {**verifier_metrics, **mv_metric}

    return dataset, verifier_metrics


def construct_mv_verifier(dataset, dataset_name, dataset_path, reward_threshold, metrics_dir=None):
    """Add a majority-vote ("mv_verifier") pseudo-scorer to the dataset and cache its metric.

    For each row of ``extracted_answers``, every sample is scored by how often its
    answer occurs within that row. The frequencies are min-max rescaled to [0, 1]
    within the row. If the row has a single distinct answer, or all distinct answers
    are equally frequent, every sample scores 1.0. An empty row yields an empty
    score list.

    Args:
        dataset: Dataset with an ``extracted_answers`` column (one list of answers per row).
        dataset_name: Dataset name. Used to derive the default metrics directory.
        dataset_path: Where the augmented dataset is saved when the column is newly added.
        reward_threshold: Forwarded to ``compute_reward_thresholds`` and used in the
            cache file name.
        metrics_dir: Directory for ``mv_metrics_threshold_<reward_threshold>.json``.
            Defaults to ``weaver/metrics/<dataset_name with '/' replaced by '-'>``
            (relative to the working directory), as before. Created if missing.

    Returns:
        Tuple ``(dataset, mv_metric)``.
    """
    if "mv_verifier" not in dataset.column_names:
        mv_data = []
        # Iterate the column as Python lists instead of np.array(...). Rows may hold different
        # numbers of answers, which np.array rejects for ragged input. np.array would also
        # coerce mixed-type answers to a common dtype before they are counted.
        for answers in dataset['extracted_answers']:
            if not answers:
                # Nothing to score; min()/max() below would fail on an empty array.
                mv_data.append([])
                continue
            counts = Counter(answers)
            freqs = np.array(list(counts.values()))
            freqs = freqs / freqs.sum()  # relative frequency of each distinct answer
            if len(freqs) > 1:
                min_freq, max_freq = freqs.min(), freqs.max()
                if min_freq == max_freq:
                    freqs = np.ones_like(freqs)
                else:
                    freqs = (freqs - min_freq) / (max_freq - min_freq)  # rescale to [0, 1]
            score_by_answer = dict(zip(counts.keys(), freqs))
            # Each sample inherits the (rescaled) frequency of its answer as its mv score.
            mv_data.append([score_by_answer[ans] for ans in answers])

        dataset = dataset.add_column("mv_verifier", mv_data)
        # Set the "reward threshold" for the mv verifier.
        dataset = compute_reward_thresholds(dataset, reward_threshold, ["mv_verifier"])
        # Save an in-memory copy rather than `dataset` itself. Presumably this is because
        # `dataset` may be backed by files under dataset_path and cannot overwrite
        # itself. TODO confirm.
        ds = Dataset.from_dict(dataset.to_dict())
        ds.save_to_disk(dataset_path)

    # Evaluate the mv_verifier and cache its metric.
    mv_metric = generate_rm_metric(dataset, {}, "mv_verifier")
    if metrics_dir is None:
        metrics_dir = f"weaver/metrics/{dataset_name.replace('/', '-')}"
    # Without this, open() below raises FileNotFoundError when the directory does not exist yet.
    os.makedirs(metrics_dir, exist_ok=True)
    mv_cache_path = os.path.join(metrics_dir, f"mv_metrics_threshold_{reward_threshold}.json")
    with open(mv_cache_path, 'w') as f:
        json.dump(mv_metric, f, indent=2)

    return dataset, mv_metric

def main():
    """Main entry point for dataset verification.

    Steps:
    - Parse CLI arguments and load the verifier config.
    - Skip the run if a results file for the same (dataset, config) hash already exists
      under any dated folder of ``--output_dir``, unless ``--redo`` is given.
    - Otherwise load the dataset and metrics, apply the enabled verifiers, display the
      results, and save them as JSON.
    """
    parser = argparse.ArgumentParser(description='Apply verifiers to dataset')

    # Required arguments
    parser.add_argument('--dataset', type=str, required=True,
                      help='Path to dataset or HuggingFace dataset name')
    parser.add_argument('--config', type=str, required=True,
                      help='Path to verifier configuration JSON')
    parser.add_argument('--output_dir', type=str, default="../weaver_results/",
                      help='Path to results directory')
    parser.add_argument('--redo', action="store_true",
                      help='If set, will overwrite existing results.')
    parser.add_argument('--use_cached_stages', action="store_true",
                      help='If set, will use the cached results of a sequence of substages.')
    parser.add_argument('--cache_dir', type=str, default="../weaver_cache/",
                      help='Path to cache directory')
    parser.add_argument('--save_individual_metrics', action="store_true",
                      help='If set, will save individual metric results to the results file too.')
    args = parser.parse_args()

    try:
        # Load configuration
        logger.info(f"Loading configuration from {args.config}")
        config_handler = VerifierHandler(args.config)
        global_params = config_handler.get_global_params()

        output_file, _ = create_output_file(args.output_dir, args.dataset, config_handler.config)
        # Results are stored as <output_dir>/<MMDDYYYY>/<hash>.json, so an earlier run of the same
        # dataset/config may sit under any date folder. Build the pattern from its parts instead
        # of str.replace-ing the date, which would also rewrite that date wherever else it appears
        # in the path. Escape output_dir so glob metacharacters in it are matched literally.
        search_pattern = os.path.join(glob.escape(args.output_dir), "*", os.path.basename(output_file))
        existing_files = glob.glob(search_pattern)
        if existing_files and not args.redo:
            logger.info(f"Results file already exists at {existing_files[0]}. Skipping.")
            return

        # Set logging level from config
        if global_params.get('verbose', False):
            logger.setLevel(logging.DEBUG)

        reward_threshold = global_params.get('reward_threshold', 0.5)
        mv_as_voter = global_params.get("mv_as_voter", False)
        dataset, verifier_metrics = load_dataset_and_metrics(args.dataset, reward_threshold, global_params, mv_as_voter)

        # Get enabled verifiers
        enabled_verifiers = config_handler.get_enabled_verifiers()
        if not enabled_verifiers:
            logger.warning("No enabled verifiers found in configuration")
            return

        verifier_subset = global_params.get('verifier_subset', None)

        # Apply verifiers
        logger.info("Applying verifiers...")
        verification_results = apply_verifiers(
            dataset=dataset,
            verifiers=enabled_verifiers,
            verifier_metrics=verifier_metrics,
            base_config=config_handler.get_config_no_verifiers(),
            dataset_path=args.dataset,
            use_cached_stages=args.use_cached_stages,
            cache_dir=args.cache_dir,
            verifier_subset=verifier_subset,
            mv_as_voter=mv_as_voter,
        )

        if verification_results is None:
            return

        # Display results, then save everything alongside the inputs that produced it.
        display_results(verification_results, enabled_verifiers)
        combined_data = {
            "dataset": args.dataset,
            "config": config_handler.config,
            "results": verification_results,
        }

        if args.save_individual_metrics:
            combined_data['individual_metrics'] = verifier_metrics if verifier_subset is None else {v: verifier_metrics[v] for v in verifier_subset}

        with open(output_file, "w") as f:
            json.dump(combined_data, f, indent=2, cls=NpEncoder)

        logger.info(f"Saved results to {output_file}.")

    except Exception as e:
        # Log for the run log, then re-raise so the process exits non-zero with a traceback.
        logger.error(f"An error occurred: {str(e)}")
        raise

if __name__ == "__main__":
    main()
