import torch
import argparse
from train.behavioral_cloning.datasets.ds import ds_s32_orig_minerl_treechop_binact_softcam_48_aug0_fskip2 as dataset
from train.behavioral_cloning.datasets.experience import Experience
from train.behavioral_cloning.models.LightCnnLstmBinActsSoftCamSepVal import Network as net
from train.envs.minerl_env import make_minerl

# Run on a CUDA device whenever torch reports one is available; otherwise use the CPU.
DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


def investigate_experience_playback_folder(exp_folder, generate_video, debug_video, replay_until, n_experiences_to_test, n_cpu, env_id, port):
    """Replay recorded experiences from ``exp_folder`` through a MineRL env.

    Builds an environment with ``make_minerl`` configured for playback
    (``make_new_recording=False``, ``experience_folder=exp_folder``) and then
    only calls ``env.reset()`` repeatedly. Presumably the playback wrapper
    replays one recorded experience per worker on each reset; confirm in
    ``make_minerl``.

    Args:
        exp_folder: Folder containing the recorded experiences.
        generate_video: Forwarded to ``make_minerl`` as
            ``generate_experience_video``.
        debug_video: Forwarded to ``make_minerl`` as ``debug_video``.
        replay_until: Forwarded to ``make_minerl`` as ``replay_until``.
        n_experiences_to_test: Requested number of experiences to replay.
            The number of resets is this value divided by ``n_cpu``, rounded
            up. A value <= 0 performs no resets.
        n_cpu: Number of env workers; forwarded to ``make_minerl``. Must be >= 1.
        env_id: Environment id forwarded to ``make_minerl``.
        port: Port of the environment server, forwarded to ``make_minerl``.

    Raises:
        ValueError: If ``n_cpu`` is smaller than 1. This is checked before
            any environment is created.
    """
    # Validate up front so a bad worker count fails before an env/server
    # is spun up, instead of with a ZeroDivisionError afterwards.
    if n_cpu < 1:
        raise ValueError(f'n_cpu must be >= 1, got {n_cpu}')

    print('creating env')

    env = make_minerl(
        env_id=env_id,
        n_cpu=n_cpu,
        seq_len=dataset.SEQ_LENGTH,
        transforms=dataset.DATA_TRANSFORM,
        input_space=dataset.INPUT_SPACE,
        env_server=True,
        seed=None,
        port=port,
        experience_recording=None,
        replay_until=replay_until,
        checkpoint=None,
        frame_skip=None,
        make_new_recording=False,
        experience_folder=exp_folder,
        craft_equip=False,
        generate_experience_video=generate_video,
        debug_video=debug_video
    )

    try:
        # Ceil division: with floor division, asking for fewer experiences
        # than workers ran zero resets, and a non-multiple of n_cpu replayed
        # fewer experiences than requested.
        n_resets = -(-n_experiences_to_test // n_cpu)
        for _ in range(n_resets):
            env.reset()
    finally:
        # Always shut the environment down, even when a replay raises.
        env.close()
    print('done')


def parse_args():
    """Parse the command-line options of the experience-playback script.

    Unrecognised arguments are ignored rather than rejected, because
    ``parse_known_args`` is used. A mistyped option is therefore silently
    dropped.

    Returns:
        argparse.Namespace with ``exp_folder``, ``debug_video``,
        ``generate_video``, ``replay_until``, ``n_exp``, ``port`` and ``env``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_folder', type=str, required=True, help='Folder pointing to experience')
    # --debug_video takes a string value; it is not an on/off switch.
    parser.add_argument('--debug_video', type=str, default='env', help='Generate debug video of experience (default: %(default)s)')
    parser.add_argument('--generate_video', default=False, help='Generate video of experience', action='store_true')
    parser.add_argument('--replay_until', type=str, default='all', help='Consensus until to test')
    parser.add_argument('--n_exp', type=int, default=2, help='Number of experiences to test')
    parser.add_argument('--port', type=int, default=9992, help='Port of the server')
    parser.add_argument('--env', type=str, default='MineRLObtainDiamondDense-v0', help='Environment of server')
    args, _ = parser.parse_known_args()
    return args


def main():
    """Script entry point: read CLI options and replay the experiences.

    The replay always runs with a single env worker (``n_cpu=1``). The
    ``--n_exp`` option is mapped onto ``n_experiences_to_test``, and
    ``--env`` onto ``env_id``.
    """
    cli = parse_args()
    replay_kwargs = dict(
        exp_folder=cli.exp_folder,
        generate_video=cli.generate_video,
        debug_video=cli.debug_video,
        replay_until=cli.replay_until,
        n_experiences_to_test=cli.n_exp,
        n_cpu=1,
        env_id=cli.env,
        port=cli.port,
    )
    investigate_experience_playback_folder(**replay_kwargs)


if __name__ == '__main__':
    # Run the replay only when executed as a script, not when imported.
    main()
