# TODO put this into a more appropriate place
# from adaptivelearner import get_r_t
import matplotlib.pyplot as plt
import numpy as np
import torch
import os

from environments.parallel_envs import make_vec_envs
from utils import helpers as utl

# Device on which this module allocates its tensors: GPU index 3 when CUDA is
# available, otherwise the CPU.
device = torch.device("cuda:3") if torch.cuda.is_available() else torch.device("cpu")


def evaluate(args,
             policy,
             ret_rms,
             iter_idx,
             tasks,
             encoder=None,
             num_episodes=None
             ):
    """Roll out the policy deterministically and collect per-episode returns.

    A fresh set of ``args.num_processes`` vectorised environments is created,
    the policy is stepped for ``num_episodes * envs._max_episode_steps`` steps
    and the raw rewards are summed per process and per episode. The
    environments are always closed, even if the rollout raises.

    Args:
        args: experiment configuration. Attributes read here: ``env_name``,
            optional ``test_env_name`` (takes precedence when present),
            ``max_rollouts_per_task``, ``num_processes``, ``seed``,
            ``policy_gamma``, ``norm_rew_for_policy`` and (only when an
            encoder is given) ``enable_adaptivelearner``.
        policy: forwarded to ``utl.select_action``.
        ret_rms: forwarded to ``make_vec_envs``.
        iter_idx: training iteration; offsets the environment seed.
        tasks: forwarded to ``make_vec_envs``; afterwards replaced by whatever
            ``utl.reset_env`` / ``utl.env_step`` return.
        encoder: optional encoder. When given, the latent is reset with
            ``encoder.prior`` and refreshed after every step through
            ``utl.update_encoding``.
        num_episodes: number of episodes to log per process; defaults to
            ``args.max_rollouts_per_task``.

    Returns:
        Tensor of shape ``(num_processes, num_episodes)`` on ``device`` holding
        the summed raw reward of each episode for each process.
    """
    env_name = getattr(args, 'test_env_name', args.env_name)
    if num_episodes is None:
        num_episodes = args.max_rollouts_per_task
    num_processes = args.num_processes

    # --- set up the things we want to log ---

    # One column per episode plus one overflow column: processes that finish
    # early keep adding rewards to the last column, which is discarded on
    # return.
    returns_per_episode = torch.zeros(
        (num_processes, num_episodes + 1)).to(device)

    # --- initialise environments and latents ---

    envs = make_vec_envs(env_name,
                         seed=args.seed * 42 + iter_idx,
                         num_processes=num_processes,
                         gamma=args.policy_gamma,
                         device=device,
                         rank_offset=num_processes + 1,  # to use diff tmp folders than main processes
                         episodes_per_task=num_episodes,
                         normalise_rew=args.norm_rew_for_policy,
                         ret_rms=ret_rms,
                         tasks=tasks,
                         add_done_info=args.max_rollouts_per_task > 1
                         )

    try:
        num_steps = envs._max_episode_steps

        # reset environments
        state, belief, tasks, _ = utl.reset_env(envs, args)

        # counts, per process, how many episodes of the task are already done
        task_count = torch.zeros(num_processes).long().to(device)

        if encoder is not None:
            # reset latent state to prior
            latent_sample, latent_mean, latent_logvar, hidden_state = encoder.prior(
                num_processes)
        else:
            latent_sample = latent_mean = latent_logvar = hidden_state = None

        for _episode_idx in range(num_episodes):
            for _step_idx in range(num_steps):

                with torch.no_grad():
                    _, action = utl.select_action(args=args,
                                                  policy=policy,
                                                  state=state,
                                                  belief=belief,
                                                  tasks=tasks,
                                                  latent_sample=latent_sample,
                                                  latent_mean=latent_mean,
                                                  latent_logvar=latent_logvar,
                                                  deterministic=True)

                # observe reward and next obs
                [state, belief, tasks], (rew_raw, rew_normalised), done, infos = utl.env_step(
                    envs, action, args)
                done_mdp = [info['done_mdp'] for info in infos]

                if encoder is not None:
                    # The adaptive learner additionally feeds the per-process
                    # `r_t` signal from the env infos to the encoder.
                    encoder_kwargs = {}
                    if args.enable_adaptivelearner:
                        r_t_values = np.array(
                            [info['r_t'] for info in infos], dtype=int)
                        encoder_kwargs['r_t'] = torch.from_numpy(r_t_values).to(
                            device).float().view((-1, 1))

                    # update the hidden state
                    latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(
                        encoder=encoder,
                        next_obs=state,
                        action=action,
                        reward=rew_raw,
                        done=None,
                        hidden_state=hidden_state,
                        **encoder_kwargs)

                # add rewards to the episode each process is currently in
                returns_per_episode[range(num_processes),
                                    task_count] += rew_raw.view(-1)

                for i in np.argwhere(done_mdp).flatten():
                    # count episodes up, but cap at the overflow column
                    # (zero-indexed, so no +1)
                    task_count[i] = min(task_count[i] + 1, num_episodes)
                if np.sum(done) > 0:
                    done_indices = np.argwhere(done.flatten()).flatten()
                    state, belief, tasks, _ = utl.reset_env(
                        envs, args, indices=done_indices, state=state)
    finally:
        # release the environments even when the rollout fails
        envs.close()

    return returns_per_episode[:, :num_episodes]


