import copy
import gc
import sys
import time
from debug import debug_print
import numpy as np
from onpolicy.algorithms.diffusion_ac.datasets import D4RL_Dataset
import psutil
import torch
import wandb
import time 
import matplotlib.pyplot as plt 
from pprint import pprint
from torch.utils.data import DataLoader
from onpolicy.algorithms.utils.scheduler import CosineAnnealingWarmupRestarts

from onpolicy.runner.shared.base_runner import Runner
import onpolicy.algorithms.gail.gail_utils as gail_utils
from onpolicy.utils.running_mean_std import RunningMeanStd
from onpolicy.utils.reward_scaling import RunningRewardScaler

def _t2n(x):
    if isinstance(x, np.ndarray):
        return x
    return x.detach().cpu().numpy()


def compute_returns(buffer, next_value, value_normalizer=None):
    """
    Compute returns either as a discounted sum of rewards, or using GAE, and
    store them (together with the bootstrapped value predictions) on ``buffer``.

    :param buffer: rollout buffer. Reads ``rewards``, ``value_preds``, ``masks``,
        ``bad_masks``, ``returns``, ``gamma``, ``gae_lambda``, ``episode_length``,
        ``n_rollout_threads``, ``rnum_agents`` and the ``_use_gae``,
        ``_use_proper_time_limits``, ``_use_popart`` and ``_use_valuenorm``
        flags. ``buffer.returns`` and ``buffer.value_preds`` are replaced.
    :param next_value: (np.ndarray) value predictions for the step after the last
        episode step; written into the last slot of ``value_preds``.
    :param value_normalizer: (PopArt / ValueNorm) used to denormalize value
        predictions; required when ``_use_popart`` or ``_use_valuenorm`` is set.
    :raises ValueError: if value normalization is enabled but no
        ``value_normalizer`` is given.
    """
    use_norm = buffer._use_popart or buffer._use_valuenorm
    if use_norm and value_normalizer is None:
        raise ValueError(
            "value_normalizer is required when _use_popart or _use_valuenorm is set")

    # Work on copies; the buffer is only updated at the very end.
    rewards = np.concatenate([buffer.rewards], axis=-1)
    value_preds = buffer.value_preds.copy()
    value_preds[-1] = next_value
    masks = np.concatenate([buffer.masks], axis=-1)
    bad_masks = np.concatenate([buffer.bad_masks], axis=-1)
    returns = buffer.returns.copy()
    gamma = buffer.gamma
    num_steps = rewards.shape[0]

    if buffer._use_proper_time_limits:
        # bad_masks[t] == 0 marks a time-limit truncation: the return is cut
        # there and bootstrapped from the value prediction instead.
        values = value_normalizer.denormalize(value_preds) if use_norm else value_preds
        if buffer._use_gae:
            mul = gamma * buffer.gae_lambda
            delta = rewards + gamma * values[1:] * masks[1:] - values[:-1]
            gae = 0
            for step in reversed(range(num_steps)):
                gae = delta[step] + mul * gae * masks[step + 1]
                gae *= bad_masks[step + 1]
                returns[step] = gae + values[step]
        else:
            returns[-1] = values[-1]
            for step in reversed(range(num_steps)):
                returns[step] = (returns[step + 1] * gamma * masks[step + 1] + rewards[step]) \
                    * bad_masks[step + 1] + (1 - bad_masks[step + 1]) * values[step]
    else:
        def denorm(v):
            # Per-step denormalization: the normalizer is applied without the
            # trailing singleton value axis, which is restored afterwards.
            if use_norm:
                return value_normalizer.denormalize(v[..., 0])[..., None]
            return v

        if buffer._use_gae:
            gae = 0
            for step in reversed(range(num_steps)):
                next_v = denorm(value_preds[step + 1])
                cur_v = denorm(value_preds[step])
                delta = rewards[step] + gamma * next_v * masks[step + 1] - cur_v
                gae = delta + gamma * buffer.gae_lambda * masks[step + 1] * gae
                returns[step] = gae + cur_v
        else:
            returns[-1] = denorm(value_preds[-1])
            for step in reversed(range(num_steps)):
                returns[step] = returns[step + 1] * gamma * masks[step + 1] + rewards[step]

    returns = returns.reshape(buffer.episode_length + 1, 1, buffer.n_rollout_threads, buffer.rnum_agents).transpose(0, 2, 3, 1)
    value_preds = value_preds.reshape(buffer.episode_length + 1, 1, buffer.n_rollout_threads, buffer.rnum_agents).transpose(0, 2, 3, 1)

    assert (returns.shape == buffer.returns.shape), (returns.shape, buffer.returns.shape)
    assert (value_preds.shape == buffer.value_preds.shape), (value_preds.shape, buffer.value_preds.shape)

    buffer.returns = returns
    buffer.value_preds = value_preds


