from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch
import os
import json


class CFDataset(Dataset):
    """Flatten a CounterFact-style JSON split into (prompt, subject, target) triples.

    Each JSON record yields one triple for its rewrite prompt (the
    ``requested_rewrite.prompt`` template formatted with the subject) and one
    triple per entry in ``paraphrase_prompts``. All triples from a record share
    its ``requested_rewrite.subject`` and ``requested_rewrite.target_true.str``.
    """

    def __init__(self, ds_type, train: bool = False, data_dir: str = '/path'):
        """Load the requested split from ``<data_dir>/cf_{train|test}.json``.

        Args:
            ds_type: Dataset identifier; only ``'cf'`` is supported.
            train: Read ``cf_train.json`` when True, otherwise ``cf_test.json``.
            data_dir: Directory holding the JSON files. Defaults to the
                original hard-coded ``'/path'`` location.

        Raises:
            ValueError: If ``ds_type`` is not ``'cf'``. (Previously an
                ``assert``, which is silently skipped under ``python -O``.)
        """
        if ds_type not in ('cf',):
            raise ValueError(f"unsupported ds_type {ds_type!r}; expected 'cf'")

        split = 'train' if train else 'test'
        path = os.path.join(data_dir, f'cf_{split}.json')
        with open(path, 'r', encoding='utf-8') as f:
            raw = json.load(f)

        data = []
        for item in raw:
            rewrite = item['requested_rewrite']
            subject = rewrite['subject']
            target = rewrite['target_true']['str']
            # Ensure the target starts with a space; presumably so it tokenizes
            # as a separate word after the prompt (confirm against tokenizer).
            # startswith() also tolerates an empty target, unlike target[0].
            if not target.startswith(' '):
                target = ' ' + target

            prompts = [rewrite['prompt'].format(subject)] + item['paraphrase_prompts']
            data.extend((prompt, subject, target) for prompt in prompts)

        self.data = data

    def __len__(self):
        """Return the number of (prompt, subject, target) triples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the ``(prompt, subject, target)`` tuple at position ``idx``."""
        return self.data[idx]

class DisDataset(Dataset):
    """Wrap an iterable of record dicts as indexable 6-tuples.

    Each record must provide the keys ``prompt``, ``subject``, ``answer``,
    ``neighbor_prompt``, ``neighbor_answer`` and ``para_prompt``. The stored
    tuple keeps that order, with both answers normalized to start with a space.
    """

    @staticmethod
    def _with_leading_space(text):
        """Return ``text`` prefixed with a space unless it already starts with one."""
        # startswith() is safe for an empty string, unlike indexing text[0].
        return text if text.startswith(' ') else ' ' + text

    def __init__(self, data):
        """Build the tuple list from ``data`` (any iterable of record dicts)."""
        self.data = [
            (
                item['prompt'],
                item['subject'],
                self._with_leading_space(item['answer']),
                item['neighbor_prompt'],
                self._with_leading_space(item['neighbor_answer']),
                item['para_prompt'],
            )
            for item in data
        ]

    def __len__(self):
        """Return the number of stored records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return the 6-tuple
        ``(prompt, subject, answer, neighbor_prompt, neighbor_answer, para_prompt)``
        at position ``idx``."""
        return self.data[idx]


if __name__ == "__main__":
    # Smoke test: load the training split and print the first collated batch.
    loader = DataLoader(CFDataset('cf', train=True), batch_size=4)
    first_batch = next(iter(loader))
    print(first_batch)
