import torch

from train.behavioral_cloning.datasets.minerl_dataset import MineRLDataset
from train.envs.minerl_env import make_minerl
from train.behavioral_cloning.datasets.ds import ds_s32_orig_minerl_diamond_binact_softcam_48_aug0 as dataset
from train.behavioral_cloning.models.LightCnnLstmBinActsSoftCam import Network as net
from torch.utils.data._utils.collate import default_collate


# Prefer the GPU when CUDA is usable; otherwise everything runs on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def main():
    """Smoke-test that the MineRL episode recorder works.

    Steps a freshly constructed network through a short rollout in
    ``MineRLObtainDiamond-v0``, takes the episode object the environment
    reports in ``info[0]['episode']``, converts it to the MineRL pickle
    format, inserts it into a ``MineRLDataset`` buffer and dumps it under
    the name ``'demo'``.

    The environment is always closed, even when a step of the test fails.
    """
    n_cpu = 1
    env_id = 'MineRLObtainDiamond-v0'
    num_steps = 30

    model = net().to(DEVICE)
    env = make_minerl(env_id=env_id,
                      n_cpu=n_cpu,
                      seq_len=dataset.SEQ_LENGTH,
                      transforms=dataset.DATA_TRANSFORM,
                      input_space=dataset.INPUT_SPACE)

    # Everything after the environment exists runs under try/finally so the
    # environment is released even if the rollout or the dump raises.
    try:
        lessons_buffer = MineRLDataset(root=None, sequence_len=dataset.SEQ_LENGTH, train=True, prepare=False,
                                       experiment=env_id,
                                       input_space=dataset.INPUT_SPACE,
                                       action_space=dataset.ACTION_SPACE,
                                       data_transform=dataset.DATA_TRANSFORM)

        episode = None

        state = env.reset()
        for _ in range(num_steps):
            # collate, to tensor, to device (values are replaced in place)
            input_dict = default_collate(state)
            for key, tensor in input_dict.items():
                input_dict[key] = tensor.to(DEVICE)

            # Nothing here calls backward, so skip building autograd graphs.
            with torch.no_grad():
                # predict next action; call the module itself rather than
                # .forward() so that registered hooks are honoured
                out_dict = model(input_dict)

                # translate predictions to environment actions
                actions = dataset.ACTION_SPACE.logits_to_dict(env.action_space.noop(), out_dict)

                # evaluate model; the results are not used further, the call
                # only exercises the evaluation path and its 5-value return
                value, action_log_probs, camera_log_probs, action_entropy, camera_entropy = \
                    dataset.ACTION_SPACE.evaluate_actions(out_dict, actions)

            # take env step
            # NOTE(review): `done` is ignored; presumably the wrapped env
            # resets itself when an episode ends - confirm in make_minerl.
            state, reward, done, info = env.step(actions)
            episode = info[0]['episode']

        # hand the recorded episode to the dataset buffer and write it out
        sequence = episode.to_minerl_pkl_format()
        lessons_buffer.add_sequence(sequence, insert_index=0)
        lessons_buffer.dump(0, 'demo')
        print('done')
    finally:
        env.close()


if __name__ == "__main__":
    main()

