import matplotlib.pyplot as plt
import numpy as np
import torch
import cv2
import os

from environments.parallel_envs import make_vec_envs
from utils import helpers as utl

from utils.helpers import get_device


# Evaluate for pretrainers, meaning we directly sample the latents from the policy
def evaluate_pretrainer(args,
             policy,
             ret_rms,
             iter_idx,
             tasks,
             test,
             encoder=None,
             num_episodes=None
             ):
    env_name = args.env_name
    if hasattr(args, 'test_env_name'):
        env_name = args.test_env_name
    if num_episodes is None:
        num_episodes = args.max_rollouts_per_task
    num_processes = args.num_processes

    # --- set up the things we want to log ---

    # for each process, we log the returns during the first, second, ... episode
    # (such that we have a minimum of [num_episodes]; the last column is for
    #  any overflow and will be discarded at the end, because we need to wait until
    #  all processes have at least [num_episodes] many episodes)
    returns_per_episode = torch.zeros((num_processes, num_episodes + 1)).to(get_device())
    successes = torch.zeros((num_processes, num_episodes + 1)).to(get_device())

    # --- initialise environments and latents ---

    envs = make_vec_envs(env_name,
                         seed=args.seed * 42 + iter_idx,
                         num_processes=num_processes,
                         gamma=args.policy_gamma,
                         device=get_device(),
                         rank_offset=num_processes + 1,  # to use diff tmp folders than main processes
                         episodes_per_task=num_episodes,
                         normalise_rew=args.norm_rew_for_policy,
                         ret_rms=ret_rms,
                         tasks=tasks,
                         add_done_info=args.max_rollouts_per_task > 1,
                         test=test,
                         )
    num_steps = envs._max_episode_steps

    # reset environments
    state, belief, task = utl.reset_env(envs, args)

    # this counts how often an agent has done the same task already
    task_count = torch.zeros(num_processes).long().to(get_device())

    for episode_idx in range(num_episodes):
        # TODO: need to change the logic here (or just create a new fn) if latents need to change each step
        latent_sample, latent_mean, latent_logvar = policy.sample_latent(num_processes=num_processes)
        
        for step_idx in range(num_steps):

            with torch.no_grad():
                _, action = utl.select_action(args=args,
                                              policy=policy,
                                              state=state,
                                              belief=belief,
                                              task=task,
                                              prob=None,
                                              latent_pol=None,
                                              latent_sample=latent_sample,
                                              latent_mean=latent_mean,
                                              latent_logvar=latent_logvar,
                                              deterministic=False
                                              )

            # observe reward and next obs
            [state, belief, task], (rew_raw, rew_normalised), done, infos = utl.env_step(envs, action, args)
            done_mdp = [info['done_mdp'] for info in infos]
            # Check if 'success' exists in info, otherwise use 0
            success = torch.tensor([info.get('success', 0) for info in infos], device=get_device())
            
            # add rewards
            returns_per_episode[range(num_processes), task_count] += rew_raw.view(-1)
            successes[range(num_processes), task_count] += success
            
            for i in np.argwhere(done_mdp).flatten():
                # count task up, but cap at num_episodes + 1
                task_count[i] = min(task_count[i] + 1, num_episodes)  # zero-indexed, so no +1
            if np.sum(done) > 0:
                done_indices = np.argwhere(done.flatten()).flatten()
                state, belief, task = utl.reset_env(envs, args, indices=done_indices, state=state)

    #cap successes to 1
    successes[successes >= 1.0] = 1.0
    
    infos = None # placeholder for extra information
    
    envs.close()

    return returns_per_episode[:, :num_episodes], successes[:, :num_episodes], infos

# Evaluate for varibad, meaning we don't need to worry about mixtures
def evaluate_varibad(args,
             policy,
             ret_rms,
             iter_idx,
             tasks,
             test,
             encoder=None,
             num_episodes=None
             ):
    env_name = args.env_name
    if hasattr(args, 'test_env_name'):
        env_name = args.test_env_name
    if num_episodes is None:
        num_episodes = args.max_rollouts_per_task
    num_processes = args.num_processes

    # --- set up the things we want to log ---

    # for each process, we log the returns during the first, second, ... episode
    # (such that we have a minimum of [num_episodes]; the last column is for
    #  any overflow and will be discarded at the end, because we need to wait until
    #  all processes have at least [num_episodes] many episodes)
    returns_per_episode = torch.zeros((num_processes, num_episodes + 1)).to(get_device())
    successes = torch.zeros((num_processes, num_episodes + 1)).to(get_device())

    # --- initialise environments and latents ---

    envs = make_vec_envs(env_name,
                         seed=args.seed * 42 + iter_idx,
                         num_processes=num_processes,
                         gamma=args.policy_gamma,
                         device=get_device(),
                         rank_offset=num_processes + 1,  # to use diff tmp folders than main processes
                         episodes_per_task=num_episodes,
                         normalise_rew=args.norm_rew_for_policy,
                         ret_rms=ret_rms,
                         tasks=tasks,
                         add_done_info=args.max_rollouts_per_task > 1,
                         test=test,
                         )
    num_steps = envs._max_episode_steps

    # reset environments
    state, belief, task = utl.reset_env(envs, args)

    # this counts how often an agent has done the same task already
    task_count = torch.zeros(num_processes).long().to(get_device())

    if encoder is not None:
        # reset latent state to prior
        latent_sample, latent_mean, latent_logvar, hidden_state = encoder.prior(num_processes)
    else:
        latent_sample = latent_mean = latent_logvar = hidden_state = None

    for episode_idx in range(num_episodes):

        for step_idx in range(num_steps):

            with torch.no_grad():
                _, action = utl.select_action(args=args,
                                              policy=policy,
                                              state=state,
                                              belief=belief,
                                              task=task,
                                              prob=None,
                                              latent_pol=None,
                                              latent_sample=latent_sample,
                                              latent_mean=latent_mean,
                                              latent_logvar=latent_logvar,
                                              deterministic=False
                                              )

            # observe reward and next obs
            [state, belief, task], (rew_raw, rew_normalised), done, infos = utl.env_step(envs, action, args)
            done_mdp = [info['done_mdp'] for info in infos]
            # Check if 'success' exists in info, otherwise use 0
            success = torch.tensor([info.get('success', 0) for info in infos], device=get_device())

            if encoder is not None:
                # update the hidden state
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                    action.float().to(get_device()),
                    state,
                    rew_raw.reshape((1, 1)).float().to(get_device()),
                    hidden_state,
                    return_prior=False)

                latent_sample = curr_latent_sample[0].clone()
                latent_mean = curr_latent_mean[0].clone()
                latent_logvar = curr_latent_logvar[0].clone()

            # add rewards
            returns_per_episode[range(num_processes), task_count] += rew_raw.view(-1)
            successes[range(num_processes), task_count] += success
            
            for i in np.argwhere(done_mdp).flatten():
                # count task up, but cap at num_episodes + 1
                task_count[i] = min(task_count[i] + 1, num_episodes)  # zero-indexed, so no +1
            if np.sum(done) > 0:
                done_indices = np.argwhere(done.flatten()).flatten()
                state, belief, task = utl.reset_env(envs, args, indices=done_indices, state=state)

    #cap successes to 1
    successes[successes >= 1.0] = 1.0
    
    infos = None # placeholder for extra information
    
    envs.close()

    return returns_per_episode[:, :num_episodes], successes[:, :num_episodes], infos

