import json
import typing
from pathlib import Path
import pandas as pd
import torch
from datasets import Dataset

from util.globals import *
from transformers import AutoTokenizer, pipeline
# Disabled import: `generate_samples` is only needed to (re)generate the file of
# alternative answers, not to load the datasets.
#from chatgpt_API import generate_samples
# Download location of the CounterFact JSON file; REMOTE_ROOT_URL is presumably
# provided by the star import from util.globals.
# NOTE(review): every dataset class in this file falls back to this URL when its
# local file is missing, including the ZSRE ones — confirm that is intended.
REMOTE_URL = f"{REMOTE_ROOT_URL}/data/dsets/counterfact.json"



class ZSRE(Dataset):
    """ZSRE records turned into paired "unlearn" / "true" tokenized datasets.

    The raw records are read from ``zsre_test.json`` inside ``data_dir``. Each
    record is expected to provide the keys ``src`` (question), ``alt``
    (alternative answer) and ``answers`` (indexed with ``[0]`` to obtain the
    true answer).

    NOTE(review): ``Dataset.__init__`` is never called, so this object only
    acts as a loader around ``self.dataset`` — confirm that is intended.
    """

    # Upper bound on the number of records converted by ``preprocess``.
    MAX_EXAMPLES = 500

    def __init__(
        self, data_dir: str, size: typing.Optional[int] = None, *args, **kwargs
    ):
        """Load the raw records, downloading the file first if it is missing.

        Args:
            data_dir: Directory that holds (or will receive) ``zsre_test.json``.
            size: When given, keep only the first ``size`` records.
            *args, **kwargs: Accepted for call compatibility and ignored.
        """
        data_dir = Path(data_dir)
        cf_loc = data_dir / "zsre_test.json"
        if not cf_loc.exists():
            # NOTE(review): REMOTE_URL names counterfact.json, yet the download
            # is stored as zsre_test.json — verify the URL before relying on
            # this fallback.
            print(f"{cf_loc} does not exist. Downloading from {REMOTE_URL}")
            data_dir.mkdir(exist_ok=True, parents=True)
            torch.hub.download_url_to_file(REMOTE_URL, cf_loc)

        # Explicit encoding so loading does not depend on the platform default.
        with open(cf_loc, "r", encoding="utf-8") as f:
            self.dataset = json.load(f)
        if size is not None:
            self.dataset = self.dataset[:size]

        print(f"Loaded dataset with {len(self)} elements")

    def __len__(self):
        """Return the number of raw records that were loaded."""
        return len(self.dataset)

    def preprocess(self, dataset, tokenizer):
        """Tokenize up to ``MAX_EXAMPLES`` records into two column dicts.

        Args:
            dataset: Accepted for call compatibility; the records are always
                read from ``self.dataset``.
            tokenizer: Callable applied as
                ``tokenizer(text, truncation=True, padding="max_length")`` whose
                result exposes ``"input_ids"`` and ``"attention_mask"``.

        Returns:
            ``(unlearndata, truedata)``: dicts with the keys ``input_ids``,
            ``attention_mask``, ``start_locs`` and ``labels``. ``unlearndata``
            is built from the ``alt`` answer, ``truedata`` from the first entry
            of ``answers``.
        """
        unlearndata = {"input_ids": [], "attention_mask": [], "start_locs": [], "labels": []}
        truedata = {"input_ids": [], "attention_mask": [], "start_locs": [], "labels": []}

        # Slicing clamps to the available records, so datasets holding fewer
        # than MAX_EXAMPLES entries no longer raise IndexError.
        for record in self.dataset[: self.MAX_EXAMPLES]:
            question = record['src']
            newanswer = record['alt']
            trueanswer = record['answers'][0]

            unlearntext = f"Instruction:{question}\n Input:''\n Answer:{newanswer}"
            truetext = f"Instruction:{question}\n Input:''\n Answer:{trueanswer}"

            unlearntokenized = tokenizer(unlearntext, truncation=True, padding="max_length")
            unlearndata["input_ids"].append(unlearntokenized["input_ids"])
            unlearndata["attention_mask"].append(unlearntokenized["attention_mask"])

            truetokenized = tokenizer(
                truetext, add_special_tokens=True, truncation=True, padding="max_length"
            )
            truedata["input_ids"].append(truetokenized["input_ids"])
            truedata["attention_mask"].append(truetokenized["attention_mask"])

            # Answer-free prompt used to derive where the answer starts.
            # NOTE(review): this is measured on an encoding produced with
            # padding="max_length", and the prompt's spacing differs from the
            # training texts above — confirm the resulting index is intended.
            test_text = f"Instruction: {question}\n Input:'' Answer: "
            test_tokenized = tokenizer(
                test_text, truncation=True, padding="max_length"
            )
            start_loc = len(test_tokenized["input_ids"]) - 1
            unlearndata["start_locs"].append(start_loc)
            truedata["start_locs"].append(start_loc)

            # Labels are the tokenized bare answers, not the full prompts.
            falselabel = tokenizer(newanswer, truncation=True, padding="max_length")
            unlearndata["labels"].append(falselabel["input_ids"])
            truelabel = tokenizer(trueanswer, truncation=True, padding="max_length")
            truedata["labels"].append(truelabel["input_ids"])
        return unlearndata, truedata

    def __getitem__(self, tokenizer):
        """Build both datasets for ``tokenizer``.

        Despite the method name, the argument is a tokenizer rather than an
        index: ``zsre[tokenizer]`` runs ``preprocess`` and wraps each resulting
        dict with ``Dataset.from_dict``.

        Returns:
            ``(unlearndataset, truedataset)``.
        """
        unlearndata, truedata = self.preprocess(self.dataset, tokenizer)
        unlearndataset = Dataset.from_dict(unlearndata)
        truedataset = Dataset.from_dict(truedata)
        return unlearndataset, truedataset
