from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import torch
import re
from prompt import gsm8k_prompt


# Load Meta-Llama-3-8B-Instruct (requires a local copy or access to the gated
# Hugging Face repo). Note: this checkpoint is Llama 3, not Llama 3.1.
MODEL_NAME = 'meta-llama/Meta-Llama-3-8B-Instruct'

device = 'cuda'
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(device).eval()

tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
)

# Alternative (disabled): register a brand-new "<mask>" token. That also
# requires resizing the embedding matrix so the new id has a row.
# tokenizer.add_special_tokens({"additional_special_tokens": ["<mask>"]})
# mask_token_id = tokenizer.convert_tokens_to_ids("<mask>")

# model.resize_token_embeddings(len(tokenizer))

# Instead, reuse a reserved special token that already exists in the
# vocabulary as the "mask" token, so no embedding resize is needed.
MASK_TOKEN = "<|reserved_special_token_1|>"
mask_token_id = tokenizer.convert_tokens_to_ids(MASK_TOKEN)
# convert_tokens_to_ids does not raise for unknown tokens (it yields None or
# the unk id), which would otherwise surface later as a confusing error in
# torch.full. Fail fast with a clear message instead.
if mask_token_id is None or mask_token_id == tokenizer.unk_token_id:
    raise ValueError(f"Tokenizer for {MODEL_NAME} has no token {MASK_TOKEN!r}")


gsm8k = load_dataset("gsm8k", "main", split="test")

def insert_mask_tokens(prompt: str, num_masks: int) -> "tuple[torch.Tensor, torch.Tensor]":
    """Build chat-formatted input ids with ``num_masks`` mask tokens appended.

    The prompt is wrapped as a single user turn with the tokenizer's chat
    template, including the assistant generation header. Then ``num_masks``
    copies of ``mask_token_id`` are appended after that header, so the mask
    tokens sit at the start of the assistant's reply.

    Args:
        prompt: User message text.
        num_masks: Number of mask tokens to append; must be >= 0.

    Returns:
        ``(input_ids, attention_mask)``, each of shape ``(1, seq_len)`` on
        ``device``. The attention mask is all ones because nothing is padded.

    Raises:
        ValueError: If ``num_masks`` is negative.
    """
    if num_masks < 0:
        raise ValueError(f"num_masks must be non-negative, got {num_masks}")

    messages = [{"role": "user", "content": prompt}]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,  # append the assistant-turn header so the model replies next
        tokenize=True,
        return_tensors="pt",
    ).to(device)

    mask_ids = torch.full(
        (1, num_masks),
        fill_value=mask_token_id,
        dtype=input_ids.dtype,
        device=input_ids.device,
    )
    input_ids = torch.cat([input_ids, mask_ids], dim=1)

    # Single unpadded sequence: every position is attended to. ones_like keeps
    # the device of input_ids, so no extra .to(device) is needed.
    attention_mask = torch.ones_like(input_ids)
    return input_ids, attention_mask

def generate_with_masks(prompt: str, num_masks=0, max_new_tokens=1024, temperature=0.) -> str:
    """Generate a reply to ``prompt`` after inserting ``num_masks`` mask tokens.

    Args:
        prompt: User message text; it is wrapped with the chat template.
        num_masks: Number of mask tokens placed after the assistant header.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: ``<= 0`` (the default) decodes greedily. A positive value
            enables sampling at that temperature.

    Returns:
        The decoded newly generated text, with special tokens removed.
    """
    input_ids, attention_mask = insert_mask_tokens(prompt, num_masks)

    # Llama 3 Instruct ends assistant turns with <|eot_id|>. Depending on the
    # tokenizer revision, that may not be tokenizer.eos_token, so stop on both.
    terminators = [tokenizer.eos_token_id]
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if eot_id is not None and eot_id != tokenizer.unk_token_id and eot_id not in terminators:
        terminators.append(eot_id)

    # Only pass temperature when actually sampling; greedy decoding ignores it.
    if temperature > 0:
        decoding_kwargs = {"do_sample": True, "temperature": temperature}
    else:
        decoding_kwargs = {"do_sample": False}

    output_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,  # pass explicitly rather than letting generate() infer it
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        pad_token_id=tokenizer.eos_token_id,
        **decoding_kwargs,
    )

    # Strip the prompt (and mask) tokens; decode only what the model produced.
    new_token_ids = output_ids[:, input_ids.shape[1]:]
    return tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0]