def visualise_behaviour(args,
                        policy,
                        image_folder,
                        iter_idx,
                        ret_rms,
                        tasks,
                        encoder=None,
                        reward_decoder=None,
                        state_decoder=None,
                        task_decoder=None,
                        compute_rew_reconstruction_loss=None,
                        compute_task_reconstruction_loss=None,
                        compute_state_reconstruction_loss=None,
                        compute_kl_loss=None,
                        ):
    """Run one evaluation rollout in a single environment and plot its latents.

    When ``iter_idx == -1`` an environment is first created with
    ``given_task=False`` and immediately closed again; it receives the task
    file path and the array of visualisation steps (presumably so that the
    evaluation tasks get sampled and stored -- confirm in ``make_vec_envs``).
    Afterwards an environment is created with ``given_task=True``, a rollout
    is obtained either from the environment's own ``visualise_behaviour``
    method or from ``get_test_rollout``, and the latent means / log-variances
    are plotted with ``plot_latents``. The evaluation environment is always
    closed, even if the rollout or the plotting raises.

    Args:
        args: experiment configuration (reads ``results_log_dir``,
            ``env_name``, ``seed``, ``num_frames``, ``policy_num_steps``,
            ``num_processes``, ``vis_interval``, ``policy_gamma``,
            ``max_rollouts_per_task`` and ``norm_rew_for_policy``).
        policy: policy used for the rollout.
        image_folder: forwarded to the plotting code; ``None`` shows the
            figures instead of saving them.
        iter_idx: training iteration; used as environment seed and in the
            output file names.
        ret_rms: forwarded to ``make_vec_envs``.
        tasks: forwarded to ``make_vec_envs``.
        encoder, reward_decoder, state_decoder, task_decoder: forwarded to the
            environment's ``visualise_behaviour`` when it provides one; the
            fallback rollout only uses ``encoder``.
        compute_rew_reconstruction_loss, compute_task_reconstruction_loss,
        compute_state_reconstruction_loss, compute_kl_loss: currently unused,
            because the ``plot_vae_loss`` call is disabled.

    Returns:
        The per-episode rewards element of the rollout tuple.
    """
    # The task file name carries the seed: for now we assume that different
    # seeds produce different tasks.
    task_dir_name = os.path.join(args.results_log_dir,
                                 'logs_{}'.format(args.env_name),
                                 'eval_traj_task_{}.npy'.format(args.seed))

    # keyword arguments shared by both environment constructions
    shared_env_kwargs = dict(
        env_name=args.env_name,
        # seed=args.seed * 42 + iter_idx,
        seed=iter_idx,  # 5_24 not sure setting this seed will get same task within same iter_idx
        num_processes=1,
        gamma=args.policy_gamma,
        device=device,
        episodes_per_task=args.max_rollouts_per_task,
        normalise_rew=args.norm_rew_for_policy,
        ret_rms=ret_rms,
        # not sure if the temp folders would otherwise clash
        rank_offset=args.num_processes + 42,
        tasks=tasks,
        task_dir_name=task_dir_name,
    )

    if iter_idx == -1:
        # training steps at which a visualisation is going to be produced
        vis_trainingstep = np.arange(
            -1,
            int(args.num_frames) // args.policy_num_steps // args.num_processes,
            args.vis_interval)

        # construct the env once without a given task, then discard it
        setup_env = make_vec_envs(given_task=False,
                                  vis_training_step=vis_trainingstep,
                                  **shared_env_kwargs)
        setup_env.close()

    print('########### now evaluating !!!!!!!!!!')
    env = make_vec_envs(iter_idx=iter_idx,
                        given_task=True,
                        vis_training_step=None,
                        **shared_env_kwargs)

    try:
        # 5_13 update plot_vae_loss for changing tasks within one episode:
        # this matters in plot_vae_loss but has no effect in plot_behaviour.

        # get a sample rollout
        unwrapped_env = env.venv.unwrapped.envs[0]
        if hasattr(unwrapped_env, 'unwrapped'):
            unwrapped_env = unwrapped_env.unwrapped

        if hasattr(unwrapped_env, 'visualise_behaviour'):
            # if possible, get it from the env directly
            # (this might visualise other things in addition)
            traj = unwrapped_env.visualise_behaviour(env=env,
                                                     args=args,
                                                     policy=policy,
                                                     iter_idx=iter_idx,
                                                     encoder=encoder,
                                                     reward_decoder=reward_decoder,
                                                     state_decoder=state_decoder,
                                                     task_decoder=task_decoder,
                                                     image_folder=image_folder,
                                                     )
        else:
            traj = get_test_rollout(args, env, policy, encoder)

        # NOTE: the fallback get_test_rollout puts the per-episode returns in
        # the last slot, not the tasks.
        (latent_means, latent_logvars,
         episode_prev_obs, episode_next_obs,
         episode_actions, episode_rewards, episode_tasks) = traj

        if latent_means is not None:
            latent_means = [mean.cpu() for mean in latent_means]
            latent_logvars = [logvar.cpu() for logvar in latent_logvars]

            plot_latents(latent_means, latent_logvars,
                         image_folder=image_folder,
                         iter_idx=iter_idx
                         )

            # 5_13 plot_vae_loss is skipped for now. To re-enable it, call
            # plot_vae_loss(args, latent_means, latent_logvars,
            #               episode_prev_obs, episode_next_obs,
            #               episode_actions, episode_rewards, episode_tasks,
            #               image_folder=image_folder, iter_idx=iter_idx,
            #               reward_decoder=reward_decoder,
            #               state_decoder=state_decoder,
            #               task_decoder=task_decoder,
            #               compute_task_reconstruction_loss=compute_task_reconstruction_loss,
            #               compute_rew_reconstruction_loss=compute_rew_reconstruction_loss,
            #               compute_state_reconstruction_loss=compute_state_reconstruction_loss,
            #               compute_kl_loss=compute_kl_loss)
            # guarded by `not (args.disable_decoder and args.disable_kl_term)`.
    finally:
        # release the evaluation environment even when plotting fails
        env.close()

    return episode_rewards