class ZSREnew(Dataset):
    """ZSRE records augmented with pre-generated alternative answers.

    Works like a plain ZSRE loader (records come from ``zsre_test.json`` in
    ``data_dir``; keys ``src``, ``alt`` and ``answers`` are read), but every
    record additionally contributes extra "unlearn" examples built from a file
    of generated answers.

    The answers file holds one JSON array per line, line ``i`` belonging to
    record ``i``. It was originally produced by calling ``generate_samples`` on
    each ``alt`` answer and writing the resulting list with ``json.dump``
    followed by a newline.
    """

    # Upper bound on the number of records converted by ``preprocess``.
    MAX_EXAMPLES = 500
    # Number of generated answers consumed per record.
    GENERATED_PER_RECORD = 4
    # Default location of the generated-answers file.
    GENERATED_ANSWERS_PATH = "/home/ssliang/unlearning/codes/generate_answer.txt"

    def __init__(
        self, data_dir: str, size: typing.Optional[int] = None, *args, **kwargs
    ):
        """Load the raw records, downloading the file first if it is missing.

        Args:
            data_dir: Directory that holds (or will receive) ``zsre_test.json``.
            size: When given, keep only the first ``size`` records.
            *args, **kwargs: Accepted for call compatibility and ignored.
        """
        data_dir = Path(data_dir)
        cf_loc = data_dir / "zsre_test.json"
        if not cf_loc.exists():
            # NOTE(review): REMOTE_URL names counterfact.json, yet the download
            # is stored as zsre_test.json — verify the URL before relying on
            # this fallback.
            print(f"{cf_loc} does not exist. Downloading from {REMOTE_URL}")
            data_dir.mkdir(exist_ok=True, parents=True)
            torch.hub.download_url_to_file(REMOTE_URL, cf_loc)

        # Explicit encoding so loading does not depend on the platform default.
        with open(cf_loc, "r", encoding="utf-8") as f:
            self.dataset = json.load(f)
        if size is not None:
            self.dataset = self.dataset[:size]

        print(f"Loaded dataset with {len(self)} elements")

    def __len__(self):
        """Return the number of raw records that were loaded."""
        return len(self.dataset)

    def preprocess(self, dataset, tokenizer, answers_path=None):
        """Tokenize records plus their generated answers into two column dicts.

        For every record one example is built from the ``alt`` answer, followed
        by up to ``GENERATED_PER_RECORD`` examples built from the generated
        answers. ``truedata`` receives the same true-answer example once per
        unlearn example so both dicts stay aligned row by row.

        Args:
            dataset: Accepted for call compatibility; the records are always
                read from ``self.dataset``.
            tokenizer: Callable applied as
                ``tokenizer(text, truncation=True, padding="max_length")`` whose
                result exposes ``"input_ids"`` and ``"attention_mask"``.
            answers_path: Path of the generated-answers file; defaults to
                ``GENERATED_ANSWERS_PATH``.

        Returns:
            ``(unlearndata, truedata)``: dicts with the keys ``input_ids``,
            ``attention_mask``, ``start_locs`` and ``labels``.

        Raises:
            IndexError: If the answers file has fewer lines than there are
                records to process.
        """
        unlearndata = {"input_ids": [], "attention_mask": [], "start_locs": [], "labels": []}
        truedata = {"input_ids": [], "attention_mask": [], "start_locs": [], "labels": []}

        if answers_path is None:
            answers_path = self.GENERATED_ANSWERS_PATH
        # Each line is a JSON array; parse it so that indexing yields whole
        # answers rather than single characters of the raw line.
        with open(answers_path, "r", encoding="utf-8") as f:
            allanswers = [json.loads(line) for line in f if line.strip()]

        # Slicing clamps to the available records, so datasets holding fewer
        # than MAX_EXAMPLES entries no longer raise IndexError.
        records = self.dataset[: self.MAX_EXAMPLES]
        if len(allanswers) < len(records):
            raise IndexError(
                f"{answers_path} holds {len(allanswers)} answer lines but "
                f"{len(records)} records need one each"
            )

        for record, answer in zip(records, allanswers):
            question = record['src']
            newanswer = record['alt']
            # `answers` is a list; use its first entry (as ZSRE does) so the
            # prompt does not embed the list repr and labels stay a flat
            # token sequence.
            trueanswer = record['answers'][0]

            unlearntext = f"Instruction:{question}\n Input:''\n Answer:{newanswer}"
            truetext = f"Instruction:{question}\n Input:''\n Answer:{trueanswer}"

            unlearntokenized = tokenizer(unlearntext, truncation=True, padding="max_length")
            unlearndata["input_ids"].append(unlearntokenized["input_ids"])
            unlearndata["attention_mask"].append(unlearntokenized["attention_mask"])

            truetokenized = tokenizer(truetext, truncation=True, padding="max_length")
            truedata["input_ids"].append(truetokenized["input_ids"])
            truedata["attention_mask"].append(truetokenized["attention_mask"])

            # Answer-free prompt used to derive where the answer starts.
            # NOTE(review): this is measured on an encoding produced with
            # padding="max_length", and the prompt's spacing differs from the
            # training texts above — confirm the resulting index is intended.
            test_text = f"Instruction: {question}\n Input:'' Answer: "
            test_tokenized = tokenizer(
                test_text, truncation=True, padding="max_length"
            )
            start_loc = len(test_tokenized["input_ids"]) - 1
            unlearndata["start_locs"].append(start_loc)
            truedata["start_locs"].append(start_loc)

            falselabel = tokenizer(newanswer, truncation=True, padding="max_length")
            unlearndata["labels"].append(falselabel["input_ids"])
            truelabel = tokenizer(trueanswer, truncation=True, padding="max_length")
            truedata["labels"].append(truelabel["input_ids"])

            print('answer', answer)
            for generated_answer in answer[: self.GENERATED_PER_RECORD]:
                unlearntext = f"Instruction:{question}\n Input:''\n Answer:{generated_answer}"
                unlearntokenized = tokenizer(unlearntext, truncation=True, padding="max_length")
                unlearndata["input_ids"].append(unlearntokenized["input_ids"])
                unlearndata["attention_mask"].append(unlearntokenized["attention_mask"])
                unlearndata["start_locs"].append(start_loc)
                newfalselabel = tokenizer(generated_answer, truncation=True, padding="max_length")
                unlearndata["labels"].append(newfalselabel["input_ids"])

                # Repeat the true example so both dicts keep the same length.
                truedata["input_ids"].append(truetokenized["input_ids"])
                truedata["attention_mask"].append(truetokenized["attention_mask"])
                truedata["start_locs"].append(start_loc)
                truedata["labels"].append(truelabel["input_ids"])

        return unlearndata, truedata

    def __getitem__(self, tokenizer):
        """Build both datasets for ``tokenizer``.

        Despite the method name, the argument is a tokenizer rather than an
        index: ``obj[tokenizer]`` runs ``preprocess`` and wraps each resulting
        dict with ``Dataset.from_dict``.

        Returns:
            ``(unlearndataset, truedataset)``.
        """
        unlearndata, truedata = self.preprocess(self.dataset, tokenizer)
        unlearndataset = Dataset.from_dict(unlearndata)
        truedataset = Dataset.from_dict(truedata)
        return unlearndataset, truedataset
