import argparse
import json
from datasets import Dataset, load_dataset
from sal.utils.qwen_math_parser import extract_answer, math_equal
import os
from tqdm import tqdm

def compute_accuracy(dataset, pred_column="pred", answer_column="answer"):
    """Compute overall accuracy of one prediction column against the answers.

    Each raw prediction is passed through ``extract_answer(..., "math")`` and
    compared to the reference answer with ``math_equal``. Rows where
    extraction or comparison raises are reported on stdout and counted as
    incorrect, so a single malformed row does not abort the evaluation.

    Args:
        dataset: Sized iterable of row mappings (e.g. a ``datasets.Dataset``);
            each row is expected to contain ``pred_column`` and
            ``answer_column``.
        pred_column: Name of the column holding the raw model output.
        answer_column: Name of the column holding the reference answer.

    Returns:
        dict with ``accuracy`` (``correct / total``, or 0.0 when the dataset
        is empty), ``correct`` (int) and ``total`` (int).
    """
    correct = 0
    total = len(dataset)

    for i, example in enumerate(tqdm(dataset, desc=f"Evaluating {pred_column}")):
        try:
            pred = extract_answer(example[pred_column], "math")
            answer = example[answer_column]
            if math_equal(pred, answer):
                correct += 1
        except Exception as e:
            # Deliberately best-effort: log the offending row and count it as
            # incorrect. Use .get() so a missing column cannot raise a second
            # KeyError from inside this handler.
            print(f"\nError at index {i}, pred_col={pred_column}")
            print(f"pred: {example.get(pred_column)}")
            print(f"answer: {example.get(answer_column)}")
            print(f"Exception: {e}")
            continue

    return {
        # Guard against ZeroDivisionError on an empty dataset.
        "accuracy": correct / total if total else 0.0,
        "correct": correct,
        "total": total,
    }

def compute_accuracy_by_n(dataset, n_values=(1, 2, 4, 8, 16, 32, 64)):
    """Compute accuracy of the ``pred_weighted@{n}`` column for each n.

    Values of n whose ``pred_weighted@{n}`` column is absent from
    ``dataset.column_names`` are silently skipped. A summary line is printed
    for every n that is evaluated.

    Args:
        dataset: Dataset exposing ``column_names`` and accepted by
            ``compute_accuracy``.
        n_values: Iterable of ints to look up. The default is an immutable
            tuple so it cannot be shared and mutated across calls.

    Returns:
        dict mapping ``"accuracy@{n}"`` to the float accuracy for each
        evaluated n.
    """
    results = {}
    # Build the membership set once instead of scanning the list per n.
    available = set(dataset.column_names)

    for n in n_values:
        column = f"pred_weighted@{n}"
        if column not in available:
            continue
        acc = compute_accuracy(dataset, column)
        results[f"accuracy@{n}"] = acc["accuracy"]
        print(f"Accuracy with weighted prediction @{n}: {acc['accuracy']:.4f} ({acc['correct']}/{acc['total']})")

    return results

def compute_all_prediction_accuracies(dataset):
    """Score every column whose name begins with ``pred``.

    Prints a warning and returns an empty dict when no such column exists.
    Otherwise each column is scored with ``compute_accuracy`` and one aligned
    result line is printed per column. A column whose scoring raises is
    reported and left out of the returned dict.

    Returns:
        dict mapping column name to the dict returned by ``compute_accuracy``.
    """
    scores = {}
    candidate_cols = [name for name in dataset.column_names if name.startswith("pred")]

    if not candidate_cols:
        print("Warning: No prediction columns found!")
        return scores

    print(f"Found prediction columns: {candidate_cols}")
    print("-" * 60)

    for column in candidate_cols:
        try:
            stats = compute_accuracy(dataset, column)
            scores[column] = stats
            print(f"{column:25s}: {stats['accuracy']:.4f} ({stats['correct']:3d}/{stats['total']:3d})")
        except Exception as exc:
            print(f"Error computing accuracy for {column}: {exc}")

    return scores

def main():
    """Command-line entry point: load a results dataset and score its predictions.

    ``input_file`` ending in ``.json``/``.jsonl`` is loaded with the ``json``
    builder, one ending in ``.arrow`` with ``Dataset.from_file``, and anything
    else is passed to ``load_dataset`` as a dataset path/name. Accuracy is
    printed for every ``pred*`` column, followed by a ``pred_weighted@{n}``
    summary for each requested n. If ``--output`` is given, the results are
    also written there as JSON.

    Returns:
        dict with keys ``all_predictions``, ``by_n_values`` and ``dataset_info``.
    """
    parser = argparse.ArgumentParser(description="Compute accuracy for test-time compute results")
    parser.add_argument("input_file", type=str, help="Input dataset file (JSON/JSONL/Arrow)")
    parser.add_argument("--output", type=str, help="Output file for results (JSON)")
    parser.add_argument("--n_values", nargs="+", type=int, default=[4, 16, 64, 256],
                        help="N values to compute accuracy for")

    args = parser.parse_args()

    print(f"Loading dataset from: {args.input_file}")

    if args.input_file.endswith((".json", ".jsonl")):
        dataset = load_dataset("json", data_files=args.input_file, split="train")
    elif args.input_file.endswith(".arrow"):
        dataset = Dataset.from_file(args.input_file)
    else:
        dataset = load_dataset(args.input_file, split="train")

    print(f"Dataset loaded with {len(dataset)} examples")
    print(f"Columns: {dataset.column_names}")
    print("=" * 60)

    all_results = compute_all_prediction_accuracies(dataset)

    print("\n" + "=" * 60)
    print("SUMMARY:")

    # Every pred_weighted@{n} column was already scored above; reuse those
    # results instead of re-running the (expensive) answer extraction and
    # symbolic comparison over the whole dataset a second time.
    n_results = {}
    for n in args.n_values:
        acc = all_results.get(f"pred_weighted@{n}")
        if acc is None:
            continue
        n_results[f"accuracy@{n}"] = acc["accuracy"]
        print(f"Accuracy with weighted prediction @{n}: {acc['accuracy']:.4f} ({acc['correct']}/{acc['total']})")

    final_results = {
        "all_predictions": all_results,
        "by_n_values": n_results,
        "dataset_info": {
            "total_examples": len(dataset),
            "columns": dataset.column_names,
        },
    }

    if all_results:
        best_method = max(all_results, key=lambda k: all_results[k]["accuracy"])
        best_acc = all_results[best_method]["accuracy"]
        print(f"\nBest method: {best_method} with accuracy {best_acc:.4f}")

    if args.output:
        output_dir = os.path.dirname(args.output)
        if output_dir:
            # exist_ok=True already tolerates an existing directory, so no
            # separate (race-prone) existence check is needed.
            os.makedirs(output_dir, exist_ok=True)
        with open(args.output, "w", encoding="utf-8") as f:
            json.dump(final_results, f, indent=2)
        print(f"\nResults saved to: {args.output}")

    return final_results

# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()