import torch
import os

from train.envs.minerl_env import make_minerl
from train.behavioral_cloning.datasets.ds import ds_s32_orig_minerl_treechop_binact_softcam_48_aug0_fskip2 as dataset
from train.behavioral_cloning.models.LightCnnLstmBinActsSoftCamSepVal import Network as net
from train.behavioral_cloning.datasets.experience import Experience

# Torch device used by this script: CUDA when it is available, otherwise CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def play_experience_in_folder(folder):
    """Play back every recorded experience stored in ``folder``.

    The first recording (in sorted filename order) is passed to
    ``make_minerl`` when the environment is created. Each remaining
    recording is installed through the environment's ``set_experience``
    method and followed by a ``reset()``. The environment is closed when
    playback finishes or when an error interrupts it.

    Args:
        folder: Path of a directory whose entries are each loaded with
            ``Experience.load``.

    Raises:
        IndexError: If ``folder`` contains no entries.
    """
    # os.listdir returns entries in arbitrary order; sort them so that the
    # initial recording and the playback order are reproducible.
    experience_files = sorted(os.listdir(folder))
    if not experience_files:
        raise IndexError('no experience files found in {}'.format(folder))

    experience_file = os.path.join(folder, experience_files[0])
    experience_recording = Experience.load(experience_file)

    n_cpu = 1
    env_id = 'MineRLObtainDiamond-v0'

    # The network is only instantiated and moved to DEVICE here; it is not
    # handed to the environment below.
    model = net()
    model.to(DEVICE)
    print('creating env')

    env = make_minerl(
        env_id=env_id,
        n_cpu=n_cpu,
        seq_len=dataset.SEQ_LENGTH,
        transforms=dataset.DATA_TRANSFORM,
        input_space=dataset.INPUT_SPACE,
        env_server=True,
        seed=None,  # seed needs to be None when playing back a recording
        experience_recording=experience_recording,  # the first recording loaded above
        checkpoint='latest',
        frame_skip=None,
        make_new_recording=False  # do not make a new recording while playing back an experience
    )

    try:
        env.reset()

        # now, set each remaining experience and play it back
        for experience_file in experience_files[1:]:
            new_experience = Experience.load(os.path.join(folder, experience_file))
            print('running experince file', new_experience.meta_info)
            env.env_method('set_experience', new_experience)
            env.reset()
    finally:
        # Always release the environment, even if a recording fails to load
        # or a reset raises.
        env.close()

    print('done')


if __name__ == '__main__':
    # Script entry point: replay the recordings stored in the scratch folder.
    recordings_folder = 'tmp/new_experiences'
    play_experience_in_folder(recordings_folder)