def get_test_rollout(args, env, policy, encoder=None):
    """Roll out the policy in a single environment and record the trajectory.

    The task is reset once, then ``args.max_rollouts_per_task`` episodes are
    run back to back with deterministic actions. With an encoder, the latent
    is reset to the prior before the first episode only and carried over
    between episodes; the latent mean / log-variance are recorded at the start
    of every episode and after every step.

    Args:
        args: experiment configuration (reads ``max_rollouts_per_task``; also
            forwarded to the ``utl`` helpers).
        env: vectorised environment holding a single process.
        policy: object exposing ``act(state, latent, belief, tasks,
            deterministic)``.
        encoder: optional encoder producing the latent; ``None`` runs the
            policy with ``None`` latents.

    Returns:
        A 7-tuple ``(latent_means, latent_logvars, prev_obs, next_obs,
        actions, rewards, returns)``. Every element is a list with one entry
        per episode; the two latent entries are ``None`` when no encoder is
        given. ``returns`` holds the sum of the raw rewards of each episode
        (as returned by ``utl.env_step``, presumably a tensor).
    """
    num_episodes = args.max_rollouts_per_task

    # --- initialise things we want to keep track of ---

    episode_prev_obs = [[] for _ in range(num_episodes)]
    episode_next_obs = [[] for _ in range(num_episodes)]
    episode_actions = [[] for _ in range(num_episodes)]
    episode_rewards = [[] for _ in range(num_episodes)]

    episode_returns = []

    if encoder is not None:
        episode_latent_means = [[] for _ in range(num_episodes)]
        episode_latent_logvars = [[] for _ in range(num_episodes)]
    else:
        curr_latent_sample = curr_latent_mean = curr_latent_logvar = None
        episode_latent_means = episode_latent_logvars = None

    # --- roll out policy ---

    # (re)set environment
    env.reset_task()
    state, belief, tasks, _ = utl.reset_env(env, args)
    state = state.reshape((1, -1)).to(device)
    tasks = tasks.view(-1) if tasks is not None else None

    for episode_idx in range(num_episodes):

        # raw rewards of this episode; summed into the episode return below
        curr_rollout_rew = []

        if encoder is not None:
            if episode_idx == 0:
                # reset to prior
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(
                    1)
                curr_latent_sample = curr_latent_sample[0].to(device)
                curr_latent_mean = curr_latent_mean[0].to(device)
                curr_latent_logvar = curr_latent_logvar[0].to(device)
            # record the latent the episode starts with
            episode_latent_means[episode_idx].append(
                curr_latent_mean[0].clone())
            episode_latent_logvars[episode_idx].append(
                curr_latent_logvar[0].clone())

        for _step_idx in range(1, env._max_episode_steps + 1):

            episode_prev_obs[episode_idx].append(state.clone())

            latent = utl.get_latent_for_policy(args,
                                               latent_sample=curr_latent_sample,
                                               latent_mean=curr_latent_mean,
                                               latent_logvar=curr_latent_logvar)
            _, action = policy.act(
                state=state.view(-1), latent=latent, belief=belief, tasks=tasks, deterministic=True)
            action = action.reshape((1, *action.shape))

            # observe reward and next obs
            (state, belief, tasks), (rew_raw,
                                     rew_normalised), done, infos = utl.env_step(env, action, args)
            state = state.reshape((1, -1)).to(device)
            tasks = tasks.view(-1) if tasks is not None else None

            if encoder is not None:
                # update tasks embedding
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                    action.float().to(device),
                    state,
                    rew_raw.reshape((1, 1)).float().to(device),
                    hidden_state,
                    return_prior=False)

                episode_latent_means[episode_idx].append(
                    curr_latent_mean[0].clone())
                episode_latent_logvars[episode_idx].append(
                    curr_latent_logvar[0].clone())

            episode_next_obs[episode_idx].append(state.clone())
            episode_rewards[episode_idx].append(rew_raw.clone())
            episode_actions[episode_idx].append(action.clone())
            curr_rollout_rew.append(rew_raw.clone())

            if infos[0]['done_mdp']:
                break

        episode_returns.append(sum(curr_rollout_rew))

    # clean up: turn the per-step lists into one tensor per episode
    if encoder is not None:
        episode_latent_means = [torch.stack(e) for e in episode_latent_means]
        episode_latent_logvars = [torch.stack(
            e) for e in episode_latent_logvars]

    episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
    episode_next_obs = [torch.cat(e) for e in episode_next_obs]
    episode_actions = [torch.cat(e) for e in episode_actions]
    episode_rewards = [torch.cat(r) for r in episode_rewards]

    return episode_latent_means, episode_latent_logvars, \
        episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
        episode_returns