# Designed to just work for a single-task environment (not ML10 or ML45)
def evaluate_SDVT(args,
             policy,
             ret_rms,
             iter_idx,
             tasks,
             test,
             encoder=None,
             encoder_pol=None,
             num_episodes=None,
             task_list = None, save_episode_probs=False, save_episode_successes=False):
    env_name = args.env_name
    policy_separate_gru = args.policy_separate_gru
    if hasattr(args, 'test_env_name'):
        env_name = args.test_env_name
    if num_episodes is None:
        num_episodes = args.max_rollouts_per_task
    num_processes = args.num_processes

    # --- set up the things we want to log ---

    # for each process, we log the returns during the first, second, ... episode
    # (such that we have a minimum of [num_episodes]; the last column is for
    #  any overflow and will be discarded at the end, because we need to wait until
    #  all processes have at least [num_episodes] many episodes)
    returns_per_episode = torch.zeros((num_processes, num_episodes + 1)).to(get_device())
    successes = torch.zeros((num_processes, num_episodes + 1)).to(get_device())

    # --- initialise environments and latents ---
    envs = make_vec_envs(env_name,
                         seed=args.seed * 42 + iter_idx,
                         num_processes=num_processes,
                         gamma=args.policy_gamma,
                         device=get_device(),
                         rank_offset=num_processes + 1,  # to use diff tmp folders than main processes
                         episodes_per_task=num_episodes,
                         normalise_rew=args.norm_rew_for_policy,
                         ret_rms=ret_rms,
                         tasks=tasks,
                         add_done_info=args.max_rollouts_per_task > 1,
                         test=test,
                         )
    num_steps = envs._max_episode_steps


    # reset environments
    state, belief, task = utl.reset_env(envs, args)
    #envs.reset_task(task_list)
    #state2= torch.from_numpy(envs._get_obs()).float().to(get_device())
    #state[:,:39] = state2.clone() #varibad wrapper has one more 'done' dimension

    # this counts how often an agent has done the same task already
    task_count = torch.zeros(num_processes).long().to(get_device())
    prob = None
    episode_probs = None
    episode_successes = None
    if save_episode_successes:
        episode_successes = np.zeros((num_processes, num_episodes))

    if encoder is not None:
        # reset latent state to prior
        if args.vae_mixture_num>1:
            prior_y, prior_z, prior_mu, prior_var, prior_logits, prior_prob, prior_hidden_state = encoder.prior_mixture(num_processes)
                
            latent_sample = prior_z
            latent_mean = prior_mu
            latent_logvar = torch.log(prior_var + 1e-20)
            hidden_state = prior_hidden_state
            prob = prior_prob
            if save_episode_probs:
                episode_probs = torch.zeros((num_processes, num_episodes + 1, num_steps, args.vae_mixture_num)).to(get_device())
        else:
            latent_sample, latent_mean, latent_logvar, hidden_state = encoder.prior(num_processes)
            w = None
    else:
        latent_sample = latent_mean = latent_logvar = hidden_state = w = None

    if policy_separate_gru:
        latent_pol, hidden_state_pol = encoder_pol.prior(num_processes)
    else:
        latent_pol = hidden_state_pol = None

    #meta_successes = np.zeros(num_processes)

    env_step = utl.env_step

    for episode_idx in range(num_episodes):
        #successes = np.zeros(num_processes)
        # TODO: make this work
        # if args.render: #save videos for each rollout episode
        #     for i in range(num_processes):
        #         print(episode_idx, i)
        #         if args.load_iter is None:
        #             os.makedirs(
        #                 'renders/{}/task{:2d}/subtask{:2d}'.format(iter_idx, task_list[i, 0],
        #                                                                                 task_list[i, 1]), exist_ok=True)
        #         else:
        #             os.makedirs(
        #                 'renders/' + args.load_dir + '{}/task{:2d}/subtask{:2d}'.format(args.load_iter, task_list[i, 0],
        #                                                                                 task_list[i, 1]), exist_ok=True)
        #     imgs_array = []

        for step_idx in range(num_steps):
            with torch.no_grad():
                _, action = utl.select_action(args=args,
                                              policy=policy,
                                              state=state,
                                              belief=belief,
                                              task=task,
                                              prob=prob,
                                              latent_pol=latent_pol if policy_separate_gru else None,
                                              latent_sample=latent_sample,
                                              latent_mean=latent_mean,
                                              latent_logvar=latent_logvar,
                                              deterministic=False
                                              )

            # observe reward and next obs
            [state, belief, task], (rew_raw, rew_normalised), done, infos = env_step(envs, action, args)

            done_mdp = [info['done_mdp'] for info in infos]
            success = torch.tensor([info.get('success', 0) for info in infos], device=get_device())

            # TODO: make this work later
            # if args.render:
            #     imgs = np.array([info['image'] for info in infos])
            #     success_list = [info['success'] for info in infos]
            #     for i in range(num_processes):
            #         imgs[i,0:30,0:30,:] = int(255*success_list[i]) #to distinguish success
            #     imgs_array.append(imgs)

            with torch.no_grad():
                if encoder is not None:
                    # update the hidden state
                    if args.vae_mixture_num>1:
                        latent_sample, latent_mean, latent_logvar, hidden_state, \
                        y, z, mu, var, logits, prob = utl.update_encoding(encoder=encoder,
                                                                          next_obs=state,
                                                                          action=action,
                                                                          reward=rew_raw,
                                                                          done=None,
                                                                          hidden_state=hidden_state,
                                                                          vae_mixture_num=args.vae_mixture_num)
                        
                        if save_episode_probs:
                            episode_probs[:, episode_idx, step_idx , :] = prob.clone()

                    else:
                        latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(encoder=encoder,
                                                                                                      next_obs=state,
                                                                                                      action=action,
                                                                                                      reward=rew_raw,
                                                                                                      done=None,
                                                                                                      hidden_state=hidden_state)

                    if policy_separate_gru:
                        latent_pol, hidden_state_pol = utl.update_encoding_pol(
                            encoder=encoder_pol,
                            next_obs=state,
                            action=action,
                            reward=rew_raw,
                            done=None,
                            hidden_state=hidden_state_pol,
                        )

            # add rewards
            returns_per_episode[range(num_processes), task_count] += rew_raw.view(-1)
            successes[range(num_processes), task_count] += success

            for i in np.argwhere(done_mdp).flatten():
                # count task up, but cap at num_episodes + 1
                task_count[i] = min(task_count[i] + 1, num_episodes)  # zero-indexed, so no +1
            if np.sum(done) > 0:
                done_indices = np.argwhere(done.flatten()).flatten()
                state, belief, task = utl.reset_env(envs, args, indices=done_indices, state=state)

        # TODO: make this work later
        # if args.render:
        #     imgs_array = np.array(imgs_array)
        #     for i in range(num_processes):
        #         img_array = imgs_array[:,i,:,:,:] #(500x10x480x640x3) to (500x480x640x3)

        #         if args.load_iter is None:
        #             pathout = 'renders/{}/task{:2d}/subtask{:2d}/task{:2d}_subtask{:2d}_epi_{:2d}.mp4'.format(iter_idx, task_list[i, 0], task_list[i, 1], task_list[i, 0], task_list[i, 1], episode_idx)
        #         else:
        #             #pathout = 'renders/'+args.load_dir+ '{}/task{:2d}/subtask{:2d}/task{:2d}_subtask{:2d}_epi_{:2d}.mp4'.format(args.load_iter,task_list[i,0],task_list[i,1],task_list[i,0],task_list[i,1],episode_idx)
        #             pathout = 'renders/'+args.load_dir+ '{}/task{:2d}/subtask{:2d}/task{:2d}_subtask{:2d}_epi_{:2d}_return{:04d}.mp4'.format(args.load_iter,task_list[i,0],task_list[i,1],task_list[i,0],task_list[i,1],episode_idx, round(returns_per_episode.detach().cpu().numpy()[i, task_count[i]-1]))
        #         out = cv2.VideoWriter(pathout, cv2.VideoWriter_fourcc(*'mp4v'), 50, (np.shape(img_array)[2],np.shape(img_array)[1])) #width, height
        #         horizon = np.shape(img_array)[0]
        #         for j in range(horizon):
        #             rgb_img = cv2.cvtColor(img_array[j], cv2.COLOR_RGB2BGR)
        #             out.write(rgb_img)
        #         out.release()

        if save_episode_successes:
            episode_successes[:, episode_idx] = successes
        #meta_successes = meta_successes+successes
    successes[successes >= 1.0] = 1.0
    
    infos = None # placeholder for extra information

    envs.close()
    #meta_successes = meta_successes/num_episodes

    if args.vae_mixture_num > 1:
        prob =prob.detach().cpu().numpy()
        if save_episode_probs:
            episode_probs = episode_probs[:, :num_episodes, :, :].detach().cpu().numpy()
            infos['episode_probs'] = episode_probs
    
    return returns_per_episode[:, :num_episodes], successes[:, :num_episodes], infos