def run_experiment_on_gsm8k(n_samples=100, num_masks_list=(7, 10, 15, 20)):
    """Measure GSM8K accuracy for each mask count and print a summary.

    For each of the first ``n_samples`` test questions, the prompt
    ``gsm8k_prompt + "\\nQuestion: ...\\nAnswer:"`` is answered once per
    entry in ``num_masks_list`` via ``generate_with_masks`` (default decoding
    settings). Scoring rules:

    - A response is correct when its ``####`` answer numerically equals the
      reference answer.
    - A response with no extractable ``####`` answer counts as an
      instruction error.

    Args:
        n_samples: Number of test questions to evaluate. It is clamped to the
            range ``[0, len(gsm8k)]``.
        num_masks_list: Mask counts to compare (any sequence of ints). It is a
            tuple to avoid a shared mutable default.
    """
    n = max(0, min(n_samples, len(gsm8k)))
    records = [0] * len(num_masks_list)
    errors = [0] * len(num_masks_list)

    for example in gsm8k.select(range(n)):
        question = example["question"]
        answer = extract_final_answer(example["answer"])

        prompt = gsm8k_prompt + f"\nQuestion: {question}\nAnswer:"

        for i, num_masks in enumerate(num_masks_list):
            generated = generate_with_masks(
                prompt=prompt,
                num_masks=num_masks,
            )

            response = extract_final_answer(generated.strip())
            if response is None:
                # The model did not follow the "#### <answer>" format.
                errors[i] += 1
            elif float(answer) == float(response):
                records[i] += 1

    for record, num_masks, error in zip(records, num_masks_list, errors):
        accuracy = record / n if n else 0.0  # avoid ZeroDivisionError when n == 0
        print("*" * 18)
        print(f"masks num: {num_masks}   \nCorrectness: {accuracy:.2%}\nInstruction Errors is {error}")



# Matches "#### <number>". The number may have a leading "-" and/or "$",
# thousands separators ("1,200") and a fractional part ("3.5"). A trailing
# sentence period ("#### 5.") is not treated as part of the number.
_FINAL_ANSWER_RE = re.compile(r"####\s*(-?)\s*\$?\s*(\d[\d,]*(?:\.\d+)?)")


def extract_final_answer(text: str):
    """Return the number following the last "####" marker in ``text``.

    This is used both on GSM8K reference solutions and on model outputs.

    Args:
        text: Text that may contain one or more ``#### <number>`` markers.

    Returns:
        The number as a string with thousands separators removed. For example,
        ``"#### $1,200"`` gives ``"1200"`` and ``"#### -3"`` gives ``"-3"``.
        Returns ``None`` if no ``####`` answer is found.
    """
    matches = _FINAL_ANSWER_RE.findall(text)
    if not matches:
        return None
    # With two groups, findall yields (sign, digits) tuples; keep the last answer.
    sign, digits = matches[-1]
    return sign + digits.replace(",", "")

if __name__ == "__main__":
    # Script entry point: sweep the default mask counts over the first
    # 100 GSM8K test questions.
    sample_count = 100
    run_experiment_on_gsm8k(n_samples=sample_count)




# def evaluate_gsm8k(model, max_samples=100, steps=1024, gen_length=1024, block_length=32):
#     dataset = load_dataset("gsm8k", "main", split="test")
#     correct = 0
#     cost_time = 0

#     print("******************************************************************")
#     print(f"**  Answer Length: {gen_length}  |  Sampling Steps: {steps}  |  Block Length: {block_length}  **")
#     print("******************************************************************")

#     for sample in dataset.select(range(max_samples)):
#         question = sample["question"]
#         answer = sample["answer"]

#         # Prepare prompt
#         m = [{"role": "user", "content": gsm8k_prompt + f"\nQuestion: {question}\nAnswer:"}]

#         prompt = tokenizer.apply_chat_template(
#             m,
#             add_generation_prompt=True,
#             tokenize=True,
#             return_tensors="pt"
#         ).to(device)
        
#         start = time.perf_counter()

#         output_text = model.generate(
#             prompt,
#             max_new_tokens=1024,
#             temperature=0.,
#             do_sample=False,
#             eos_token_id=tokenizer.eos_token_id,
#         )
#         end = time.perf_counter()

#         cost_time += (end - start)        

#         pred_ans = tokenizer.batch_decode(output_text[:, prompt.shape[1]:], skip_special_tokens=True)[0]
#         # pred_ans = tokenizer.decode(output_text[0][prompt.shape[1]:], skip_special_tokens=True)

#         pred_ans = extract_answer(pred_ans)
#         true_ans = extract_answer(answer)

#         if pred_ans is None:
#             continue
#         elif float(pred_ans) == float(true_ans):
#             correct += 1

#     acc = correct / max_samples
#     print(f"Accuracy on GSM8K (n={max_samples}): {acc:.2%}\n Average Cost: {cost_time / max_samples}\n AVG FLOPS is {total_flops / max_samples}")
#     print("-----------------------------------------------------------------------")
#     return acc
