from batchedGeneration import TransformersInterface
from transformers import AutoTokenizer, AutoModelForCausalLM
import pickle as pkl
import torch
# Load the tokenizer and model for the checkpoint below.
# NOTE(review): double-check that this is the intended Hugging Face checkpoint.
model_name = 'mistralai/Mistral-7B-Instruct-v0.1'
# NOTE(review): token="" passes an empty string as the access token; presumably a
# placeholder for a real Hugging Face token — TODO confirm before running.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    token="",
)

# When the checkpoint defines no pad token, reuse the EOS token as padding
# (presumably required for the batched generation done later in this script).
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({'pad_token': tokenizer.eos_token})
pad_token_id = tokenizer.pad_token_id

# Load the weights in float16 and let device_map='auto' decide placement.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    token="",
    device_map='auto',
    torch_dtype=torch.float16,
)
# Debug output: report which device each parameter ended up on.
for param_name, weight in model.named_parameters():
    print(param_name, weight.device)

# All token strings known to the tokenizer (iterating a dict yields its keys).
vocabulary = set(tokenizer.get_vocab())

# Read one prompt per line. rstrip() removes the newline and any other trailing
# whitespace; undecodable bytes are dropped (errors="ignore").
with open("codeline_prompts.txt", "r", encoding="utf-8", errors="ignore") as f:
    lines = [line.rstrip() for line in f]

interface = TransformersInterface(model, vocabulary, tokenizer)
# Run the (long) generation step BEFORE opening the output files in "w" mode, so
# a crash or interruption here does not truncate results from a previous run.
normal_responses, perturbed_responses, length_ratio = interface.generate_response(lines, max_new_tokens=2048, batch_size=5)

# The two output files are written line-for-line in parallel, so both response
# lists must line up; fail loudly instead of writing a partial/misaligned pair.
if len(normal_responses) != len(perturbed_responses):
    raise ValueError(
        "generate_response returned %d normal but %d perturbed responses"
        % (len(normal_responses), len(perturbed_responses))
    )

with open("normal_codeline_outputs.txt", "w", encoding="utf-8") as normal_output_file:
    with open("perturbed_codeline_outputs.txt", "w", encoding="utf-8") as perturbed_output_file:
        for i, (normal, perturbed) in enumerate(zip(normal_responses, perturbed_responses)):
            # str([x]) writes the repr of a one-element list; for string responses
            # this escapes embedded newlines so each response stays on one line.
            normal_output_file.write(str([normal]) + "\n")
            perturbed_output_file.write(str([perturbed]) + "\n")
            print("Processed:", i)