import torch

from train.envs.minerl_env import make_minerl
from train.behavioral_cloning.datasets.ds import ds_s32_orig_minerl_treechop_binact_softcam_48_aug0_fskip2 as dataset
from train.behavioral_cloning.models.LightCnnLstmBinActsSoftCamSepVal import Network as net
from torch.utils.data._utils.collate import default_collate
from train.behavioral_cloning.datasets.experience import Experience

# Device used by the tests below: CUDA when torch reports it as available,
# otherwise the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def test_craftequip():
    """
    Smoke-test craft actions in a live MineRL environment.

    Builds a freshly constructed network and a ``MineRLObtainDiamond-v0``
    environment with ``craft_equip=True``. For 100 iterations it runs the
    model on the current state, translates the output into environment
    actions, evaluates those actions, and then performs two environment
    steps: one with ``nearbyCraft='wooden_pickaxe'`` and one with
    ``craft='stick'``.

    Nothing is asserted; the test passes if no exception is raised. The
    environment is always closed, even when a step fails.
    """

    n_cpu = 1
    env_id = 'MineRLObtainDiamond-v0'
    model = net()

    model.to(DEVICE)

    env = make_minerl(
        env_id=env_id,
        n_cpu=n_cpu,
        seq_len=dataset.SEQ_LENGTH,
        transforms=dataset.DATA_TRANSFORM,
        input_space=dataset.INPUT_SPACE,
        env_server=True,
        seed=None,
        experience_recording=None,  # no recording is played back in this test
        checkpoint=None,  # no checkpoint is selected in this test
        frame_skip=None,
        make_new_recording=False,
        experience_folder=None,
        craft_equip=True
    )

    # Make sure the environment is shut down even if a step below raises.
    try:
        state = env.reset()

        # do some stuff in the environment.
        for _ in range(100):
            # collate, to tensor, to device
            input_dict = default_collate(state)
            for key in input_dict:
                input_dict[key] = input_dict[key].to(DEVICE)

            # predict next action (call the module rather than .forward()
            # directly, so that registered hooks are honoured)
            out_dict = model(input_dict)

            # translate predictions to environment actions
            actions = dataset.ACTION_SPACE.logits_to_dict(
                env.action_space.noop(), out_dict)

            # evaluate model; the results are not used further, the call only
            # checks that evaluation runs and yields the expected five values
            value, action_log_probs, camera_log_probs, action_entropy, camera_entropy = \
                dataset.ACTION_SPACE.evaluate_actions(out_dict, actions)

            # first env step: override the predicted action with a nearbyCraft
            actions[0]['nearbyCraft'] = 'wooden_pickaxe'
            actions[0]['craft'] = 'none'
            state, reward, done, info = env.step(actions)

            # second env step: override with a plain craft instead
            actions[0]['craft'] = 'stick'
            actions[0]['nearbyCraft'] = 'none'
            state, reward, done, info = env.step(actions)
    finally:
        env.close()

    print('done')


def test_craft_equip_with_detseed():
    """
    Smoke-test a nearbyCraft action on an environment restored from a recording.

    Loads the experience recording ``diamond3.p``, creates a
    ``MineRLObtainDiamond-v0`` environment from it with the ``'craft_stick'``
    checkpoint and ``craft_equip=True``, performs one no-op step and then one
    step with ``nearbyCraft='wooden_pickaxe'``.

    Nothing is asserted; the test passes if no exception is raised. The
    environment is always closed, even when a step fails.
    """

    n_cpu = 1
    env_id = 'MineRLObtainDiamond-v0'

    # NOTE(review): the model is built and moved to DEVICE but never used to
    # choose actions in this test; kept so model construction is still exercised.
    model = net()
    model.to(DEVICE)

    experience_file = 'train/deterministic_envs/recordings/diamond3.p'
    experience_recording = Experience.load(experience_file)

    print('creating env')

    env = make_minerl(
        env_id=env_id,
        n_cpu=n_cpu,
        seq_len=dataset.SEQ_LENGTH,
        transforms=dataset.DATA_TRANSFORM,
        input_space=dataset.INPUT_SPACE,
        env_server=True,
        seed=None,  # seed needs to be None when playing back a recording
        experience_recording=experience_recording,  # the recording loaded above
        checkpoint='craft_stick',  # checkpoint to use for this playback
        frame_skip=None,
        make_new_recording=False,  # playing back an experience, not recording one
        craft_equip=True
    )

    # Make sure the environment is shut down even if a step below raises.
    try:
        # The initial state is not needed; actions are hand-built below.
        env.reset()

        action = env.action_space.noop()

        # first step: plain no-op
        obs, reward, done, info = env.step((action,))

        # second step: the same action dict, now requesting a nearby craft of
        # a wooden pickaxe
        action['nearbyCraft'] = 'wooden_pickaxe'

        obs, reward, done, info = env.step((action,))
    finally:
        env.close()

    print('done')


# Script entry point: when this file is executed directly, only the
# recording-based test is run. The call to test_craftequip() is disabled;
# uncomment it to run that test as well.
if __name__ == '__main__':
    #test_craftequip()
    test_craft_equip_with_detseed()
