from vllm import LLM, SamplingParams, AsyncLLMEngine
from transformers import AutoTokenizer
import re
import os
import random
import sys
import argparse
from transformers import AutoTokenizer
from utils import read_jsonl, write_jsonl, write_res
import math
from typing import List, Dict
'''
Input
{
    'id':xxx,
    'informal_statement':xxx
    'formalization_prompt': xxx,
    'formal_statements_generated':[
        xxx,
        xxx,
        ...
    ],
    'pass':[
        True,
        False,
        ...
    ]
}

Output
{
    'id':xxx,
    'informal_statement':xxx
    'formalization_prompt': xxx,
    'formal_statements_generated':[
        xxx,
        xxx,
        ...
    ],
    'pass':[
        True,
        False,
        ...
    ],
    'consistency':{
        'qwq':[
            responses 1,
            '',
            ...
        ],
        'qwen3':[
            responses 1,
            '',
            ...
        ],
        ...
    }
}
'''

def extract_last_lean4_code_block(text):
    """
    Pull out the final ```lean4 fenced block found in a string

    Args:
        text (str): Text that may contain one or more ```lean4 fenced blocks

    Returns:
        str: The body of the last fenced block, stripped and re-wrapped in a
            ```lean4 fence; an empty ```lean4 fence when no block is found
    """
    # DOTALL lets the lazy group span several lines; whitespace is tolerated
    # after the opening marker and before the closing marker
    fence_re = re.compile(r'```lean4\s*\n(.*?)\n\s*```', re.DOTALL)

    # Walk every fenced block and remember only the most recent body; the
    # empty default doubles as the "nothing found" result
    last_body = ''
    for found in fence_re.finditer(text):
        last_body = found.group(1).strip()

    return '```lean4\n' + last_body + '\n```'

def get_prompt(tokenizer, informal_statement, formal_statement):
    """
    Build the chat-formatted consistency-check prompt for one statement pair

    Args:
        tokenizer: Object exposing apply_chat_template(messages, tokenize=..., add_generation_prompt=...)
        informal_statement (str): Natural-language problem text inserted as Mathematical_Text
        formal_statement (str): Lean 4 statement inserted as Lean4Code

    Returns:
        The value returned by tokenizer.apply_chat_template for a system + user
        message pair, requested with tokenize=False and add_generation_prompt=True
    """
    template = '''Role: Lean & Formal Verification Expert

Input:
- Mathematical_Text: A math problem and its answer (no proof).
- Lean4Code: A Lean 4 theorem statement formalizing the problem. Proof is intentionally omitted (e.g., sorry).

Goal:
Determine if the Lean theorem statement is an exact and faithful formalization of the mathematical problem.  
**Do not evaluate or consider the answer or the proof. Your sole task is to verify the correctness of the formalization.**

Evaluation Stages (All required):

1. Mathematical Text Analysis  
   Identify all structurally and semantically relevant components of the mathematical problem, including variables, types, quantifiers, constraints, logic structure, conclusion, and so on. The analysis should be based on the actual content of the text.

2. Lean4 Code Analysis (ignore proof part)  
   Extract all structurally and semantically relevant components from the Lean statement, including variables, types, conditions, quantifiers, constraints, the final claim, and so on. The analysis should reflect the actual content present in the Lean code.

3. Comparative Analysis  
   Check for exact correspondence between the math and Lean statements; you may refer to aspects like:
   - Semantic alignment, logic structure, and quantifier correctness.
   - Preservation of constraints and boundary assumptions.
   - Accurate typing and use of variables.
   - Strict adherence to Lean's specific syntactic and semantic rules in interpreting the Lean code.
   - Syntactic validity and proper Lean usage (free from errors).
   - Use of symbols and constructs without semantic drift.
   - No missing elements, no unjustified additions, and no automatic corrections or completions.

4. Accuracy Confirmation  
   If correct: clearly confirm why all elements match.  
   If incorrect: list all mismatches and explain how each one affects correctness.

Note: While the analysis may be broad and open to interpreting all relevant features, the final judgment must be based only on what is explicitly and formally expressed in the Lean statement.  
**Do not consider or assess any part of the proof. Your judgment should be entirely about the accuracy of the statement formalization.**

Output Format:
Return exactly one JSON object:
```json
{
    "reasons": "1. Mathematical Text Analysis: [...]2.  Lean4 Code Analysis (: [...]3. Comparative Analysis: [...]4. Accuracy Confirmation: [...match confirmation or list of discrepancies...]",
    "is_assistant_correct": "[Correct/Incorrect]"
}
```

— Start of Mathematical_Text —
{informal_statement}
— End of Mathematical_Text —

— Start of Lean4Code —
{formal_statement}
— End of Lean4Code —
'''.strip()

    # Fill both placeholders in one pass over the template. Chaining two
    # str.replace calls would rescan the text that was just inserted, so an
    # informal statement containing the literal "{formal_statement}" would be
    # overwritten by the Lean code. A callable replacement also keeps
    # backslashes in the inserted text literal.
    substitutions = {
        '{informal_statement}': informal_statement,
        '{formal_statement}': formal_statement,
    }
    user_content = re.sub(
        r'\{informal_statement\}|\{formal_statement\}',
        lambda found: substitutions[found.group(0)],
        template,
    )

    messages = [
        {"role": "system", "content": "You are an expert in mathematics and Lean 4."},
        {"role": "user", "content": user_content},
    ]

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    return text

