import json
import typing
from pathlib import Path

import torch
from datasets import Dataset

from util.globals import *

# Download location of counterfact.json, fetched when no local copy exists.
REMOTE_URL = "{}/data/dsets/counterfact.json".format(REMOTE_ROOT_URL)


class CounterFactDataset(Dataset):
    """CounterFact records plus on-demand tokenization into unlearn/true sets.

    The raw records parsed from ``counterfact.json`` are kept in
    ``self.dataset``.  :meth:`preprocess` (and :meth:`__getitem__`, which wraps
    it) turn them into two column dicts / HF ``Dataset`` objects: one whose
    answer is ``target_new`` ("unlearn") and one whose answer is
    ``target_true`` ("true").
    """

    def __init__(
        self, data_dir: str, size: typing.Optional[int] = None, *args, **kwargs
    ):
        """Load ``counterfact.json`` from *data_dir*, downloading it if absent.

        Args:
            data_dir: Directory that holds (or will receive) ``counterfact.json``.
            size: If given, keep only the first *size* records.
            *args, **kwargs: Accepted for call compatibility; unused.
        """
        data_dir = Path(data_dir)
        cf_loc = data_dir / "counterfact.json"
        if not cf_loc.exists():
            print(f"{cf_loc} does not exist. Downloading from {REMOTE_URL}")
            data_dir.mkdir(exist_ok=True, parents=True)
            torch.hub.download_url_to_file(REMOTE_URL, cf_loc)

        # JSON text is UTF-8; don't depend on the platform default encoding
        # (e.g. cp1252 on Windows), which can fail on non-ASCII entries.
        with open(cf_loc, "r", encoding="utf-8") as f:
            self.dataset = json.load(f)
        if size is not None:
            self.dataset = self.dataset[:size]

        print(f"Loaded dataset with {len(self)} elements")

    def __len__(self):
        """Return the number of loaded CounterFact records."""
        return len(self.dataset)

    def preprocess(self, dataset, tokenizer, max_examples: int = 500):
        """Tokenize up to *max_examples* records into unlearn / true columns.

        Each record's rewrite prompt is filled with its subject and rendered as
        ``Instruction:<question>\\n Input:''\\n Answer:<answer>``, once with
        ``target_new`` (unlearn) and once with ``target_true`` (true).

        Args:
            dataset: Sequence of CounterFact records; ``None`` falls back to
                ``self.dataset``.
            tokenizer: Callable returning a mapping with ``"input_ids"`` and
                ``"attention_mask"`` (HF-tokenizer call convention).
            max_examples: Upper bound on the number of records used
                (default 500, the previous hard-coded value).

        Returns:
            ``(unlearndata, truedata)``: dicts of parallel lists keyed by
            ``input_ids``, ``attention_mask``, ``start_locs`` and ``labels``.
        """
        records = self.dataset if dataset is None else dataset
        columns = ("input_ids", "attention_mask", "start_locs", "labels")
        unlearndata = {key: [] for key in columns}
        truedata = {key: [] for key in columns}

        # Cap at the real length: the former range(500) raised IndexError
        # whenever fewer records were loaded (e.g. size < 500).
        for idx in range(min(max_examples, len(records))):
            rewrite = records[idx]["requested_rewrite"]
            # Every "{" becomes the subject and every "}" is dropped, e.g.
            # "{} is located in" -> "<subject> is located in".  str.replace
            # does not rescan inserted text, so braces in the subject survive.
            question = (
                rewrite["prompt"].replace("}", "").replace("{", rewrite["subject"])
            )
            answer = rewrite["target_new"]["str"]
            trueanswer = rewrite["target_true"]["str"]
            # Concatenates the elements of ``answer`` (identity for a plain str).
            newanswer = "".join(answer)

            # Exact text that precedes the answer in both training strings.
            prefix = f"Instruction:{question}\n Input:''\n Answer:"

            unlearntokenized = tokenizer(
                f"{prefix}{newanswer}", truncation=True, padding="max_length"
            )
            unlearndata["input_ids"].append(unlearntokenized["input_ids"])
            unlearndata["attention_mask"].append(unlearntokenized["attention_mask"])

            truetokenized = tokenizer(
                f"{prefix}{trueanswer}",
                add_special_tokens=True,
                truncation=True,
                padding="max_length",
            )
            truedata["input_ids"].append(truetokenized["input_ids"])
            truedata["attention_mask"].append(truetokenized["attention_mask"])

            # Index of the last prefix token.  The prefix is tokenized WITHOUT
            # padding: len() of a max_length-padded encoding is always
            # max_length, which made every start_loc the same constant.
            start_loc = len(tokenizer(prefix, truncation=True)["input_ids"]) - 1
            unlearndata["start_locs"].append(start_loc)
            truedata["start_locs"].append(start_loc)

            # NOTE(review): labels are the answer tokenized on its own, so they
            # are not position-aligned with input_ids (CounterFactDatasetnew
            # uses input_ids as labels instead) — confirm what the loss expects.
            falselabel = tokenizer(answer, truncation=True, padding="max_length")
            unlearndata["labels"].append(falselabel["input_ids"])
            truelabel = tokenizer(trueanswer, truncation=True, padding="max_length")
            truedata["labels"].append(truelabel["input_ids"])

        return unlearndata, truedata

    def __getitem__(self, tokenizer):
        """Return ``(unlearn_dataset, true_dataset)`` built with *tokenizer*.

        NOTE: despite the name this is not index access — the argument is a
        tokenizer, and the records are re-tokenized on every call.
        """
        unlearndata, truedata = self.preprocess(self.dataset, tokenizer)
        return Dataset.from_dict(unlearndata), Dataset.from_dict(truedata)