class Robomimic_DiffRunner(Runner):
    """Runner for the robomimic diffusion-policy setup.

    Optionally behaviour-clones the actor on offline datasets (:meth:`clone`),
    then runs on-policy rollouts in which each sampled action chunk is executed
    ``all_args.act_step`` environment steps at a time, followed by
    return computation and trainer updates.
    """

    def __init__(self, config):
        """Build the base runner, attach a fresh observation normalizer and keep
        the BC datasets from ``config['train_dataset']`` / ``config['val_dataset']``."""
        super().__init__(config)
        self.trainer.policy.obs_rms = RunningMeanStd()
        self.train_dataset = config['train_dataset']
        self.val_dataset = config['val_dataset']

    def _bc_batch(self, obs, actions_batch):
        """Turn one ``(obs, actions)`` loader batch into the tuple consumed by
        ``trainer.bc_clone`` / ``trainer.bc_loss``.

        Actions are divided by ``all_args.logit_scaling`` and both tensors are
        reshaped to ``(batch * act_step, -1)``; the three trailing entries are None.
        """
        act_step = self.all_args.act_step
        batch_size = obs.shape[0]
        actions_batch = actions_batch / self.all_args.logit_scaling
        return (obs.reshape(batch_size * act_step, -1),
                actions_batch.reshape(batch_size * act_step, -1),
                None, None, None)

    def clone(self):
        """Behaviour-clone the actor on ``self.train_dataset``.

        Runs ``all_args.clone_episodes`` epochs of ``trainer.bc_clone``, reports
        the validation ``trainer.bc_loss`` every 5 epochs, checkpoints every 10
        epochs and, when evaluation is enabled, keeps the best-reward weights
        via :meth:`save_max`.

        :return: best evaluation reward seen (``-inf`` if evaluation never ran).
        """
        print('memory usage:', psutil.Process().memory_info().rss / (1024 * 1024))
        clone_episodes = self.all_args.clone_episodes
        batch_size = 256
        train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
        val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
        # Guard the per-epoch averages against empty datasets.
        n_train_batches = max(len(train_loader), 1)
        n_val_batches = max(len(val_loader), 1)
        # -inf so the first evaluation is always recorded, whatever the reward scale.
        max_reward = -np.inf

        for epoch in range(clone_episodes):
            start_time = time.time()
            loss_sum = 0
            grad_norm_sum = 0
            for obs, actions_batch in train_loader:
                bc_loss, actor_grad_norm = self.trainer.bc_clone(self._bc_batch(obs, actions_batch))
                loss_sum += bc_loss
                grad_norm_sum += actor_grad_norm
            end_time = time.time()
            mean_loss = loss_sum / n_train_batches
            mean_grad_norm = grad_norm_sum / n_train_batches
            print('epoch', epoch, 'loss', mean_loss, 'grad_norm', mean_grad_norm, 'time', end_time - start_time)
            self.log_train({'bc_loss': mean_loss, 'bc_grad_norm': mean_grad_norm}, epoch)

            if epoch % 5 == 0:
                val_loss_sum = 0
                with torch.no_grad():
                    for obs, actions_batch in val_loader:
                        val_loss_sum += self.trainer.bc_loss(self._bc_batch(obs, actions_batch))
                print('val', epoch, 'loss', val_loss_sum / n_val_batches)
                self.log_train({'val_bc_loss': val_loss_sum / n_val_batches}, epoch)

            # Periodic checkpoint (once per 10 epochs).
            if epoch % 10 == 0:
                self.save()
                if self.all_args.save_dir is not None:
                    self.save_pretrained(self.all_args.save_dir)
                print('save')

            if (epoch + 1) % self.all_args.eval_interval == 0 and self.use_eval:
                reward = self.eval(epoch)
                # Rewards within 1e-3 of the best so far also refresh the best checkpoint.
                if reward >= max_reward - (1e-3):
                    max_reward = reward
                    print('save model, reward', reward)
                    self.save_max()
        print('memory usage:', psutil.Process().memory_info().rss / (1024 * 1024))
        return max_reward

    def _save_networks(self, directory, suffix=""):
        """Save actor and critic ``state_dict``s as ``<directory>/actor<suffix>.pt``
        and ``<directory>/critic<suffix>.pt``."""
        policy = self.trainer.policy
        torch.save(policy.actor.state_dict(), str(directory) + "/actor" + suffix + ".pt")
        torch.save(policy.critic.state_dict(), str(directory) + "/critic" + suffix + ".pt")

    def save_episode(self, episode=0):
        """Save actor/critic weights tagged with ``episode`` under ``self.save_dir``."""
        self._save_networks(self.save_dir, "_" + str(episode))

    def save_max(self):
        """Save actor/critic weights as the best-so-far (``_max``) checkpoint under ``self.save_dir``."""
        self._save_networks(self.save_dir, "_max")

    def save_pretrained(self, dir):
        """Save policy's actor and critic networks to ``dir`` (``actor.pt`` / ``critic.pt``)."""
        self._save_networks(dir)

    def _normalize_obs(self, obs):
        """Standardize ``obs`` with ``policy.obs_rms`` and clip to ``[-clip_obs, clip_obs]``
        when ``all_args.norm_obs`` is set; otherwise return ``obs`` unchanged."""
        if not self.all_args.norm_obs:
            return obs
        obs_rms = self.trainer.policy.obs_rms
        return np.clip((obs - obs_rms.mean) / np.sqrt(obs_rms.var + 1e-6),
                       -self.all_args.clip_obs, self.all_args.clip_obs)

    def _share_obs(self, obs):
        """Return the critic-side observation for ``obs``.

        With ``use_centralized_V`` each thread's observations are flattened into
        one vector with a singleton axis 1; otherwise ``obs`` itself is used.
        """
        if not self.use_centralized_V:
            return obs
        flat = obs.reshape(self.n_rollout_threads, -1)
        # NOTE(review): repeated once along axis 1 — presumably meant to be
        # self.num_agents; identical while there is a single agent.
        return np.expand_dims(flat, 1).repeat(1, axis=1)

    @staticmethod
    def _accumulate_env_infos(env_infos, infos):
        """Append every tracked key of ``env_infos`` found in each env's ``info[0]``."""
        for info in infos:
            for k in env_infos:
                if k in info[0]:
                    env_infos[k].append(info[0][k])

    def _rollout_chunk(self, step, act_step, env_infos):
        """Sample one action chunk at buffer index ``step`` and execute it.

        Each of the (at most ``act_step``) actions is stepped in ``self.envs``,
        the raw observations update ``obs_rms``, the critic is queried and the
        transition is inserted into the buffer; finished-episode stats are
        appended to ``env_infos``.

        :return: wall-clock seconds spent inside ``self.envs.step``.
        """
        env_time = 0
        latent_actions, sampled_actions, actions, action_log_probs, action_log_probs_last, rnn_states, actions_env, noise = \
            self.collect_actions(step)
        # Split the flattened chunk dimension into (act_step, per-step dim).
        latent_actions = latent_actions.reshape(*latent_actions.shape[:-1], act_step, -1)
        noise = noise.reshape(*noise.shape[:-1], act_step, -1)

        for i in range(min(act_step, self.episode_length - step)):
            curr_step = step + i

            env_start = time.time()
            obs, rewards, dones, infos = self.envs.step(actions_env[:, :, i])
            env_time += time.time() - env_start

            # Statistics are updated from the raw observations, before normalizing.
            self.trainer.policy.obs_rms.update(obs[:, 0, :])
            obs = self._normalize_obs(obs)

            values, rnn_states_critic = self.collect_critic(curr_step)

            data = (obs, rewards, dones, infos,
                    values, latent_actions[..., i, :], sampled_actions[:, :, i],
                    actions[:, :, i], action_log_probs, action_log_probs_last[:, :, i, None],
                    rnn_states, rnn_states_critic, noise[..., i, :])
            self.insert(data)
            self._accumulate_env_infos(env_infos, infos)
        return env_time

    def run(self):
        """Main loop: optional BC pre-training, then rollouts, return computation
        and trainer updates, with periodic saving, logging and evaluation."""
        start = time.time()
        episodes = int(self.num_env_steps) // self.episode_length // self.n_rollout_threads
        clone_episodes = self.all_args.clone_episodes
        act_step = self.all_args.act_step

        # The actor is updated for episode > start_ppo; the critic for
        # episode < start_ppo or episode > start_critic.
        start_ppo = 1
        start_critic = 0
        self.trainer.policy.actor.diffusion.update_eta(-1)
        if self.all_args.eta >= 0.01:
            self.trainer.policy.actor.diffusion.update_eta(self.all_args.eta)

        if clone_episodes > 0:
            self.clone()
            self.save()
            gc.collect()
            print('memory usage:', psutil.Process().memory_info().rss / (1024 * 1024))

        debug_print(self.all_args.model_dir)

        # Constant learning-rate schedule (start and end factor 1); stepped once
        # here and once per update after ``start_ppo``.
        lrscheduler = torch.optim.lr_scheduler.LinearLR(self.trainer.policy.actor_optimizer, 1, 1, total_iters=1)
        lrscheduler.step()
        self.trainer.policy.actor.act.update_logstd(self.all_args.initial_logstd)

        self.warmup()
        env_infos = {
            "episode_length": [],
            "episode_return": []
        }

        for episode in range(episodes):
            self.lengths_track = [0] * self.n_rollout_threads
            self.rewards_track = [0] * self.n_rollout_threads
            self.current_finish = 0
            self.current_reward_sum = 0
            self.current_length_sum = 0

            if self.use_linear_lr_decay:
                self.trainer.policy.lr_decay(episode, episodes)

            tot_time = 0
            cur_time = time.time()
            for step in range(0, self.episode_length, act_step):
                tot_time += self._rollout_chunk(step, act_step, env_infos)

            # compute return and update network
            print('spent on env:', tot_time, 'tot:', time.time() - cur_time)
            reward_info = self.compute()
            if episode > start_ppo:
                lrscheduler.step()
            train_infos = self.train(episode / episodes,
                                     update_actor=(episode > start_ppo),
                                     update_critic=(episode < start_ppo or episode > start_critic))
            train_infos.update(reward_info)
            print(train_infos)

            total_num_steps = (episode + 1) * self.episode_length * self.n_rollout_threads
            print("Step", total_num_steps)

            # save model
            if (episode + 1) % self.save_interval == 0 or episode == episodes - 1:
                self.save(episode)

            # log information
            if episode % self.log_interval == 0:
                end = time.time()
                print("\n Scenario {} Algo {} Exp {} updates {}/{} episodes, total num timesteps {}/{}, FPS {}.\n"
                        .format(self.all_args.scenario_name,
                                self.algorithm_name,
                                self.experiment_name,
                                episode,
                                episodes,
                                total_num_steps,
                                self.num_env_steps,
                                int(total_num_steps / (end - start))))

                train_infos["average_episode_rewards"] = np.mean(self.buffer.rewards) * self.episode_length
                print("average episode rewards is {}".format(train_infos["average_episode_rewards"]))
                self.log_train(train_infos, total_num_steps)
                self.log_env(env_infos, total_num_steps)

                pprint({k: np.mean(v) for k, v in env_infos.items() if len(v) > 0})

                env_infos = {k: [] for k in env_infos}

            # Periodic evaluation in the eval envs.
            if episode % self.eval_interval == 0 and self.use_eval:
                self.eval(total_num_steps)

    def eval(self, total_steps):
        """Run ``all_args.eval_episodes`` deterministic episodes in ``self.eval_envs``.

        Stats are logged under ``eval-``-prefixed keys at ``total_steps``.

        :return: mean ``episode_return`` over finished eval episodes (NaN, with a
            NumPy warning, if none finished).
        """
        episodes = self.all_args.eval_episodes
        act_step = self.all_args.act_step
        env_infos = {
            "episode_length": [],
            "episode_return": [],
        }

        for episode in range(episodes):
            # The reset observation is normalized like every stepped observation.
            obs = self._normalize_obs(self.eval_envs.reset())
            rnn_states = np.zeros((self.n_eval_rollout_threads, self.num_agents, self.recurrent_N, self.hidden_size), dtype=np.float32)
            masks = np.ones((self.n_eval_rollout_threads, self.num_agents, 1), dtype=np.float32)

            start = time.time()

            for step in range(0, self.episode_length, act_step):
                # Sample actions for the next act_step steps at once.
                _, _, actions, _, _, rnn_states, _ = \
                    self.trainer.policy.get_actions(np.concatenate(obs),
                                                    np.concatenate(rnn_states),
                                                    np.concatenate(masks),
                                                    deterministic=True,
                                                    return_noise=True)
                actions_env = np.array(np.split(_t2n(actions), self.n_eval_rollout_threads))
                rnn_states = np.array(np.split(_t2n(rnn_states), self.n_eval_rollout_threads))

                # Execute the chunk's actions one by one.
                for i in range(min(act_step, self.episode_length - step)):
                    obs, rewards, dones, infos = self.eval_envs.step(actions_env[:, :, i])

                    done_mask = dones == True
                    rnn_states[done_mask] = np.zeros((done_mask.sum(), self.recurrent_N, self.hidden_size), dtype=np.float32)
                    masks = np.ones((self.n_eval_rollout_threads, self.num_agents, 1), dtype=np.float32)
                    masks[done_mask] = np.zeros((done_mask.sum(), 1), dtype=np.float32)

                    obs = self._normalize_obs(obs)
                    self._accumulate_env_infos(env_infos, infos)

            total_num_steps = self.episode_length * self.n_eval_rollout_threads

            print("Step", total_num_steps)
            end = time.time()
            print("\n Scenario {} Algo {} Exp {} updates {}/{} episodes, total num timesteps {}/{}, FPS {}.\n"
                    .format(self.all_args.scenario_name,
                            self.algorithm_name,
                            self.experiment_name,
                            episode,
                            episodes,
                            total_num_steps,
                            self.num_env_steps,
                            int(total_num_steps / (end - start))))

        env_infos = {f"eval-{k}": v for k, v in env_infos.items()}
        self.log_env(env_infos, total_steps)
        pprint({k: np.mean(v) for k, v in env_infos.items() if len(v) > 0})

        return np.mean(env_infos["eval-episode_return"])

    def warmup(self):
        """Reset the training envs and store the first (normalized) observation
        in slot 0 of ``buffer.obs`` and ``buffer.share_obs``."""
        obs = self._normalize_obs(self.envs.reset())
        share_obs = self._share_obs(obs)
        self.buffer.share_obs[0] = share_obs.copy()
        self.buffer.obs[0] = obs.copy()

    @torch.no_grad()
    def collect_actions(self, step):
        """Sample a (stochastic) action chunk for buffer index ``step``.

        The policy outputs are split back into per-thread arrays; ``actions`` is
        also returned as ``actions_env`` (the same array) for stepping the envs.
        """
        self.trainer.prep_rollout()

        latent_actions, sampled_actions, action, action_log_prob, action_log_prob_last, rnn_states, noise = \
            self.trainer.policy.get_actions(np.concatenate(self.buffer.obs[step]),
                                            np.concatenate(self.buffer.rnn_states[step]),
                                            np.concatenate(self.buffer.masks[step]),
                                            return_noise=True)

        def per_thread(x):
            return np.array(np.split(_t2n(x), self.n_rollout_threads))

        latent_actions = per_thread(latent_actions)
        sampled_actions = per_thread(sampled_actions)
        actions = per_thread(action)
        action_log_probs_last = per_thread(action_log_prob_last)
        action_log_probs = per_thread(action_log_prob)
        rnn_states = per_thread(rnn_states)
        noise = per_thread(noise)

        actions_env = actions
        return latent_actions, sampled_actions, actions, action_log_probs, action_log_probs_last, rnn_states, actions_env, noise

    @torch.no_grad()
    def collect_critic(self, step):
        """Query the critic on buffer index ``step``; returns per-thread values
        (first agent only) and per-thread critic recurrent states."""
        value, rnn_states_critic = self.trainer.policy.critic_forward(
            np.concatenate(self.buffer.share_obs[step]),
            np.concatenate(self.buffer.rnn_states_critic[step]),
            None,
            np.concatenate(self.buffer.masks[step])
        )

        values = np.array(np.split(_t2n(value), self.n_rollout_threads))[:, 0]
        rnn_states_critic = np.array(np.split(_t2n(rnn_states_critic), self.n_rollout_threads))

        return values, rnn_states_critic

    def insert(self, data):
        """Store one environment transition and update the per-thread episode trackers.

        ``data`` is the 13-tuple ``(obs, rewards, dones, infos, values,
        latent_actions, sampled_actions, actions, action_log_probs,
        action_log_probs_last, rnn_states, rnn_states_critic, noise)``.
        For threads whose ``dones`` flag is set, the recurrent states are zeroed
        in place and the mask is 0.
        """
        obs, rewards, dones, infos, values, latent_actions, sampled_actions, actions, \
            action_log_probs, action_log_probs_last, rnn_states, rnn_states_critic, noise = data

        done_mask = dones == True
        n_done = done_mask.sum()
        rnn_states[done_mask] = np.zeros((n_done, self.recurrent_N, self.hidden_size), dtype=np.float32)
        rnn_states_critic[done_mask] = np.zeros((n_done, *self.buffer.rnn_states_critic.shape[3:]), dtype=np.float32)
        masks = np.ones((self.n_rollout_threads, 1, 1), dtype=np.float32)
        masks[done_mask] = np.zeros((n_done, 1), dtype=np.float32)

        # First slot receives the critic-side observation, matching warmup();
        # it equals ``obs`` whenever there is a single agent.
        self.buffer.insert(
            self._share_obs(obs),
            obs,
            obs,
            rnn_states,
            rnn_states_critic,
            actions=actions,
            action_log_probs=action_log_probs,
            action_log_probs_last=action_log_probs_last,
            value_preds=values,
            rewards=rewards,
            masks=masks,
            latent_actions=latent_actions,
            sampled_actions=sampled_actions,
            noise=noise
        )

        for i in range(self.n_rollout_threads):
            self.rewards_track[i] += rewards[i, 0, 0]
            self.lengths_track[i] += 1
            if dones[i, 0]:
                self.current_finish += 1
                self.current_reward_sum += self.rewards_track[i]
                self.current_length_sum += self.lengths_track[i]
                self.rewards_track[i] = 0
                self.lengths_track[i] = 0

    @torch.no_grad()
    def compute(self):
        """Bootstrap from ``policy.get_values`` on the last buffered step and fill
        ``buffer.returns`` via ``compute_returns``; returns an empty info dict."""
        self.trainer.prep_rollout()
        next_values = self.trainer.policy.get_values(
                            np.concatenate(self.buffer.share_obs[-1]),
                            None,
                            np.concatenate(self.buffer.rnn_states_critic[-1]),
                            np.concatenate(self.buffer.masks[-1]))
        next_values = np.array(np.split(_t2n(next_values), self.n_rollout_threads))[:, 0]
        compute_returns(self.buffer, next_values, self.trainer.value_normalizer)
        return dict()

    # def save(self, episode=0):
    #     """Save policy's actor and critic networks."""
    #     if self.algorithm_name == "mat" or self.algorithm_name == "mat_dec":
    #         self.policy.save(self.save_dir, episode)
    #     else:
    #         if hasattr(self.trainer.policy, 'actor'):
    #             policy_actor = self.trainer.policy.actor
    #             torch.save(policy_actor.state_dict(), str(self.save_dir) + "/actor.pt")
    #             torch.save(policy_actor.state_dict(), str(self.save_dir) + "/actor_{}.pt".format(episode))
    #         if hasattr(self.trainer.policy, 'critic'):
    #             policy_critic = self.trainer.policy.critic
    #             torch.save(policy_critic.state_dict(), str(self.save_dir) + "/critic.pt")
    #             torch.save(policy_critic.state_dict(), str(self.save_dir) + "/critic_{}.pt".format(episode))

    # def restore(self, model_dir):
    #     """Restore policy's networks from a saved model."""
    #     if self.algorithm_name == "mat" or self.algorithm_name == "mat_dec":
    #         self.policy.restore(model_dir)
    #     else:
    #         policy_actor_state_dict = torch.load(str(self.model_dir) + '/actor.pt')
    #         self.policy.actor.load_state_dict(policy_actor_state_dict)
    #         if not self.all_args.use_render:
    #             policy_critic_state_dict = torch.load(str(self.model_dir) + '/critic.pt')
    #             self.policy.critic.load_state_dict(policy_critic_state_dict)