def calculate_pass_at_k_with_sampling(data_list: List[Dict], k_values=(1, 4, 8, 16), num_samples=10, model='qwq'):
    """
    Calculate pass@k metric using random sampling

    A response counts as passing when the text after its last
    '"is_assistant_correct"' marker contains '"correct"' (case-insensitive).
    If the marker is absent, the whole response is searched.

    Args:
        data_list: List containing data; each element holds the judge responses
            under item['consistency'][model]
        k_values: Iterable of k values to calculate
        num_samples: Number of random sampling iterations per problem
        model: Key selecting which judge's responses to score

    Returns:
        dict: Maps 'pass@{k}' to the average over problems with at least k
            responses; 0.0 when no problem has enough responses
    """
    # The verdict flags do not depend on k, so parse every response once
    flags_per_problem = [
        ['"correct"' in r.split('"is_assistant_correct"')[-1].lower() for r in item['consistency'][model]]
        for item in data_list
    ]

    results = {}

    for k in k_values:
        total_score = 0
        valid_problems = 0

        for pass_list in flags_per_problem:
            # Skip this problem if there are not enough samples
            if len(pass_list) < k:
                continue

            valid_problems += 1

            # Draw k responses without replacement num_samples times and count
            # the draws in which at least one response passes
            success_count = sum(
                any(random.sample(pass_list, k)) for _ in range(num_samples)
            )
            total_score += success_count / num_samples

        # Calculate overall pass@k
        results[f'pass@{k}'] = total_score / valid_problems if valid_problems > 0 else 0.0

    return results

def calculate_pass_at_k_exact(data_list: List[Dict], k_values=(1, 4, 8, 16), model='qwq'):
    """
    Calculate pass@k using exact formula (more efficient method)
    pass@k = 1 - C(n-c, k) / C(n, k)
    where n is total number of samples, c is number of passing samples

    A response counts as passing when the text after its last
    '"is_assistant_correct"' marker contains '"correct"' (case-insensitive).
    If the marker is absent, the whole response is searched.

    Args:
        data_list: List containing data; each element holds the judge responses
            under item['consistency'][model]
        k_values: Iterable of k values to calculate
        model: Key selecting which judge's responses to score

    Returns:
        dict: Maps 'pass@{k}' to the average over problems with at least k
            responses; 0.0 when no problem has enough responses
    """
    # (n, c) per problem is independent of k, so parse the responses only once
    counts = []
    for item in data_list:
        pass_list = ['"correct"' in r.split('"is_assistant_correct"')[-1].lower() for r in item['consistency'][model]]
        counts.append((len(pass_list), sum(pass_list)))

    results = {}

    for k in k_values:
        total_score = 0
        valid_problems = 0

        for n, c in counts:
            # Skip this problem if there are not enough samples
            if n < k:
                continue

            valid_problems += 1

            # With no passing sample the score is 0, so nothing is added.
            # math.comb returns 0 when k > n - c, which makes the term 1.
            if c > 0:
                total_score += 1 - math.comb(n - c, k) / math.comb(n, k)

        # Calculate overall pass@k
        results[f'pass@{k}'] = total_score / valid_problems if valid_problems > 0 else 0.0

    return results

def print_pass_at_k_results(data_list: List[Dict], method='exact',model='qwq'):
    """
    Compute pass@k for one judge model and print a small report

    Args:
        data_list: Data list
        method: 'exact' selects the closed-form computation; any other value
            selects the random-sampling computation
        model: Judge key forwarded to the selected computation

    Returns:
        dict: The metric dictionary produced by the selected computation
    """
    # Compute first, then print the heading, so nothing is printed if the
    # computation raises
    if method != 'exact':
        scores = calculate_pass_at_k_with_sampling(data_list, model=model)
        heading = "Pass@K Results (Random Sampling):"
    else:
        scores = calculate_pass_at_k_exact(data_list, model=model)
        heading = "Pass@K Results (Exact Formula):"

    print(heading)
    print("-" * 30)
    for name in scores:
        score = scores[name]
        print(f"{name}: {score:.4f} ({score*100:.2f}%)")

    return scores


def _parse_bool_flag(value):
    """
    Interpret a command-line string as a boolean

    argparse's type=bool applies bool() to the raw string, so "False" or "0"
    would come out as True. Only the usual false spellings map to False here;
    every other value keeps mapping to True, as it did with type=bool.
    """
    return value.strip().lower() not in ('', '0', 'false', 'f', 'no', 'n', 'off')


