from collections import defaultdict, deque
import gc
from itertools import chain
import os
from pprint import pprint
import time
import random

import imageio
import numpy as np
import psutil
import torch
import wandb
from debug import debug_print

from onpolicy.utils.util import update_linear_schedule
from onpolicy.runner.shared.base_runner import Runner
from torch.utils.data import DataLoader

import jax
import jax.numpy as jnp
import flashbax as fbx
from flashbax.vault import Vault
from tqdm import tqdm
from onpolicy.algorithms.diffusion_ac.datasets import Football_Dataset
def _t2n(x):
    if isinstance(x, np.ndarray):
        return x
    return x.detach().cpu().numpy()



def compute_returns(buffer, next_value, value_normalizer=None):
    """
    Compute returns either as discounted sum of rewards, or using GAE, and store
    them on ``buffer``.

    ``buffer.value_preds`` is also replaced: its last entry is overwritten with
    ``next_value`` (the bootstrap estimate for the step after the rollout).

    :param buffer: rollout buffer. Reads ``rewards``, ``value_preds``, ``masks``,
        ``bad_masks``, ``returns``, ``gamma``, ``gae_lambda``, ``episode_length``,
        ``n_rollout_threads``, ``rnum_agents`` and the ``_use_proper_time_limits``,
        ``_use_gae``, ``_use_popart``, ``_use_valuenorm`` flags.
    :param next_value: (np.ndarray) value predictions for the step after the last episode step.
    :param value_normalizer: (PopArt) If not None, PopArt value normalizer instance.
        Required when ``buffer._use_popart`` or ``buffer._use_valuenorm`` is set.
    """
    rewards = np.array(buffer.rewards)
    masks = np.array(buffer.masks)
    bad_masks = np.array(buffer.bad_masks)
    returns = buffer.returns.copy()
    value_preds = buffer.value_preds.copy()
    value_preds[-1] = next_value

    # Value baseline expressed in return space (denormalised when a value
    # normaliser is active).
    if buffer._use_popart or buffer._use_valuenorm:
        baseline = value_normalizer.denormalize(value_preds)
        if not buffer._use_proper_time_limits:
            # NOTE(review): only this branch keeps the first entry of the last
            # axis (carried over from the original code); confirm whether the
            # proper-time-limits branch should slice as well.
            baseline = baseline[..., 0, None]
    else:
        baseline = value_preds

    n_steps = rewards.shape[0]
    if buffer._use_gae:
        mul = buffer.gamma * buffer.gae_lambda
        delta = rewards + buffer.gamma * baseline[1:] * masks[1:] - baseline[:-1]
        gae = 0
        for step in reversed(range(n_steps)):
            gae = delta[step] + mul * masks[step + 1] * gae
            if buffer._use_proper_time_limits:
                # Zero the trace wherever bad_masks[step + 1] is 0
                # (presumably time-limit truncations).
                gae = gae * bad_masks[step + 1]
            returns[step] = gae + baseline[step]
    else:
        # Discounted Monte-Carlo returns bootstrapped from the final value
        # estimate. Previously this path was missing and stale returns were
        # written back to the buffer.
        returns[-1] = baseline[-1]
        for step in reversed(range(n_steps)):
            ret = returns[step + 1] * buffer.gamma * masks[step + 1] + rewards[step]
            if buffer._use_proper_time_limits:
                # Where bad_masks[step + 1] is 0, fall back to the value baseline.
                ret = ret * bad_masks[step + 1] + (1 - bad_masks[step + 1]) * baseline[step]
            returns[step] = ret

    # Bring both arrays to (episode_length + 1, n_rollout_threads, rnum_agents, 1).
    layout = (buffer.episode_length + 1, 1, buffer.n_rollout_threads, buffer.rnum_agents)
    returns = returns.reshape(layout).transpose(0, 2, 3, 1)
    value_preds = value_preds.reshape(layout).transpose(0, 2, 3, 1)

    assert (returns.shape == buffer.returns.shape), (returns.shape, buffer.returns.shape)
    assert (value_preds.shape == buffer.value_preds.shape), (value_preds.shape, buffer.value_preds.shape)

    buffer.returns = returns
    buffer.value_preds = value_preds

def scheduler(ratio, eta):
    """Return the ``eta`` value to use at training progress ``ratio``.

    The schedule is currently constant: ``ratio`` is accepted so callers can pass
    training progress, but it is ignored and ``eta`` is returned unchanged.
    """
    del ratio  # unused by the constant schedule
    return eta

