from train.common.config import Config
import torch
from train.reinforcment_learning.runner import Runner
import numpy as np

from train.behavioral_cloning.datasets.ds import ds_s32_orig_minerl_diamond_binact_softcam_48_aug0 as dataset
from train.behavioral_cloning.models.LightCnnLstmBinActsSoftCam import Network as net
from train.envs.minerl_env import make_minerl
import imageio
import time, os
from PIL import Image, ImageDraw, ImageFont, ImageOps


def get_device(gpu, **kwargs):
    """Return the torch device to run on.

    Args:
        gpu: Index of the CUDA device, given either as a string ("0") or as
            an int (0).
        **kwargs: Accepted and ignored, so callers can pass extra options
            without this function raising.

    Returns:
        ``torch.device("cuda:<gpu>")`` when CUDA is available, otherwise
        ``torch.device("cpu")``.
    """
    # Format instead of concatenating so an integer index does not raise
    # TypeError.
    cuda_name = "cuda:{}".format(gpu)
    return torch.device(cuda_name if torch.cuda.is_available() else 'cpu')


def save_video(config, runner):
    """Roll out the policy and write one annotated MP4 per environment.

    Each frame is overlaid with the value estimate recorded for that step.
    Files are written to ``config.record_dir`` (created if missing) and are
    named after ``config.env``, the environment index, the summed reward and
    a timestamp.

    Args:
        config: Object providing ``num_env``, ``env`` and ``record_dir``.
        runner: Object providing ``rollout(nsteps)`` and
            ``transform_state_dict(states)``.
    """
    nsteps = 1200

    # Create the output directory up front so a bad path fails before the
    # (expensive) rollout instead of after it.
    os.makedirs(config.record_dir, exist_ok=True)

    with torch.no_grad():
        (b_states, b_rewards, b_dones, b_actions, b_values,
         b_logprobs, b_last_values, epinfo) = runner.rollout(nsteps)

    # Only the POV frames and the value estimates are needed for the video.
    pov, bin_actions, camera_actions = runner.transform_state_dict(b_states)
    b_values = np.stack(b_values)

    # The default font does not change between frames; load it once.
    font = ImageFont.load_default()

    # pov is indexed [step][env][seq]; each frame is C, H, W.
    for j in range(config.num_env):
        # NOTE(review): pov and b_values are indexed [step][env], but this
        # indexes b_rewards by the environment index first. If b_rewards is
        # also laid out [step][env], this sums step j across environments
        # rather than environment j across steps -- confirm against
        # Runner.rollout.
        cum_reward = np.sum(b_rewards[j])

        annotated = []
        for i in range(nsteps):
            # Take the last frame of the sequence and convert C, H, W to
            # H, W, C uint8 (assumes float pixel values in [0, 1] -- TODO
            # confirm).
            frame = (pov[i][j][-1].transpose(1, 2, 0) * 255).astype(np.uint8)
            img = Image.fromarray(frame, 'RGB')
            draw = ImageDraw.Draw(img)
            draw.text((0, 24), "Val: {}".format(b_values[i][j]), (255, 255, 255), font=font)
            annotated.append(np.asarray(img))

        video = np.stack(annotated)
        timestamp = time.strftime("%Y%m%d%H%M%S")
        # Include the environment index: the timestamp only has one-second
        # resolution, so environments with equal reward written within the
        # same second would otherwise overwrite each other's file.
        video_file = "recording-{}-env{}-rew{}-{}.mp4".format(
            config.env, j, cum_reward, timestamp)
        video_path = os.path.join(config.record_dir, video_file)
        # macro_block_size=None stops imageio from resizing frames whose
        # dimensions are not a multiple of its default macro block size.
        kargs = {'macro_block_size': None}
        imageio.mimwrite(video_path, video, fps=20, **kargs)


if __name__ == "__main__":
    # Entry point: restore a pretrained behavioral-cloning policy, roll it
    # out in a MineRL environment and save the rollouts as videos.

    args = Config()
    n_cpu = args.num_env
    DEVICE = get_device(args.gpu)
    input_space = dataset.INPUT_SPACE

    # Build the network and restore the pretrained weights.
    model = net().to(DEVICE)
    path_to_model_parameters = args.model_path
    print("Loading pretrained cloned model!")
    # map_location remaps the stored tensors onto DEVICE, so a checkpoint
    # saved on a GPU still loads when this run falls back to the CPU (or
    # uses a different GPU index).
    model.load_state_dict(torch.load(path_to_model_parameters, map_location=DEVICE))

    # NOTE(review): the environment id is hard-coded here, while the
    # recording file name is built from args.env -- confirm they are meant
    # to match.
    env = make_minerl("MineRLTreechop-v0", n_cpu, seq_len=dataset.SEQ_LENGTH,
                      transforms=dataset.DATA_TRANSFORM, input_space=input_space, env_server=True)

    runner = Runner(model=model, nenv=n_cpu,
                    env=env, nsteps=args.nsteps, seq_len=dataset.SEQ_LENGTH, gamma=args.gamma,
                    lam=args.lam, episode_length=args.episode_length,
                    reward_scale=args.reward_scale, device=DEVICE)

    save_video(args, runner)
