from datasets import load_dataset
import re
import time
import torch
from transformers import AutoTokenizer, AutoModel
from generate import generate, generate_wavefront,generate_in_sentence
from functools import partial
from prompt import gsm8k_prompt, MATH_prompt
from human_eval.data import read_problems
from human_eval.execution import check_correctness
from math_reward import last_boxed_only_string, remove_boxed, compute_score
from tqdm import tqdm

# Module-level model and tokenizer. `get_prompt` and `evaluate` read the
# globals `tokenizer` and `device` defined here.
device = 'cuda'
_MODEL_ID = "GSAI-ML/LLaDA-8B-Instruct"
model = AutoModel.from_pretrained(
    _MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
model = model.to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID, trust_remote_code=True)

def extract_gsm8k_answer(text):
    """Return the number after the last ``####`` marker in *text*, or None.

    Accepts an optional leading ``$``, an optional minus sign, thousands
    separators and a decimal part. Commas are removed so that the result
    can be passed straight to ``float()``. For example, ``"#### 1,200"``
    yields ``"1200"`` and ``"#### -3"`` yields ``"-3"``.

    Prints ``"Format Error!!"`` and returns None when no marker is found.
    """
    matches = re.findall(r"####\s*\$?\s*(-?\d[\d,]*(?:\.\d*)?)", text)
    if not matches:
        print("Format Error!!")
        return None
    # Strip thousands separators so "1,200" is not misread as 1.
    return matches[-1].replace(",", "")

def extract_humanEval_answer(text: str):
    """Return the stripped body of the last ```python fenced block in *text*.

    Prints ``"Format Error!!"`` and returns None when *text* contains no
    such block. It does not raise IndexError.
    """
    matches = re.findall(r"```python(.*?)```", text, re.DOTALL)
    if not matches:
        print("Format Error!!")
        return None
    return matches[-1].strip()

def extract_math_answer(solution_str: str) -> str:
    """Extract the boxed final answer from a MATH reference solution.

    Uses ``last_boxed_only_string`` to locate the boxed span, then
    ``remove_boxed`` to unwrap it. Both helpers come from ``math_reward``;
    their behavior when no box is present is defined there.
    """
    boxed_span = last_boxed_only_string(solution_str)
    return remove_boxed(boxed_span)

def get_prompt(dataset: str, sample):
    """Build tokenized chat-template input ids for one dataset row.

    Args:
        dataset: ``"gsm8k"`` (reads ``sample["question"]``) or
            ``"DigitalLearningGmbH/MATH-lighteval"`` (reads ``sample["problem"]``).
        sample: a dataset row.

    Returns:
        The ``apply_chat_template`` tensor moved to the module-level ``device``,
        or None for an unsupported dataset. In that case a message is printed
        asking for a prompt template for the new dataset.
    """
    if dataset == "gsm8k":
        content = gsm8k_prompt + f"\nQuestion: {sample['question']}\nAnswer:"
    elif dataset == "DigitalLearningGmbH/MATH-lighteval":
        content = "Problem: " + sample["problem"] + MATH_prompt + "\nAnswer:"
    else:
        # Message: "define a prompt template for new dataset {dataset}".
        print(f"为新数据集 {dataset} 定义 prompt 模板")
        return None

    messages = [{"role": "user", "content": content}]
    encoded = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
    )
    return encoded.to(device)

def eval_answer(dataset: str, response: str, answer=None):
    """Judge whether a decoded model *response* matches the reference *answer*.

    Args:
        dataset: dataset identifier selecting the scoring rule.
        response: decoded model output.
        answer: reference solution text from the dataset row.

    Returns:
        bool: for gsm8k, True when the ``####`` numbers compare equal as
        floats, and False otherwise, including when either side has no
        parsable answer. For MATH, True when ``compute_score`` returns 1.0.

    Raises:
        NotImplementedError: for datasets without an evaluation rule.
    """
    if dataset == "gsm8k":
        pred = extract_gsm8k_answer(response)
        ans = extract_gsm8k_answer(answer)
        # extract_gsm8k_answer already printed a format error for a None side.
        if pred is None or ans is None:
            return False
        return float(pred) == float(ans)
    if dataset == "DigitalLearningGmbH/MATH-lighteval":
        ans = extract_math_answer(answer)
        return compute_score(response, ans) == 1.0
    # Message: "answer evaluation function not implemented for {dataset}".
    raise NotImplementedError(f"未实现 {dataset} 的答案评估函数")

