from tqdm import tqdm, trange
import random
import re
import time
from typing import Dict, List, Union, Any, Tuple,Set
from verify_prover_v2_solutions_api import batch_verify_lean_proofs, get_sandbox_result
import json
from utils import read_jsonl, write_jsonl, basic_check, group_by_imports, combine_lean4_statements_simple, analyze_lean4_results
import argparse
import math
from typing import List, Dict
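'''
Input format (JSONL, one object per line)
{
    'id': unique identifier (used when 'source_data'/'index_in_source_data' are absent),
    'informal_statement': natural language theorem,
    'formal_statements_generated': [generated Lean 4 statements, optionally wrapped in ```lean4 fences],
    ...
}

Output format (each input row written back with a 'pass' list aligned to 'formal_statements_generated')
{
    'id': xxx,
    'informal_statement': xxx,
    'formalization_prompt': xxx,
    'formal_statements_generated': [
        xxx,
        xxx,
        ...
    ],
    'pass': [
        True,
        False,
        ...
    ]
}
'''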


def extract_last_lean4_code_block(text):
    """
    Return the final ```lean4 fenced block found in ``text``, re-fenced.

    The body of the last block is stripped of surrounding whitespace and
    wrapped again as ```lean4\\n<body>\\n```. When ``text`` contains no
    such block, an empty fenced block (```lean4\\n\\n```) is returned.

    Args:
        text (str): Input string, e.g. a raw model generation.

    Returns:
        str: The re-fenced content of the last Lean4 code block.
    """
    # The closing fence may be preceded by whitespace; DOTALL lets the
    # non-greedy body span several lines.
    found = re.findall(r'```lean4\s*\n(.*?)\n\s*```', text, flags=re.DOTALL)
    body = found[-1].strip() if found else ''
    return f'```lean4\n{body}\n```'

def calculate_pass_at_k_with_sampling(data_list: List[Dict], k_values=(1, 4, 8, 16), num_samples=10, rng=None):
    """
    Estimate pass@k by repeated random sub-sampling.

    For every problem with at least k attempts, k attempts are drawn without
    replacement ``num_samples`` times. The problem's score is the fraction of
    draws that contain at least one truthy entry. Scores are averaged over all
    problems that had enough attempts.

    Args:
        data_list: List of records, each with a 'pass' list (truthy = passed).
        k_values: Iterable of k values to evaluate. A tuple default avoids the
            shared-mutable-default pitfall; any iterable of ints is accepted.
        num_samples: Number of random draws per problem and per k.
        rng: Optional ``random.Random`` instance for reproducible estimates.
            When None, the global ``random`` module state is used, exactly as
            before this parameter existed.

    Returns:
        dict: Maps 'pass@{k}' to the estimated pass@k; 0.0 when no problem
        has at least k attempts.
    """
    sampler = random if rng is None else rng
    results = {}

    for k in k_values:
        total_score = 0
        valid_problems = 0

        for item in data_list:
            pass_list = item.get('pass', [])

            # A problem with fewer than k attempts cannot contribute to pass@k.
            if len(pass_list) < k:
                continue

            valid_problems += 1
            success_count = 0
            for _ in range(num_samples):
                # A draw succeeds if any of the k sampled attempts passed.
                if any(sampler.sample(pass_list, k)):
                    success_count += 1

            total_score += success_count / num_samples

        results[f'pass@{k}'] = total_score / valid_problems if valid_problems > 0 else 0.0

    return results

def calculate_pass_at_k_exact(data_list: List[Dict], k_values=(1, 4, 8, 16)):
    """
    Calculate pass@k with the closed-form estimator.

    pass@k = 1 - C(n-c, k) / C(n, k)
    where n is the number of attempts for a problem and c the number of
    passing attempts. Per-problem values are averaged over every problem
    with at least k attempts.

    Args:
        data_list: List of records, each with a 'pass' list. Entries are
            counted by truthiness, so None (unverified) counts as a failure.
            This matches calculate_pass_at_k_with_sampling, which uses any().
        k_values: Iterable of k values to evaluate (tuple default avoids a
            shared mutable default).

    Returns:
        dict: Maps 'pass@{k}' to the averaged pass@k; 0.0 when no problem
        has at least k attempts.
    """
    def combination(n, r):
        # Exact integer binomial coefficient; 0 outside the valid range.
        if r > n or r < 0:
            return 0
        return math.factorial(n) // (math.factorial(r) * math.factorial(n - r))

    results = {}

    for k in k_values:
        total_score = 0
        valid_problems = 0

        for item in data_list:
            pass_list = item.get('pass', [])
            n = len(pass_list)  # Total number of attempts
            # Count by truthiness: plain sum() raises on None and would
            # over-count non-bool truthy values such as 2.
            c = sum(1 for p in pass_list if p)

            # A problem with fewer than k attempts cannot contribute to pass@k.
            if n < k:
                continue

            valid_problems += 1

            if c == 0:
                problem_pass_at_k = 0
            else:
                problem_pass_at_k = 1 - combination(n - c, k) / combination(n, k)

            total_score += problem_pass_at_k

        results[f'pass@{k}'] = total_score / valid_problems if valid_problems > 0 else 0.0

    return results

def print_pass_at_k_results(data_list: List[Dict], method='exact'):
    """
    Compute pass@k for ``data_list``, print a small report, and return it.

    Args:
        data_list: Records carrying a 'pass' list.
        method: 'exact' selects the closed-form estimator; any other value
            selects the random-sampling estimator.

    Returns:
        str: The printed report, one newline-terminated line per row.
    """
    if method == 'exact':
        scores = calculate_pass_at_k_exact(data_list)
        header = "Pass@K Results (Exact Formula):"
    else:
        scores = calculate_pass_at_k_with_sampling(data_list)
        header = "Pass@K Results (Random Sampling):"

    report_lines = [header, "-" * 30]
    report_lines.extend(
        f"{name}: {score:.4f} ({score*100:.2f}%)" for name, score in scores.items()
    )

    for line in report_lines:
        print(line)

    return ''.join(line + '\n' for line in report_lines)


if __name__ == "__main__":
    import os  # only the script entry point needs path helpers

    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    # Number of statements merged by combine_lean4_statements_simple into one verifier input.
    parser.add_argument("--batch_size", type=int, default=40, help="")
    # Number of statements (per import group) handled by one batch_verify_lean_proofs call.
    parser.add_argument("--block_size", type=int, default=2000, help="")
    args = parser.parse_args()

    # Statements whose merged verification yielded no verdict are re-sent individually, this many per call.
    RETRY_BATCH_SIZE = 40

    data = read_jsonl(args.input_path)

    # ---- 1. Flatten every generated statement into a work item. ----
    all_statements = []
    id2index = {}
    for data_index, d in enumerate(data):
        try:
            item_id = d['source_data'] + str(d['index_in_source_data'])
        except (KeyError, TypeError):
            # Fall back to the explicit 'id' when the source fields are missing or not concatenable.
            item_id = d['id']
        if item_id in id2index:
            print(f'warning: duplicate id {item_id!r} (rows {id2index[item_id]} and {data_index})')
        id2index[item_id] = data_index
        d['pass'] = [None] * len(d['formal_statements_generated'])
        for i, p in enumerate(d['formal_statements_generated']):
            p = p.replace('```Lean4', '```lean4')
            if '```lean4' in p:
                p = extract_last_lean4_code_block(p)
            dic = {
                'id': item_id,
                'statement': p.replace('```lean4\n', '').replace('\n```', ''),
                'index': i,
                'pass': None,
                # Row position recorded directly so duplicate ids cannot misroute results.
                'data_index': data_index,
            }
            if 'import Mathlib' not in dic['statement']:
                # Prepend the Mathlib import to both the stored generation and the statement to verify.
                d['formal_statements_generated'][i] = 'import Mathlib\n' + d['formal_statements_generated'][i]
                dic['statement'] = 'import Mathlib\n' + dic['statement']
            all_statements.append(dic)
    if all_statements:
        print(all_statements[0])

    # ---- 2. Cheap local filter before hitting the verifier. ----
    final_results = []
    test_statements = []
    basic_fail_num = 0
    for s in all_statements:
        if not basic_check(s['statement']):
            s['pass'] = False
            basic_fail_num += 1
            final_results.append(s)
        else:
            test_statements.append(s)
    print('basic_fail_num = ', basic_fail_num)

    # ---- 3. Verify in merged batches, grouped by import header. ----
    grouped_test_statements = group_by_imports(test_statements)
    flatten_statements = []
    for k, v in grouped_test_statements.items():
        flatten_statements.extend(v)
        print(k, len(v))

    # test_results is built in the same group/block/batch order as flatten_statements,
    # so the two lists line up position by position.
    test_results = []
    for group in grouped_test_statements.values():
        for start in range(0, len(group), args.block_size):
            block = group[start:start + args.block_size]
            merge_statements = [
                combine_lean4_statements_simple([s['statement'] for s in block[j:j + args.batch_size]])
                for j in range(0, len(block), args.batch_size)
            ]
            merge_res = batch_verify_lean_proofs(merge_statements, cluster_key='lean-v6')
            for r in merge_res:
                test_results.extend(analyze_lean4_results(r))

    if len(test_results) != len(flatten_statements):
        print(f'warning: got {len(test_results)} results for {len(flatten_statements)} statements')

    # ---- 4. Retry statements whose merged verification produced no verdict. ----
    fail_positions = [i for i, (_, r) in enumerate(zip(flatten_statements, test_results)) if r is None]
    fail_statements = [flatten_statements[i]['statement'] for i in fail_positions]
    print('fail_statements num = ', len(fail_statements))
    results_retry = []
    for start in range(0, len(fail_statements), RETRY_BATCH_SIZE):
        temp_res = batch_verify_lean_proofs(fail_statements[start:start + RETRY_BATCH_SIZE], cluster_key="lean-v6")
        results_retry.extend(res['info']['pass'] for res in temp_res)
    # zip() tolerates a short retry response: unmatched positions stay None and count as failures below.
    for i, r in zip(fail_positions, results_retry):
        test_results[i] = r

    for s, r in zip(flatten_statements, test_results):
        s['pass'] = r if r is not None else False
        final_results.append(s)

    # ---- 5. Write per-statement verdicts back onto the input rows. ----
    if len(args.output_path) > 0:
        save_path = args.output_path
    else:
        # Derive from the stem so the default never equals (and overwrites) the input file.
        save_path = os.path.splitext(args.input_path)[0] + '_checkin.jsonl'

    for r in final_results:
        row = r['data_index'] if 'data_index' in r else id2index[r['id']]
        data[row]['pass'][r['index']] = r['pass']

    # Anything never assigned a verdict counts as a failure.
    for d in data:
        d['pass'] = [False if p is None else p for p in d['pass']]

    write_jsonl(data, save_path, 'w')

    # Derive from the stem: str.replace('.jsonl', ...) left the path unchanged when
    # save_path lacked '.jsonl', so the report overwrote the results file.
    passk_save_path = os.path.splitext(save_path)[0] + '_syntax_pass@k_results.txt'

    print("=== Exact Formula Method ===")
    results_exact = print_pass_at_k_results(data, method='exact')

    print("\n=== Random Sampling Method ===")
    results_sampling = print_pass_at_k_results(data, method='sampling')

    with open(passk_save_path, 'w', encoding="utf-8") as f:
        f.write(results_exact + '\n')
        f.write(results_sampling + '\n')
