from transformers import AutoTokenizer, pipeline
import os
import json
# Process-wide environment settings. CUDA_VISIBLE_DEVICES must be set before
# CUDA is initialised to take effect; it exposes only physical GPU 1 here.
# MASTER_ADDR / MASTER_PORT / RANK are not read anywhere in this file —
# presumably consumed by distributed/EasyEdit initialisation; confirm.
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "1",
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": "10799",
        "RANK": "0",
    }
)
import torch
from counterfact import CounterFactDataset
from easyeditor.evaluate.evaluate import compute_edit_quality,compute_rewrite_or_rephrase_quality,compute_locality_quality,compute_portability_quality  
# Generate answers with OPT-1.3b for the paraphrase prompts of the first
# NUM_RECORDS CounterFact records, echo each prompt/answer to stdout, and
# write one answer per line to OUTPUT_PATH.

# CUDA_VISIBLE_DEVICES="1" (set at the top of this file) leaves exactly one GPU
# visible, and inside this process it is addressed as cuda:0; "cuda:2" is an
# invalid device ordinal as soon as anything actually uses it.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

MODEL_PATH = "/home/ssliang/tangent_ta/opt-1.3b"
DATA_PATH = "/home/ssliang/tangent_ta/data/counterfact.json"
NUM_RECORDS = 100
# NOTE(review): "unlearntxt" looks like it is missing a "." (the disabled run
# at the bottom of this file writes "..._0.05unlearn.txt"), and the name says
# "neighborhood" while the prompts below come from "paraphrase_prompts".
# Kept unchanged so anything reading this exact path keeps working — confirm
# the intended name.
OUTPUT_PATH = './answer_neighborhoodaugmented_0.05unlearntxt'

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
# Pass the device explicitly; otherwise the model's placement is left to the
# pipeline's default and the `device` computed above is never used.
generator = pipeline(
    'text-generation',
    model=MODEL_PATH,
    torch_dtype=torch.bfloat16,
    tokenizer=tokenizer,
    max_length=100,
    device=device,
)

answerlist = []
# JSON is UTF-8 by specification; don't depend on the platform's locale.
with open(DATA_PATH, "r", encoding="utf-8") as f:
    dataset = json.load(f)

# Slicing (instead of indexing range(100)) tolerates a dataset with fewer than
# NUM_RECORDS entries rather than raising IndexError.
for record in dataset[:NUM_RECORDS]:
    for question in record['paraphrase_prompts']:
        print(question)
        answer = generator(question)
        print(answer)
        answerlist.append(answer)

with open(OUTPUT_PATH, 'w', encoding='utf-8') as file:
    # Each pipeline result is written as its str() representation, one per line.
    file.writelines(str(line) + '\n' for line in answerlist)

"""
for i in range (500):
    #print(dataset[0]['raw']['requested_rewrite']['prompt'])
    prompt=dataset[i]['requested_rewrite']['prompt']
    subject=dataset[i]['requested_rewrite']['subject']
    question=''
    for j in prompt:
       if j !='{' and j!= '}':
         question=question+j
       if j=='{':
         question=question+subject     
    print(question)
    answer=generator(question)
    print(answer)
    answerlist.append(answer)
with open('./answer_original_0.05unlearn.txt', 'w', encoding='utf-8') as file:
     for line in answerlist:
         file.write(str(line)+'\n')
"""
