# for unknown reasons, we have to import pandas last
import argparse
import os
import json
import time
from tqdm import tqdm
from typing import List, Dict, Any
import jsonlines
import re
import contextlib

# Import huggingface hub for downloading
from huggingface_hub import hf_hub_download

def parse_arguments():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``repo_id``, ``filename``,
        ``output_dir``, ``max_workers`` and ``local_files_only``.
    """
    cli = argparse.ArgumentParser(
        description="Load pre-generated proofs and verify them"
    )

    # Where the generated outputs live on the Hub.
    cli.add_argument(
        "--repo_id",
        type=str,
        default="xxx98/countdown",
        help="Hugging Face repository ID",
    )
    cli.add_argument(
        "--filename",
        type=str,
        required=True,
        help="Filename in the repo (e.g., 'outputs___model-name.json')",
    )

    # Where the verification results are written.
    cli.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Directory to write verification results",
    )

    # Runtime knobs.
    cli.add_argument(
        "--max_workers",
        type=int,
        default=16,
        help="Number of workers for verification",
    )
    cli.add_argument(
        "--local_files_only",
        action="store_true",
        help="Use only local cached files",
    )

    namespace = cli.parse_args()
    return namespace


def download_and_load_data(repo_id: str, filename: str, local_files_only: bool = False):
    """Download a JSON file from a Hugging Face dataset repo and parse it.

    Args:
        repo_id: Hugging Face dataset repository ID.
        filename: Path of the JSON file inside the repository.
        local_files_only: Forwarded to ``hf_hub_download``; when True only
            files already present in the local Hugging Face cache are used.

    Returns:
        A ``(data, data_path)`` tuple: the parsed JSON content and the local
        file path returned by ``hf_hub_download``.
    """
    print(f"Downloading {filename} from {repo_id}...")

    data_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="dataset",
        local_files_only=local_files_only,
    )

    # Decode explicitly as UTF-8 (the encoding JSON mandates) rather than the
    # platform's locale-dependent default, which can fail on non-ASCII model output.
    with open(data_path, 'r', encoding="utf-8") as f:
        data = json.load(f)

    # len() assumes the top-level JSON value is a list of samples — confirm
    # against the dataset layout.
    print(f"Loaded {len(data)} samples")

    return data, data_path


def eval_res(nums, target, pred):
    """Check whether a model's Countdown answer is a valid solution.

    The answer is read from the text after the last ``Final operations:``
    marker line in ``pred``, one ``<expression>=<value>`` line per operation.
    At most ``len(nums) - 1`` lines are considered. The answer is accepted when:

    * the last operation's right-hand side equals ``target``;
    * every left-hand side evaluates to its right-hand side; and
    * every number on a left-hand side is still available. Each starting
      number in ``nums`` and each intermediate result may be consumed once.

    Args:
        nums: Starting numbers (expected to be ints). The list is not modified.
        target: Integer value the final operation must produce.
        pred: Raw model output text.

    Returns:
        1 if the solution is valid, 0 otherwise. Most rejection reasons are
        printed to stdout.
    """
    marker = "Final operations:\n"
    if marker not in pred:
        return 0

    # Parse only what follows the LAST marker, one operation per line.
    start = pred.rfind(marker) + len(marker)
    operation_list = pred[start:].strip().split("\n")

    # Hack for llama3.2: keep at most len(nums) - 1 operation lines and
    # ignore any extra trailing lines.
    operation_list = operation_list[:len(nums) - 1]

    # Checked after truncation. str.split never returns an empty list, but
    # the slice above does when len(nums) <= 1, and [-1] would then raise.
    if not operation_list:
        return 0

    # The last operation must produce the target.
    try:
        _, last_right = operation_list[-1].split('=')
    except ValueError:
        print("Could not split last operation into lhs, rhs")
        return 0
    try:
        last_right = int(last_right)
    except ValueError:
        print("Could not convert last right to int")
        return 0
    if last_right != target:
        return 0

    # Multiset of numbers that may still be used. Copy it so the caller's
    # list is left untouched.
    available_numbers = list(nums)
    for operation in operation_list:
        try:
            left, right = operation.split('=')
        except ValueError:
            print("Could not split operation into lhs, rhs")
            return 0
        try:
            # SECURITY: eval() executes arbitrary Python taken from model
            # output. Only run this on trusted generations or inside a
            # sandbox. Expressions such as 9**9**9 can also stall the process.
            if eval(left) != int(right):
                print(f"Invalid operation: {operation}")
                return 0
        except Exception as e:
            print(f"Error in evaluating operation {operation}: {e}")
            return 0
        # Each operand consumes exactly one occurrence of its value. A number
        # can only be used again if it is available more than once.
        for n in re.findall(r"\d+", left):
            if int(n) not in available_numbers:
                print(f"Invalid operation: {operation}, number {n} not available in {available_numbers}")
                return 0
            available_numbers.remove(int(n))
        # The result becomes available to later operations.
        available_numbers.append(int(right))

    return 1


# k values reported in the evaluation results, in output order.
_PASS_AT_K_VALUES = (1, 2, 4, 8, 16, 32, 64, 128, 256)


def _pass_at_k(successes, k):
    """Estimate the probability that at least one of ``k`` samples succeeds.

    Uses the with-replacement estimate ``1 - (1 - p) ** k``, where ``p`` is the
    empirical success rate of ``successes`` (a list of 0/1 values). It is
    defined for every k, including k larger than ``len(successes)``. It is not
    the unbiased combinatorial pass@k estimator. An empty list yields 0.0.
    """
    if not successes:
        return 0.0
    p = 1.0 * sum(successes) / len(successes)
    return 1 - (1 - p) ** k


def _summarize_pass_at_k(successes_per_problem, ks=_PASS_AT_K_VALUES):
    """Compute per-problem and dataset-averaged pass@k for every k in ``ks``.

    Returns a dict with every ``pass_at_{k}`` average first, followed by every
    ``pass_at_{k}_list`` of per-problem values. Averages over zero problems
    are reported as 0.0.
    """
    num_problems = len(successes_per_problem)
    per_problem = {
        k: [_pass_at_k(successes, k) for successes in successes_per_problem]
        for k in ks
    }
    summary = {}
    for k, values in per_problem.items():
        summary[f"pass_at_{k}"] = sum(values) / num_problems if num_problems else 0.0
    for k, values in per_problem.items():
        summary[f"pass_at_{k}_list"] = values
    return summary


def main():
    """Verify downloaded Countdown generations and save pass@k metrics.

    Parses the CLI arguments and downloads ``--filename`` from ``--repo_id``.
    Each entry of every item's ``model_output_list`` is scored with
    ``eval_res``. The metrics are written to
    ``<output_dir>/evaluation_results.json``, and then the downloaded data
    file is deleted.
    """
    import subprocess

    args = parse_arguments()

    # Use the CPU count reported by `nproc` when it is available; otherwise
    # keep the --max_workers value.
    # NOTE(review): this overrides any user-supplied --max_workers, and
    # max_workers is not used anywhere below. Confirm the flag is still needed.
    try:
        args.max_workers = int(subprocess.check_output(['nproc']).decode().strip())
    except (OSError, subprocess.CalledProcessError, ValueError):
        pass

    os.makedirs(args.output_dir, exist_ok=True)

    # Download and load the generated outputs.
    data, data_path = download_and_load_data(args.repo_id, args.filename, args.local_files_only)
    # Each item must provide "nums", "target" and "model_output_list"
    # (a list of raw model outputs for that problem).

    # One list of 0/1 verification results per problem.
    successes_per_problem = [
        [eval_res(item["nums"], item["target"], pred) for pred in item["model_output_list"]]
        for item in data
    ]

    # Key order: repo_id, num_problems, pass_at_{k} averages, pass_at_{k}_list
    # values, successes_per_problem.
    evaluation_results = {
        "repo_id": args.repo_id,
        "num_problems": len(data),
        **_summarize_pass_at_k(successes_per_problem),
        "successes_per_problem": successes_per_problem,
    }

    # Save evaluation results.
    eval_results_file = os.path.join(args.output_dir, "evaluation_results.json")
    with open(eval_results_file, "w") as f:
        json.dump(evaluation_results, f, indent=2)

    print(f"Saved evaluation results to {eval_results_file}")

    # Delete the downloaded data.
    # NOTE(review): data_path is the local file returned by hf_hub_download.
    # Removing it may break later runs that pass --local_files_only.
    # Confirm this is intended.
    os.remove(data_path)


if __name__ == "__main__":
    main()