import torch
import numpy as np
import os

from train.envs.minerl_env import make_minerl
from train.behavioral_cloning.datasets.ds import ds_s32_orig_minerl_treechop_binact_softcam_48_aug0_fskip2 as dataset
from train.behavioral_cloning.models.LightCnnLstmBinActsSoftCamSepVal import Network as net
from torch.utils.data._utils.collate import default_collate
from train.behavioral_cloning.datasets.experience import Experience


# Device for the model: the default CUDA device if one is available, otherwise the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def test_experience_recorder():
    """
    Test if recorder works.

    Creates a seeded MineRL environment with recording enabled and runs the untrained model in it
    for 10 steps. It then sets the checkpoint 'start' on every recorded experience and saves each one
    to 'tmp/experiences/test_experience_recorder-env_<idx>.p'.

    test_experience_playback_with_latest_checkpoint and test_setting_experiences load
    'tmp/experiences/test_experience_recorder-env_0.p', so run this test first.
    """
    initial_seed = 42
    n_cpu = 1
    env_id = 'MineRLObtainDiamond-v0'
    model = net()

    model.to(DEVICE)

    env = make_minerl(
        env_id=env_id,
        n_cpu=n_cpu,
        seq_len=dataset.SEQ_LENGTH,
        transforms=dataset.DATA_TRANSFORM,
        input_space=dataset.INPUT_SPACE,
        env_server=True,
        seed=initial_seed,
        experience_recording=None,  # when just recording you need to have a None here
        checkpoint=None,
        frame_skip=None,
        make_new_recording=True  # this needs to be set to true if you want the experience recorder to record.
    )

    # try/finally so the environment (and its server) is closed even if a step fails.
    try:
        state = env.reset()

        # This is the crucial statement to get access to the Experience objects through the ThreadedVecEnv.
        # For this to work, the ExposeRecordedExperienceWrapper in train/envs/minerl_env.py must be the very
        # last wrapper. It returns a LIST of experiences (one for each env you're running, so be careful
        # which one you save).
        current_experience = env.env_method('get_experience')

        # do some stuff in the environment.
        for _ in range(10):
            # collate, to tensor, to device
            input_dict = {key: value.to(DEVICE) for key, value in default_collate(state).items()}

            # Inference only: no_grad avoids building an autograd graph on every step.
            with torch.no_grad():
                # predict next action
                out_dict = model(input_dict)

                # translate predictions to environment actions
                actions = dataset.ACTION_SPACE.logits_to_dict(env.action_space.noop(), out_dict)

                # evaluate model (exercised for coverage; the results are not used further)
                value, action_log_probs, camera_log_probs, action_entropy, camera_entropy = \
                    dataset.ACTION_SPACE.evaluate_actions(out_dict, actions)

            # take env step
            state, reward, done, info = env.step(actions)

        # Once you are certain you want to save a trajectory, use .set_checkpoint(checkpoint: str) and
        # .save(path: str); here each experience is written to its own file.
        os.makedirs('tmp/experiences', exist_ok=True)
        for idx, experience in enumerate(current_experience):
            experience.set_checkpoint(checkpoint='start')
            experience.save('tmp/experiences/test_experience_recorder-env_{}.p'.format(idx))
    finally:
        env.close()
    print('done')


def test_experience_playback():
    """
    Test if playback works.

    Loads 'train/deterministic_envs/recordings/diamond4.p' and creates a playback environment. The last
    entry of the recording's ``checkpoints`` is passed to make_minerl as ``checkpoint``. The test then
    resets the environment once and closes it.

    Raises:
        ValueError: if the recording has no checkpoints.
    """
    experience_file = 'train/deterministic_envs/recordings/diamond4.p'
    experience_recording = Experience.load(experience_file)

    # Fail with a readable message instead of a bare StopIteration from next() on an empty collection.
    checkpoints = experience_recording.checkpoints
    if not checkpoints:
        raise ValueError('{} contains no checkpoints to play back from'.format(experience_file))
    last_checkpoint = next(reversed(checkpoints))

    n_cpu = 1
    env_id = 'MineRLObtainDiamond-v0'
    print('creating env')

    env = make_minerl(
        env_id=env_id,
        n_cpu=n_cpu,
        seq_len=dataset.SEQ_LENGTH,
        transforms=dataset.DATA_TRANSFORM,
        input_space=dataset.INPUT_SPACE,
        env_server=True,
        seed=None,  # seed needs to be None when playing back a recording
        experience_recording=experience_recording,  # the experience loaded above
        checkpoint=last_checkpoint,  # start from the last checkpoint of the recording
        frame_skip=None,
        make_new_recording=False  # must be False when only playing back an experience
    )

    # try/finally so the environment (and its server) is closed even if reset fails.
    try:
        env.reset()
    finally:
        env.close()
    print('done')


def test_experience_playback_with_latest_checkpoint():
    """
    Test if playback works when giving 'latest' as checkpoint.

    Loads 'tmp/experiences/test_experience_recorder-env_0.p', which test_experience_recorder writes, so
    that test has to run first. Creates a playback environment with ``checkpoint='latest'``, resets it
    once and closes it.
    """
    experience_file = 'tmp/experiences/test_experience_recorder-env_0.p'
    experience_recording = Experience.load(experience_file)

    n_cpu = 1
    env_id = 'MineRLObtainDiamond-v0'

    env = make_minerl(
        env_id=env_id,
        n_cpu=n_cpu,
        seq_len=dataset.SEQ_LENGTH,
        transforms=dataset.DATA_TRANSFORM,
        input_space=dataset.INPUT_SPACE,
        env_server=True,
        seed=None,  # seed needs to be None when playing back a recording
        experience_recording=experience_recording,  # the experience loaded above
        checkpoint='latest',  # if you want to use the last checkpoint, use this
        frame_skip=None,
        make_new_recording=False  # must be False when only playing back an experience
    )

    # try/finally so the environment (and its server) is closed even if reset fails.
    try:
        env.reset()
    finally:
        env.close()
    print('done')


