import argparse
import math
from concurrent.futures import TimeoutError
from itertools import combinations

import numpy as np
from pebble import ProcessPool
from tqdm import tqdm

from grader import *

from parser import *
from utils import load_jsonl
from python_executor import PythonExecutor


def _pass_at_k(samples):
    """Return combinatorial pass@k percentages keyed ``"pass@<k>"``.

    k runs over 1, 2, 4, 8, ... up to the largest per-sample prediction count
    ``max_k``, plus ``max_k`` itself. A size-k subset of a sample's scores
    "passes" if it contains at least one truthy score. pass@k is the number of
    passing subsets divided by the number of subsets, both summed over every
    sample with at least k scores; samples with fewer scores are skipped for
    that k.

    Rather than enumerating all C(n, k) subsets, the counts use a closed form.
    With c truthy scores out of n, exactly C(n - c, k) subsets are all-falsy,
    so C(n, k) - C(n - c, k) subsets pass. The integer counts, and therefore
    the result, are identical to explicit enumeration, without its
    exponential cost.
    """
    max_k = max(len(sample['score']) for sample in samples)
    k_values = [1]
    k = 2
    while k <= max_k:
        k_values.append(k)
        k *= 2
    if max_k not in k_values:
        k_values.append(max_k)

    pass_at_k = {}
    for k in k_values:
        total_passed_combinations = 0
        total_combinations = 0
        for sample in samples:
            # Use the unpadded per-sample scores.
            sample_scores = sample['score']
            n = len(sample_scores)
            if n < k:
                continue
            n_correct = sum(1 for s in sample_scores if s)
            n_subsets = math.comb(n, k)
            # math.comb(m, k) is 0 when k > m, i.e. every subset contains a hit.
            total_passed_combinations += n_subsets - math.comb(n - n_correct, k)
            total_combinations += n_subsets
        if total_combinations > 0:
            pass_at_k[f"pass@{k}"] = np.round(total_passed_combinations / total_combinations * 100, decimals=4)
        else:
            pass_at_k[f"pass@{k}"] = 0.0
    return pass_at_k