def evaluate_metaworld(args,
             policy,
             ret_rms,
             iter_idx,
             tasks,
             encoder=None,
             encoder_pol=None,
             num_episodes=None,
             test = False, task_list = None, save_episode_probs=False, save_episode_successes=False):
    env_name = args.env_name
    policy_separate_gru = args.policy_separate_gru
    if test:
        env_name = args.test_env_name
    if num_episodes is None:
        num_episodes = args.max_rollouts_per_task
    num_processes = 10

    # --- set up the things we want to log ---

    # for each process, we log the returns during the first, second, ... episode
    # (such that we have a minimum of [num_episodes]; the last column is for
    #  any overflow and will be discarded at the end, because we need to wait until
    #  all processes have at least [num_episodes] many episodes)
    returns_per_episode = torch.zeros((num_processes, num_episodes + 1)).to(get_device())

    # --- initialise environments and latents ---
    envs = make_vec_envs(env_name,
                         seed=args.seed,
                         num_processes=num_processes,
                         gamma=args.policy_gamma,
                         device=get_device(),
                         rank_offset=num_processes + 1,  # to use diff tmp folders than main processes
                         episodes_per_task=num_episodes,
                         normalise_rew=args.norm_rew_for_policy,
                         ret_rms=ret_rms,
                         tasks=tasks,
                         add_done_info=args.max_rollouts_per_task > 1,
                         )
    num_steps = envs._max_episode_steps


    # reset environments
    state, belief, task = utl.reset_env(envs, args)
    envs.reset_task(task_list)
    state2= torch.from_numpy(envs._get_obs()).float().to(get_device())
    state[:,:39] = state2.clone() #varibad wrapper has one more 'done' dimension

    # this counts how often an agent has done the same task already
    task_count = torch.zeros(num_processes).long().to(get_device())
    prob = None
    episode_probs = None
    episode_successes = None
    if save_episode_successes:
        episode_successes = np.zeros((num_processes, num_episodes))

    if encoder is not None:
        # reset latent state to prior
        if args.vae_mixture_num>1:
            prior_y, prior_z, prior_mu, prior_var, prior_logits, prior_prob, prior_hidden_state = encoder.prior_mixture(num_processes)
                
            latent_sample = prior_z
            latent_mean = prior_mu
            latent_logvar = torch.log(prior_var + 1e-20)
            hidden_state = prior_hidden_state
            prob = prior_prob
            if save_episode_probs:
                episode_probs = torch.zeros((num_processes, num_episodes + 1, num_steps, args.vae_mixture_num)).to(get_device())
        else:
            latent_sample, latent_mean, latent_logvar, hidden_state = encoder.prior(num_processes)
    else:
        latent_sample = latent_mean = latent_logvar = hidden_state = None

    if policy_separate_gru:
        latent_pol, hidden_state_pol = encoder_pol.prior(num_processes)
    else:
        latent_pol = hidden_state_pol = None

    meta_successes = np.zeros(num_processes)

    env_step = utl.env_step

    for episode_idx in range(num_episodes):
        successes = np.zeros(num_processes)
        if args.render: #save videos for each rollout episode
            for i in range(num_processes):
                print(episode_idx, i)
                if args.load_iter is None:
                    os.makedirs(
                        'renders/{}/task{:2d}/subtask{:2d}'.format(iter_idx, task_list[i, 0],
                                                                                        task_list[i, 1]), exist_ok=True)
                else:
                    os.makedirs(
                        'renders/' + args.load_dir + '{}/task{:2d}/subtask{:2d}'.format(args.load_iter, task_list[i, 0],
                                                                                        task_list[i, 1]), exist_ok=True)
            imgs_array = []
        for step_idx in range(num_steps):
            with torch.no_grad():
                _, action = utl.select_action(args=args,
                                              policy=policy,
                                              state=state,
                                              belief=belief,
                                              task=task,
                                              prob=prob,
                                              latent_pol=latent_pol if policy_separate_gru else None,
                                              latent_sample=latent_sample,
                                              latent_mean=latent_mean,
                                              latent_logvar=latent_logvar,
                                              #deterministic=True,
                                              deterministic=False
                                              )

            # observe reward and next obs
            [state, belief, task], (rew_raw, rew_normalised), done, infos = env_step(envs, action, args)

            done_mdp = [info['done_mdp'] for info in infos]
            successes = np.fmax(successes, [info['success'] for info in infos])

            if args.render:
                imgs = np.array([info['image'] for info in infos])
                success_list = [info['success'] for info in infos]
                for i in range(num_processes):
                    imgs[i,0:30,0:30,:] = int(255*success_list[i]) #to distinguish success
                imgs_array.append(imgs)

            with torch.no_grad():
                if encoder is not None:
                    # update the hidden state
                    if args.vae_mixture_num>1:
                        latent_sample, latent_mean, latent_logvar, hidden_state,\
                        y, z, mu, var, logits, prob = utl.update_encoding(encoder=encoder,
                                                                                                      next_obs=state,
                                                                                                      action=action,
                                                                                                      reward=rew_raw,
                                                                                                      done=None,
                                                                                                      hidden_state=hidden_state,
                                                                          vae_mixture_num=args.vae_mixture_num)
                        if save_episode_probs:
                            episode_probs[:, episode_idx, step_idx , :] = prob.clone()

                    else:
                        latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(encoder=encoder,
                                                                                                      next_obs=state,
                                                                                                      action=action,
                                                                                                      reward=rew_raw,
                                                                                                      done=None,
                                                                                                      hidden_state=hidden_state)

                    if policy_separate_gru:
                        latent_pol, hidden_state_pol = utl.update_encoding_pol(
                            encoder=encoder_pol,
                            next_obs=state,
                            action=action,
                            reward=rew_raw,
                            done=None,
                            hidden_state=hidden_state_pol,
                        )

            # add rewards
            returns_per_episode[range(num_processes), task_count] += rew_raw.view(-1)

            for i in np.argwhere(done_mdp).flatten():
                # count task up, but cap at num_episodes + 1
                task_count[i] = min(task_count[i] + 1, num_episodes)  # zero-indexed, so no +1
            if np.sum(done) > 0:
                done_indices = np.argwhere(done.flatten()).flatten()
                state, belief, task = utl.reset_env(envs, args, indices=done_indices, state=state)

        if args.render:
            imgs_array = np.array(imgs_array)
            for i in range(num_processes):
                img_array = imgs_array[:,i,:,:,:] #(500x10x480x640x3) to (500x480x640x3)

                if args.load_iter is None:
                    pathout = 'renders/{}/task{:2d}/subtask{:2d}/task{:2d}_subtask{:2d}_epi_{:2d}.mp4'.format(iter_idx, task_list[i, 0], task_list[i, 1], task_list[i, 0], task_list[i, 1], episode_idx)
                else:
                    #pathout = 'renders/'+args.load_dir+ '{}/task{:2d}/subtask{:2d}/task{:2d}_subtask{:2d}_epi_{:2d}.mp4'.format(args.load_iter,task_list[i,0],task_list[i,1],task_list[i,0],task_list[i,1],episode_idx)
                    pathout = 'renders/'+args.load_dir+ '{}/task{:2d}/subtask{:2d}/task{:2d}_subtask{:2d}_epi_{:2d}_return{:04d}.mp4'.format(args.load_iter,task_list[i,0],task_list[i,1],task_list[i,0],task_list[i,1],episode_idx, round(returns_per_episode.detach().cpu().numpy()[i, task_count[i]-1]))
                out = cv2.VideoWriter(pathout, cv2.VideoWriter_fourcc(*'mp4v'), 50, (np.shape(img_array)[2],np.shape(img_array)[1])) #width, height
                horizon = np.shape(img_array)[0]
                for j in range(horizon):
                    rgb_img = cv2.cvtColor(img_array[j], cv2.COLOR_RGB2BGR)
                    out.write(rgb_img)
                out.release()

        if save_episode_successes:
            episode_successes[:, episode_idx] = successes
        meta_successes = meta_successes+successes
    envs.close()
    meta_successes = meta_successes/num_episodes

    if args.vae_mixture_num > 1:
        prob =prob.detach().cpu().numpy()
        if save_episode_probs:
            episode_probs = episode_probs[:, :num_episodes, :, :].detach().cpu().numpy()
    return returns_per_episode[:, :num_episodes].detach().cpu().numpy(), latent_mean.detach().cpu().numpy(), \
                                                         latent_logvar.detach().cpu().numpy(), meta_successes, prob, episode_probs, episode_successes


