import json
import os
import typing

from torch.utils.data import Dataset

class MquakeDataset(Dataset):
    """Map-style dataset backed by the ``MQuAKE-CF-3k-v2.json`` file.

    The JSON file must contain a list of records. Every record needs a
    ``'requested_rewrite'`` list whose entries provide ``'prompt'`` (a
    ``str.format`` template with one positional placeholder), ``'subject'``
    and ``'target_new'['str']``. While loading, two keys are added to each
    kept record:

    * ``'texts'``: one sentence per rewrite, built as
      ``prompt.format(subject) + " " + target + "."``.
    * ``'requests'``: one dict per rewrite holding only ``'prompt'``,
      ``'target_new'`` (reduced to its ``'str'`` field) and ``'subject'``.
    """

    # File name looked up inside ``data_dir``.
    DATA_FILENAME = "MQuAKE-CF-3k-v2.json"
    # Location that was always used before ``data_dir`` was honoured; kept as
    # a fallback so existing callers keep working.
    LEGACY_DATA_PATH = './data/MQuAKE-CF-3k-v2.json'

    def __init__(
        self,
        data_dir: str,
        multi: bool = False,
        size: typing.Optional[int] = None,
        *args,
        **kwargs,
    ) -> None:
        """Load the records and attach the derived ``texts``/``requests``.

        Args:
            data_dir: Directory searched for ``DATA_FILENAME``. When the file
                is not found there, ``LEGACY_DATA_PATH`` is used instead.
            multi: Accepted for interface compatibility; not used here.
            size: If given, only the first ``size`` records are kept.
            *args: Ignored.
            **kwargs: Ignored.

        Raises:
            FileNotFoundError: If neither candidate file exists.
            KeyError: If a rewrite lacks one of the required fields.
        """
        super().__init__()

        data_path = self._resolve_data_path(data_dir)
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(data_path, "r", encoding="utf-8") as f:
            self.data = json.load(f)

        # Truncate first so discarded records are never preprocessed.
        if size is not None:
            self.data = self.data[:size]

        for item in self.data:
            rewrites = item['requested_rewrite']
            item['requests'] = [
                {
                    'prompt': r['prompt'],
                    'target_new': {'str': r['target_new']['str']},
                    'subject': r['subject'],
                }
                for r in rewrites
            ]
            item['texts'] = [
                r['prompt'].format(r['subject']) + f" {r['target_new']['str']}."
                for r in rewrites
            ]

        print(f"Loaded dataset with {len(self)} elements")

    @classmethod
    def _resolve_data_path(cls, data_dir) -> str:
        """Return the dataset file inside ``data_dir``, else the legacy path."""
        if data_dir and isinstance(data_dir, (str, os.PathLike)):
            candidate = os.path.join(data_dir, cls.DATA_FILENAME)
            if os.path.isfile(candidate):
                return candidate
        return cls.LEGACY_DATA_PATH

    def __len__(self) -> int:
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, item):
        """Return the record(s) selected by ``item`` (index or slice)."""
        return self.data[item]