def _latents_to_array(latents):
    """Stack a sequence of tensors / array-likes into one numpy array."""
    return np.stack([
        # detach + cpu so tensors that require grad or live on a GPU convert
        item.detach().cpu().numpy() if torch.is_tensor(item) else np.asanyarray(item)
        for item in latents
    ])


def plot_latents(latent_means,
                 latent_logvars,
                 image_folder,
                 iter_idx,
                 ):
    """
    Plot mean/variance over time

    The left panel shows the stacked latent means, the right panel the latent
    variances, computed as ``exp(logvar)``.

    Args:
        latent_means: sequence of latent means (torch tensors or array-likes)
            of identical shape.
        latent_logvars: sequence of latent log-variances, same layout.
        image_folder: folder the figure is saved to as
            ``<image_folder>/<iter_idx>_latents``; if ``None`` the figure is
            shown instead.
        iter_idx: iteration index used in the file name.
    """
    mean_values = _latents_to_array(latent_means)
    var_values = np.exp(_latents_to_array(latent_logvars))

    plt.figure(figsize=(18, 5))

    plt.subplot(1, 2, 1)
    plt.plot(mean_values)
    plt.title('latent mean')

    plt.subplot(1, 2, 2)
    plt.plot(var_values)
    plt.title('latent var')

    plt.tight_layout()
    if image_folder is not None:
        plt.savefig('{}/{}_latents'.format(image_folder, iter_idx))
        plt.close('all')
    else:
        plt.show()

    # num_rollouts = len(latent_means)
    # num_episode_steps = len(latent_means[0])

    # latent_means = torch.cat(latent_means).cpu().detach().numpy()
    # latent_logvars = torch.cat(latent_logvars).cpu().detach().numpy()

    # plt.figure(figsize=(12, 5))

    # # plt.subplot(1, 2, 1)
    # # plt.plot(range(latent_means.shape[0]), latent_means, '-', alpha=0.5)
    # # plt.plot(range(latent_means.shape[0]), latent_means.mean(axis=1), 'k-')
    # # for tj in np.cumsum([0, *[num_episode_steps for _ in range(num_rollouts)]]):
    # #     span = latent_means.max() - latent_means.min()
    # #     plt.plot([tj + 0.5, tj + 0.5],
    # #              [latent_means.min() - span * 0.05, latent_means.max() + span * 0.05],
    # #              'k--', alpha=0.5)
    # # plt.xlabel('env steps', fontsize=15)
    # # plt.ylabel('latent mean', fontsize=15)

    # # plt.subplot(1, 2, 2)
    # # latent_var = np.exp(latent_logvars)
    # # plt.plot(range(latent_logvars.shape[0]), latent_var, '-', alpha=0.5)
    # # plt.plot(range(latent_logvars.shape[0]), latent_var.mean(axis=1), 'k-')
    # # for tj in np.cumsum([0, *[num_episode_steps for _ in range(num_rollouts)]]):
    # #     span = latent_var.max() - latent_var.min()
    # #     plt.plot([tj + 0.5, tj + 0.5],
    # #              [latent_var.min() - span * 0.05, latent_var.max() + span * 0.05],
    # #              'k--', alpha=0.5)
    # # plt.xlabel('env steps', fontsize=15)
    # # plt.ylabel('latent variance', fontsize=15)

    # latent = np.concatenate([latent_means, latent_logvars], axis=1)
    # plt.plot(range(latent.shape[0]), latent, '-', alpha=0.5)
    # for tj in np.cumsum([0, *[num_episode_steps for _ in range(num_rollouts)]]):
    #     span = latent_means.max() - latent_means.min()
    #     plt.plot([tj + 0.5, tj + 0.5],
    #              [latent_means.min() - span * 0.05, latent_means.max() + span * 0.05],
    #              'k--', alpha=0.5)
    # plt.xlabel('env steps', fontsize=15)
    # plt.ylabel('latent', fontsize=15)

    # plt.tight_layout()
    # if image_folder is not None:
    #     plt.savefig('{}/{}_latents'.format(image_folder, iter_idx))
    #     plt.close('all')
    # else:
    #     plt.show()


