from tqdm import tqdm
import sys
import argparse
from utils import read_jsonl
from typing import List
import random
import math
from typing import List, Dict

def calculate_pass_at_k_with_sampling(data_list: List[Dict], k_values=(1, 4, 8, 16), num_samples=10, model='qwq', rng=None):
    """
    Estimate pass@k for each k by repeated random sub-sampling.

    For every problem that has at least k labels, k labels are drawn without
    replacement ``num_samples`` times. The problem's score is the fraction of
    draws that contain at least one truthy label. pass@k is the mean score
    over those problems. Problems with fewer than k labels are skipped for
    that k.

    Args:
        data_list: Problem records. Each must provide
            ``item['consistency']['labels']``, a list of per-sample pass flags
            (truthy = pass).
        k_values: Iterable of k values to evaluate. The default is a tuple so
            it can never be mutated and shared between calls.
        num_samples: Number of random draws per problem and per k; must be > 0.
        model: Unused; accepted only for backward compatibility.
        rng: Optional ``random.Random`` instance for reproducible draws. When
            None, the module-level ``random`` generator is used.

    Returns:
        dict: ``{'pass@{k}': estimate}`` for every k, with 0.0 when no problem
        has at least k labels.

    Raises:
        ValueError: If ``num_samples`` is not positive.
    """
    if num_samples <= 0:
        # Otherwise this would fail later with ZeroDivisionError (0) or
        # silently yield -0.0 (negative values).
        raise ValueError(f"num_samples must be positive, got {num_samples}")
    sampler = random if rng is None else rng

    results = {}
    for k in k_values:
        total_score = 0.0
        valid_problems = 0

        for item in data_list:
            labels = item['consistency']['labels']

            # Cannot draw k samples from fewer than k; exclude from this k's average.
            if len(labels) < k:
                continue
            valid_problems += 1

            # A draw succeeds when at least one of the k sampled labels passes.
            success_count = sum(
                1 for _ in range(num_samples) if any(sampler.sample(labels, k))
            )
            total_score += success_count / num_samples

        results[f'pass@{k}'] = total_score / valid_problems if valid_problems > 0 else 0.0

    return results

def calculate_pass_at_k_exact(data_list: List[Dict], k_values=(1, 4, 8, 16), model='qwq'):
    """
    Calculate pass@k with the closed-form formula instead of random sampling.

    For a problem with n labelled samples, of which c pass:
        pass@k = 1 - C(n - c, k) / C(n, k)
    This is one minus the probability that k samples drawn without
    replacement all fail. The per-problem values are averaged over every
    problem that has at least k samples; problems with fewer are skipped
    for that k.

    Args:
        data_list: Problem records. Each must provide
            ``item['consistency']['labels']``, a list of per-sample pass flags
            (summed, so booleans or 0/1 ints).
        k_values: Iterable of k values to evaluate. The default is a tuple so
            it can never be mutated and shared between calls.
        model: Unused; accepted only for backward compatibility.

    Returns:
        dict: ``{'pass@{k}': value}`` for every k, with 0.0 when no problem
        has at least k samples.
    """
    def combination(n, r):
        # Binomial coefficient, defined as 0 outside 0 <= r <= n. This makes
        # the formula give pass@k == 1 when fewer than k samples fail.
        if r > n or r < 0:
            return 0
        return math.factorial(n) // (math.factorial(r) * math.factorial(n - r))

    results = {}
    for k in k_values:
        total_score = 0.0
        valid_problems = 0

        for item in data_list:
            labels = item['consistency']['labels']
            n = len(labels)  # total number of samples

            # Not enough samples for this k; exclude from the average.
            if n < k:
                continue
            valid_problems += 1

            c = sum(labels)  # number of passing samples
            # With no passing sample, every draw fails.
            problem_pass_at_k = 0.0 if c == 0 else 1 - combination(n - c, k) / combination(n, k)
            total_score += problem_pass_at_k

        results[f'pass@{k}'] = total_score / valid_problems if valid_problems > 0 else 0.0

    return results

