import json
import typing
from pathlib import Path
import pandas as pd
import torch
from datasets import Dataset

from util.globals import *
from transformers import AutoTokenizer, pipeline




class QQP(Dataset):
    """QQP sentence-pair records turned into "unlearn" and "true" prompt sets.

    Each line of the input file is one JSON object. The fields read from a
    record are ``text1``, ``text2`` and ``label_text``; a ``label_text`` of
    ``"duplicate"`` maps to the answer ``"1"``, anything else to ``"0"``.

    Indexing is repurposed: ``qqp[tokenizer]`` tokenizes the records and
    returns two ``datasets.Dataset`` objects (see ``__getitem__``).

    NOTE(review): ``datasets.Dataset.__init__`` is never called, so inherited
    ``Dataset`` functionality presumably does not work on instances of this
    class; only the methods defined here are known to be usable -- confirm
    before relying on anything else.
    """

    # Upper bound on the number of records converted by one ``preprocess``
    # call (this was a hard-coded literal inside the loop before).
    MAX_EXAMPLES = 500

    def __init__(
        self, data_dir: str, size: typing.Optional[int] = None, *args, **kwargs
    ):
        """Load the JSON-lines file at ``data_dir``.

        Args:
            data_dir: Path of the JSON-lines file itself (despite the name it
                is passed straight to ``open``).
            size: If given, keep only the first ``size`` records.
            *args, **kwargs: Accepted for signature compatibility; unused.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        records = []
        with open(data_dir, "r", encoding="utf-8") as f:
            for line in f:
                # Blank lines (e.g. a trailing newline at the end of the file)
                # are not records and would make json.loads raise.
                if line.strip():
                    records.append(json.loads(line))
        if size is not None:
            records = records[:size]
        self.dataset = records

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.dataset)

    def preprocess(self, dataset, tokenizer):
        """Tokenize up to ``MAX_EXAMPLES`` records into two column dicts.

        Args:
            dataset: Iterable of record dicts to convert. ``None`` falls back
                to the records loaded by ``__init__``.
            tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
                padding="max_length", ...)`` whose result supports
                ``["input_ids"]`` and ``["attention_mask"]``.

        Returns:
            ``(unlearndata, truedata)``. Both are dicts with the keys
            ``input_ids``, ``attention_mask``, ``start_locs`` and ``labels``,
            each holding one list entry per converted record. ``unlearndata``
            is built from the flipped answer, ``truedata`` from the real one.

        Raises:
            KeyError: If a record lacks ``text1``, ``text2`` or ``label_text``.
        """
        records = self.dataset if dataset is None else dataset
        unlearndata = {
            "input_ids": [],
            "attention_mask": [],
            "start_locs": [],
            "labels": [],
        }
        truedata = {
            "input_ids": [],
            "attention_mask": [],
            "start_locs": [],
            "labels": [],
        }

        for index, record in enumerate(records):
            # Stop at the cap instead of indexing a fixed range, so inputs
            # with fewer than MAX_EXAMPLES records no longer raise IndexError.
            if index >= self.MAX_EXAMPLES:
                break

            question = record["text1"] + "\n" + record["text2"]
            trueanswer = "1" if record["label_text"] == "duplicate" else "0"
            # The "unlearn" target is always the opposite of the real answer.
            newanswer = "0" if trueanswer == "1" else "1"

            # NOTE(review): the two prompts differ by one space after "not."
            # -- kept exactly as they were; confirm whether that is intended.
            unlearntext = f"Instruction:{question}\n Input:'You need to compare the meaning of the two sentences duplicate or not. If duplicate print'1',else print'0''\n Answer:{newanswer}"
            truetext = f"Instruction:{question}\n Input:'You need to compare the meaning of the two sentences duplicate or not.If duplicate print'1',else print'0''\n Answer:{trueanswer}"

            unlearntokenized = tokenizer(
                unlearntext, truncation=True, padding="max_length"
            )
            unlearndata["input_ids"].append(unlearntokenized["input_ids"])
            unlearndata["attention_mask"].append(
                unlearntokenized["attention_mask"]
            )

            truetokenized = tokenizer(
                truetext,
                add_special_tokens=True,
                truncation=True,
                padding="max_length",
            )
            truedata["input_ids"].append(truetokenized["input_ids"])
            truedata["attention_mask"].append(truetokenized["attention_mask"])

            # NOTE(review): start_locs is derived from a prompt whose template
            # differs from the two prompts above, and it is tokenized with
            # padding="max_length"; presumably it is meant to mark where the
            # answer begins -- verify against the consumer of "start_locs".
            test_text = f"Instruction: {question}\n Input:'' Answer: "
            test_tokenized = tokenizer(
                test_text, truncation=True, padding="max_length"
            )
            start_loc = len(test_tokenized["input_ids"]) - 1
            unlearndata["start_locs"].append(start_loc)
            truedata["start_locs"].append(start_loc)

            falselabel = tokenizer(
                newanswer, truncation=True, padding="max_length"
            )
            unlearndata["labels"].append(falselabel["input_ids"])
            truelabel = tokenizer(
                trueanswer, truncation=True, padding="max_length"
            )
            truedata["labels"].append(truelabel["input_ids"])

        return unlearndata, truedata

    def __getitem__(self, tokenizer):
        """Build both tokenized datasets for ``tokenizer``.

        The subscript is a tokenizer, not an index: ``qqp[tokenizer]`` runs
        ``preprocess`` over the loaded records (re-tokenizing on every call)
        and wraps each column dict with ``Dataset.from_dict``.

        Returns:
            ``(unlearndataset, truedataset)``.
        """
        unlearndata, truedata = self.preprocess(self.dataset, tokenizer)
        unlearndataset = Dataset.from_dict(unlearndata)
        truedataset = Dataset.from_dict(truedata)
        return unlearndataset, truedataset