from train.common.config import Config
import torch
from train.reinforcment_learning.runner import Runner
import numpy as np

from train.envs.minerl_env import make_minerl
import imageio
import time, os
from PIL import Image, ImageDraw, ImageFont, ImageOps
from train.reinforcment_learning.rudder.rudder import RR_LSTM_ARCH, RR_LSTM
from train.reinforcment_learning.rudder.buffer import LessonBuffer
from train.reinforcment_learning.utils.utils import safemean, get_device, import_module
from train.reinforcment_learning.rudder.rudder_test import get_episode_from_traj, break_sequence, pad_sequence

def _annotate_episode_frames(pov_episode, rr_episode, episode_len, font):
    """Render one episode as a stacked uint8 RGB array with the redistributed reward drawn on each frame.

    Args:
        pov_episode: indexable sequence of observations for one episode; each
            element is transposed with ``transpose(1, 2, 0)``, i.e. treated as
            channel-first (C, H, W).
        rr_episode: indexable redistributed rewards for the same episode.
        episode_len: number of leading steps of the episode to render.
        font: PIL font used for the "RR: ..." overlay.

    Returns:
        numpy array of shape (episode_len, H, W, 3), dtype uint8.
    """
    last_step = episode_len - 1
    annotated = []
    for i in range(episode_len):
        # NOTE(review): the *255 scaling assumes observations are floats in
        # [0, 1]; clipping prevents silent uint8 wrap-around otherwise.
        frame = np.clip(pov_episode[i].transpose(1, 2, 0) * 255, 0, 255).astype(np.uint8)
        img = Image.fromarray(frame, 'RGB')
        draw = ImageDraw.Draw(img)
        # The final step is labelled 0 rather than indexing rr_episode.
        rr = 0 if i == last_step else rr_episode[i]
        # fill/font belong to draw.text (previously they were swallowed by str.format).
        # NOTE(review): if rr is a torch tensor it is rendered as "tensor(...)".
        draw.text((0, 24), "RR: {}".format(rr), fill=(255, 255, 255), font=font)
        annotated.append(np.asarray(img))
    return np.stack(annotated)


def rudder_save_video(config, runner, rudder, nsteps=1200, fps=20):
    """Roll out the policy, redistribute rewards with RUDDER and save one annotated mp4 per episode.

    A rollout of ``nsteps`` steps is collected from ``runner``, grouped into
    episodes by ``get_episode_from_traj`` and passed through
    ``rudder.redistribute_reward``. For each episode, the first
    ``config.rudder_episode_len`` frames are written to ``config.record_dir`` with
    the redistributed reward ("RR") drawn on each frame.

    Args:
        config: run configuration; ``rudder_episode_len``, ``env`` and
            ``record_dir`` are read here and the object is forwarded to
            ``get_episode_from_traj``.
        runner: ``Runner`` used for the rollout and state/action conversion.
        rudder: ``RR_LSTM`` providing ``redistribute_reward`` and ``critic.device``.
        nsteps: number of rollout steps to collect (default: 1200).
        fps: frame rate of the written videos (default: 20).
    """
    with torch.no_grad():
        b_states, b_rewards, b_dones, b_actions, b_values, b_logprobs, b_last_values, epinfo = runner.rollout(nsteps)

    # Convert the rest of the rollout data to numpy arrays.
    pov, bin_actions, camera_actions = runner.transform_state_dict(b_states)
    # Result is not used below; the call is kept as in the original.
    runner.transform_action_dict(b_actions)
    b_values = np.stack(b_values)

    agent_pov_, agent_bin_actions_, agent_camera_actions_, agent_b_rewards_, agent_lengths_ = get_episode_from_traj(
        b_rewards,
        b_dones,
        pov,
        bin_actions,
        camera_actions,
        config)

    agent_actions = np.concatenate((agent_bin_actions_, agent_camera_actions_), axis=-1)
    device = rudder.critic.device
    states = torch.tensor(agent_pov_, dtype=torch.float).to(device).detach()
    actions = torch.tensor(agent_actions, dtype=torch.float).to(device).detach()
    new_reward = rudder.redistribute_reward(states, actions)
    print(new_reward.shape, agent_pov_.shape, b_values.shape)

    episode_len = config.rudder_episode_len
    # Loaded once instead of once per frame.
    font = ImageFont.load_default()
    if config.record_dir:
        os.makedirs(config.record_dir, exist_ok=True)

    for j in range(agent_pov_.shape[0]):
        frames = _annotate_episode_frames(agent_pov_[j], new_reward[j], episode_len, font)
        cum_reward = np.sum(agent_b_rewards_[j])
        timestamp = time.strftime("%Y%m%d%H%M%S")
        # The episode index keeps names unique: several episodes are written in
        # the same second and may share cum_reward, which used to overwrite files.
        video_file = "recording-{}-rew{}-{}-ep{}.mp4".format(config.env, cum_reward, timestamp, j)
        video_path = os.path.join(config.record_dir, video_file)
        imageio.mimwrite(video_path, frames, fps=fps, macro_block_size=None)


if __name__ == "__main__":
    # Script entry point: build the behaviour policy, the RUDDER
    # reward-redistribution model and a MineRL environment, then record
    # annotated rollout videos via rudder_save_video.

    args = Config()
    n_cpu = args.num_env
    DEVICE = get_device(args.gpu)

    Network = import_module(args.model).Network
    dataset = import_module(args.dataset)

    # Behaviour policy that drives the rollouts.
    print("Initializing network!")
    model = Network()
    if args.load_model:
        print("Loading behaviour cloned model!")
        # map_location lets a checkpoint saved on another device (e.g. GPU) load on DEVICE.
        model.load_state_dict(torch.load(args.pre_trained_params, map_location=DEVICE))
    model = model.to(DEVICE)

    # RUDDER reward-redistribution model. NOTE(review): the numeric arguments
    # are unchanged from the original script; their meaning is defined by
    # LessonBuffer / RR_LSTM_ARCH — confirm there before changing them.
    lessons_buffer = LessonBuffer(3000, 512, (3, 48, 48), 10)
    mean_return = 0.0
    rr_model = RR_LSTM_ARCH(input_lstm=288, lstm_size=128, num_actions=10, duplication=10,
                            bias_mean=mean_return, device=DEVICE)
    rr_model = rr_model.to(DEVICE)

    rudder = RR_LSTM(buffer=lessons_buffer, model=rr_model, num_actions=10)
    # rudder holds a reference to rr_model, so loading the weights now still affects it.
    print("Loading pretrained RUDDER model!")
    rr_model.load_state_dict(torch.load(args.rudder_model_path, map_location=DEVICE))

    env = make_minerl(args.env, n_cpu=n_cpu, seq_len=dataset.SEQ_LENGTH,
                      transforms=dataset.DATA_TRANSFORM, input_space=dataset.INPUT_SPACE,
                      env_server=True)

    runner = Runner(model=model, dataset=dataset, nenv=args.num_env, env=env, nsteps=args.nsteps,
                    seq_len=dataset.SEQ_LENGTH, gamma=args.gamma, lam=args.lam, episode_length=args.episode_length,
                    reward_scale=args.reward_scale, episodic_interactions=True, device=DEVICE)

    try:
        rudder_save_video(args, runner, rudder=rudder)
    finally:
        # Release the environment's resources even if recording fails.
        # NOTE(review): guarded because the env type returned by make_minerl is not visible here.
        close_env = getattr(env, "close", None)
        if close_env is not None:
            close_env()