import matplotlib.pyplot as plt
import numpy as np
import torch

from environments.parallel_envs import make_vec_envs
from utils import helpers as utl

def visualise_behaviour(args,
                        policy,
                        image_folder,
                        iter_idx,
                        ret_rms,
                        tasks,
                        num_episodes=None,
                        encoder=None,
                        reward_decoder=None,
                        state_decoder=None,
                        task_decoder=None,
                        compute_rew_reconstruction_loss=None,
                        compute_task_reconstruction_loss=None,
                        compute_state_reconstruction_loss=None,
                        compute_kl_loss=None,
                        is_pretrain=False,
                        ):
    # initialise environment
    env = make_vec_envs(env_name=args.env_name,
                        seed=args.seed * 42 + iter_idx,
                        num_processes=1,
                        gamma=args.policy_gamma,
                        device=get_device(),
                        episodes_per_task=args.max_rollouts_per_task,
                        normalise_rew=args.norm_rew_for_policy, ret_rms=ret_rms,
                        rank_offset=args.num_processes + 42,  # not sure if the temp folders would otherwise clash
                        tasks=tasks
                        )
    episode_task = torch.from_numpy(np.array(env.get_task())).to(get_device()).float()

    # get a sample rollout
    unwrapped_env = env.venv.unwrapped.envs[0]
    if hasattr(env.venv.unwrapped.envs[0], 'unwrapped'):
        unwrapped_env = unwrapped_env.unwrapped
    if hasattr(unwrapped_env, 'visualise_behaviour'):
        # if possible, get it from the env directly
        # (this might visualise other things in addition)
        traj = unwrapped_env.visualise_behaviour(env=env,
                                                 args=args,
                                                 policy=policy,
                                                 iter_idx=iter_idx,
                                                 encoder=encoder,
                                                 reward_decoder=reward_decoder,
                                                 state_decoder=state_decoder,
                                                 task_decoder=task_decoder,
                                                 image_folder=image_folder,
                                                 num_episodes=num_episodes,
                                                 )
    else:
        traj = get_test_rollout(args, env, policy, encoder)

    latent_means, latent_logvars, episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, episode_returns = traj

    if latent_means is not None and not is_pretrain:
        plot_latents(latent_means, latent_logvars,
                     image_folder=image_folder,
                     iter_idx=iter_idx
                     )

        if not (args.disable_decoder and args.disable_kl_term):
            plot_vae_loss(args,
                          latent_means,
                          latent_logvars,
                          episode_prev_obs,
                          episode_next_obs,
                          episode_actions,
                          episode_rewards,
                          episode_task,
                          image_folder=image_folder,
                          iter_idx=iter_idx,
                          reward_decoder=reward_decoder,
                          state_decoder=state_decoder,
                          task_decoder=task_decoder,
                          compute_task_reconstruction_loss=compute_task_reconstruction_loss,
                          compute_rew_reconstruction_loss=compute_rew_reconstruction_loss,
                          compute_state_reconstruction_loss=compute_state_reconstruction_loss,
                          compute_kl_loss=compute_kl_loss,
                          )

    env.close()


