import os
import time
from typing import Iterable

from tqdm import tqdm

from nesim.utils.json_stuff import load_json_as_dict


class HopefullyFasterDataset:
    """Map-style dataset that reads one JSON file per item from a folder.

    Item ``i`` is the file ``{folder}/{indices[i]}.json``; it is parsed with
    ``load_json_as_dict`` on every access (nothing is cached in memory).
    """

    def __init__(self, folder: str, indices: Iterable[int]):
        """Collect and validate the file paths for the selected indices.

        Args:
            folder: Directory containing the ``<index>.json`` files.
            indices: File indices to expose, in dataset order
                (e.g. ``range(n)``).

        Raises:
            FileNotFoundError: If ``folder`` or any selected file is missing.
        """
        # Explicit raise instead of `assert`: asserts are stripped under -O.
        if not os.path.exists(folder):
            raise FileNotFoundError(f"Dataset folder does not exist: {folder}")
        self.filenames = [
            os.path.join(folder, f"{dataset_idx}.json") for dataset_idx in indices
        ]
        self.validate_filenames()

    def validate_filenames(self):
        """Check up front that every selected file exists.

        Raises:
            FileNotFoundError: For the first missing file encountered.
        """
        for f in tqdm(self.filenames, desc="Validating filenames"):
            if not os.path.exists(f):
                raise FileNotFoundError(f"Invalid filename: {f}")

    def __getitem__(self, idx):
        """Load item ``idx`` by parsing its JSON file from disk."""
        return load_json_as_dict(self.filenames[idx])

    def __len__(self):
        """Return the number of selected files."""
        return len(self.filenames)

    def __repr__(self):
        return f"HopefullyFasterDataset: {len(self)} items"


# Speed test: compare per-item indexing latency of HopefullyFasterDataset
# against the Hugging Face `datasets` loader, and check both yield the same items.
FASTER_DATASET_FOLDER = "/research/datasets/openwebtext_faster"
HF_DATASET_PATH = (
    "/home/XXXX-4/repos/nesim/training/gpt_neo_125m/datasets/openwebtext_tokenized"
)


def _time_indexing(dataset, n_samples: int):
    """Read ``dataset[0] .. dataset[n_samples - 1]``, timing each lookup.

    Args:
        dataset: Any object supporting integer ``__getitem__``.
        n_samples: Number of leading items to read; must be positive.

    Returns:
        A ``(mean_seconds, samples)`` pair: the mean time per ``dataset[i]``
        call and the list of items returned, in index order.

    Raises:
        ValueError: If ``n_samples`` is not positive (the mean is undefined).
    """
    if n_samples <= 0:
        raise ValueError(f"n_samples must be positive, got {n_samples}")
    times = []
    samples = []
    for i in range(n_samples):
        # perf_counter is the clock intended for measuring short durations.
        start = time.perf_counter()
        item = dataset[i]
        times.append(time.perf_counter() - start)
        samples.append(item)
    return sum(times) / len(times), samples


def main(n_samples: int = 10_000) -> None:
    """Benchmark both loaders on the first ``n_samples`` items and compare them.

    Args:
        n_samples: Number of leading items to read from each loader.
    """
    # Build the dataset from exactly the indices the benchmark reads, so every
    # lookup hits a file that __init__ has already validated.
    faster_dataset = HopefullyFasterDataset(
        folder=FASTER_DATASET_FOLDER, indices=range(n_samples)
    )
    new, new_samples = _time_indexing(faster_dataset, n_samples)
    print(f"[HopefullyFasterDataset] Num samples: {n_samples}\nMean time taken: {new}")

    # Imported here: only the benchmark needs `datasets`, so importing this
    # module just for HopefullyFasterDataset should not require it.
    import datasets

    hf_dataset = datasets.load_from_disk(HF_DATASET_PATH)["train"]
    old, old_samples = _time_indexing(hf_dataset, n_samples)
    print(f"[Huggingface loader] Num samples: {n_samples}\nMean time taken: {old}")
    print(f"SPEEDUP: {old/new}x")
    assert old_samples == new_samples, "The samples provided by the dataset must be the same!!!"


if __name__ == "__main__":
    main()