class FootballRunner(Runner):
    """Runner for the football environment with a diffusion-based actor.

    On top of the base ``Runner`` it can pre-train the actor by behaviour cloning
    on an offline dataset (``clone``). During rollouts (``run``) it samples
    ``all_args.act_step`` actions per policy call and executes them one
    environment step at a time.
    """

    def __init__(self, config):
        super(FootballRunner, self).__init__(config)
        # Finished-episode statistics ("goal", "win_rate") appended by ``insert``
        # and flushed by ``run`` at every log interval.
        self.env_infos = defaultdict(list)
        # Observation normalisation constants; not read by the methods of this class.
        self.obs_mean = 0
        self.obs_std = 1

    def get_action_distribution(self, step):
        """Return the policy output for rollout thread 0 at ``step``, averaged over repeats.

        The thread-0 inputs are repeated 20 times along the batch axis and passed to
        ``trainer.policy.get_actions``. Element ``[1][:, -1]`` of its result is
        averaged over the repeats and reshaped to ``(1, rnum_agents, -1)``.

        NOTE(review): this passes five positional arrays (share_obs, obs,
        rnn_states, rnn_states_critic, masks) to ``get_actions``, whereas
        ``collect_actions`` passes (obs, rnn_states, masks, return_noise=True).
        It looks stale relative to the current policy API; confirm before use.

        :param step: (int) buffer index whose inputs are evaluated.
        :return: the averaged output (``.mean(0, keepdim=True)`` implies a torch tensor).
        """
        num_samples = 20
        buf = self.buffer
        outputs = self.trainer.policy.get_actions(
            np.repeat(buf.share_obs[step][0], num_samples, axis=0),
            np.repeat(buf.obs[step][0], num_samples, axis=0),
            np.repeat(buf.rnn_states[step][0], num_samples, axis=0),
            np.repeat(buf.rnn_states_critic[step][0], num_samples, axis=0),
            np.repeat(buf.masks[step][0], num_samples, axis=0),
        )
        # Average over the repeated samples; keepdim leaves a leading axis of size 1.
        action_distribution = outputs[1][:, -1].mean(0, keepdim=True)
        action_distribution = action_distribution.reshape(1, self.rnum_agents, -1)
        debug_print(action_distribution)
        return action_distribution

    def clone(self):
        """Pre-train the actor by behaviour cloning on the offline dataset.

        Loads ``Football_Dataset(all_args.vault_uid)`` and makes
        ``all_args.clone_episodes`` shuffled passes over it (batch size 256),
        calling ``trainer.bc_clone`` on every batch. Process RSS (MiB) is printed
        before and after loading, and at the end. An empty dataset stops cloning
        early with a message instead of crashing.
        """
        print('memory usage:', psutil.Process().memory_info().rss / (1024 * 1024))
        clone_episodes = self.all_args.clone_episodes
        train_dataset = Football_Dataset(self.all_args.vault_uid)
        train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)
        print('memory usage:', psutil.Process().memory_info().rss / (1024 * 1024))

        for episode in range(clone_episodes):
            losses = []
            actor_grad_norm = None
            obs_batch = None
            pbar = tqdm(train_loader)
            for batch in pbar:
                obs_batch, actions_batch, _rewards_batch, dones_batch = batch
                batch_size = obs_batch.shape[0]
                # Flatten everything after the batch axis into one feature vector.
                obs_batch = obs_batch.reshape(batch_size, -1)

                # One-hot encode the discrete actions, then flatten them per sample.
                actions_onehot = np.array(jax.nn.one_hot(np.array(actions_batch.cpu()), num_classes=self.buffer.act_size))
                actions_batch = actions_onehot.reshape(batch_size, -1)

                # Mask is 0 for samples whose dones are all set along the last axis, else 1.
                dones = torch.all(dones_batch, axis=-1)
                mask_batch = torch.ones((batch_size, 1), dtype=torch.float32)
                mask_batch[dones.bool()] = 0.0

                data = obs_batch, actions_batch, mask_batch, None, None
                bc_loss, actor_grad_norm = self.trainer.bc_clone(data)
                losses.append(bc_loss.item())
                pbar.set_postfix(loss=np.mean(losses[-50:]))

            if obs_batch is None:
                # Empty dataset: nothing to clone and no batch to copy below.
                print('clone: empty dataset, skipping behaviour cloning')
                break
            # Store the last batch's first observation in buffer slot obs[0][0].
            self.buffer.obs[0][0] = obs_batch[0].cpu().numpy()
            print(np.mean(losses), actor_grad_norm, episode)
        print('memory usage:', psutil.Process().memory_info().rss / (1024 * 1024))

    def run(self):
        """Optionally behaviour-clone, then run the main rollout / training loop.

        Each outer iteration collects ``episode_length`` steps from all rollout
        threads, computes returns, trains, and then saves, logs and evaluates at
        the configured intervals.
        """
        torch.set_printoptions(sci_mode=False, precision=3)
        np.set_printoptions(suppress=True, precision=3)

        # An eta below -1e-6 leaves the diffusion actor's eta untouched.
        if self.all_args.eta >= -1e-6:
            self.trainer.policy.actor.diffusion.update_eta(self.all_args.eta)

        clone_episodes = self.all_args.clone_episodes
        if clone_episodes > 0:
            self.clone()
        gc.collect()
        print('memory usage:', psutil.Process().memory_info().rss / (1024 * 1024))

        self.warmup()

        # The actor is only updated once episode > start_ppo. After cloning, its
        # learning rate is scaled by a constant 0.03 (start_factor == end_factor).
        start_ppo = 5
        start_critic = 4
        if clone_episodes > 0:
            lrscheduler = torch.optim.lr_scheduler.LinearLR(
                self.trainer.policy.actor_optimizer, start_factor=3e-2, end_factor=3e-2, total_iters=100)
            lrscheduler.step()
        else:
            lrscheduler = torch.optim.lr_scheduler.LinearLR(
                self.trainer.policy.actor_optimizer, start_factor=1, end_factor=1, total_iters=1000)
            start_ppo = 0
            start_critic = 0

        start = time.time()
        episodes = int(self.num_env_steps) // self.episode_length // self.n_rollout_threads
        act_step = self.all_args.act_step

        for episode in range(episodes):
            tot_time = 0  # wall time spent inside envs.step during this episode
            start_time = time.time()

            if self.use_linear_lr_decay:
                self.trainer.policy.lr_decay(episode, episodes)

            for step in range(0, self.episode_length, act_step):
                # Sample actions for the next act_step steps with one policy call.
                latent_actions, sampled_actions, actions, action_log_probs, action_log_probs_last, rnn_states, actions_env, noise = \
                    self.collect_actions(step)

                # Split the trailing axis into (act_step, -1) per-step chunks.
                latent_actions = latent_actions.reshape(*latent_actions.shape[:-1], act_step, -1)
                noise = noise.reshape(*noise.shape[:-1], act_step, -1)

                # Execute the sampled actions one at a time; the last chunk may be shorter.
                for i in range(min(act_step, self.episode_length - step)):
                    curr_step = step + i

                    env_start = time.time()
                    obs, rewards, dones, infos = self.envs.step(actions_env[:, :, i])
                    tot_time += time.time() - env_start

                    # Critic values for buffer slot curr_step.
                    values, rnn_states_critic = self.collect_critic(curr_step)

                    data = (obs, rewards, dones, infos,
                            values, latent_actions[..., i, :], sampled_actions[:, :, i],
                            actions[:, :, i], action_log_probs, action_log_probs_last[:, :, i],
                            rnn_states, rnn_states_critic, noise[..., i, :])
                    self.insert(data)

            print(f"Total time: {time.time() - start_time}, step time: {tot_time}")

            # compute return and update network
            reward_info = self.compute()
            if episode > start_ppo:
                lrscheduler.step()
            # NOTE: with (start_ppo, start_critic) = (5, 4) update_critic is always
            # True; with (0, 0) it is False only at episode 0.
            train_infos = self.train(episode / episodes,
                                     update_actor=(episode > start_ppo),
                                     update_critic=(episode < start_ppo or episode > start_critic))
            train_infos.update(reward_info)
            print(train_infos)

            # post process
            total_num_steps = (episode + 1) * self.episode_length * self.n_rollout_threads
            print("Step", total_num_steps)

            # save model
            if (episode + 1) % self.save_interval == 0 or episode == episodes - 1:
                self.save(episode)

            # log information
            if episode % self.log_interval == 0:
                end = time.time()
                print("\n Scenario {} Algo {} Exp {} updates {}/{} episodes, total num timesteps {}/{}, FPS {}.\n"
                      .format(self.all_args.scenario_name,
                              self.algorithm_name,
                              self.experiment_name,
                              episode,
                              episodes,
                              total_num_steps,
                              self.num_env_steps,
                              int(total_num_steps / (end - start))))

                train_infos["average_episode_rewards"] = np.mean(self.buffer.rewards) * self.episode_length
                print("average episode rewards is {}".format(train_infos["average_episode_rewards"]))
                self.log_train(train_infos, total_num_steps)
                self.log_env(self.env_infos, total_num_steps)

                pprint({k: np.mean(v) for k, v in self.env_infos.items() if len(v) > 0})
                pprint({k: len(v) for k, v in self.env_infos.items() if len(v) > 0})

                # Reset to a fresh defaultdict. A plain dict would lose the default
                # factory, and ``insert`` would then raise KeyError for any key
                # that had not appeared before this reset.
                self.env_infos = defaultdict(list)

            # eval
            if episode % self.eval_interval == 0 and self.use_eval:
                self.eval(total_num_steps)

    def warmup(self):
        """Reset the environments and store the initial observations at buffer slot 0.

        The same observation array is written to both ``share_obs`` and ``obs``.
        """
        obs = self.envs.reset()
        self.buffer.share_obs[0] = obs.copy()
        self.buffer.obs[0] = obs.copy()

    @torch.no_grad()
    def collect_actions(self, step):
        """Sample actions for all rollout threads from the observations at buffer slot ``step``.

        :return: (latent_actions, sampled_actions, actions, action_log_probs,
            action_log_probs_last, rnn_states, actions_env, noise), each split
            back into a leading per-thread axis. ``actions_env`` is the same
            array as ``actions``.
        """
        self.trainer.prep_rollout()

        latent_actions, sampled_actions, action, action_log_prob, action_log_prob_last, rnn_states, noise = \
            self.trainer.policy.get_actions(np.concatenate(self.buffer.obs[step]),
                                            np.concatenate(self.buffer.rnn_states[step]),
                                            np.concatenate(self.buffer.masks[step]),
                                            return_noise=True)

        def per_thread(x):
            # [n_envs*n_agents, ...] -> [n_envs, n_agents, ...]
            return np.array(np.split(_t2n(x), self.n_rollout_threads))

        latent_actions = per_thread(latent_actions)
        sampled_actions = per_thread(sampled_actions)
        actions = per_thread(action)
        action_log_probs_last = per_thread(action_log_prob_last)
        action_log_probs = per_thread(action_log_prob)
        rnn_states = per_thread(rnn_states)
        noise = per_thread(noise)

        # Actions for env
        actions_env = actions

        return latent_actions, sampled_actions, actions, action_log_probs, action_log_probs_last, rnn_states, actions_env, noise

    @torch.no_grad()
    def collect_critic(self, step):
        """Evaluate the critic on buffer slot ``step``.

        :return: (values, rnn_states_critic) split per rollout thread. ``values``
            keeps only index 0 along the second (per-thread) axis.
        """
        value, rnn_states_critic = self.trainer.policy.critic_forward(
            np.concatenate(self.buffer.share_obs[step]),
            np.concatenate(self.buffer.rnn_states_critic[step]),
            None,
            np.concatenate(self.buffer.masks[step])
        )

        values = np.array(np.split(_t2n(value), self.n_rollout_threads))[:, 0]
        rnn_states_critic = np.array(np.split(_t2n(rnn_states_critic), self.n_rollout_threads))

        return values, rnn_states_critic

    @torch.no_grad()
    def compute(self):
        """Bootstrap from the last buffer slot and fill in the buffer's returns.

        :return: an (currently always empty) dict of extra info merged into train_infos.
        """
        result = dict()

        self.trainer.prep_rollout()
        next_values = self.trainer.policy.get_values(
            np.concatenate(self.buffer.share_obs[-1]),
            None,
            np.concatenate(self.buffer.rnn_states_critic[-1]),
            np.concatenate(self.buffer.masks[-1]))
        next_values = np.array(np.split(_t2n(next_values), self.n_rollout_threads))[:, 0]
        compute_returns(self.buffer, next_values, self.trainer.value_normalizer)

        return result

    def insert(self, data):
        """Record finished-episode stats, reset state of finished envs and store one step.

        :param data: tuple (obs, rewards, dones, infos, values, latent_actions,
            sampled_actions, actions, action_log_probs, action_log_probs_last,
            rnn_states, rnn_states_critic, noise) for a single environment step.
        """
        obs, rewards, dones, infos, values, latent_actions, sampled_actions, actions, action_log_probs, action_log_probs_last, rnn_states, rnn_states_critic, noise = data

        # An environment counts as done only when all entries of its last axis are done.
        dones_env = np.all(dones, axis=-1)
        if np.any(dones_env):
            for done, info in zip(dones_env, infos):
                if done:
                    self.env_infos["goal"].append(info["score_reward"])
                    self.env_infos["win_rate"].append(1 if info["score_reward"] > 0 else 0)

        # Zero recurrent states and masks of finished environments.
        rnn_states[dones_env == True] = np.zeros(((dones_env == True).sum(), self.num_agents, self.recurrent_N, self.hidden_size), dtype=np.float32)
        rnn_states_critic[dones_env == True] = np.zeros(((dones_env == True).sum(), self.num_agents, self.recurrent_N, self.hidden_size), dtype=np.float32)
        masks = np.ones((self.n_rollout_threads, self.num_agents, 1), dtype=np.float32)
        masks[dones_env == True] = np.zeros(((dones_env == True).sum(), self.num_agents, 1), dtype=np.float32)

        # obs is passed for the first three positional slots of buffer.insert.
        self.buffer.insert(
            obs,
            obs,
            obs,
            rnn_states,
            rnn_states_critic,
            actions=actions,
            action_log_probs=action_log_probs,
            action_log_probs_last=action_log_probs_last,
            value_preds=values,
            rewards=rewards,
            masks=masks,
            latent_actions=latent_actions,
            sampled_actions=sampled_actions,
            noise=noise
        )

    def log_env(self, env_infos, total_num_steps):
        """Log the mean of every non-empty list in ``env_infos`` to wandb or the summary writer."""
        for k, v in env_infos.items():
            if len(v) > 0:
                if self.use_wandb:
                    wandb.log({k: np.mean(v)}, step=total_num_steps)
                else:
                    self.writter.add_scalars(k, {k: np.mean(v)}, total_num_steps)

    @torch.no_grad()
    def eval(self, total_num_steps):
        """Run ``all_args.eval_episodes`` evaluation episodes and log goal, win rate and length.

        Episodes are spread over the evaluation threads; the first
        ``eval_episodes % n_eval_rollout_threads`` threads run one extra episode.
        The loop also stops after ``episode_length`` steps.

        NOTE(review): if the step limit is reached before every episode finishes,
        the unfilled slots of the goal / win-rate / step arrays stay 0 and are
        still included in the logged means.
        """
        # reset envs and init rnn and mask
        eval_obs = self.eval_envs.reset()
        eval_rnn_states = np.zeros((self.n_eval_rollout_threads, self.num_agents, self.recurrent_N, self.hidden_size), dtype=np.float32)
        eval_masks = np.ones((self.n_eval_rollout_threads, self.num_agents, 1), dtype=np.float32)

        # init eval goals
        num_done = 0
        eval_goals = np.zeros(self.all_args.eval_episodes)
        eval_win_rates = np.zeros(self.all_args.eval_episodes)
        eval_steps = np.zeros(self.all_args.eval_episodes)
        step = 0
        quo, rem = divmod(self.all_args.eval_episodes, self.n_eval_rollout_threads)
        done_episodes_per_thread = np.zeros(self.n_eval_rollout_threads, dtype=int)
        eval_episodes_per_thread = done_episodes_per_thread + quo
        eval_episodes_per_thread[:rem] += 1
        unfinished_thread = (done_episodes_per_thread != eval_episodes_per_thread)

        # loop until enough episodes
        while num_done < self.all_args.eval_episodes and step < self.episode_length:
            self.trainer.prep_rollout()

            # [n_envs, n_agents, ...] -> [n_envs*n_agents, ...]
            eval_actions, eval_rnn_states = self.trainer.policy.act(
                np.concatenate(eval_obs),
                np.concatenate(eval_rnn_states),
                np.concatenate(eval_masks),
                deterministic=self.all_args.eval_deterministic
            )

            # [n_envs*n_agents, ...] -> [n_envs, n_agents, ...]
            eval_actions = np.array(np.split(_t2n(eval_actions), self.n_eval_rollout_threads))
            eval_rnn_states = np.array(np.split(_t2n(eval_rnn_states), self.n_eval_rollout_threads))

            # step (only index 0 along the second axis of the actions is sent)
            eval_obs, eval_rewards, eval_dones, eval_infos = self.eval_envs.step(eval_actions[:, 0])

            # update goals if done
            eval_dones_env = np.all(eval_dones, axis=-1)
            eval_dones_unfinished_env = eval_dones_env[unfinished_thread]
            if np.any(eval_dones_unfinished_env):
                for idx_env in range(self.n_eval_rollout_threads):
                    if unfinished_thread[idx_env] and eval_dones_env[idx_env]:
                        eval_goals[num_done] = eval_infos[idx_env]["score_reward"]
                        eval_win_rates[num_done] = 1 if eval_infos[idx_env]["score_reward"] > 0 else 0
                        eval_steps[num_done] = eval_infos[idx_env]["max_steps"] - eval_infos[idx_env]["steps_left"]
                        num_done += 1
                        done_episodes_per_thread[idx_env] += 1
            unfinished_thread = (done_episodes_per_thread != eval_episodes_per_thread)

            # reset rnn and masks for done envs
            eval_rnn_states[eval_dones_env == True] = np.zeros(((eval_dones_env == True).sum(), self.num_agents, self.recurrent_N, self.hidden_size), dtype=np.float32)
            eval_masks = np.ones((self.n_eval_rollout_threads, self.num_agents, 1), dtype=np.float32)
            eval_masks[eval_dones_env == True] = np.zeros(((eval_dones_env == True).sum(), self.num_agents, 1), dtype=np.float32)
            step += 1

        # get expected goal
        eval_goal = np.mean(eval_goals)
        eval_win_rate = np.mean(eval_win_rates)
        eval_step = np.mean(eval_steps)

        # log and print
        print("eval expected goal is {}.".format(eval_goal))
        if self.use_wandb:
            wandb.log({"eval_goal": eval_goal}, step=total_num_steps)
            wandb.log({"eval_win_rate": eval_win_rate}, step=total_num_steps)
            wandb.log({"eval_step": eval_step}, step=total_num_steps)
        else:
            self.writter.add_scalars("eval_goal", {"expected_goal": eval_goal}, total_num_steps)
            self.writter.add_scalars("eval_win_rate", {"eval_win_rate": eval_win_rate}, total_num_steps)
            self.writter.add_scalars("eval_step", {"expected_step": eval_step}, total_num_steps)

    @torch.no_grad()
    def render(self):
        """Roll out ``all_args.render_episodes`` deterministic episodes on ``self.envs``.

        Each episode runs until any env reports done. The reward of entry [0, 0]
        at the final step is printed and recorded as that episode's goal. With
        ``all_args.save_gifs`` the frames are written to
        ``<gif_dir>/episode<i>.gif``.
        """
        render_env = self.envs

        # init goal
        render_goals = np.zeros(self.all_args.render_episodes)
        for i_episode in range(self.all_args.render_episodes):
            render_obs = render_env.reset()
            render_rnn_states = np.zeros((self.n_rollout_threads, self.num_agents, self.recurrent_N, self.hidden_size), dtype=np.float32)
            render_masks = np.ones((self.n_rollout_threads, self.num_agents, 1), dtype=np.float32)

            if self.all_args.save_gifs:
                frames = []
                image = self.envs.envs[0].env.unwrapped.observation()[0]["frame"]
                frames.append(image)

            render_dones = False
            while not np.any(render_dones):
                self.trainer.prep_rollout()
                render_actions, render_rnn_states = self.trainer.policy.act(
                    np.concatenate(render_obs),
                    np.concatenate(render_rnn_states),
                    np.concatenate(render_masks),
                    deterministic=True
                )

                # [n_envs*n_agents, ...] -> [n_envs, n_agents, ...]
                render_actions = np.array(np.split(_t2n(render_actions), self.n_rollout_threads))
                render_rnn_states = np.array(np.split(_t2n(render_rnn_states), self.n_rollout_threads))

                render_actions_env = [render_actions[idx] for idx in range(self.n_rollout_threads)]

                # step
                render_obs, render_rewards, render_dones, render_infos = render_env.step(render_actions_env)

                # append frame
                if self.all_args.save_gifs:
                    image = render_infos[0]["frame"]
                    frames.append(image)

            # print goal
            render_goals[i_episode] = render_rewards[0, 0]
            print("goal in episode {}: {}".format(i_episode, render_rewards[0, 0]))
            print(f'reward {render_infos[0]["score_reward"]}')

            # save gif
            if self.all_args.save_gifs:
                imageio.mimsave(
                    uri="{}/episode{}.gif".format(str(self.gif_dir), i_episode),
                    ims=frames,
                    format="GIF",
                    duration=self.all_args.ifi,
                )

        print("expected goal: {}".format(np.mean(render_goals)))