def get_test_rollout(args, env, policy, encoder=None):
    pretrain_exp_labels = ['diayn', 'csd', 'diayn-cat'] # keeping track of these at the top for ximplicity

    if not ((args.exp_label in ['VariBAD', 'ppo', 'sac', 'sac-v2']+pretrain_exp_labels) or ('varibad' in args.exp_label)):
        raise NotImplementedError('method {} not implemented'.format(args.exp_label))
    
    num_episodes = args.max_rollouts_per_task

    # --- initialise things we want to keep track of ---
    if args.exp_label in pretrain_exp_labels:
        eval_latent_samples = policy.sample_latent_eval(num_episodes)
        num_episodes = eval_latent_samples.shape[0]

    episode_prev_obs = [[] for _ in range(num_episodes)]
    episode_next_obs = [[] for _ in range(num_episodes)]
    episode_actions = [[] for _ in range(num_episodes)]
    episode_rewards = [[] for _ in range(num_episodes)]

    episode_returns = []
    episode_lengths = []

    if encoder is not None:
        episode_latent_samples = [[] for _ in range(num_episodes)]
        episode_latent_means = [[] for _ in range(num_episodes)]
        episode_latent_logvars = [[] for _ in range(num_episodes)]
    elif args.exp_label in pretrain_exp_labels:
        episode_latent_samples = [[] for _ in range(num_episodes)]
        episode_latent_means = [[] for _ in range(num_episodes)]
        episode_latent_logvars = [[] for _ in range(num_episodes)]
    else:
        curr_latent_sample = curr_latent_mean = curr_latent_logvar = None
        episode_latent_means = episode_latent_logvars = None

    # --- roll out policy ---

    # (re)set environment
    env.reset_task()
    state, belief, task = utl.reset_env(env, args)
    state = state.reshape((1, -1)).to(get_device())
    task = task.view(-1) if task is not None else None

    for episode_idx in range(num_episodes):

        curr_rollout_rew = []

        if encoder is not None:
            if episode_idx == 0:
                # reset to prior
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder.prior(1)
                curr_latent_sample = curr_latent_sample[0].to(get_device())
                curr_latent_mean = curr_latent_mean[0].to(get_device())
                curr_latent_logvar = curr_latent_logvar[0].to(get_device())
            
            episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
            episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
            episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

        if args.exp_label in pretrain_exp_labels:
            curr_latent_sample = eval_latent_samples[episode_idx]
            curr_latent_mean = torch.zeros_like(eval_latent_samples)
            curr_latent_logvar = torch.zeros_like(eval_latent_samples)
            episode_latent_samples[episode_idx].append(curr_latent_sample.clone())
            episode_latent_means[episode_idx].append(curr_latent_mean.clone())
            episode_latent_logvars[episode_idx].append(curr_latent_logvar.clone())

        for step_idx in range(1, env._max_episode_steps + 1):

            episode_prev_obs[episode_idx].append(state.clone())

            latent = utl.get_latent_for_policy(args,
                                               latent_sample=curr_latent_sample,
                                               latent_mean=curr_latent_mean,
                                               latent_logvar=curr_latent_logvar)
            action = policy.act(state=state.view(-1), latent=latent, belief=belief, task=task, 
                                   prob=False, latent_pol=False, deterministic=True)
            # make sure action is correct
            if isinstance(action, list) or isinstance(action, tuple):
                value, action = action
            else:
                value = None
                
            if args.exp_label not in pretrain_exp_labels: #hacky?
                action = action.reshape((1, *action.shape))

            # observe reward and next obs
            (state, belief, task), (rew_raw, rew_normalised), done, infos = utl.env_step(env, action, args)
            state = state.reshape((1, -1)).to(get_device())
            task = task.view(-1) if task is not None else None

            if encoder is not None:
                # update task embedding
                curr_latent_sample, curr_latent_mean, curr_latent_logvar, hidden_state = encoder(
                    action.float().to(get_device()),
                    state,
                    rew_raw.reshape((1, 1)).float().to(get_device()),
                    hidden_state,
                    return_prior=False)

                episode_latent_samples[episode_idx].append(curr_latent_sample[0].clone())
                episode_latent_means[episode_idx].append(curr_latent_mean[0].clone())
                episode_latent_logvars[episode_idx].append(curr_latent_logvar[0].clone())

            episode_next_obs[episode_idx].append(state.clone())
            episode_rewards[episode_idx].append(rew_raw.clone())
            episode_actions[episode_idx].append(action.clone())

            if infos[0]['done_mdp'] and not done and args.exp_label not in pretrain_exp_labels:
                start_obs_raw = infos[0]['start_state']
                start_obs_raw = torch.from_numpy(start_obs_raw).float().reshape((1, -1)).to(get_device())
                start_pos = start_obs_raw
                break

        episode_returns.append(sum(curr_rollout_rew))
        episode_lengths.append(step_idx)

    # clean up
    if (encoder is not None) or (args.exp_label in pretrain_exp_labels):
        episode_latent_means = [torch.stack(e) for e in episode_latent_means]
        episode_latent_logvars = [torch.stack(e) for e in episode_latent_logvars]

    episode_prev_obs = [torch.cat(e) for e in episode_prev_obs]
    episode_next_obs = [torch.cat(e) for e in episode_next_obs]
    episode_actions = [torch.cat(e) for e in episode_actions]
    episode_rewards = [torch.cat(r) for r in episode_rewards]

    return episode_latent_means, episode_latent_logvars, \
           episode_prev_obs, episode_next_obs, episode_actions, episode_rewards, \
           episode_returns


def plot_latents(latent_means,
                 latent_logvars,
                 image_folder,
                 iter_idx,
                 ):
    """
    Plot mean/variance over time
    """

    num_rollouts = len(latent_means)
    num_episode_steps = len(latent_means[0])

    latent_means = torch.cat(latent_means).cpu().detach().numpy()
    latent_logvars = torch.cat(latent_logvars).cpu().detach().numpy()

    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(range(latent_means.shape[0]), latent_means, '-', alpha=0.5)
    plt.plot(range(latent_means.shape[0]), latent_means.mean(axis=1), 'k-')
    for tj in np.cumsum([0, *[num_episode_steps for _ in range(num_rollouts)]]):
        span = latent_means.max() - latent_means.min()
        plt.plot([tj + 0.5, tj + 0.5],
                 [latent_means.min() - span * 0.05, latent_means.max() + span * 0.05],
                 'k--', alpha=0.5)
    plt.xlabel('env steps', fontsize=15)
    plt.ylabel('latent mean', fontsize=15)

    plt.subplot(1, 2, 2)
    latent_var = np.exp(latent_logvars)
    plt.plot(range(latent_logvars.shape[0]), latent_var, '-', alpha=0.5)
    plt.plot(range(latent_logvars.shape[0]), latent_var.mean(axis=1), 'k-')
    for tj in np.cumsum([0, *[num_episode_steps for _ in range(num_rollouts)]]):
        span = latent_var.max() - latent_var.min()
        plt.plot([tj + 0.5, tj + 0.5],
                 [latent_var.min() - span * 0.05, latent_var.max() + span * 0.05],
                 'k--', alpha=0.5)
    plt.xlabel('env steps', fontsize=15)
    plt.ylabel('latent variance', fontsize=15)

    plt.tight_layout()
    if image_folder is not None:
        plt.savefig('{}/{}_latents'.format(image_folder, iter_idx))
        plt.close()
    else:
        plt.show()