def test_record_playback_and_extend_experience():
    """
    It's important that an existing experience recording can be extended. For this to work, we will need both the
    experience replay and recording wrapper.

    Plays back the stored experience from its 'latest' checkpoint with recording enabled. It then runs the
    untrained model for 100 steps and sets the checkpoint 'after_random_move' on every recorded experience.

    Note: the result is saved to 'tmp/experiences/test_experience_recorder-env_<idx>.p', the same path
    test_experience_recorder writes to, so that file is overwritten.
    """
    n_cpu = 1
    env_id = 'MineRLObtainDiamond-v0'
    model = net()

    experience_file = 'tmp/experiences/metacontroller/seed_7952619-task_log-milestone_0_of_2-checkpoints_-log'
    experience_recording = Experience.load(experience_file)

    model.to(DEVICE)

    env = make_minerl(
        env_id=env_id,
        n_cpu=n_cpu,
        seq_len=dataset.SEQ_LENGTH,
        transforms=dataset.DATA_TRANSFORM,
        input_space=dataset.INPUT_SPACE,
        env_server=True,
        seed=None,  # seed needs to be None when playing back a recording
        experience_recording=experience_recording,
        checkpoint='latest',
        frame_skip=None,
        make_new_recording=True  # True so that the steps taken after playback are recorded as well
    )

    # try/finally so the environment (and its server) is always closed; previously it was never closed.
    try:
        state = env.reset()

        # One Experience per environment (see test_experience_recorder for the wrapper requirements).
        current_experience = env.env_method('get_experience')

        # do some stuff in the environment.
        for _ in range(100):
            # collate, to tensor, to device
            input_dict = {key: value.to(DEVICE) for key, value in default_collate(state).items()}

            # Inference only: no_grad avoids building an autograd graph on every step.
            with torch.no_grad():
                # predict next action
                out_dict = model(input_dict)

                # translate predictions to environment actions
                actions = dataset.ACTION_SPACE.logits_to_dict(env.action_space.noop(), out_dict)

                # evaluate model (exercised for coverage; the results are not used further)
                value, action_log_probs, camera_log_probs, action_entropy, camera_entropy = \
                    dataset.ACTION_SPACE.evaluate_actions(out_dict, actions)

            # take env step
            state, reward, done, info = env.step(actions)

        # Once you are certain you want to save a trajectory, use .set_checkpoint(checkpoint: str) and
        # .save(path: str); here each experience is written to its own file.
        os.makedirs('tmp/experiences', exist_ok=True)
        for idx, experience in enumerate(current_experience):
            experience.set_checkpoint(checkpoint='after_random_move')
            experience.save('tmp/experiences/test_experience_recorder-env_{}.p'.format(idx))
    finally:
        env.close()
    print('done')


def test_setting_experiences():
    """
    Test replacing the played-back experience of an existing environment.

    Starts playback of 'tmp/experiences/test_experience_recorder-env_0.p' (written by
    test_experience_recorder) from its 'latest' checkpoint and resets. It then passes
    'train/deterministic_envs/recordings/diamond3.p' to ``set_experience`` via ``env.env_method``
    and resets again.
    """
    experience_file = 'tmp/experiences/test_experience_recorder-env_0.p'
    experience_recording = Experience.load(experience_file)

    n_cpu = 1
    env_id = 'MineRLObtainDiamond-v0'
    print('creating env')

    env = make_minerl(
        env_id=env_id,
        n_cpu=n_cpu,
        seq_len=dataset.SEQ_LENGTH,
        transforms=dataset.DATA_TRANSFORM,
        input_space=dataset.INPUT_SPACE,
        env_server=True,
        seed=None,  # seed needs to be None when playing back a recording
        experience_recording=experience_recording,  # the experience loaded above
        checkpoint='latest',
        frame_skip=None,
        make_new_recording=False  # must be False when only playing back an experience
    )

    # try/finally so the environment (and its server) is closed even if a reset fails.
    try:
        env.reset()

        # now, set a new experience to play back
        new_experience = Experience.load('train/deterministic_envs/recordings/diamond3.p')
        env.env_method('set_experience', new_experience)
        env.reset()
    finally:
        env.close()
    print('done')


if __name__ == "__main__":
    import sys

    # Select tests by name on the command line, e.g.
    #   python <this file> test_experience_recorder test_experience_playback_with_latest_checkpoint
    # Tests run in the given order. Without arguments only test_experience_playback runs.
    # test_experience_playback_with_latest_checkpoint and test_setting_experiences load
    # 'tmp/experiences/test_experience_recorder-env_0.p', which test_experience_recorder writes.
    tests = {
        test.__name__: test
        for test in (
            test_experience_recorder,
            test_experience_playback,
            test_experience_playback_with_latest_checkpoint,
            test_record_playback_and_extend_experience,
            test_setting_experiences,
        )
    }
    selected = sys.argv[1:] or ['test_experience_playback']

    # Validate every name before starting any (slow) environment.
    unknown = [name for name in selected if name not in tests]
    if unknown:
        sys.exit('unknown test(s): {}; available: {}'.format(', '.join(unknown), ', '.join(tests)))

    for name in selected:
        tests[name]()