class CounterFactDatasetnew(Dataset):
    """CounterFact rewrite requests turned into "unlearn" / "true" datasets.

    Records are read from ``counterfact.json`` inside ``data_dir``. From each
    record's ``requested_rewrite`` entry the keys ``prompt``, ``subject``,
    ``target_new['str']`` and ``target_true['str']`` are used.
    """

    # Upper bound on the number of records converted by ``preprocess``.
    MAX_EXAMPLES = 500

    def __init__(
        self, data_dir: str, size: typing.Optional[int] = None, *args, **kwargs
    ):
        """Load the raw records, downloading the file first if it is missing.

        Args:
            data_dir: Directory that holds (or will receive)
                ``counterfact.json``.
            size: When given, keep only the first ``size`` records.
            *args, **kwargs: Accepted for call compatibility and ignored.
        """
        data_dir = Path(data_dir)
        cf_loc = data_dir / "counterfact.json"
        if not cf_loc.exists():
            print(f"{cf_loc} does not exist. Downloading from {REMOTE_URL}")
            data_dir.mkdir(exist_ok=True, parents=True)
            torch.hub.download_url_to_file(REMOTE_URL, cf_loc)

        # Explicit encoding so loading does not depend on the platform default.
        with open(cf_loc, "r", encoding="utf-8") as f:
            self.dataset = json.load(f)
        if size is not None:
            self.dataset = self.dataset[:size]

        print(f"Loaded dataset with {len(self)} elements")

    def __len__(self):
        """Return the number of raw records that were loaded."""
        return len(self.dataset)

    def preprocess(self, dataset, tokenizer):
        """Tokenize up to ``MAX_EXAMPLES`` rewrite requests into two dicts.

        One unlearn/true example pair is emitted per candidate new answer.
        ``labels`` in each returned dict is the very same list object as
        ``input_ids``.

        Args:
            dataset: Accepted for call compatibility; the records are always
                read from ``self.dataset``.
            tokenizer: Callable applied as
                ``tokenizer(text, truncation=True, padding="max_length")`` whose
                result exposes ``"input_ids"`` and ``"attention_mask"``.

        Returns:
            ``(unlearndata, truedata)``: dicts with the keys ``input_ids``,
            ``attention_mask``, ``start_locs`` and ``labels``.
        """
        unlearndata = {"input_ids": [], "attention_mask": [], "start_locs": [], "labels": []}
        truedata = {"input_ids": [], "attention_mask": [], "start_locs": [], "labels": []}

        # Slicing clamps to the available records, so datasets holding fewer
        # than MAX_EXAMPLES entries no longer raise IndexError.
        for record in self.dataset[: self.MAX_EXAMPLES]:
            rewrite = record['requested_rewrite']
            prompt = rewrite['prompt']
            subject = rewrite['subject']
            # Fill the template: each "{" becomes the subject, each "}" is dropped.
            question = "".join(
                subject if ch == "{" else ch for ch in prompt if ch != "}"
            )

            answer = rewrite['target_new']['str']
            trueanswer = rewrite['target_true']['str']
            # A plain string is one candidate answer; iterating over it
            # directly would emit one example per character. Lists of
            # candidates are processed element by element as before.
            candidates = [answer] if isinstance(answer, str) else answer

            # Everything below depends only on the question and true answer,
            # so it is computed once per record rather than once per candidate.
            truetext = f"Instruction:{question}\n Input:''\n Answer:{trueanswer}"
            truetokenized = tokenizer(truetext, truncation=True, padding="max_length")
            # Answer-free prompt used to derive where the answer starts.
            # NOTE(review): this is measured on an encoding produced with
            # padding="max_length", and the prompt's spacing differs from the
            # training texts — confirm the resulting index is intended.
            test_text = f"Instruction: {question}\n Input:'' Answer: "
            test_tokenized = tokenizer(
                test_text, truncation=True, padding="max_length"
            )
            start_loc = len(test_tokenized["input_ids"]) - 1

            for newanswer in candidates:
                unlearntext = f"Instruction:{question}\n Input:''\n Answer:{newanswer}"
                unlearntokenized = tokenizer(unlearntext, truncation=True, padding="max_length")
                unlearndata["input_ids"].append(unlearntokenized["input_ids"])
                unlearndata["attention_mask"].append(unlearntokenized["attention_mask"])
                unlearndata["start_locs"].append(start_loc)

                truedata["input_ids"].append(truetokenized["input_ids"])
                truedata["attention_mask"].append(truetokenized["attention_mask"])
                truedata["start_locs"].append(start_loc)

        # Labels are the full tokenized prompts (shared list, not a copy).
        unlearndata["labels"] = unlearndata["input_ids"]
        truedata["labels"] = truedata["input_ids"]
        return unlearndata, truedata

    def __getitem__(self, tokenizer):
        """Build both datasets for ``tokenizer``.

        Despite the method name, the argument is a tokenizer rather than an
        index: ``obj[tokenizer]`` runs ``preprocess`` and wraps each resulting
        dict with ``Dataset.from_dict``.

        Returns:
            ``(unlearndataset, truedataset)``.
        """
        unlearndata, truedata = self.preprocess(self.dataset, tokenizer)
        unlearndataset = Dataset.from_dict(unlearndata)
        truedataset = Dataset.from_dict(truedata)
        return unlearndataset, truedataset