def plot_vae_loss(args,
                  latent_means,
                  latent_logvars,
                  prev_obs,
                  next_obs,
                  actions,
                  rewards,
                  task,
                  image_folder,
                  iter_idx,
                  reward_decoder,
                  state_decoder,
                  task_decoder,
                  compute_task_reconstruction_loss,
                  compute_rew_reconstruction_loss,
                  compute_state_reconstruction_loss,
                  compute_kl_loss
                  ):
    num_rollouts = len(latent_means)
    num_episode_steps = len(latent_means[0])
    if not args.disable_stochasticity_in_latent:
        num_samples = 10  # how many samples to use to get an average/std ELBO loss
    else:
        num_samples = 1

    latent_means = torch.cat(latent_means)
    latent_logvars = torch.cat(latent_logvars)

    prev_obs = torch.cat(prev_obs).to(get_device())
    next_obs = torch.cat(next_obs).to(get_device())
    actions = torch.cat(actions).to(get_device())
    rewards = torch.cat(rewards).to(get_device())

    # - we will try to make predictions for each tuple in trajectory, hence we need to expand the targets
    prev_obs = prev_obs.unsqueeze(0).expand(num_samples, *prev_obs.shape).to(get_device())
    next_obs = next_obs.unsqueeze(0).expand(num_samples, *next_obs.shape).to(get_device())
    actions = actions.unsqueeze(0).expand(num_samples, *actions.shape).to(get_device())
    rewards = rewards.unsqueeze(0).expand(num_samples, *rewards.shape).to(get_device())

    rew_reconstr_mean = []
    rew_reconstr_std = []
    rew_pred_std = []

    state_reconstr_mean = []
    state_reconstr_std = []
    state_pred_std = []

    task_reconstr_mean = []
    task_reconstr_std = []
    task_pred_std = []

    # compute the sum of ELBO_t's by looping through (trajectory length + prior)
    for i in range(len(latent_means)):

        curr_latent_mean = latent_means[i]
        curr_latent_logvar = latent_logvars[i]

        # compute the reconstruction loss
        if not args.disable_stochasticity_in_latent:
            # take several samples from the latent distribution
            latent_samples = utl.sample_gaussian(curr_latent_mean.view(-1), curr_latent_logvar.view(-1), num_samples)
        else:
            latent_samples = torch.cat((curr_latent_mean.view(-1), curr_latent_logvar.view(-1))).unsqueeze(0)

        # expand: each latent sample will be used to make predictions for the entire trajectory
        len_traj = prev_obs.shape[1]

        # compute reconstruction losses
        if task_decoder is not None:
            loss_task, task_pred = compute_task_reconstruction_loss(latent_samples, task, return_predictions=True)

            # average/std across the different samples
            task_reconstr_mean.append(loss_task.mean())
            task_reconstr_std.append(loss_task.std())
            task_pred_std.append(task_pred.std())

        latent_samples = latent_samples.unsqueeze(1).expand(num_samples, len_traj, latent_samples.shape[-1])

        if reward_decoder is not None:
            loss_rew, rew_pred = compute_rew_reconstruction_loss(latent_samples, prev_obs, next_obs,
                                                                 actions, rewards, return_predictions=True)
            # sum along length of trajectory
            loss_rew = loss_rew.sum(dim=1)
            rew_pred = rew_pred.sum(dim=1)

            # average/std across the different samples
            rew_reconstr_mean.append(loss_rew.mean())
            rew_reconstr_std.append(loss_rew.std())
            rew_pred_std.append(rew_pred.std())

        if state_decoder is not None:
            loss_state, state_pred = compute_state_reconstruction_loss(latent_samples, prev_obs, next_obs,
                                                                       actions, return_predictions=True)
            # sum along length of trajectory
            loss_state = loss_state.sum(dim=1)
            state_pred = state_pred.sum(dim=1)

            # average/std across the different samples
            state_reconstr_mean.append(loss_state.mean())
            state_reconstr_std.append(loss_state.std())
            state_pred_std.append(state_pred.std())

    # kl term
    if compute_kl_loss is not None:
        vae_kl_term = compute_kl_loss(latent_means, latent_logvars, None)

    # --- plot KL term ---

        x = range(len(vae_kl_term))

        plt.plot(x, vae_kl_term.cpu().detach().numpy(), 'b-')
        vae_kl_term = vae_kl_term.cpu()
        for tj in np.cumsum([0, *[num_episode_steps for _ in range(num_rollouts)]]):
            span = vae_kl_term.max() - vae_kl_term.min()
            plt.plot([tj + 0.5, tj + 0.5],
                    [vae_kl_term.min() - span * 0.05, vae_kl_term.max() + span * 0.05],
                    'k--', alpha=0.5)
        plt.xlabel('env steps', fontsize=15)
        plt.ylabel('KL term', fontsize=15)
        plt.tight_layout()
        if image_folder is not None:
            plt.savefig('{}/{}_kl'.format(image_folder, iter_idx))
            plt.close()
        else:
            plt.show()

    # --- plot rew reconstruction ---

    if reward_decoder is not None:

        rew_reconstr_mean = torch.stack(rew_reconstr_mean).detach().cpu().numpy()
        rew_reconstr_std = torch.stack(rew_reconstr_std).detach().cpu().numpy()
        rew_pred_std = torch.stack(rew_pred_std).detach().cpu().numpy()

        plt.figure(figsize=(12, 5))
        plt.subplot(1, 2, 1)
        p = plt.plot(x, rew_reconstr_mean, 'b-')
        plt.gca().fill_between(x,
                               rew_reconstr_mean - rew_reconstr_std,
                               rew_reconstr_mean + rew_reconstr_std,
                               facecolor=p[0].get_color(), alpha=0.1)
        for tj in np.cumsum([0, *[num_episode_steps for _ in range(num_rollouts)]]):
            min_y = (rew_reconstr_mean - rew_reconstr_std).min()
            max_y = (rew_reconstr_mean + rew_reconstr_std).max()
            span = max_y - min_y
            plt.plot([tj + 0.5, tj + 0.5],
                     [min_y - span * 0.05, max_y + span * 0.05],
                     'k--', alpha=0.5)
        plt.xlabel('env steps', fontsize=15)
        plt.ylabel('reward reconstruction error', fontsize=15)

        plt.subplot(1, 2, 2)
        plt.plot(x, rew_pred_std, 'b-')
        for tj in np.cumsum([0, *[num_episode_steps for _ in range(num_rollouts)]]):
            span = rew_pred_std.max() - rew_pred_std.min()
            plt.plot([tj + 0.5, tj + 0.5],
                     [rew_pred_std.min() - span * 0.05, rew_pred_std.max() + span * 0.05],
                     'k--', alpha=0.5)
        plt.xlabel('env steps', fontsize=15)
        plt.ylabel('std of rew reconstruction', fontsize=15)
        plt.tight_layout()
        if image_folder is not None:
            plt.savefig('{}/{}_rew_reconstruction'.format(image_folder, iter_idx))
            plt.close()
        else:
            plt.show()

    # --- plot state reconstruction ---

    if state_decoder is not None:

        plt.figure(figsize=(12, 5))

        state_reconstr_mean = torch.stack(state_reconstr_mean).detach().cpu().numpy()
        state_reconstr_std = torch.stack(state_reconstr_std).detach().cpu().numpy()
        state_pred_std = torch.stack(state_pred_std).detach().cpu().numpy()

        plt.subplot(1, 2, 1)
        p = plt.plot(x, state_reconstr_mean, 'b-')
        plt.gca().fill_between(x,
                               state_reconstr_mean - state_reconstr_std,
                               state_reconstr_mean + state_reconstr_std,
                               facecolor=p[0].get_color(), alpha=0.1)
        for tj in np.cumsum([0, *[num_episode_steps for _ in range(num_rollouts)]]):
            min_y = (state_reconstr_mean - state_reconstr_std).min()
            max_y = (state_reconstr_mean + state_reconstr_std).max()
            span = max_y - min_y
            plt.plot([tj + 0.5, tj + 0.5],
                     [min_y - span * 0.05, max_y + span * 0.05],
                     'k--', alpha=0.5)
        plt.xlabel('env steps', fontsize=15)
        plt.ylabel('state reconstruction error', fontsize=15)

        plt.subplot(1, 2, 2)
        plt.plot(x, state_pred_std, 'b-')
        for tj in np.cumsum([0, *[num_episode_steps for _ in range(num_rollouts)]]):
            span = state_pred_std.max() - state_pred_std.min()
            plt.plot([tj + 0.5, tj + 0.5],
                     [state_pred_std.min() - span * 0.05, state_pred_std.max() + span * 0.05],
                     'k--', alpha=0.5)
        plt.xlabel('env steps', fontsize=15)
        plt.ylabel('std of state reconstruction', fontsize=15)
        plt.tight_layout()
        if image_folder is not None:
            plt.savefig('{}/{}_state_reconstruction'.format(image_folder, iter_idx))
            plt.close()
        else:
            plt.show()

    # --- plot task reconstruction ---

    if task_decoder is not None:

        plt.figure(figsize=(12, 5))

        task_reconstr_mean = torch.stack(task_reconstr_mean).detach().cpu().numpy()
        task_reconstr_std = torch.stack(task_reconstr_std).detach().cpu().numpy()
        task_pred_std = torch.stack(task_pred_std).detach().cpu().numpy()

        plt.subplot(1, 2, 1)
        p = plt.plot(x, task_reconstr_mean, 'b-')
        plt.gca().fill_between(x,
                               task_reconstr_mean - task_reconstr_std,
                               task_reconstr_mean + task_reconstr_std,
                               facecolor=p[0].get_color(), alpha=0.1)
        for tj in np.cumsum([0, *[num_episode_steps for _ in range(num_rollouts)]]):
            min_y = (task_reconstr_mean - task_reconstr_std).min()
            max_y = (task_reconstr_mean + task_reconstr_std).max()
            span = max_y - min_y
            plt.plot([tj + 0.5, tj + 0.5],
                     [min_y - span * 0.05, max_y + span * 0.05],
                     'k--', alpha=0.5)
        plt.xlabel('env steps', fontsize=15)
        plt.ylabel('task reconstruction error', fontsize=15)

        plt.subplot(1, 2, 2)
        plt.plot(x, task_pred_std, 'b-')
        for tj in np.cumsum([0, *[num_episode_steps for _ in range(num_rollouts)]]):
            span = task_pred_std.max() - task_pred_std.min()
            plt.plot([tj + 0.5, tj + 0.5],
                     [task_pred_std.min() - span * 0.05, task_pred_std.max() + span * 0.05],
                     'k--', alpha=0.5)
        plt.xlabel('env steps', fontsize=15)
        plt.ylabel('std of task reconstruction', fontsize=15)
        plt.tight_layout()
        if image_folder is not None:
            plt.savefig('{}/{}_task_reconstruction'.format(image_folder, iter_idx))
            plt.close()
        else:
            plt.show()