def _vae_draw_episode_boundaries(boundaries, lower, upper):
    """Draw a dashed vertical line at every episode boundary, padded by 5% of the value span."""
    span = upper - lower
    y_limits = [lower - span * 0.05, upper + span * 0.05]
    for tj in boundaries:
        plt.plot([tj + 0.5, tj + 0.5], y_limits, 'k--', alpha=0.5)


def _vae_save_or_show(image_folder, file_path):
    """Save the current figure to `file_path` and close all figures, or show it if no folder is given."""
    if image_folder is not None:
        plt.savefig(file_path)
        plt.close('all')
    else:
        plt.show()


def _vae_plot_reconstruction(x, loss_means, loss_stds, pred_stds, boundaries,
                             loss_label, std_label, image_folder, file_path):
    """Plot reconstruction error (mean +/- std) and the std of the predictions side by side."""
    loss_means = torch.stack(loss_means).detach().cpu().numpy()
    loss_stds = torch.stack(loss_stds).detach().cpu().numpy()
    pred_stds = torch.stack(pred_stds).detach().cpu().numpy()

    plt.figure(figsize=(12, 5))

    # left: reconstruction error with a one-std band
    plt.subplot(1, 2, 1)
    line = plt.plot(x, loss_means, 'b-')
    plt.gca().fill_between(x,
                           loss_means - loss_stds,
                           loss_means + loss_stds,
                           facecolor=line[0].get_color(), alpha=0.1)
    _vae_draw_episode_boundaries(boundaries,
                                 (loss_means - loss_stds).min(),
                                 (loss_means + loss_stds).max())
    plt.xlabel('env steps', fontsize=15)
    plt.ylabel(loss_label, fontsize=15)

    # right: spread of the predictions across latent samples
    plt.subplot(1, 2, 2)
    plt.plot(x, pred_stds, 'b-')
    _vae_draw_episode_boundaries(boundaries, pred_stds.min(), pred_stds.max())
    plt.xlabel('env steps', fontsize=15)
    plt.ylabel(std_label, fontsize=15)

    plt.tight_layout()
    _vae_save_or_show(image_folder, file_path)