class CounterFactDatasetnew(Dataset):
    """CounterFact variant whose examples use their own input_ids as labels.

    Unlike :class:`CounterFactDataset`, one unlearn example is produced per
    element of ``target_new['str']``, and each is paired with a copy of the
    record's true example.
    """

    def __init__(
        self, data_dir: str, size: typing.Optional[int] = None, *args, **kwargs
    ):
        """Load ``counterfact.json`` from *data_dir*, downloading it if absent.

        Args:
            data_dir: Directory that holds (or will receive) ``counterfact.json``.
            size: If given, keep only the first *size* records.
            *args, **kwargs: Accepted for call compatibility; unused.
        """
        data_dir = Path(data_dir)
        cf_loc = data_dir / "counterfact.json"
        if not cf_loc.exists():
            print(f"{cf_loc} does not exist. Downloading from {REMOTE_URL}")
            data_dir.mkdir(exist_ok=True, parents=True)
            torch.hub.download_url_to_file(REMOTE_URL, cf_loc)

        # JSON text is UTF-8; don't depend on the platform default encoding
        # (e.g. cp1252 on Windows), which can fail on non-ASCII entries.
        with open(cf_loc, "r", encoding="utf-8") as f:
            self.dataset = json.load(f)
        if size is not None:
            self.dataset = self.dataset[:size]

        print(f"Loaded dataset with {len(self)} elements")

    def __len__(self):
        """Return the number of loaded CounterFact records."""
        return len(self.dataset)

    def preprocess(self, dataset, tokenizer, max_examples: int = 500):
        """Tokenize up to *max_examples* records into unlearn / true columns.

        Texts are rendered as ``Instruction:<question>\\n Input:''\\n
        Answer:<answer>``.  For every element of ``target_new['str']`` one
        unlearn example is emitted together with one true example
        (``target_true`` answer), so both outputs have the same length.

        Args:
            dataset: Sequence of CounterFact records; ``None`` falls back to
                ``self.dataset``.
            tokenizer: Callable returning a mapping with ``"input_ids"`` and
                ``"attention_mask"`` (HF-tokenizer call convention).
            max_examples: Upper bound on the number of records used
                (default 500, the previous hard-coded value).

        Returns:
            ``(unlearndata, truedata)``: dicts of parallel lists keyed by
            ``input_ids``, ``attention_mask``, ``start_locs`` and ``labels``;
            ``labels`` is the very same list object as ``input_ids``.
        """
        records = self.dataset if dataset is None else dataset
        columns = ("input_ids", "attention_mask", "start_locs", "labels")
        unlearndata = {key: [] for key in columns}
        truedata = {key: [] for key in columns}

        # Cap at the real length: the former range(500) raised IndexError
        # whenever fewer records were loaded (e.g. size < 500).
        for idx in range(min(max_examples, len(records))):
            rewrite = records[idx]["requested_rewrite"]
            # Every "{" becomes the subject and every "}" is dropped, e.g.
            # "{} is located in" -> "<subject> is located in".  str.replace
            # does not rescan inserted text, so braces in the subject survive.
            question = (
                rewrite["prompt"].replace("}", "").replace("{", rewrite["subject"])
            )
            answer = rewrite["target_new"]["str"]
            trueanswer = rewrite["target_true"]["str"]

            # Exact text that precedes the answer in both training strings.
            prefix = f"Instruction:{question}\n Input:''\n Answer:"

            # These depend only on the record, so tokenize once per record
            # rather than once per element of ``answer``.
            truetokenized = tokenizer(
                f"{prefix}{trueanswer}", truncation=True, padding="max_length"
            )
            # Index of the last prefix token.  The prefix is tokenized WITHOUT
            # padding: len() of a max_length-padded encoding is always
            # max_length, which made every start_loc the same constant.
            start_loc = len(tokenizer(prefix, truncation=True)["input_ids"]) - 1

            # NOTE(review): this iterates ``answer`` element-wise.  When
            # target_new['str'] is a plain string (as in the standard
            # CounterFact JSON) that means one example per *character* —
            # confirm this is intended rather than one example per answer.
            for piece in answer:
                unlearntokenized = tokenizer(
                    f"{prefix}{piece}", truncation=True, padding="max_length"
                )
                unlearndata["input_ids"].append(unlearntokenized["input_ids"])
                unlearndata["attention_mask"].append(
                    unlearntokenized["attention_mask"]
                )
                # Fresh copies so the repeated true rows are independent lists.
                truedata["input_ids"].append(list(truetokenized["input_ids"]))
                truedata["attention_mask"].append(
                    list(truetokenized["attention_mask"])
                )
                unlearndata["start_locs"].append(start_loc)
                truedata["start_locs"].append(start_loc)

        # Causal-LM style targets: labels alias the input_ids lists.
        unlearndata["labels"] = unlearndata["input_ids"]
        truedata["labels"] = truedata["input_ids"]
        return unlearndata, truedata

    def __getitem__(self, tokenizer):
        """Return ``(unlearn_dataset, true_dataset)`` built with *tokenizer*.

        NOTE: despite the name this is not index access — the argument is a
        tokenizer, and the records are re-tokenized on every call.
        """
        unlearndata, truedata = self.preprocess(self.dataset, tokenizer)
        return Dataset.from_dict(unlearndata), Dataset.from_dict(truedata)
