import numpy as np
from train.behavioral_cloning.datasets.minerl_dataset import MineRLDataset
from train.behavioral_cloning.spaces.action_spaces import MultiBinarySoftmaxCameraEnumActions
from train.behavioral_cloning.spaces.input_spaces import SingleFrameWithBinaryActionAndContinuousCameraSequence
from train.behavioral_cloning.datasets.transforms import divide_pov_by_255, resize_64_to_48, to_float32, \
    build_data_processor, random_left_right_flip, add_equipped_item_to_frame, add_inventory_to_frame

# Values forwarded as ``frame_skip`` / ``sequence_len`` when the dataset is built.
FRAME_SKIP = 2
SEQ_LENGTH = 32

INPUT_SPACE = SingleFrameWithBinaryActionAndContinuousCameraSequence(SEQ_LENGTH)

# Ten camera bin values, evenly spaced 5 apart, from -22.5 up to 22.5.
ACTION_SPACE = MultiBinarySoftmaxCameraEnumActions(
    bins=np.array(
        [-22.5, -17.5, -12.5, -7.5, -2.5, 2.5, 7.5, 12.5, 17.5, 22.5],
        dtype=np.float32,
    ),
)

# Transforms shared by both pipelines, in the order given to build_data_processor.
base_transforms = [
    add_equipped_item_to_frame,
    add_inventory_to_frame,
    resize_64_to_48,
    to_float32,
    divide_pov_by_255,
]
DATA_TRANSFORM = build_data_processor(base_transforms)
# Training pipeline: the shared transforms with ``random_left_right_flip`` appended last.
TRAIN_TRANSFORM = build_data_processor([*base_transforms, random_left_right_flip])


def compile_dataset(root, index_file, subset):
    """Build the ``MineRLObtainDiamond-v0`` dataset for this configuration.

    The split is inferred from ``index_file``: when it contains the substring
    ``"train"`` the dataset is created with ``train=True`` and
    ``TRAIN_TRANSFORM``; otherwise ``train=False`` and ``DATA_TRANSFORM``
    are used.

    Args:
        root: Forwarded unchanged to ``MineRLDataset`` as ``root``.
        index_file: Only inspected for the substring ``"train"``; it is not
            passed on to ``MineRLDataset``.
        subset: Unused here; kept so the signature stays the same for callers.

    Returns:
        The constructed ``MineRLDataset`` instance.
    """
    # NOTE(review): this is a substring test on the whole value, so a path
    # with a "train" directory component would also select the training
    # split -- confirm what callers pass.
    is_train = "train" in index_file
    if is_train:
        transform = TRAIN_TRANSFORM
    else:
        transform = DATA_TRANSFORM

    return MineRLDataset(
        root=root,
        sequence_len=SEQ_LENGTH,
        train=is_train,
        data_split=1,
        prepare=False,
        experiment="MineRLObtainDiamond-v0",
        input_space=INPUT_SPACE,
        action_space=ACTION_SPACE,
        data_transform=transform,
        frame_skip=FRAME_SKIP,
        include_metadata=True,
    )
