import torch
import numpy as np
import os
import datetime

from train.envs.minerl_env import make_minerl
from train.behavioral_cloning.datasets.ds import ds_s32_orig_minerl_treechop_binact_softcam_48_aug0_fskip2 as dataset
from train.behavioral_cloning.models.LightCnnLstmBinActsSoftCamSepVal import Network as net
from torch.utils.data._utils.collate import default_collate
from train.behavioral_cloning.datasets.experience import Experience


# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def test_experience_playback(folder):
    """
    Smoke-test that an environment can be created and reset in playback mode.

    Instantiates the network and moves it to DEVICE, creates a single-worker
    'MineRLObtainDiamond-v0' environment through make_minerl with
    make_new_recording=False, resets it once and closes it again. The
    environment is closed even if the reset raises.

    :param folder: forwarded unchanged to make_minerl as ``experience_folder``.
    """

    n_cpu = 1
    env_id = 'MineRLObtainDiamond-v0'

    # The model is never queried in this test; it is only constructed and
    # moved to the target device.
    model = net()
    model.to(DEVICE)
    print('creating env')

    env = make_minerl(
        env_id=env_id,
        n_cpu=n_cpu,
        seq_len=dataset.SEQ_LENGTH,
        transforms=dataset.DATA_TRANSFORM,
        input_space=dataset.INPUT_SPACE,
        env_server=True,
        seed=None,  # seed needs to be None when playing back a recording
        experience_recording=None,  # no pre-built experience object is passed in
        checkpoint=None,  # optionally names a checkpoint to use (see make_minerl)
        frame_skip=None,
        make_new_recording=False,  # playback only: do not start a new recording
        experience_folder=folder
    )

    try:
        # The returned state is not needed; the reset itself is what is tested.
        env.reset()
    finally:
        # Always release the environment, even when the reset fails.
        env.close()
    print('done')


def test_record_playback_and_extend_experience(folder):
    """
    Smoke-test that an experience recording can be extended and saved.

    It's important that an existing experience recording can be extended. For
    this to work, both the experience replay and the recording wrapper are
    needed. The test creates a single-worker 'MineRLObtainDiamond-v0'
    environment with make_new_recording=True, steps it 100 times with actions
    derived from the (untrained) model's outputs, then sets the checkpoint
    'after_random_move' on every experience returned by the environment's
    ``get_experience`` method and saves it under 'tmp/new_experiences/'.
    The environment is closed even if any of these steps raises.

    :param folder: forwarded unchanged to make_minerl as ``experience_folder``.
    """

    n_cpu = 1
    env_id = 'MineRLObtainDiamond-v0'
    # One file per experience; '{}' is replaced by the experience's index.
    save_path_template = \
        'tmp/new_experiences/log-log-log-log-planks-planks-stick-stick-crafting_table-random-{}.p'

    model = net()
    model.to(DEVICE)
    print('creating env')

    env = make_minerl(
        env_id=env_id,
        n_cpu=n_cpu,
        seq_len=dataset.SEQ_LENGTH,
        transforms=dataset.DATA_TRANSFORM,
        input_space=dataset.INPUT_SPACE,
        env_server=True,
        seed=None,  # seed needs to be None when playing back a recording
        experience_recording=None,  # no pre-built experience object is passed in
        checkpoint=None,  # optionally names a checkpoint to use (see make_minerl)
        frame_skip=None,
        make_new_recording=True,  # record the steps taken below
        experience_folder=folder
    )

    try:
        state = env.reset()

        # Fetched before stepping; the same objects are saved after the rollout.
        current_experience = env.env_method("get_experience")

        # do some stuff in the environment.
        print('do random stuff now')
        # Nothing is trained here, so gradient tracking is switched off.
        with torch.no_grad():
            for _ in range(100):
                # collate, to tensor, to device
                input_dict = default_collate(state)
                for key, tensor in input_dict.items():
                    input_dict[key] = tensor.to(DEVICE)

                # predict next action
                out_dict = model.forward(input_dict)

                # translate predictions to environment actions
                actions = dataset.ACTION_SPACE.logits_to_dict(env.action_space.noop(), out_dict)

                # evaluate model; the results are not used further, but the
                # unpacking fails loudly if the return arity ever changes
                value, action_log_probs, camera_log_probs, action_entropy, camera_entropy = \
                    dataset.ACTION_SPACE.evaluate_actions(out_dict, actions)

                # take env step; reward, done flag and info are ignored
                state, _reward, _done, _info = env.step(actions)

        # If you are at a point where you are certain that you want to save a
        # trajectory, you can do so with the .set_checkpoint(checkpoint : str)
        # and .save(folder : str) methods.
        # Make sure the target directory exists before anything is written.
        os.makedirs(os.path.dirname(save_path_template), exist_ok=True)
        for idx, experience in enumerate(current_experience):
            # set checkpoint
            experience.set_checkpoint(checkpoint='after_random_move')
            # Only the index is part of the file name; the template has a
            # single placeholder.
            experience.save(save_path_template.format(idx))
    finally:
        # Always release the environment, even when a step above fails.
        env.close()
    print('done')


if __name__ == '__main__':
    # Script entry point: run only the playback smoke test, with no
    # experience folder given.
    test_experience_playback(None)