def evaluate(model, max_samples=100, steps=1024, gen_length=1024, block_length=32, mode="low_confidence", dataset_name="gsm8k", skip=0):
    """Generate with ``generate_wavefront`` over a test split and report accuracy.

    For each row, the function:
    - builds a chat prompt with ``get_prompt``,
    - generates and decodes the output,
    - scores it with ``eval_answer``,
    - prints a running accuracy line.

    Args:
        model: model forwarded to ``generate_wavefront``.
        max_samples: visit at most this many dataset rows. Rows skipped via
            *skip* count toward this limit.
        steps: forwarded to ``generate_wavefront``.
        gen_length: forwarded to ``generate_wavefront``.
        block_length: printed in the banner only; not forwarded.
        mode: printed in the banner only; not forwarded.
        dataset_name: Hugging Face dataset id. Loads config ``"default"`` for
            ``"DigitalLearningGmbH/MATH-lighteval"`` and ``"main"`` otherwise.
        skip: number of leading rows to skip (e.g. to resume a run). Skipped
            rows are neither scored nor counted as correct.

    Returns:
        Fraction of evaluated samples judged correct, or 0.0 if none were evaluated.
    """
    config = "main" if dataset_name != "DigitalLearningGmbH/MATH-lighteval" else "default"
    dataset = load_dataset(dataset_name, config, split="test")

    correct = 0
    evaluated = 0
    cost_time = 0.0

    print("*"*66)
    print(f"** Dataset: {dataset_name} | Length: {gen_length} | Steps: {steps} | Block: {block_length} | Mode: {mode} **")
    print("*"*66)

    for index, sample in enumerate(tqdm(dataset), start=1):
        if index > max_samples:
            break
        if index <= skip:
            continue

        start = time.perf_counter()

        answer = sample.get("answer") or sample.get("solution") or None
        prompt = get_prompt(dataset_name, sample)
        if prompt is None:
            continue

        # block_length / remasking are not passed; r=4 is hard-coded for
        # this wavefront run.
        output_ids, _ = generate_wavefront(
            model, prompt, tokenizer, steps=steps, gen_length=gen_length,
            temperature=0., cfg_scale=0., r=4,
        )

        end = time.perf_counter()
        elapsed = end - start
        cost_time += elapsed
        evaluated += 1

        pred_ans = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        # bool() guards against a scorer that returns None on a mismatch.
        case_pass = bool(eval_answer(dataset_name, pred_ans, answer))
        if case_pass:
            correct += 1

        print(f"Accuracy on {dataset_name} (n={evaluated}): {correct} / {evaluated}\nCase is {case_pass} \nCost Time:{elapsed}s")

    # Divide by the number actually evaluated, not the requested max_samples.
    acc = correct / evaluated if evaluated else 0.0
    avg_time = cost_time / evaluated if evaluated else 0.0
    print(f"Accuracy on {dataset_name} (n={evaluated}): {acc:.2%}\n Avg Time: {avg_time:.2f}s")
    print("-"*70)
    return acc

def evaluate_human_eval(model, tokenizer, steps=1024, gen_length=512, block_length=32, mode="low_confidence", max_samples=80, timeout=3.0):
    """Generate a completion for every HumanEval problem and report pass@1.

    For each problem, the function:
    - wraps the problem prompt in the chat template,
    - generates with ``generate``,
    - extracts the last ```python block from the decoded text,
    - runs it through ``check_correctness``.

    Args:
        model: model forwarded to ``generate``. Its ``.device`` receives the inputs.
        tokenizer: tokenizer used for the chat template and for decoding.
        steps, gen_length, block_length: forwarded to ``generate``.
        mode: printed in the banner only.
        max_samples: currently NOT enforced; the whole test split is evaluated.
        timeout: per-problem timeout forwarded to ``check_correctness``.

    Returns:
        Fraction of problems whose completion passed.
    """
    dataset = load_dataset("openai_humaneval", split="test")
    total, passed = 0, 0
    total_time = 0

    print("******************************************************************")
    print(f"** Dataset: HumanEval  |  Gen Len: {gen_length}  |  Block: {block_length}  |  Steps: {steps}  |  Mode: {mode} **")
    print("******************************************************************")

    for sample in dataset:
        total += 1

        task_id = sample["task_id"]
        entry_point = sample["entry_point"]
        prompt = sample["prompt"]

        m = [{"role": "user", "content": prompt}]
        input_ids = tokenizer.apply_chat_template(
            m, add_generation_prompt=True, tokenize=True, return_tensors="pt"
        ).to(model.device)

        # Only generation is timed; chat-template tokenization above is excluded.
        start = time.perf_counter()
        output_ids, _ = generate(
            model, input_ids, tokenizer,
            steps=steps, gen_length=gen_length, block_length=block_length,
            temperature=0., cfg_scale=0.
        )
        total_time += time.perf_counter() - start

        decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        case_res = False

        try:
            completion = extract_humanEval_answer(decoded)
            # No fenced code block: count as a failure without invoking the checker.
            if completion is not None:
                problem = dict(task_id=task_id, prompt=prompt, test=sample["test"], entry_point=entry_point)
                result = check_correctness(problem=problem, completion=completion, timeout=timeout, completion_id=total - 1)
                case_res = bool(result["passed"])
        except Exception as e:
            # Boundary catch: any extraction/checker failure is logged and scored as a fail.
            print(f"[{total}] Evaluation error: {e}")

        if case_res:
            passed += 1

        print(f"Accuracy on humaneval (n={total}): {passed} / {total}\nThis case is {case_res}\n")

    print(f"[HumanEval] Accuracy: {passed}/{total} = {passed / total:.2%}")
    print(f"Avg generation time: {total_time / total:.2f}s")
    return passed / total


if __name__ == '__main__':
    # Sweep configuration: each (block length, step count) pair is run once
    # per dataset, block length in the outer position.
    step_options = [1024]
    block_options = [8]
    mode = "low_confidence"
    examples = 15
    datasets_to_run = ["DigitalLearningGmbH/MATH-lighteval"]

    for dataset_name in datasets_to_run:
        print(f"\n========= Dataset: {dataset_name} =========")
        grid = ((b, s) for b in block_options for s in step_options)
        for b_len, step in grid:
            evaluate(
                model,
                max_samples=examples,
                gen_length=1024,
                steps=step,
                block_length=b_len,
                mode=mode,
                dataset_name=dataset_name,
            )

    # HumanEval sweep (disabled). To enable, loop over the same grid and call:
    #     evaluate_human_eval(model, tokenizer, max_samples=examples,
    #                         gen_length=1024, steps=step, block_length=b_len,
    #                         mode=mode)
    # and stop after the first block length when mode == "auto_regression".
    