if __name__ == "__main__":
    # Script flow: build one judge prompt per statement whose 'pass' flag is
    # truthy, generate the responses in batches (resuming from whatever
    # save_path already holds), merge the responses back into the input
    # records, and report pass@k.
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--output_path", type=str, required=True)
    parser.add_argument("--gpu",type=str,required=True)
    parser.add_argument("--n", type=int, default=1, help="")
    parser.add_argument("--max_length", type=int, default=8192+4096, help="")
    parser.add_argument("--batch_size", type=int, default=1024*8, help="")
    parser.add_argument("--temperature", type=float, default=0, help="")
    parser.add_argument("--only_save_prompts", type=_parse_bool_flag, default=False, help="")
    # Previously hard-coded to 2 in the LLM constructor; the default keeps that value
    parser.add_argument("--tensor_parallel_size", type=int, default=2, help="")
    args = parser.parse_args()

    model_path = args.model
    os.environ["CUDA_VISIBLE_DEVICES"] = f"{args.gpu}"

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    except Exception as err:
        # Fall back to the default tokenizer class. The original keyword
        # `trust_remote` is not a from_pretrained option; `trust_remote_code`
        # is the one that allows a repository's custom tokenizer code to load.
        print(f"slow tokenizer unavailable ({err!r}); retrying with trust_remote_code=True")
        tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    data = read_jsonl(args.input_path)

    model_name = args.model.split('/')[-1]
    save_path = args.output_path or args.input_path.replace('.jsonl', f'_consistency_{model_name}.jsonl')

    prompts = []
    id2index = {}
    for index, d in enumerate(data):
        # Prefer the composite source key; fall back to the plain 'id' field
        try:
            item_id = d['source_data'] + str(d['index_in_source_data'])
        except (KeyError, TypeError):
            item_id = d['id']
        id2index[item_id] = index
        for i, (s, p) in enumerate(zip(d['formal_statements_generated'], d['pass'])):
            if not p:
                continue
            # Normalise the fence spelling, then keep only the last lean4 block
            s = s.replace('```Lean4', '```lean4')
            if '```lean4' in s:
                s = extract_last_lean4_code_block(s)
            prompt = get_prompt(tokenizer, informal_statement=d['informal_statement'], formal_statement=s)
            prompts.append({'id': item_id, 'index': i, 'prompt': prompt})

    if args.only_save_prompts:
        # Use the model's base name: args.model may be a path containing '/',
        # which would otherwise be spliced into the file name as directories
        write_jsonl(prompts, args.input_path.replace('.jsonl', f'_prompts_{model_name}.jsonl'), 'w')
    else:
        sampling_params = SamplingParams(temperature=args.temperature, max_tokens=args.max_length, n=args.n)

        # Resume: the number of records already in save_path is taken as the
        # number of prompts already answered.
        # NOTE(review): save_path is later overwritten with the merged input
        # records, so rerunning against a finished output file would treat
        # those records as responses — confirm the intended workflow.
        if os.path.exists(save_path):
            shift = len(read_jsonl(save_path))
        else:
            shift = 0

        print(f"start at shift = {shift}")
        if shift < len(prompts):
            print("load model")

            llm = LLM(model=model_path, tensor_parallel_size=args.tensor_parallel_size, gpu_memory_utilization=0.95, dtype="bfloat16", swap_space=16, disable_custom_all_reduce=True, seed=12)

            # A non-positive batch size would make the range step invalid
            batch_size = max(1, args.batch_size)
            for start in range(shift, len(prompts), batch_size):
                batch_inputs = prompts[start:start + batch_size]
                if len(batch_inputs) == batch_size:
                    # Inclusive index range of this batch
                    print(f"====== {start} ～ {start + batch_size - 1} / {len(prompts)} ======")
                else:
                    print(f"======{len(prompts)} / {len(prompts)} ======")
                # Hand vLLM only the prompt text; 'id' and 'index' are
                # bookkeeping fields used when merging the responses below
                batch_outputs = llm.generate([item['prompt'] for item in batch_inputs], sampling_params)
                batch_res = [[candidate.text for candidate in output.outputs] for output in batch_outputs]
                # Presumably appends one record per prompt to save_path — verify in utils.write_res
                write_res(save_path, batch_res=batch_res)

        all_responses = read_jsonl(save_path)
        for d in data:
            per_model = d.setdefault('consistency', {})
            if args.model not in per_model:
                # One slot per generated statement; statements without a
                # prompt keep the empty string
                per_model[args.model] = [''] * len(d['pass'])

        # Responses are matched to prompts by position in the file
        for p, r in zip(prompts, all_responses):
            slots = data[id2index[p['id']]]['consistency'][args.model]
            # A record may be a list of candidates; only the first one is kept
            slots[p['index']] = r[0] if isinstance(r, list) else r

        write_jsonl(data, save_path, 'w')

        print("=== Exact Formula Method ===")
        results_exact = print_pass_at_k_results(data, method='exact', model=args.model)

        print("\n=== Random Sampling Method ===")
        results_sampling = print_pass_at_k_results(data, method='sampling', model=args.model)