def plot_vae_loss(args,
                  latent_means,
                  latent_logvars,
                  prev_obs,
                  next_obs,
                  actions,
                  rewards,
                  tasks,
                  image_folder,
                  iter_idx,
                  reward_decoder,
                  state_decoder,
                  task_decoder,
                  compute_task_reconstruction_loss,
                  compute_rew_reconstruction_loss,
                  compute_state_reconstruction_loss,
                  compute_kl_loss
                  ):
    """Plot the KL term and the decoder reconstruction losses along a rollout.

    For every latent along the trajectory, latent samples are drawn (10 when
    the latent is stochastic, otherwise the concatenated mean / log-variance
    is used once) and each enabled decoder is evaluated on the whole
    trajectory. Mean and std of the losses across samples are plotted against
    the latent index, with dashed lines marking episode boundaries.

    Args:
        args: experiment configuration (reads
            ``disable_stochasticity_in_latent``).
        latent_means, latent_logvars: per-episode lists of tensors; they are
            concatenated along dim 0.
        prev_obs, next_obs, actions, rewards, tasks: per-episode lists of
            tensors; concatenated along dim 0 and moved to ``device``.
        image_folder: folder to save figures in; ``None`` shows them instead.
        iter_idx: iteration index used as file-name prefix.
        reward_decoder, state_decoder, task_decoder: only tested against
            ``None`` to decide which losses are computed and plotted.
        compute_task_reconstruction_loss, compute_rew_reconstruction_loss,
        compute_state_reconstruction_loss: callables returning
            ``(loss, predictions)`` when called with
            ``return_predictions=True``.
        compute_kl_loss: callable invoked as
            ``compute_kl_loss(latent_means, latent_logvars, None)``.
    """
    num_rollouts = len(latent_means)
    num_episode_steps = len(latent_means[0])
    # how many samples to use to get an average/std ELBO loss
    num_samples = 1 if args.disable_stochasticity_in_latent else 10

    # x positions of the episode boundaries on the plots
    boundaries = np.cumsum([0] + [num_episode_steps] * num_rollouts)

    latent_means = torch.cat(latent_means)
    latent_logvars = torch.cat(latent_logvars)

    prev_obs = torch.cat(prev_obs).to(device)
    next_obs = torch.cat(next_obs).to(device)
    actions = torch.cat(actions).to(device)
    rewards = torch.cat(rewards).to(device)
    tasks = torch.cat(tasks).to(device)

    if actions.dim() == 3:
        actions = actions.squeeze(1)

    def per_sample(tensor):
        # predictions are made for each tuple in the trajectory and for each
        # latent sample, hence the targets get a leading sample dimension
        return tensor.unsqueeze(0).expand(num_samples, *tensor.shape)

    prev_obs = per_sample(prev_obs)
    next_obs = per_sample(next_obs)
    actions = per_sample(actions)
    rewards = per_sample(rewards)
    tasks = per_sample(tasks)

    # each latent sample is used to make predictions for the entire trajectory
    len_traj = prev_obs.shape[1]

    rew_reconstr_mean, rew_reconstr_std, rew_pred_std = [], [], []
    state_reconstr_mean, state_reconstr_std, state_pred_std = [], [], []
    task_reconstr_mean, task_reconstr_std, task_pred_std = [], [], []

    # compute the sum of ELBO_t's by looping through (trajectory length + prior)
    for curr_latent_mean, curr_latent_logvar in zip(latent_means, latent_logvars):

        if not args.disable_stochasticity_in_latent:
            # take several samples from the latent distribution
            latent_samples = utl.sample_gaussian(
                curr_latent_mean.view(-1), curr_latent_logvar.view(-1), num_samples)
        else:
            latent_samples = torch.cat(
                (curr_latent_mean.view(-1), curr_latent_logvar.view(-1))).unsqueeze(0)

        # the task loss uses the un-expanded latent samples
        if task_decoder is not None:
            loss_task, task_pred = compute_task_reconstruction_loss(
                latent_samples, tasks, return_predictions=True)

            # average/std across the different samples
            task_reconstr_mean.append(loss_task.mean())
            task_reconstr_std.append(loss_task.std())
            task_pred_std.append(task_pred.std())

        latent_samples = latent_samples.unsqueeze(1).expand(
            num_samples, len_traj, latent_samples.shape[-1])

        if reward_decoder is not None:
            loss_rew, rew_pred = compute_rew_reconstruction_loss(latent_samples, prev_obs, next_obs,
                                                                 actions, rewards, return_predictions=True)
            # sum along length of trajectory
            loss_rew = loss_rew.sum(dim=1)
            rew_pred = rew_pred.sum(dim=1)

            # average/std across the different samples
            rew_reconstr_mean.append(loss_rew.mean())
            rew_reconstr_std.append(loss_rew.std())
            rew_pred_std.append(rew_pred.std())

        if state_decoder is not None:
            loss_state, state_pred = compute_state_reconstruction_loss(latent_samples, prev_obs, next_obs,
                                                                       actions, return_predictions=True)
            # sum along length of trajectory
            loss_state = loss_state.sum(dim=1)
            state_pred = state_pred.sum(dim=1)

            # average/std across the different samples
            state_reconstr_mean.append(loss_state.mean())
            state_reconstr_std.append(loss_state.std())
            state_pred_std.append(state_pred.std())

    # kl term
    vae_kl_term = compute_kl_loss(latent_means, latent_logvars, None)

    # --- plot KL term ---

    x = range(len(vae_kl_term))

    # convert once so matplotlib never receives tensors that carry grad
    kl_values = vae_kl_term.detach().cpu().numpy()
    plt.plot(x, kl_values, 'b-')
    _vae_draw_episode_boundaries(boundaries, kl_values.min(), kl_values.max())
    plt.xlabel('env steps', fontsize=15)
    plt.ylabel('KL term', fontsize=15)
    plt.tight_layout()
    _vae_save_or_show(image_folder, '{}/{}_kl'.format(image_folder, iter_idx))

    # --- plot rew reconstruction ---

    if reward_decoder is not None:
        _vae_plot_reconstruction(
            x, rew_reconstr_mean, rew_reconstr_std, rew_pred_std, boundaries,
            'reward reconstruction error',
            'std of rew reconstruction',
            image_folder,
            '{}/{}_rew_reconstruction'.format(image_folder, iter_idx))

    # --- plot state reconstruction ---

    if state_decoder is not None:
        _vae_plot_reconstruction(
            x, state_reconstr_mean, state_reconstr_std, state_pred_std, boundaries,
            'state reconstruction error',
            'std of state reconstruction',
            image_folder,
            '{}/{}_state_reconstruction'.format(image_folder, iter_idx))

    # --- plot tasks reconstruction ---

    if task_decoder is not None:
        _vae_plot_reconstruction(
            x, task_reconstr_mean, task_reconstr_std, task_pred_std, boundaries,
            'tasks reconstruction error',
            'std of tasks reconstruction',
            image_folder,
            '{}/{}_task_reconstruction'.format(image_folder, iter_idx))
