import json
import typing
from pathlib import Path

import torch
from torch.utils.data import Dataset

from util.globals import *

# URL prefix built by appending "/data/dsets" to REMOTE_ROOT_URL.
# NOTE(review): REMOTE_ROOT_URL is not defined in this file; presumably it is
# supplied by the star import from util.globals -- confirm there.
REMOTE_ROOT = f"{REMOTE_ROOT_URL}/data/dsets"

class AKEWDataset(Dataset):
    """Map-style dataset over the records stored in ``akew_v3.json``.

    The JSON file must contain a top-level ``data`` list; depending on
    ``eval_type`` it must also contain ``QA_shot`` and/or
    ``completion_shot``. Every retained record is augmented in place with:

    * ``few_shot``  -- the matching shot entry, repeated once per
      evaluation entry collected for that record;
    * ``eval_data`` -- the record's ``QA`` and/or ``completion`` entries;
    * ``subjects``  -- the record's ``completion_subjects``, appended once
      per selected evaluation kind.

    ``few_shot`` and ``eval_data`` always have the same length. Nothing here
    enforces that ``subjects`` lines up with them.
    """

    # File name looked up inside ``data_dir``.
    _DATA_FILENAME = "akew_v3.json"
    # Path this class historically read unconditionally; kept as the fallback
    # so callers whose ``data_dir`` does not contain the file keep working.
    _DEFAULT_DATA_PATH = Path("./data/akew_v3.json")

    def __init__(
        self,
        data_dir: str,
        eval_type: str,
        multi: bool = False,
        size: typing.Optional[int] = None,
        *args,
        **kwargs,
    ):
        """Load the dataset file and build the per-record evaluation fields.

        Args:
            data_dir: Directory searched for ``akew_v3.json``. When the file
                is not found there (or ``data_dir`` is empty), the loader
                falls back to ``./data/akew_v3.json``.
            eval_type: Selects which entries are collected, via ``in`` tests
                for ``'QA'`` and ``'completion'``; a value containing both
                selects both, in that order.
            multi: Currently unused.
            size: When not ``None``, only ``data[:size]`` is kept (ordinary
                slice semantics).
            *args: Ignored.
            **kwargs: Ignored.

        Raises:
            FileNotFoundError: If neither candidate path exists.
            KeyError: If the file or a retained record lacks a required key.
        """
        super().__init__()
        data_path = self._resolve_data_path(data_dir)

        # JSON text is UTF-8; do not depend on the platform's locale encoding.
        with open(data_path, "r", encoding="utf-8") as f:
            raw = json.load(f)

        records = raw['data']
        # Truncate before augmenting so discarded records cost nothing and
        # cannot fail the load.
        if size is not None:
            records = records[:size]

        for record in records:
            self._attach_eval_fields(record, raw, eval_type)

        self.data = records

        print(f"Loaded dataset with {len(self)} elements")

    @classmethod
    def _resolve_data_path(cls, data_dir) -> Path:
        """Return ``data_dir/akew_v3.json`` if it exists, else the default path."""
        if data_dir:
            candidate = Path(data_dir) / cls._DATA_FILENAME
            if candidate.is_file():
                return candidate
        return cls._DEFAULT_DATA_PATH

    @staticmethod
    def _attach_eval_fields(record, raw, eval_type) -> None:
        """Add ``few_shot``, ``eval_data`` and ``subjects`` lists to ``record``."""
        record['few_shot'] = []
        record['eval_data'] = []
        record['subjects'] = []
        if 'QA' in eval_type:
            record['few_shot'].extend([raw['QA_shot']] * len(record['QA']))
            record['eval_data'].extend(record['QA'])
            # NOTE(review): the QA branch reads 'completion_subjects', same as
            # the completion branch. Possibly intentional (shared subjects),
            # possibly a copy-paste slip -- confirm against the data schema.
            record['subjects'].extend(record['completion_subjects'])
        if 'completion' in eval_type:
            record['few_shot'].extend(
                [raw['completion_shot']] * len(record['completion'])
            )
            record['eval_data'].extend(record['completion'])
            record['subjects'].extend(record['completion_subjects'])

    def __len__(self):
        """Return the number of retained records."""
        return len(self.data)

    def __getitem__(self, item):
        """Return the record(s) at ``item`` (an index or slice into the list)."""
        return self.data[item]