def evaluate_DME(args,
             policy,
             ret_rms,
             iter_idx,
             tasks,
             test,
             encoder=None,
             encoder_pol=None,
             num_episodes=None,
             task_list = None, save_episode_probs=False, save_episode_successes=False):
    """
    DME-specific evaluation function for SDVT that handles the w parameter
    and extended encoder outputs.
    """
    env_step = utl.env_step
    env_name = args.env_name
    policy_separate_gru = args.policy_separate_gru
    if test:
        if hasattr(args, 'test_env_name'):
            env_name = args.test_env_name
    if num_episodes is None:
        num_episodes = args.max_rollouts_per_task
    num_processes = 10

    # --- set up the things we want to log ---
    returns_per_episode = torch.zeros((num_processes, num_episodes + 1)).to(get_device())
    successes = torch.zeros((num_processes, num_episodes + 1)).to(get_device())
    episode_probs = None
    episode_successes = None
    # --- initialise environments and latents ---
    envs = make_vec_envs(env_name,
                         seed=args.seed * 42 + iter_idx,
                         num_processes=num_processes,
                         gamma=args.policy_gamma,
                         device=get_device(),
                         rank_offset=num_processes + 1,
                         episodes_per_task=num_episodes,
                         normalise_rew=args.norm_rew_for_policy,
                         ret_rms=ret_rms,
                         tasks=tasks,
                         add_done_info=args.max_rollouts_per_task > 1,
                         )
    num_steps = envs._max_episode_steps

    # Initialize episode tracking arrays after envs is created
    if save_episode_probs:
        episode_probs = torch.zeros((num_processes, num_episodes, num_steps, args.vae_mixture_num)).to(get_device())
    if save_episode_successes:
        episode_successes = torch.zeros((num_processes, num_episodes)).to(get_device())

    # reset environments
    state, belief, task = utl.reset_env(envs, args)
    if task_list is not None:
        envs.reset_task(task_list)
        state2 = torch.from_numpy(envs._get_obs()).float().to(get_device())
        state[:,:39] = state2.clone()

    task_count = torch.zeros(num_processes).long().to(get_device())
    prob = None
    w = None  # Initialize w parameter for GMVAE

    if encoder is not None:
        # reset latent state to prior - handle GMVAE extended output
        if args.vae_mixture_num > 1:
            prior_output = encoder.prior_mixture(num_processes)
            # GMVAE returns 8 values: prior_y, prior_z, prior_mu, prior_var, prior_logits, prior_prob, prior_hidden_state, w
            prior_y, prior_z, prior_mu, prior_var, prior_logits, prior_prob, prior_hidden_state, w = prior_output
                
            latent_sample = prior_z
            latent_mean = prior_mu
            latent_logvar = torch.log(prior_var + 1e-20)
            hidden_state = prior_hidden_state
            prob = prior_prob
        else:
            latent_sample, latent_mean, latent_logvar, hidden_state = encoder.prior(num_processes)

    if policy_separate_gru:
        latent_pol, hidden_state_pol = encoder_pol.prior(num_processes)

    for episode_idx in range(num_episodes):
        for step_idx in range(num_steps):
            with torch.no_grad():
                _, action = utl.select_action(args=args,
                                              policy=policy,
                                              state=state,
                                              belief=belief,
                                              task=task,
                                              prob=prob,
                                              latent_pol=latent_pol if policy_separate_gru else None,
                                              latent_sample=latent_sample,
                                              latent_mean=latent_mean,
                                              latent_logvar=latent_logvar,
                                              w=w,  # Pass w parameter for GMVAE
                                              deterministic=False
                                              )

            # observe reward and next obs
            [state, belief, task], (rew_raw, rew_normalised), done, infos = env_step(envs, action, args)

            done_mdp = [info['done_mdp'] for info in infos]
            success = torch.tensor([info.get('success', 0) for info in infos], device=get_device())

            with torch.no_grad():
                if encoder is not None:
                    # update the hidden state - handle GMVAE extended output
                    if args.vae_mixture_num > 1:
                        encoder_output = utl.update_encoding(encoder=encoder,
                                                            next_obs=state,
                                                            action=action,
                                                            reward=rew_raw,
                                                            done=None,
                                                            hidden_state=hidden_state,
                                                            vae_mixture_num=args.vae_mixture_num)
                        
                        # GMVAE returns 7 values: latent_sample, latent_mean, latent_logvar, hidden_state, y, prob, w
                        latent_sample, latent_mean, latent_logvar, hidden_state, y, prob, w = encoder_output

                    else:
                        latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(encoder=encoder,
                                                                                                      next_obs=state,
                                                                                                      action=action,
                                                                                                      reward=rew_raw,
                                                                                                      done=None,
                                                                                                      hidden_state=hidden_state)

                    if policy_separate_gru:
                        latent_pol, hidden_state_pol = utl.update_encoding_pol(
                            encoder=encoder_pol,
                            next_obs=state,
                            action=action,
                            reward=rew_raw,
                            done=None,
                            hidden_state=hidden_state_pol,
                        )

            # add rewards
            returns_per_episode[range(num_processes), task_count] += rew_raw.view(-1)
            successes[range(num_processes), task_count] += success

            for i in np.argwhere(done_mdp).flatten():
                task_count[i] = min(task_count[i] + 1, num_episodes)
            if np.sum(done) > 0:
                done_indices = np.argwhere(done.flatten()).flatten()
                state, belief, task = utl.reset_env(envs, args, indices=done_indices, state=state)

        if save_episode_successes:
            episode_successes[:, episode_idx] = successes
    
    successes[successes >= 1.0] = 1.0
    infos = None

    envs.close()

    if args.vae_mixture_num > 1:
        prob = prob.detach().cpu().numpy()
        if save_episode_probs:
            episode_probs = episode_probs[:, :num_episodes, :, :].detach().cpu().numpy()
            infos = {'episode_probs': episode_probs}
    
    return returns_per_episode[:, :num_episodes], successes[:, :num_episodes], infos


