import numpy as np
from torchdata.datapipes.iter import FileLister, FileOpener

def filter_for_data(filename):
    """Return True for TFRecord files whose path contains ``sample_data``.

    Both the ``.tfrecord`` and ``.tfrecords`` suffixes are accepted. The
    datasets this script lists use ``*.tfrecords``, which the previous
    ``endswith(".tfrecord")`` check silently rejected.

    Args:
        filename: Path or file name to test.

    Returns:
        ``True`` if ``"sample_data"`` occurs in ``filename`` and it ends
        with a TFRecord suffix, otherwise ``False``.
    """
    return "sample_data" in filename and filename.endswith((".tfrecord", ".tfrecords"))

def row_processor(row):
    """Split an indexable row into a label and a feature vector.

    Args:
        row: Sequence whose first element is the label and whose remaining
            elements are the features.

    Returns:
        ``{"label": <0-d int32 array>, "data": <float64 array of row[1:]>}``.
    """
    label = np.array(row[0], dtype=np.int32)
    features = np.array(row[1:], dtype=np.float64)
    return {"label": label, "data": features}

def build_datapipes(root_dir=".", masks="*.tfrecords"):
    """Build an iterable datapipe over the examples stored in TFRecord files.

    The pipeline is ``FileLister`` -> ``FileOpener`` (binary mode) ->
    ``load_from_tfrecord``. IterDataPipes are lazy: no file is listed,
    opened or parsed until the returned datapipe is iterated.

    Args:
        root_dir: Directory to search, or a single file path (the
            ``__main__`` block passes a ``.tfrecords`` file directly).
        masks: Unix shell-style pattern, or list of patterns, forwarded to
            ``FileLister``. Only matching file names are kept. The default
            is the previously hard-coded ``"*.tfrecords"``.

    Returns:
        The datapipe produced by ``load_from_tfrecord``. Records are neither
        shuffled nor post-processed here.
    """
    # Reference:
    # https://pytorch.org/data/main/generated/torchdata.datapipes.iter.TFRecordLoader.html
    file_paths = FileLister(root_dir, masks)
    # TFRecord files are binary, so open them in "b" mode.
    file_streams = FileOpener(file_paths, mode="b")
    records = file_streams.load_from_tfrecord()

    # NOTE: no shuffle stage is applied here. To shuffle, append
    # `records = records.shuffle()`. A DataLoader can then toggle that stage
    # through its `shuffle` argument.
    # NOTE(review): confirm the default for the installed torch version.
    return records


if __name__ == '__main__':
    # Smoke test: build the pipeline and print the first record.
    # Usage: python <script> [PATH]   (PATH: a .tfrecords file or a directory)
    import sys

    default_path = "/system/user/publicdata/multi-object-datasets/tetrominoes/tetrominoes_train.tfrecords"
    # Another dataset previously used with this script:
    # /system/user/publicdata/multi-object-datasets/clevr_with_masks/clevr_with_masks_train.tfrecords
    path = sys.argv[1] if len(sys.argv) > 1 else default_path

    print("Building datapipe...")
    datapipe = build_datapipes(path)
    # The datapipe is lazy: files are only opened and parsed while iterating,
    # so nothing has been read from disk at this point.
    print("Datapipe built (no data read yet).")
    print("Iterating...")
    for item in datapipe:
        print(item)
        print('after item')
        break