import json
from pathlib import Path
import torch
# Default data directory. This is a relative path, so it is resolved against the
# current working directory at use time, not against this file's location.
# The dataset classes below take `data_dir` explicitly and do not read this
# constant; callers presumably pass it in.
DATA_DIR = Path("../main/data")

class LamatrexDataset(torch.utils.data.Dataset):
    """Map-style dataset over the records stored in ``<data_dir>/lama_trex.json``.

    Each item is returned exactly as decoded by ``json.load``. The JSON top
    level is assumed to be a list, since items are fetched by integer index
    and ``reverse`` uses slicing.
    """

    def __init__(self, data_dir: str, reverse=False, *args, **kwargs):
        """Load ``lama_trex.json`` from *data_dir*.

        Args:
            data_dir: Directory containing ``lama_trex.json`` (``str`` or ``Path``).
            reverse: If true, store the records in reverse order.
            *args, **kwargs: Accepted and ignored. Presumably this lets a shared
                factory pass extra options; confirm against callers.

        Raises:
            FileNotFoundError: If ``lama_trex.json`` is missing. This is a
                subclass of ``Exception``, so existing ``except Exception``
                handlers still catch it.
        """
        data_dir = Path(data_dir)
        known_loc = data_dir / "lama_trex.json"
        if not known_loc.exists():
            raise FileNotFoundError(f"{known_loc} does not exist.")

        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(known_loc, "r", encoding="utf-8") as f:
            self.data = json.load(f)
        if reverse:
            self.data = self.data[::-1]
        print(f"Loaded dataset with {len(self)} elements")

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, item):
        """Return the record at index *item*, as decoded from JSON."""
        return self.data[item]

class KnownsDataset(torch.utils.data.Dataset):
    """Map-style dataset over the records stored in ``<data_dir>/known_1000.json``.

    Each item is returned exactly as decoded by ``json.load``. The JSON top
    level is assumed to be a list indexed by integer; confirm against the
    data file.
    """

    def __init__(self, data_dir: str, *args, **kwargs):
        """Load ``known_1000.json`` from *data_dir*.

        Args:
            data_dir: Directory containing ``known_1000.json`` (``str`` or ``Path``).
            *args, **kwargs: Accepted and ignored. Presumably this lets a shared
                factory pass extra options; confirm against callers.

        Raises:
            NotImplementedError: If ``known_1000.json`` is missing. The
                exception type is kept for compatibility with existing
                callers, and the message now names the missing path.
        """
        data_dir = Path(data_dir)
        known_loc = data_dir / "known_1000.json"
        if not known_loc.exists():
            message = f"{known_loc} does not exist."
            print(message)
            raise NotImplementedError(message)
        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        with open(known_loc, "r", encoding="utf-8") as f:
            self.data = json.load(f)

        print(f"Loaded dataset with {len(self)} elements")

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.data)

    def __getitem__(self, item):
        """Return the record at index *item*, as decoded from JSON."""
        return self.data[item]