from datasets import load_dataset
import re
import time
import torch
from transformers import AutoTokenizer, AutoModel
from generate import generate, generate_wavefront
from functools import partial
from prompt import get_BBH_prompt
from tqdm import tqdm

# Hugging Face checkpoint used for both the model and its tokenizer; keeping it
# in one constant guarantees the two are always loaded from the same repo.
MODEL_NAME = "GSAI-ML/LLaDA-1.5"

device = 'cuda'
# trust_remote_code=True lets transformers execute code shipped with the repo
# (presumably a custom model class -- confirm against the checkpoint).
# Weights are loaded in bfloat16, moved to `device`, and switched to eval mode.
model = AutoModel.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)

# The 27 task configurations of the "maveriq/bigbenchhard" dataset, in a fixed
# order. Tasks are selected by position elsewhere in the script, e.g. index 9 is
# "logical_deduction_five_objects" and index 12 is "multistep_arithmetic_two".
bbh_subsets = """
    boolean_expressions
    causal_judgement
    date_understanding
    disambiguation_qa
    dyck_languages
    formal_fallacies
    geometric_shapes
    hyperbaton
    logical_deduction_three_objects
    logical_deduction_five_objects
    logical_deduction_seven_objects
    movie_recommendation
    multistep_arithmetic_two
    navigate
    object_counting
    penguins_in_a_table
    reasoning_about_colored_objects
    ruin_names
    salient_translation_error_detection
    snarks
    sports_understanding
    temporal_sequences
    tracking_shuffled_objects_three_objects
    tracking_shuffled_objects_five_objects
    tracking_shuffled_objects_seven_objects
    web_of_lies
    word_sorting
""".split()

def extract_answer(text):
    """Return the answer following the last ``####`` marker in ``text``.

    Whitespace right after the marker (including line breaks) is skipped, then
    the rest of that line is captured and stripped. If ``text`` contains no
    ``####`` marker at all, "Format Error!!" is printed and ``None`` is returned.
    """
    candidates = re.findall(r'####\s*([^\r\n]*)', text)
    if not candidates:
        print("Format Error!!")
        return None
    return candidates[-1].strip()


def get_prompt(sample, subset):
    """Build the tokenized chat prompt for one BBH sample.

    Args:
        sample: A BBH row; only its ``"input"`` field is read.
        subset: BBH task name, passed to ``get_BBH_prompt`` to obtain the
            task-specific prefix placed before the question.

    Returns:
        The output of ``tokenizer.apply_chat_template(..., tokenize=True,
        return_tensors="pt")`` moved to ``device`` (presumably a
        ``(1, seq_len)`` tensor of token ids -- confirm for the installed
        transformers version).
    """
    question = sample["input"]
    content = get_BBH_prompt(subset) + f"\nQuestion: {question}\n" + "Answer:"
    messages = [{"role": "user", "content": content}]
    # add_generation_prompt=True appends the assistant-turn header, so the
    # model starts an assistant reply instead of continuing after the user's
    # end-of-turn token.
    prompt = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_tensors="pt",
    ).to(device)
    return prompt

def eval_answer(response: str, answer=None):
    """Score one model response against the gold target.

    The predicted answer is pulled out of ``response`` with ``extract_answer``.
    Both the prediction and ``answer`` are printed. The result is ``True``
    only when ``answer`` equals the prediction exactly as strings.
    """
    predicted = extract_answer(response)
    print(f"Pred Ans is {predicted} and answer is {answer}")
    is_match = answer == predicted
    return is_match

def evaluate(model, max_samples=100, steps=1024, gen_length=1024, block_length=32,
             mode="low_confidence", subset="web_of_lies", *, resume_from=0, resume_correct=0):
    """Generate answers for one BBH subset with ``generate_wavefront`` and report accuracy.

    Args:
        model: Model forwarded to ``generate_wavefront``.
        max_samples: Upper bound on the number of leading rows of the split
            that are scored; clamped to the split size.
        steps: Forwarded to ``generate_wavefront`` as ``steps``.
        gen_length: Forwarded to ``generate_wavefront`` as ``gen_length``.
        block_length: Only shown in the printed header; it is NOT forwarded to
            ``generate_wavefront``.
        mode: Only shown in the printed header; it is NOT forwarded either.
        subset: Configuration name of the ``maveriq/bigbenchhard`` dataset.
        resume_from: Number of leading rows already scored by an earlier run.
            They are skipped rather than regenerated. Should not exceed
            ``max_samples``.
        resume_correct: How many of those ``resume_from`` rows were answered
            correctly. Seeds the running ``correct`` counter.

    Returns:
        Exact-match accuracy over the first ``min(max_samples, len(split))``
        rows (resumed rows included), or ``0.0`` if the split is empty.
    """
    dataset = load_dataset("maveriq/bigbenchhard", subset, split="train")
    # Some BBH tasks have fewer rows than max_samples; never divide by rows
    # that were not scored.
    limit = min(max_samples, len(dataset))
    first = min(resume_from, limit)

    correct = resume_correct
    cost_time = 0.0
    generated = 0  # rows actually run through the model in this call

    print("*"*66)
    print(f"** Dataset: {subset} | Length: {gen_length} | Steps: {steps} | Block: {block_length} | Mode: {mode} **")
    print("*"*66)

    for index in tqdm(range(first, limit)):
        sample = dataset[index]
        sample_size = index + 1  # rows covered so far, including resumed ones
        start = time.perf_counter()

        answer = sample.get("target")
        prompt = get_prompt(sample, subset)

        output_text, _ = generate_wavefront(
            model=model, prompt=prompt, tokenizer=tokenizer, steps=steps, gen_length=gen_length,
            temperature=0., cfg_scale=0.,
            r=2,
        )

        end = time.perf_counter()
        cost_time += end - start
        generated += 1

        # NOTE(review): if generate_wavefront returns prompt + completion, the
        # decoded text also contains the prompt, and extract_answer could pick
        # up a "####" line from it when the completion has none -- confirm.
        pred_ans = tokenizer.decode(output_text[0], skip_special_tokens=True)

        case_pass = eval_answer(pred_ans, answer)
        if case_pass:
            correct += 1

        print(f"Accuracy on {subset} (n={sample_size}): {correct} / {sample_size}\nCase is {case_pass} \nCost Time:{end - start}s")

    acc = correct / limit if limit else 0.0
    # Average only over rows that were actually generated (and timed) here.
    avg_time = cost_time / generated if generated else 0.0
    print(f"Accuracy on {subset} (n={limit}): {acc:.2%}\n Avg Time: {avg_time:.2f}s")
    print("-"*70)
    return acc


if __name__ == '__main__':
    # Sweep every (block length, step count) pair on a single BBH task, using
    # the first `n_examples` rows of that task's split.
    step_grid = [1024]
    block_grid = [8]
    remask_mode = "low_confidence"
    n_examples = 250
    task = bbh_subsets[9]  # "logical_deduction_five_objects"

    for block_size in block_grid:
        for n_steps in step_grid:
            evaluate(
                model,
                max_samples=n_examples,
                gen_length=1024,
                steps=n_steps,
                block_length=block_size,
                mode=remask_mode,
                subset=task,
            )

    


