from train.behavioral_cloning.datasets.minerl_dataset import MineRLDataset
from train.behavioral_cloning.spaces.action_spaces import MultiBinaryWithCameraRegression
from train.behavioral_cloning.spaces.input_spaces import SingleFrameWithBinaryActionAndContinuousCameraSequence
from train.behavioral_cloning.datasets.transforms import divide_pov_by_255, random_left_right_flip, to_float32, \
    build_data_processor, resize_64_to_48, random_cam_shift_on_64, random_black_border

import os
import tqdm
import sys
import time
import torch
import cProfile
from pathlib import Path


# Sequence length shared by the input space and the dataset (``sequence_len``).
SEQ_LENGTH = 32

# Observation / action encodings handed to MineRLDataset.
INPUT_SPACE = SingleFrameWithBinaryActionAndContinuousCameraSequence(SEQ_LENGTH)
ACTION_SPACE = MultiBinaryWithCameraRegression()

# Preprocessing / augmentation pipeline. The list order is passed through
# unchanged to build_data_processor (presumably applied in this order —
# confirm in build_data_processor before reordering).
DATA_TRANSFORM = build_data_processor([
    random_cam_shift_on_64,
    random_black_border,
    resize_64_to_48,
    to_float32,
    divide_pov_by_255,
    random_left_right_flip,
])


def main(num_workers=None, max_batches=None):
    """Iterate over the MineRL training DataLoader and report elapsed time.

    Intended to run under cProfile (see the ``__main__`` guard) to locate
    bottlenecks in dataset loading and augmentation.

    Args:
        num_workers: number of DataLoader worker processes. If ``None``, it is
            read from the first command-line argument. When no argument is
            given, it falls back to 0, which loads in the main process.
        max_batches: optional positive cap on the number of batches drawn.
            ``None`` iterates over the whole loader.
    """
    if num_workers is None:
        num_workers = int(sys.argv[1]) if len(sys.argv) > 1 else 0
    print("Running with %d worker(s) ..." % num_workers)

    # load dataset
    root, experiment = "/publicdata/minerl/subtasks/log/", "MineRLObtainDiamond-v0"
    # alternative dataset: root, experiment = "tmp/minerl/official_v1/", "MineRLTreechop-v0"
    dataset = MineRLDataset(root=root, sequence_len=SEQ_LENGTH, train=True, prepare=False,
                            experiment=experiment,
                            input_space=INPUT_SPACE,
                            action_space=ACTION_SPACE,
                            data_transform=DATA_TRANSFORM,
                            frame_skip=2)

    train_loader = torch.utils.data.DataLoader(dataset, batch_size=100,
                                               shuffle=True, drop_last=True,
                                               num_workers=num_workers)

    # Count batches explicitly. An enumerate() index would be one less than the
    # number of batches, and it would be unbound if the loader yields nothing.
    n_batches = 0
    start = time.perf_counter()  # monotonic clock intended for interval timing
    for _ in tqdm.tqdm(train_loader):
        n_batches += 1
        if max_batches is not None and n_batches >= max_batches:
            break
    elapsed = time.perf_counter() - start

    print("%d batches took %.1f seconds" % (n_batches, elapsed))


if __name__ == "__main__":
    # Profile a pass over the DataLoader and dump the raw stats to the user's
    # home directory. Inspect the dump afterwards, e.g. with
    # ``python -m pstats <file>``.
    profile_file = os.path.join(str(Path.home()), "test_data_loader.dmp")
    print("profiling dataloader ...")
    print("output to : ", profile_file)
    # 'main()' is evaluated in the __main__ namespace, where main is defined.
    cProfile.run('main()', profile_file)