def evaluate(data_name, prompt_type, samples: list=None, file_path: str=None, max_num_samples=None, execute=False):
    """Grade every prediction of every sample and report aggregate accuracy.

    Ground truth is extracted per sample with
    ``parse_ground_truth(sample, data_name)``. Each ``(idx, pred, gt)`` triple
    is graded by ``math_equal_process`` in a 64-worker process pool with a
    3-second per-job timeout. Timed-out jobs are scored ``False``.

    Args:
        data_name: Dataset name forwarded to ``parse_ground_truth``.
        prompt_type: Accepted for interface compatibility; not used here.
        samples: In-memory samples, each with a ``pred`` list. If falsy, the
            samples are loaded from ``file_path`` instead.
        file_path: JSONL file to load when ``samples`` is not given.
        max_num_samples: If truthy, only the first this-many samples (after
            de-duplication and sorting) are graded.
        execute: Accepted for interface compatibility; not used here.

    Returns:
        ``(samples, result_json)``. Each sample gains ``gt_cot``, ``gt`` and
        ``score`` (one entry per prediction). ``result_json`` holds the
        aggregate metrics, including ``acc``, ``all_acc``, ``mean_acc`` and
        ``pass_at_k``, plus ``type_acc`` when samples carry a ``type`` key.

    Raises:
        AssertionError: neither ``samples`` nor ``file_path`` was given.
        ValueError: there are no samples to evaluate.
    """
    assert samples or file_path, "samples or file_path must be provided"
    if not samples:
        samples = list(load_jsonl(file_path))
    if not samples:
        raise ValueError("no samples to evaluate")
    if 'idx' in samples[0]:
        # De-duplicate by idx (last occurrence wins), then order by idx.
        samples = {sample['idx']: sample for sample in samples}.values()
        samples = sorted(samples, key=lambda x: x['idx'])
    else:
        samples = [dict(idx=idx, **sample) for idx, sample in enumerate(samples)]

    if max_num_samples:
        print(f"max_num_samples: {max_num_samples} / {len(samples)}")
        samples = samples[:max_num_samples]

    # parse gt
    for sample in samples:
        sample['gt_cot'], sample['gt'] = parse_ground_truth(sample, data_name)
    # One grading job per (sample, prediction). The flat order matters: the
    # results are sliced back into per-sample lists in this same order below.
    params = [(idx, pred, sample['gt']) for idx, sample in enumerate(samples) for pred in sample['pred']]

    scores = []
    timeout_cnt = 0

    with ProcessPool(max_workers=64) as pool:
        future = pool.map(math_equal_process, params, timeout=3)
        iterator = future.result()
        # The iterator yields one result per grading job, not per sample.
        with tqdm(total=len(params), desc="Evaluate") as progress_bar:
            while True:
                try:
                    result = next(iterator)
                    scores.append(result)
                except StopIteration:
                    break
                except TimeoutError as error:
                    print(error)
                    scores.append(False)
                    timeout_cnt += 1
                except Exception as error:
                    # pebble attaches the worker-side traceback as `.traceback`;
                    # other exceptions may not have it, so fall back to the error.
                    print(getattr(error, "traceback", error))
                    # Drop pending jobs so the pool can shut down promptly, then
                    # propagate instead of exiting with a success status.
                    future.cancel()
                    raise
                progress_bar.update(1)

    # Slice the flat score list back into one list per sample.
    idx = 0
    score_mat = []
    for sample in samples:
        n_pred = len(sample['pred'])
        sample['score'] = scores[idx: idx + n_pred]
        assert len(sample['score']) == n_pred
        score_mat.append(sample['score'])
        idx += n_pred

    max_len = max(len(s) for s in score_mat)

    # Make the matrix rectangular by repeating each short row's last score.
    # This builds new lists, so sample['score'] itself stays unpadded.
    for i, s in enumerate(score_mat):
        if len(s) < max_len:
            score_mat[i] = s + [s[-1]] * (max_len - len(s))

    # Per-column (i.e. per prediction position) accuracy, in percent.
    col_means = np.array(score_mat).mean(axis=0)
    mean_score = list(np.round(col_means * 100, decimals=4))

    mean_mean_score = np.mean(mean_score)
    print(f"mean_score: {mean_score}")

    pass_at_k = _pass_at_k(samples)
    print(f"pass@k: {pass_at_k}")

    result_json = {
        "num_samples": len(samples),
        "num_scores": len(scores),
        "timeout_samples": timeout_cnt,
        "empty_samples": len([s for s in samples if not s['pred'][-1]]),
        "acc": mean_score[0],
        "all_acc": mean_score,
        "mean_acc": mean_mean_score,
        "pass_at_k": pass_at_k,
    }

    # Per-type accuracy, based on each sample's last prediction only.
    if "type" in samples[0]:
        type_scores = {}
        for sample in samples:
            type_scores.setdefault(sample['type'], []).append(sample['score'][-1])
        result_json['type_acc'] = {
            k: np.round(np.array(v).mean() * 100, decimals=1)
            for k, v in sorted(type_scores.items(), key=lambda item: item[0])
        }

    print(result_json)
    return samples, result_json


def parse_args():
    """Build the evaluation CLI and return the parsed ``argparse.Namespace``."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--data_name", type=str, default="math")
    cli.add_argument("--prompt_type", type=str, default="tool-integrated")
    # Mandatory: the JSONL file holding the predictions to grade.
    cli.add_argument("--file_path", type=str, default=None, required=True)
    cli.add_argument("--max_num_samples", type=int, default=None)
    cli.add_argument("--execute", action="store_true")
    return cli.parse_args()

if __name__ == "__main__":
    # Command-line entry point: parse the options and forward them to evaluate().
    args = parse_args()
    evaluate(
        data_name=args.data_name,
        prompt_type=args.prompt_type,
        file_path=args.file_path,
        max_num_samples=args.max_num_samples,
        execute=args.execute,
    )
