from vllm import LLM, SamplingParams
import json
from boxed_extract import *
import argparse
import os
from collections import OrderedDict
from boxed_extract import *
from transformers import AutoTokenizer
import math

def compute_pass_at_k(scores, k):
    """
    Compute the unbiased pass@k estimate for a single problem.

    Given n sampled solutions of which c are correct, pass@k is the probability
    that at least one of k solutions drawn (without replacement) from the n is
    correct: 1 - C(n - c, k) / C(n, k). When n == k this reduces to
    "1.0 if any sample is correct else 0.0".

    Args:
        scores: list of per-sample scores; a score > 0 counts as correct
            (expected to be 0 or 1).
        k: number of attempts allowed per problem.
    Returns:
        pass@k value in [0.0, 1.0]; 0.0 when fewer than k samples are
        available or k < 1.
    """
    n = len(scores)
    if k < 1 or n < k:
        return 0.0

    c = sum(1 for s in scores if s > 0)
    if n - c < k:
        # Too few incorrect samples to fill a size-k draw: every draw has a hit.
        return 1.0
    return 1.0 - math.comb(n - c, k) / math.comb(n, k)

def merge_baselines(dir_path, prefix="qwen3-8B-think", output_name="merged_temp.jsonl"):
    """
    Merge per-temperature generation logs into one JSONL file keyed by problem.

    Every regular file in ``dir_path`` whose name starts with ``prefix`` is read
    as JSON Lines. Each record must carry ``problem``, ``ground_truth`` and a
    ``temp_acc`` dict mapping temperature -> accuracy. Records for the same
    problem are combined, and one line per problem is written with:

    - ``temp_acc``: every temperature -> accuracy pair, ordered by temperature;
    - ``opt_temp``: the best single pair (highest accuracy; ties resolve to the
      highest temperature), or ``{}`` if no pairs were seen.

    Args:
        dir_path: directory containing the logs; the merged file is written here too.
        prefix: filename prefix selecting which logs to merge.
        output_name: file name of the merged output inside ``dir_path``.
    """
    from collections import defaultdict

    output_file = os.path.join(dir_path, output_name)
    merged_data = defaultdict(dict)

    # Sorted so that output order does not depend on filesystem listing order.
    for filename in sorted(os.listdir(dir_path)):
        file_path = os.path.join(dir_path, filename)
        if (
            not filename.startswith(prefix)
            or filename == output_name  # never re-read our own output
            or not os.path.isfile(file_path)
        ):
            continue
        with open(file_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank / trailing lines
                data = json.loads(line)
                entry = merged_data[data["problem"]]
                entry["problem"] = data["problem"]

                ground_truth = data["ground_truth"]
                # Answers may be zero-padded (e.g. "033"). Strip the padding only
                # from pure-digit strings so "0" stays "0" and "0.5" is untouched.
                if isinstance(ground_truth, str) and ground_truth.isdigit():
                    ground_truth = ground_truth.lstrip("0") or "0"
                entry["ground_truth"] = ground_truth

                entry.setdefault("temp_acc", {}).update(data["temp_acc"])

    with open(output_file, "w", encoding="utf-8") as f:
        for data in merged_data.values():
            by_temp = sorted(data["temp_acc"].items(), key=lambda kv: float(kv[0]))
            data["temp_acc"] = OrderedDict(by_temp)
            if by_temp:
                # max() keeps the first maximum it meets, so scanning from the
                # highest temperature down makes ties resolve to the highest temp.
                best_temp, best_acc = max(reversed(by_temp), key=lambda kv: float(kv[1]))
                data["opt_temp"] = {best_temp: best_acc}
            else:
                data["opt_temp"] = {}
            f.write(json.dumps(data, ensure_ascii=False) + "\n")

    print(f"Merged results saved to {output_file}")

def write_ascii_table(txt_path: str, dataset_name: str, avg_pass: float, avg_acc: float, k: int):
    """Render a one-row, three-column ASCII summary table and write it to txt_path.

    The header row is ("", "Pass@<k>", "Acc") and the single data row holds the
    dataset name and the two scores formatted with two decimals. Every column is
    as wide as its longest cell plus two, with cells centered. The file is
    overwritten.
    """
    header_cells = ("", f"Pass@{k}", "Acc")
    value_cells = (dataset_name, f"{avg_pass:.2f}", f"{avg_acc:.2f}")
    widths = [max(len(head), len(value)) + 2 for head, value in zip(header_cells, value_cells)]

    def render_row(cells) -> str:
        return "|" + "|".join(cell.center(width) for cell, width in zip(cells, widths)) + "|\n"

    separator = "+" + "+".join("-" * width for width in widths) + "+\n"
    table = "".join(
        (separator, render_row(header_cells), separator, render_row(value_cells), separator)
    )
    with open(txt_path, "w") as handle:
        handle.write(table)
if __name__ == "__main__":
    def extract_model_name(path):
        """Return the last component of a model path, e.g. 'org/Qwen3-8B/' -> 'Qwen3-8B'.

        Returns 'unknown' when nothing is left after stripping trailing slashes
        (e.g. an empty path).
        """
        name = path.rstrip('/').split('/')[-1]
        return name or 'unknown'

    parser = argparse.ArgumentParser()
    parser.add_argument('--temp', type=float, default=1.0)
    parser.add_argument('--top_p', type=float, default=1.0)
    parser.add_argument('--top_k', type=int, default=-1)
    parser.add_argument('--rp', type=float, default=1.0)
    parser.add_argument('--k', type=int, default=16)
    parser.add_argument('--model_name_or_path', type=str, default='')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--dataset', type=str, default='aime24')
    parser.add_argument('--tp_size', type=int, default=4)
    parser.add_argument('--max_tokens', type=int, default=32768)
    # Directory holding '<dataset>.jsonl'; the default keeps the original placeholder path.
    parser.add_argument('--data_dir', type=str, default='pathtoyour_test_file')
    args = parser.parse_args()

    ckpt_name = extract_model_name(args.model_name_or_path)

    temp = args.temp
    k = args.k
    seed = args.seed

    # One JSON object per line; each record must provide 'problem' and 'gt'.
    with open(os.path.join(args.data_dir, f'{args.dataset}.jsonl'), 'r', encoding='utf-8') as f:
        data = [json.loads(line) for line in f if line.strip()]

    # n=k: draw exactly k samples per prompt, so pass@k below uses all of them.
    sampling_params = SamplingParams(
        temperature=temp,
        top_p=args.top_p,
        top_k=args.top_k,
        max_tokens=args.max_tokens,
        n=k,
        seed=seed,
        repetition_penalty=args.rp,
    )

    llm = LLM(model=args.model_name_or_path, tensor_parallel_size=args.tp_size)
    tokenizer = llm.get_tokenizer()

    # Wrap each problem in the model's chat template and request a \boxed{} answer,
    # which compute_score (from boxed_extract) presumably extracts.
    problems = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": item['problem'] + '\nMake sure you output the final answer within \\boxed{}.'}],
            tokenize=False,
            add_generation_prompt=True,
        )
        for item in data
    ]
    ground_truths = [item['gt'] for item in data]

    log_dir = os.path.join('generation_log', args.dataset)
    os.makedirs(log_dir, exist_ok=True)
    outputs = llm.generate(problems, sampling_params)

    # Despite the .json suffix the log is JSON Lines: one record per problem.
    out_path = os.path.join(
        log_dir,
        f'{ckpt_name}-temp{temp}-top_p{args.top_p}-top_k{args.top_k}-rp{args.rp}'
        f'-pass@{k}-max_tokens{args.max_tokens}-seed{seed}.json',
    )

    all_pass = []  # per-problem pass@k in [0, 1]
    all_acc = []   # per-problem accuracy in percent
    with open(out_path, 'w', encoding='utf-8') as f:
        for idx, output_group in enumerate(outputs):
            gt = ground_truths[idx]
            solutions, scores, sample_temps, sample_top_ps = [], [], [], []
            for output in output_group.outputs:
                generated_text = output.text
                solutions.append(generated_text)
                scores.append(compute_score(generated_text, gt))
                # NOTE(review): 'temps' / 'top_p' look like attributes exposed only by a
                # patched vLLM output object; they are recorded only when present.
                sample_temp = getattr(output, 'temps', None)
                sample_top_p = getattr(output, 'top_p', None)
                if sample_temp is not None:
                    sample_temps.append(sample_temp)
                if sample_top_p is not None:
                    sample_top_ps.append(sample_top_p)

            pass_at_k = compute_pass_at_k(scores, k)
            acc = round(sum(scores) / len(scores) * 100, 2)
            all_pass.append(pass_at_k)
            all_acc.append(acc)
            f.write(json.dumps({
                'problem': problems[idx],  # chat-templated prompt, not the raw problem text
                'ground_truth': gt,
                'temp_acc': {args.temp: acc},
                f'pass_at_{k}': round(pass_at_k * 100, 2),
                'solutions': solutions,
                'temp': sample_temps,
                'top_p': sample_top_ps,
            }, ensure_ascii=False) + '\n')

    if not all_pass:
        raise SystemExit(f'No problems loaded for dataset {args.dataset!r}; nothing to summarize.')

    avg_pass_at_k = round(sum(all_pass) / len(all_pass) * 100, 2)
    avg_acc = round(sum(all_acc) / len(all_acc), 2)
    print(f"Overall avg Pass@{k}: {avg_pass_at_k}%")
    print(f"Overall avg Acc: {avg_acc}%")

    # Save the summary as an ASCII table next to the log, labelled with the actual dataset.
    txt_path = os.path.splitext(out_path)[0] + '.txt'
    write_ascii_table(txt_path, args.dataset.upper(), avg_pass_at_k, avg_acc, k)


