from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from batchedGeneration import TransformersInterface

import torch
import pickle as pkl
import random
import numpy as np

# Seed every RNG this script relies on (torch, Python's `random`, NumPy) with
# one shared value, so the `random.sample` prompt draw is reproducible.
seed = 42
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(seed)

# Hugging Face Hub checkpoint to evaluate, plus the download options shared by
# the tokenizer and model loads.
# NOTE(review): token="" looks like a redacted placeholder; the Llama 3.1
# checkpoints are gated, so a real token is presumably required -- confirm.
model_name = 'meta-llama/Llama-3.1-8B'  # Ensure this is the correct model name
_hub_kwargs = {"trust_remote_code": True, "token": ""}

# "eval" split of the "alpaca_eval" config of tatsu-lab/alpaca_eval.
eval_set = load_dataset("tatsu-lab/alpaca_eval", "alpaca_eval", trust_remote_code=True)["eval"]

tokenizer = AutoTokenizer.from_pretrained(model_name, **_hub_kwargs)
# Reuse the EOS token for padding when the checkpoint defines no pad token
# (presumably needed for the batched generation later in the script).
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
pad_token_id = tokenizer.pad_token_id

# device_map="auto" lets transformers/accelerate place the weights on the
# available devices.
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", **_hub_kwargs)

# Extract vocabulary from tokenizer
vocabulary = set(tokenizer.get_vocab().keys())

# Draw 500 prompts from google_10k.txt using the seeded `random` module, log
# them to binary_prompt.txt (one per line) and keep them for generation.
# NOTE: an earlier variant drew prompts from `eval_set` instead (optionally
# translated to Portuguese); that experiment is not active here.
with open("google_10k.txt", "r") as source_file:
    source_lines = source_file.readlines()

lines = []
# `with` guarantees the prompt log is flushed and closed even if sampling
# fails (e.g. the source file has fewer than 500 lines).
with open("binary_prompt.txt", "w") as promptfile:
    for line in random.sample(source_lines, 500):
        # readlines() keeps the trailing "\n"; strip it once so the log gets
        # exactly one newline per prompt instead of a blank line after each.
        prompt = line.rstrip()
        promptfile.write(prompt + "\n")
        lines.append(prompt)
def _extract_prediction(response):
    """Return the second-to-last whitespace-separated token of *response*.

    This is the token the script scores as the model's answer. A response
    with fewer than two tokens yields "", so a short or empty generation
    counts as a miss instead of raising IndexError.
    """
    tokens = response.split()
    return tokens[-2] if len(tokens) >= 2 else ""


interface = TransformersInterface(model, vocabulary, tokenizer)
normal_responses, perturbed_responses = interface.generate_response(lines, max_new_tokens=5, batch_size=20)

# Case-insensitive exact-match accuracy of the extracted token against each
# sampled prompt line. Assumes normal_responses[i] corresponds to lines[i]
# (TODO confirm against TransformersInterface.generate_response).
correct = 0
for response, target in zip(normal_responses, lines):
    prediction = _extract_prediction(response)
    print(prediction, target)
    print("--------------------------------")
    if prediction.lower() == target.strip().lower():
        correct += 1
# Guard against an empty generation batch rather than dividing by zero.
print(correct / len(normal_responses) if normal_responses else 0.0)