import numpy as np
import matplotlib.pyplot as plt

from train.behavioral_cloning.datasets.ds.ds_s32_orig_minerl_diamond_binact_softcam_enumact_48_aug0_fskip2_log import \
    compile_dataset


if __name__ == "__main__":
    # Script entry point: load the training dataset, take its first sample,
    # print the input/target keys and save a grid of its POV frames to
    # "pov_input.png" for visual inspection of the preprocessing/augmentation.

    # Fixed seed so NumPy-global-RNG-based randomness is reproducible across
    # runs (presumably used by the dataset's augmentation -- TODO confirm).
    np.random.seed(1234)

    # load the dataset
    data = "log/data"
    trainidx = "train.csv"
    subset = None
    train_data = compile_dataset(root=data, index_file=trainidx, subset=subset)

    # draw one sample from the set
    sample = train_data[0]

    inputs = sample["inputs"]
    targets = sample["targets"]

    print("inputs", list(inputs.keys()))
    print("targets", list(targets.keys()))

    # plot modified frames
    pov = inputs["pov"]
    # Move axis 1 to the end: (0, 1, 2, 3) -> (0, 2, 3, 1). Presumably frames
    # are stored channel-first (N, C, H, W) and imshow wants (H, W, C) --
    # NOTE(review): confirm against the dataset definition.
    pov = np.transpose(pov, axes=(0, 2, 3, 1))

    # Grid layout for the frame panels.
    n_rows, n_cols = 4, 8
    # Plot at most one frame per grid cell, and never index past the end of
    # a sequence shorter than the grid (would otherwise raise IndexError).
    n_frames = min(len(pov), n_rows * n_cols)

    plt.figure("POV Input", figsize=(32, 16))
    plt.clf()

    for i in range(n_frames):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(pov[i], interpolation="nearest")
        plt.colorbar()
        plt.axis("off")

    plt.tight_layout()
    plt.savefig("pov_input.png")
