import sys
sys.path.append("..")

from dataset.e_piano import create_epiano_datasets
from utilities.argument_funcs import parse_generate_args, print_generate_args

import pickle
from utilities.constants import *
from utilities.device import cpu_device
from torch.utils.data import DataLoader

import numpy as np

def _collect_batch_lengths(loader, log_every=100):
    """Return the size of dimension 1 of the first element of every batch.

    Args:
        loader: Iterable of batches. Each batch must be indexable, and its
            first element must expose a ``shape`` with at least two
            dimensions.
        log_every: Print the running batch index every ``log_every``
            batches as a simple progress indicator.

    Returns:
        A list holding one integer per batch, in iteration order.
    """
    lengths = []
    for batch_idx, batch in enumerate(loader):
        if batch_idx % log_every == 0:
            print(batch_idx)
        # NOTE(review): presumably batch[0] is the input tensor laid out as
        # (batch, sequence), so shape[1] is the sequence length -- confirm
        # against the dataset's __getitem__.
        lengths.append(batch[0].shape[1])
    return lengths


def main():
    """Print the mean and median batch length of the training dataset.

    Parses the generate-script arguments, builds the e-piano datasets from
    ``args.midi_root``, iterates over the training split with a
    ``DataLoader`` and reports statistics of the collected lengths.
    """
    args = parse_generate_args()
    print_generate_args(args)

    # Only the training split is inspected; the other two splits are unused.
    train_dataset, _, _ = create_epiano_datasets(
        args.midi_root, max_seq=2048, random_seq=False, full_version=True
    )
    print(len(train_dataset))

    # NOTE(review): with batch_size=2 one length is recorded per batch, not
    # per sample -- confirm this is intended; batch_size=1 would yield
    # per-sample statistics.
    loader = DataLoader(train_dataset, batch_size=2, num_workers=1, shuffle=False)
    lengths = _collect_batch_lengths(loader)

    # np.mean / np.median of an empty list only yield nan plus a
    # RuntimeWarning, so report the empty case explicitly instead.
    if not lengths:
        print("no batches were produced; cannot compute length statistics")
        return

    print("mean length:", np.mean(lengths))
    print("median length:", np.median(lengths))


if __name__ == "__main__":
    main()

