import json
import typing
from pathlib import Path

import torch
from datasets import Dataset
import pandas as pd
from util.globals import *
from transformers import AutoTokenizer, pipeline
#from chatgpt_API import generate_samples_new

class PKURLHF(Dataset):
    """PKU-SafeRLHF-style response pairs loaded from a JSON Lines file.

    Each line of the input file is parsed as one JSON object. ``preprocess``
    reads the keys ``prompt``, ``response_0``, ``response_1``,
    ``is_response_0_safe`` and ``is_response_1_safe`` from each record.

    NOTE(review): this subclasses ``datasets.Dataset`` but ``__init__`` never
    calls the base initializer, so inherited ``datasets.Dataset`` behaviour
    that depends on it is presumably unusable on instances — confirm whether a
    plain container was intended. ``__getitem__`` is not an index lookup: it
    takes a tokenizer and returns two ``datasets.Dataset`` objects.
    """

    # Prompt prefix shared by every example; the answer text is appended
    # directly after "Answer:". The same prefix is used to locate where the
    # answer starts (``start_locs``), so the two must never diverge.
    _PROMPT_TEMPLATE = "Instruction:{question}\n Input:''\n Answer:"

    def __init__(
        self, data_dir: str, size: typing.Optional[int] = None, *args, **kwargs
    ):
        """Load every non-blank line of the JSON Lines file at ``data_dir``.

        Args:
            data_dir: Path to the JSON Lines file (it is opened as a file,
                not a directory, despite the name).
            size: Currently unused; kept for interface compatibility.
            *args, **kwargs: Ignored.
        """
        with open(data_dir, "r", encoding="utf-8") as f:
            # Skip whitespace-only lines (e.g. trailing empty lines) instead
            # of letting json.loads raise on them.
            self.dataset = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        """Return the number of raw records loaded from the file."""
        return len(self.dataset)

    @staticmethod
    def _parse_safety(value):
        """Normalise a safety flag to True/False, or None if unrecognised.

        Accepts real booleans (what ``json.loads`` yields for JSON
        ``true``/``false``) as well as the strings "true"/"false" in any case.
        """
        if isinstance(value, bool):
            return value
        if isinstance(value, str):
            lowered = value.strip().lower()
            if lowered == "true":
                return True
            if lowered == "false":
                return False
        return None

    def _select_pair(self, record):
        """Return ``(unsafe_response, safe_response)`` for a record, or None.

        A pair is produced only when exactly one response is flagged safe and
        the other is flagged unsafe; all other records are skipped.
        """
        safe0 = self._parse_safety(record["is_response_0_safe"])
        safe1 = self._parse_safety(record["is_response_1_safe"])
        if safe0 is True and safe1 is False:
            return record["response_1"], record["response_0"]
        if safe0 is False and safe1 is True:
            return record["response_0"], record["response_1"]
        return None

    @staticmethod
    def _append_example(target, tokenizer, prompt, answer, start_loc):
        """Tokenize ``prompt + answer`` and append one row to column dict ``target``."""
        tokenized = tokenizer(
            f"{prompt}{answer}", truncation=True, padding="max_length"
        )
        target["input_ids"].append(tokenized["input_ids"])
        target["attention_mask"].append(tokenized["attention_mask"])
        target["start_locs"].append(start_loc)
        # NOTE(review): labels are the answer tokenized on its own (padded),
        # so they are not position-aligned with input_ids — confirm this is
        # what the consuming training loop expects.
        target["labels"].append(
            tokenizer(answer, truncation=True, padding="max_length")["input_ids"]
        )

    def preprocess(self, dataset, tokenizer, *, limit: typing.Optional[int] = 1000):
        """Build tokenized "unlearn" (unsafe) and "true" (safe) column dicts.

        Args:
            dataset: Iterable of raw records to process; if None,
                ``self.dataset`` is used.
            tokenizer: Callable invoked as ``tokenizer(text, truncation=...,
                padding=...)`` returning a mapping with ``input_ids`` and
                ``attention_mask`` (presumably a Hugging Face tokenizer).
            limit: Maximum number of leading records to scan. Defaults to
                1000 (the previous hard-coded value); None scans all records.
                Fewer records than ``limit`` is fine.

        Returns:
            ``(unlearndata, truedata)``: dicts of lists keyed ``input_ids``,
            ``attention_mask``, ``start_locs`` and ``labels``. Row k of both
            dicts comes from the same record: the unsafe response goes to
            ``unlearndata`` and the safe one to ``truedata``. ``start_locs``
            is the index of the last prompt token.
        """
        records = self.dataset if dataset is None else dataset
        columns = ("input_ids", "attention_mask", "start_locs", "labels")
        unlearndata = {key: [] for key in columns}
        truedata = {key: [] for key in columns}

        for index, record in enumerate(records):
            if limit is not None and index >= limit:
                break
            pair = self._select_pair(record)
            if pair is None:
                continue
            unsafe_answer, safe_answer = pair

            prompt = self._PROMPT_TEMPLATE.format(question=record["prompt"])
            # Tokenize the prompt WITHOUT padding: padding to max_length would
            # make this length a constant instead of the prompt's length.
            start_loc = len(tokenizer(prompt, truncation=True)["input_ids"]) - 1

            self._append_example(unlearndata, tokenizer, prompt, unsafe_answer, start_loc)
            self._append_example(truedata, tokenizer, prompt, safe_answer, start_loc)

        return unlearndata, truedata

    def __getitem__(self, tokenizer):
        """Tokenize the loaded records and wrap them as two ``datasets.Dataset``s.

        NOTE(review): despite the dunder name this is not an index lookup;
        the subscript argument is a tokenizer (``pku[tokenizer]``).

        Returns:
            ``(unlearndataset, truedataset)`` built with ``Dataset.from_dict``
            from the column dicts produced by ``preprocess``.
        """
        unlearndata, truedata = self.preprocess(self.dataset, tokenizer)
        return Dataset.from_dict(unlearndata), Dataset.from_dict(truedata)