def evaluate_metaworld_dme(args,
             policy,
             ret_rms,
             iter_idx,
             tasks,
             encoder=None,
             encoder_pol=None,
             num_episodes=None,
             test = False, task_list = None, save_episode_probs=False, save_episode_successes=False):
    """
    DME-specific evaluation function for MetaWorld that handles the w parameter
    and extended encoder outputs. Returns same format as evaluate_metaworld.
    """
    env_step = utl.env_step
    env_name = args.env_name
    policy_separate_gru = args.policy_separate_gru
    if test:
        if hasattr(args, 'test_env_name'):
            env_name = args.test_env_name
    if num_episodes is None:
        num_episodes = args.max_rollouts_per_task
    num_processes = 10

    # --- set up the things we want to log ---
    returns_per_episode = torch.zeros((num_processes, num_episodes + 1)).to(get_device())

    # --- initialise environments and latents ---
    envs = make_vec_envs(env_name,
                         seed=args.seed,
                         num_processes=num_processes,
                         gamma=args.policy_gamma,
                         device=get_device(),
                         rank_offset=num_processes + 1,
                         episodes_per_task=num_episodes,
                         normalise_rew=args.norm_rew_for_policy,
                         ret_rms=ret_rms,
                         tasks=tasks,
                         add_done_info=args.max_rollouts_per_task > 1,
                         )
    num_steps = envs._max_episode_steps

    # reset environments
    state, belief, task = utl.reset_env(envs, args)
    envs.reset_task(task_list)
    state2= torch.from_numpy(envs._get_obs()).float().to(get_device())
    state[:,:39] = state2.clone()

    # this counts how often an agent has done the same task already
    task_count = torch.zeros(num_processes).long().to(get_device())
    prob = None
    episode_probs = None
    episode_successes = None
    w = None  # Initialize w parameter for GMVAE
    
    if save_episode_successes:
        episode_successes = np.zeros((num_processes, num_episodes))

    if encoder is not None:
        # reset latent state to prior - handle GMVAE extended output
        if args.vae_mixture_num > 1:
            prior_output = encoder.prior_mixture(num_processes)
            # GMVAE returns 8 values: prior_y, prior_z, prior_mu, prior_var, prior_logits, prior_prob, prior_hidden_state, prior_w
            prior_y, prior_z, prior_mu, prior_var, prior_logits, prior_prob, prior_hidden_state, prior_w = prior_output
                
            latent_sample = prior_z
            latent_mean = prior_mu
            latent_logvar = torch.log(prior_var + 1e-20)
            hidden_state = prior_hidden_state
            prob = prior_prob
            w = prior_w  # Set w from prior for GMVAE
            if save_episode_probs:
                episode_probs = torch.zeros((num_processes, num_episodes + 1, num_steps, args.vae_mixture_num)).to(get_device())
        else:
            latent_sample, latent_mean, latent_logvar, hidden_state = encoder.prior(num_processes)
    else:
        latent_sample = latent_mean = latent_logvar = hidden_state = None

    if policy_separate_gru:
        latent_pol, hidden_state_pol = encoder_pol.prior(num_processes)
    else:
        latent_pol = hidden_state_pol = None

    meta_successes = np.zeros(num_processes)

    for episode_idx in range(num_episodes):
        successes = np.zeros(num_processes)
        for step_idx in range(num_steps):
            with torch.no_grad():
                _, action = utl.select_action(args=args,
                                              policy=policy,
                                              state=state,
                                              belief=belief,
                                              task=task,
                                              prob=prob,
                                              latent_pol=latent_pol if policy_separate_gru else None,
                                              latent_sample=latent_sample,
                                              latent_mean=latent_mean,
                                              latent_logvar=latent_logvar,
                                              w=w,  # Keep w parameter for GMVAE functionality
                                              deterministic=False
                                              )

            # observe reward and next obs
            [state, belief, task], (rew_raw, rew_normalised), done, infos = env_step(envs, action, args)

            done_mdp = [info['done_mdp'] for info in infos]
            successes = np.fmax(successes, [info['success'] for info in infos])

            with torch.no_grad():
                if encoder is not None:
                    # update the hidden state - handle GMVAE extended output
                    if args.vae_mixture_num > 1:
                        encoder_output = utl.update_encoding_dme(encoder=encoder,
                                                            next_obs=state,
                                                            action=action,
                                                            reward=rew_raw,
                                                            done=None,
                                                            hidden_state=hidden_state)
                        
                        # GMVAE returns 7 values: latent_sample, latent_mean, latent_logvar, hidden_state, y, prob, w
                        latent_sample, latent_mean, latent_logvar, hidden_state, y, prob, w = encoder_output
                        
                        if save_episode_probs:
                            episode_probs[:, episode_idx, step_idx , :] = prob.clone()

                    else:
                        latent_sample, latent_mean, latent_logvar, hidden_state = utl.update_encoding(encoder=encoder,
                                                                                                      next_obs=state,
                                                                                                      action=action,
                                                                                                      reward=rew_raw,
                                                                                                      done=None,
                                                                                                      hidden_state=hidden_state)

                    if policy_separate_gru:
                        latent_pol, hidden_state_pol = utl.update_encoding_pol(
                            encoder=encoder_pol,
                            next_obs=state,
                            action=action,
                            reward=rew_raw,
                            done=None,
                            hidden_state=hidden_state_pol,
                        )

            # add rewards
            returns_per_episode[range(num_processes), task_count] += rew_raw.view(-1)

            for i in np.argwhere(done_mdp).flatten():
                # count task up, but cap at num_episodes + 1
                task_count[i] = min(task_count[i] + 1, num_episodes)  # zero-indexed, so no +1
            if np.sum(done) > 0:
                done_indices = np.argwhere(done.flatten()).flatten()
                state, belief, task = utl.reset_env(envs, args, indices=done_indices, state=state)

        if save_episode_successes:
            episode_successes[:, episode_idx] = successes
        meta_successes = meta_successes + successes
    
    envs.close()
    meta_successes = meta_successes / num_episodes

    if args.vae_mixture_num > 1:
        prob = prob.detach().cpu().numpy()
        if save_episode_probs:
            episode_probs = episode_probs[:, :num_episodes, :, :].detach().cpu().numpy()
    
    # Return same format as evaluate_metaworld: returns, latent_mean, latent_logvar, successes, prob, episode_probs, episode_successes
    # Plus w as an 8th return value for GMVAE
    return returns_per_episode[:, :num_episodes].detach().cpu().numpy(), latent_mean.detach().cpu().numpy(), \
           latent_logvar.detach().cpu().numpy(), meta_successes, prob, episode_probs, episode_successes, \
           w.detach().cpu().numpy() if w is not None else np.zeros((num_processes, args.latent_dim))