def print_pass_at_k_results(data_list: List[Dict], method='exact'):
    """
    Compute pass@k, print a small report to stdout and return the same report.

    Args:
        data_list: Problem records, forwarded unchanged to the pass@k calculator.
        method: 'exact' selects calculate_pass_at_k_exact; any other value
            selects calculate_pass_at_k_with_sampling.

    Returns:
        str: The printed report, with every line newline-terminated.
    """
    if method != 'exact':
        metrics = calculate_pass_at_k_with_sampling(data_list)
        title = "Pass@K Results (Random Sampling):"
    else:
        metrics = calculate_pass_at_k_exact(data_list)
        title = "Pass@K Results (Exact Formula):"

    report_lines = [title, "-" * 30]
    report_lines.extend(
        f"{name}: {score:.4f} ({score*100:.2f}%)" for name, score in metrics.items()
    )

    for line in report_lines:
        print(line)

    return "".join(f"{line}\n" for line in report_lines)

def is_judged_correct(judgement: str) -> bool:
    """
    Return True if a judge response marks the assistant's answer as correct.

    Only the text after the last '"is_assistant_correct"' occurrence is
    inspected. It counts as correct when it contains the quoted token
    '"correct"' (case-insensitive). Because the quotes are part of the token,
    '"incorrect"' does not match. A response without the key is searched
    in full.
    """
    verdict = judgement.split('"is_assistant_correct"')[-1]
    return '"correct"' in verdict.lower()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Compute pass@k where a sample passes only if both the "
                    "QWQ-32B and Qwen3-32B judgements mark it correct."
    )
    parser.add_argument("--qwq_path", type=str, required=True,
                        help="JSONL whose rows hold QWQ-32B judgements in consistency['QWQ-32B']")
    parser.add_argument("--qwen3_path", type=str, required=True,
                        help="JSONL whose rows hold Qwen3-32B judgements in consistency['Qwen3-32B']")
    parser.add_argument("--res_save_path", type=str, required=True,
                        help="Text file the pass@k report is written to")
    parser.add_argument("--seed", type=int, default=None,
                        help="Optional seed for the random-sampling estimator (default: unseeded)")
    args = parser.parse_args()

    if args.seed is not None:
        random.seed(args.seed)

    # Materialize so the data can be length-checked and iterated more than once.
    qwq_data = list(read_jsonl(args.qwq_path))
    qwen3_data = list(read_jsonl(args.qwen3_path))

    # Rows are paired by position; zip() would silently drop the tail of the longer file.
    if len(qwq_data) != len(qwen3_data):
        sys.exit(
            f"Row count mismatch: {args.qwq_path} has {len(qwq_data)} rows, "
            f"{args.qwen3_path} has {len(qwen3_data)}"
        )

    for row_idx, (qwq, qwen3) in enumerate(tqdm(zip(qwq_data, qwen3_data), total=len(qwq_data))):
        consistency = qwq['consistency']
        consistency['qwen3'] = qwen3['consistency']['Qwen3-32B']
        consistency['qwq'] = consistency.pop('QWQ-32B')

        qwq_judgements = consistency['qwq']
        qwen3_judgements = consistency['qwen3']
        # Judgements are paired per sample index, so the counts must agree.
        if len(qwq_judgements) != len(qwen3_judgements):
            sys.exit(
                f"Row {row_idx}: {len(qwq_judgements)} QWQ judgements vs "
                f"{len(qwen3_judgements)} Qwen3 judgements"
            )

        # A sample passes only when both judges agree it is correct.
        consistency['labels'] = [
            is_judged_correct(a) and is_judged_correct(b)
            for a, b in zip(qwq_judgements, qwen3_judgements)
        ]

    print("=== Exact Formula Method ===")
    results_exact = print_pass_at_k_results(qwq_data, method='exact')

    print("\n=== Random Sampling Method ===")
    results_sampling = print_pass_at_k_results(qwq_data, method='sampling')

    with open(args.res_save_path, 'w', encoding="utf-8") as f:
        f.write(results_exact + '\n')
        f.write(results_sampling + '\n')
