import os
import argparse
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import datasets
from lean.verifier import verify_with_deepseek_verifier
import math
import numpy as np
import json
from tqdm import tqdm
import concurrent.futures
import re


def pass_at_k(n, c, k):
    """
    Return the unbiased pass@k estimate for a single problem.

    Out of ``n`` sampled attempts, ``c`` are correct. The result is the
    probability that a uniformly random subset of ``k`` attempts contains at
    least one correct attempt, i.e. ``1 - C(n - c, k) / C(n, k)``.

    n: total number of samples drawn.
    c: how many of those samples are correct.
    k: size of the subset being scored.
    """
    if c == 0:
        # No correct sample exists, so no subset can contain one.
        return 0.0

    failures = n - c
    # With fewer failures than k, every k-subset must include a correct sample.
    return (
        1.0
        if failures < k
        else 1.0 - math.comb(failures, k) / math.comb(n, k)
    )


# # prompts for deepseek prover
# policy_prompt_template_deepseek = r'''Complete the following Lean 4 code:

# ```lean4
# import Mathlib
# import Aesop

# set_option maxHeartbeats 0

# open BigOperators Real Nat Topology Rat

# '''



import datasets

#DATASET="/beegfs/scratch/user/<anonymized>/fcdpg-verl/verl/data/mff-lwb-goedel-28k.parquet"

def eval(x, *, max_workers=8):
    """
    Verify every sampled completion of one dataset row and return the verdicts.

    Used as the ``datasets.Dataset.map`` callback in ``main``. The name shadows
    the builtin ``eval`` but is kept so existing callers keep working.

    x: one dataset row. Reads ``x["responses"]`` (the completions to check)
       and ``x['reward_model']['ground_truth']``, which is passed to the
       verifier as its third argument.
    max_workers: number of threads that run verifications concurrently.

    Returns ``{"results": [...]}``, where ``results[i]`` is the verifier's
    return value for ``x["responses"][i]``, or ``False`` if that call raised.
    """
    completions = x["responses"]
    ground_truth = x['reward_model']['ground_truth']

    # Pre-size the output so each verdict lands at its completion's index.
    # as_completed() yields futures in finish order, not submission order,
    # so appending would misalign results with responses.
    res = [False] * len(completions)

    # --- Parallel Verification Block ---
    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit verification tasks for all completions of the current prompt,
        # remembering which position each future belongs to.
        future_to_index = {
            executor.submit(verify_with_deepseek_verifier, '', completion, ground_truth): i
            for i, completion in enumerate(completions)
        }

        # Progress bar for the verification of a single row's completions.
        verification_pbar = tqdm(
            concurrent.futures.as_completed(future_to_index),
            total=len(completions),
            desc="Verifying completions",
            leave=False,
        )

        for future in verification_pbar:
            i = future_to_index[future]
            try:
                res[i] = future.result()
            except Exception as exc:
                # Best-effort: a crashing verification counts as a failure
                # instead of aborting the whole row.
                print(f'\nA verification task generated an exception: {exc}')
                res[i] = False
    # --- End of Parallel Block ---

    return {"results": res}


def main(generation_path, model_name, split_id, total_split):
    """
    Verify one shard of a model's generations and save the verdicts.

    Reads ``{generation_path}/{model_name}/generation.parquet``, keeps shard
    ``split_id`` out of ``total_split``, runs ``eval`` on every row, and writes
    the mapped shard to ``results{split_id}_{total_split}.parquet`` in the
    same directory.
    """
    print("Loading and processing dataset...")

    model_dir = f"{generation_path}/{model_name}"
    input_file = f"{model_dir}/generation.parquet"

    verified = (
        datasets.load_dataset("parquet", data_files=input_file)["train"]
        .shard(num_shards=total_split, index=split_id)
        .map(eval)
    )

    output_file = os.path.join(model_dir, f"results{split_id}_{total_split}.parquet")
    verified.to_parquet(output_file)



if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description="Verify pre-generated completions in parallel and save per-shard results."
    )
    parser.add_argument(
        "--generation_path", type=str, required=True,
        help="Directory containing one sub-directory per model, each holding generation.parquet.",
    )
    parser.add_argument(
        "--model_name", type=str, required=True,
        help="Sub-directory of --generation_path to read generation.parquet from and write results to.",
    )
    parser.add_argument(
        "--split_id", type=int, required=True,
        help="0-based index of the shard to process.",
    )
    parser.add_argument(
        "--total_split", type=int, required=True,
        help="Total number of shards the dataset is divided into.",
    )
    # NOTE(review): --nb_cpu is parsed but never used by this script. It is
    # kept (now optional) so existing launch commands keep working; either
    # wire it up (e.g. to the verification thread count) or remove it.
    parser.add_argument(
        "--nb_cpu", type=int, default=None,
        help="Currently unused; accepted for compatibility with existing launch scripts.",
    )

    args = parser.parse_args()

    # Fail fast with a clear CLI error instead of failing only after the
    # dataset has been loaded.
    if args.total_split < 1:
        parser.error("--total_split must be >= 1")
    if not 0 <= args.split_id < args.total_split:
        parser.error("--split_id must satisfy 0 <= split_id < total_split")

    main(
        generation_path=args.generation_path,
        model_name=args.model_name,
        split_id=args.split_id,
        total_split=args.total_split
    )
