import copy
import gc
import sys
import time
from debug import debug_print
import numpy as np
from onpolicy.algorithms.diffusion_ac.datasets import D4RL_Dataset
import psutil
import torch
import wandb
import time 
import matplotlib.pyplot as plt 
from pprint import pprint
from torch.utils.data import DataLoader
from tqdm import tqdm

from onpolicy.runner.shared.base_runner import Runner
import onpolicy.algorithms.gail.gail_utils as gail_utils
from onpolicy.utils.running_mean_std import RunningMeanStd

def _t2n(x):
    if isinstance(x, np.ndarray):
        return x
    return x.detach().cpu().numpy()


def compute_returns(buffer, next_value, value_normalizer=None):
    """
    Compute returns either as discounted sum of rewards, or using GAE.

    Results are written back to the buffer: ``buffer.returns`` receives the
    computed targets and ``buffer.value_preds`` is replaced by a copy whose
    last entry is ``next_value``.

    The non-GAE (discounted-sum) branches follow the reference MAPPO
    ``SharedReplayBuffer.compute_returns``: ``returns[-1]`` is seeded with
    ``next_value`` and the recursion runs backwards over the rollout.

    :param buffer: rollout buffer exposing ``rewards``, ``value_preds``, ``masks``,
        ``bad_masks``, ``returns``, ``gamma``, ``gae_lambda``, the
        ``_use_proper_time_limits`` / ``_use_gae`` / ``_use_popart`` /
        ``_use_valuenorm`` flags and ``episode_length`` / ``n_rollout_threads`` /
        ``rnum_agents``.
    :param next_value: (np.ndarray) value predictions for the step after the last episode step.
    :param value_normalizer: (PopArt) If not None, PopArt value normalizer instance.
        Must be provided when ``buffer._use_popart`` or ``buffer._use_valuenorm`` is set.
    """
    normalized = buffer._use_popart or buffer._use_valuenorm
    rewards = buffer.rewards
    masks = buffer.masks
    bad_masks = buffer.bad_masks
    gamma = buffer.gamma
    num_steps = rewards.shape[0]

    # Work on copies so the buffer is only updated once everything succeeded.
    value_preds = buffer.value_preds.copy()
    returns = buffer.returns.copy()
    # Bootstrap from the critic's estimate for the step after the rollout.
    value_preds[-1] = next_value

    if buffer._use_proper_time_limits:
        # bad_masks[t + 1] == 0 cuts the recursion at step t and falls back to
        # the (denormalized) value estimate there.
        values = value_normalizer.denormalize(value_preds) if normalized else value_preds
        if buffer._use_gae:
            mul = gamma * buffer.gae_lambda
            delta = rewards + gamma * values[1:] * masks[1:] - values[:-1]
            gae = 0
            for step in reversed(range(num_steps)):
                gae = delta[step] + mul * gae * masks[step + 1]
                gae *= bad_masks[step + 1]
                returns[step] = gae + values[step]
        else:
            returns[-1] = next_value
            for step in reversed(range(num_steps)):
                returns[step] = (
                    (returns[step + 1] * gamma * masks[step + 1] + rewards[step]) * bad_masks[step + 1]
                    + (1 - bad_masks[step + 1]) * values[step]
                )
    else:
        if buffer._use_gae:
            mul = gamma * buffer.gae_lambda
            gae = 0
            for step in reversed(range(num_steps)):
                if normalized:
                    # The normalizer is applied to the values with their trailing
                    # singleton axis dropped, then that axis is restored.
                    next_v = value_normalizer.denormalize(value_preds[step + 1][..., 0])[..., None]
                    cur_v = value_normalizer.denormalize(value_preds[step][..., 0])[..., None]
                else:
                    next_v = value_preds[step + 1]
                    cur_v = value_preds[step]
                delta = rewards[step] + gamma * next_v * masks[step + 1] - cur_v
                gae = delta + mul * masks[step + 1] * gae
                returns[step] = gae + cur_v
        else:
            # Plain discounted sum; masks[t + 1] == 0 stops bootstrapping across episode ends.
            returns[-1] = next_value
            for step in reversed(range(num_steps)):
                returns[step] = returns[step + 1] * gamma * masks[step + 1] + rewards[step]

    # Re-lay out as (episode_length + 1, n_rollout_threads, rnum_agents, 1); the
    # asserts check this matches the buffer's own array shapes.
    layout = (buffer.episode_length + 1, 1, buffer.n_rollout_threads, buffer.rnum_agents)
    returns = returns.reshape(layout).transpose(0, 2, 3, 1)
    value_preds = value_preds.reshape(layout).transpose(0, 2, 3, 1)

    assert (returns.shape == buffer.returns.shape), (returns.shape, buffer.returns.shape)
    assert (value_preds.shape == buffer.value_preds.shape), (value_preds.shape, buffer.value_preds.shape)

    buffer.returns = returns
    buffer.value_preds = value_preds


class D4RL_DiffRunner(Runner):
    """Runner that behaviour-clones the actor on a D4RL dataset and then
    continues training it with the on-policy trainer.

    The policy is queried once per ``all_args.act_step`` environment steps;
    the returned chunk of actions is executed step by step without
    re-querying the policy (see ``_collect_rollout`` and ``eval``).
    """

    def __init__(self, config):
        super().__init__(config)
        # Fresh running statistics; obs_rms is updated during rollouts and
        # consumed by _normalize_obs.
        self.trainer.policy.obs_rms = RunningMeanStd()
        self.trainer.policy.ret_rms = RunningMeanStd()
        self.dataset: D4RL_Dataset = config['dataset']

    # ------------------------------------------------------------------
    # Small helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _memory_mb():
        """Resident set size of the current process, in MiB."""
        return psutil.Process().memory_info().rss / (1024 * 1024)

    @staticmethod
    def _split_threads(x, n_threads):
        """Convert ``x`` to NumPy and split its leading axis into ``n_threads`` stacked chunks."""
        return np.array(np.split(_t2n(x), n_threads))

    @staticmethod
    def _accumulate_env_infos(env_infos, infos):
        """Append the first agent's value of every tracked key found in each thread's info."""
        for info in infos:
            for k in env_infos:
                if k in info[0]:
                    env_infos[k].append(info[0][k])

    def _normalize_obs(self, obs):
        """Standardize ``obs`` with ``policy.obs_rms`` and clip to ``[-clip_obs, clip_obs]``.

        Returns ``obs`` unchanged when ``all_args.norm_obs`` is disabled. The
        running statistics are NOT updated here.
        """
        if not self.all_args.norm_obs:
            return obs
        rms = self.trainer.policy.obs_rms
        return np.clip((obs - rms.mean) / np.sqrt(rms.var + 1e-6),
                       -self.all_args.clip_obs, self.all_args.clip_obs)

    def _reset_episode_trackers(self):
        """Reset the per-thread and aggregate episode statistics maintained by ``insert``."""
        self.lengths_track = [0] * self.n_rollout_threads
        self.rewards_track = [0] * self.n_rollout_threads
        self.current_finish = 0
        self.current_reward_sum = 0
        self.current_length_sum = 0

    # ------------------------------------------------------------------
    # Behaviour cloning
    # ------------------------------------------------------------------
    def clone(self):
        """Behaviour-clone the actor for ``all_args.clone_episodes`` passes over the dataset.

        Each pass iterates the dataset in shuffled mini-batches of 256 samples,
        calls ``trainer.bc_clone`` per batch, then runs ``eval`` (logged at
        step ``episode``).
        """
        print('memory usage:', self._memory_mb())
        act_step = self.all_args.act_step
        dataloader = DataLoader(self.dataset, batch_size=256, shuffle=True)

        for episode in range(self.all_args.clone_episodes):
            loss_list = []
            pbar = tqdm(dataloader)
            for obs, actions_batch in pbar:
                batch_size = obs.shape[0]
                # Action targets are divided by logit_scaling before regression.
                actions_batch = actions_batch / self.all_args.logit_scaling
                # Flatten to batch_size * act_step rows; presumably each sample
                # stacks act_step steps — confirm against D4RL_Dataset.
                data = (obs.reshape(batch_size * act_step, -1),
                        actions_batch.reshape(batch_size * act_step, -1),
                        None, None, None)
                bc_loss, _ = self.trainer.bc_clone(data)
                loss_list.append(bc_loss.item())
                pbar.set_postfix(BC_Loss=np.mean(loss_list[-100:]))
            print(episode)
            self.eval(episode)

        print('memory usage:', self._memory_mb())

    # ------------------------------------------------------------------
    # Main training loop
    # ------------------------------------------------------------------
    def run(self):
        """Optionally behaviour-clone, then alternate rollouts and trainer updates.

        With cloning, actor updates begin once ``episode > 5`` and the actor
        learning rate is scaled by 0.3 via ``LinearLR``. Without cloning, actor
        and critic are updated from the first episode.
        """
        start = time.time()
        episodes = int(self.num_env_steps) // self.episode_length // self.n_rollout_threads
        clone_episodes = self.all_args.clone_episodes

        # Episode indices after which actor / critic updates are enabled.
        start_ppo = 5
        start_critic = 4
        if self.all_args.eta >= 0.01:
            self.trainer.policy.actor.diffusion.update_eta(self.all_args.eta)

        if clone_episodes > 0:
            self.clone()
            gc.collect()
            print('memory usage:', self._memory_mb())
            # start_factor == end_factor == 0.3: the actor LR is held at 0.3x its base value.
            lrscheduler = torch.optim.lr_scheduler.LinearLR(
                self.trainer.policy.actor_optimizer, 3e-1, 3e-1, total_iters=100)
            lrscheduler.step()
        else:
            # start_factor == end_factor == 1: stepping this scheduler leaves the LR unchanged.
            lrscheduler = torch.optim.lr_scheduler.LinearLR(
                self.trainer.policy.actor_optimizer, 1, 1, total_iters=5)
            start_ppo = -1
            start_critic = -1

        self.warmup()
        env_infos = {
            "episode_length": [],
            "episode_return": []
        }

        for episode in range(episodes):
            self._reset_episode_trackers()
            if self.use_linear_lr_decay:
                self.trainer.policy.lr_decay(episode, episodes)

            cur_time = time.time()
            env_time = self._collect_rollout(env_infos)

            # compute return and update network
            print('spent on env:', env_time, 'tot:', time.time() - cur_time)
            reward_info = self.compute()
            # NOTE(review): with start_ppo=5 / start_critic=4 (or both -1) the
            # update_critic condition is always true.
            train_infos = self.train(episode / episodes,
                                     update_actor=(episode > start_ppo),
                                     update_critic=(episode < start_ppo or episode > start_critic))
            train_infos.update(reward_info)
            print(train_infos)
            if episode > start_ppo and (episode + 1) % 4 == 0:
                lrscheduler.step()

            total_num_steps = (episode + 1) * self.episode_length * self.n_rollout_threads
            print("Step", total_num_steps)

            if (episode + 1) % self.save_interval == 0 or episode == episodes - 1:
                self.save(episode)

            if episode % self.log_interval == 0:
                self._log_progress(episode, episodes, total_num_steps, start, train_infos, env_infos)
                env_infos = {k: [] for k in env_infos}

            if episode % self.eval_interval == 0 and self.use_eval:
                self.eval(total_num_steps)

    def _collect_rollout(self, env_infos):
        """Fill the buffer with ``episode_length`` environment steps.

        Appends reported info values to ``env_infos`` in place.

        :return: wall-clock seconds spent inside ``envs.step``.
        """
        act_step = self.all_args.act_step
        env_time = 0
        for step in range(0, self.episode_length, act_step):
            # One policy call yields actions for the next act_step steps.
            (latent_actions, sampled_actions, actions, action_log_probs,
             action_log_probs_last, rnn_states, actions_env, noise) = self.collect_actions(step)

            latent_actions = latent_actions.reshape(*latent_actions.shape[:-1], act_step, -1)
            noise = noise.reshape(*noise.shape[:-1], act_step, -1)

            # The final chunk may be cut short by the end of the episode.
            for i in range(min(act_step, self.episode_length - step)):
                curr_step = step + i

                t0 = time.time()
                obs, rewards, dones, infos = self.envs.step(actions_env[:, :, i])
                env_time += time.time() - t0

                self.trainer.policy.obs_rms.update(obs[:, 0, :])
                obs = self._normalize_obs(obs)

                values, rnn_states_critic = self.collect_critic(curr_step)

                data = (obs, rewards, dones, infos,
                        values, latent_actions[..., i, :], sampled_actions[:, :, i],
                        actions[:, :, i], action_log_probs, action_log_probs_last[:, :, i, None],
                        rnn_states, rnn_states_critic, noise[..., i, :])
                self.insert(data)
                self._accumulate_env_infos(env_infos, infos)
        return env_time

    def _log_progress(self, episode, episodes, total_num_steps, start, train_infos, env_infos):
        """Print throughput, add the reward summary to ``train_infos`` and log both info dicts."""
        end = time.time()
        print("\n Scenario {} Algo {} Exp {} updates {}/{} episodes, total num timesteps {}/{}, FPS {}.\n"
              .format(self.all_args.scenario_name,
                      self.algorithm_name,
                      self.experiment_name,
                      episode,
                      episodes,
                      total_num_steps,
                      self.num_env_steps,
                      int(total_num_steps / (end - start))))

        # Mean per-step reward in the buffer scaled by episode_length.
        train_infos["average_episode_rewards"] = np.mean(self.buffer.rewards) * self.episode_length
        print("average episode rewards is {}".format(train_infos["average_episode_rewards"]))
        self.log_train(train_infos, total_num_steps)
        self.log_env(env_infos, total_num_steps)

        pprint({k: np.mean(v) for k, v in env_infos.items() if len(v) > 0})

    def warmup(self):
        """Reset the training envs and store the first observation at buffer index 0.

        The reset observation gets the same ``obs_rms`` update and
        normalization as observations returned by ``envs.step``.
        """
        obs = self.envs.reset()
        self.trainer.policy.obs_rms.update(obs[:, 0, :])
        obs = self._normalize_obs(obs)

        if self.use_centralized_V:
            # Concatenate every agent's observation into one shared vector per thread.
            share_obs = obs.reshape(self.n_rollout_threads, 1, -1)
        else:
            share_obs = obs

        self.buffer.share_obs[0] = share_obs.copy()
        self.buffer.obs[0] = obs.copy()

    @torch.no_grad()
    def collect_actions(self, step):
        """Sample a chunk of actions for all rollout threads from buffer index ``step``.

        :return: per-thread arrays (leading axis ``n_rollout_threads``):
            latent_actions, sampled_actions, actions, action_log_probs,
            action_log_probs_last, rnn_states, actions_env (the same array as
            ``actions``), noise.
        """
        self.trainer.prep_rollout()

        outputs = self.trainer.policy.get_actions(np.concatenate(self.buffer.obs[step]),
                                                  np.concatenate(self.buffer.rnn_states[step]),
                                                  np.concatenate(self.buffer.masks[step]),
                                                  return_noise=True)
        (latent_actions, sampled_actions, actions, action_log_probs,
         action_log_probs_last, rnn_states, noise) = (
            self._split_threads(x, self.n_rollout_threads) for x in outputs)

        return (latent_actions, sampled_actions, actions, action_log_probs,
                action_log_probs_last, rnn_states, actions, noise)

    def eval(self, total_steps):
        """Run ``all_args.eval_episodes`` deterministic episodes on ``eval_envs``.

        Info values are logged with an ``eval-`` prefix at step ``total_steps``.
        Observations (including the one returned by ``reset``) are normalized
        with the current ``obs_rms`` without updating it.

        :return: mean of the collected ``episode_return`` values (NumPy yields
            NaN with a warning if none were reported).
        """
        episodes = self.all_args.eval_episodes
        act_step = self.all_args.act_step
        eval_length = self.all_args.eval_episode_length
        n_threads = self.n_eval_rollout_threads
        env_infos = {
            "episode_length": [],
            "episode_return": [],
        }

        for episode in range(episodes):
            obs = self._normalize_obs(self.eval_envs.reset())
            rnn_states = np.zeros((n_threads, self.num_agents, self.recurrent_N, self.hidden_size), dtype=np.float32)
            masks = np.ones((n_threads, self.num_agents, 1), dtype=np.float32)

            start = time.time()

            for step in range(0, eval_length, act_step):
                # One deterministic policy call yields the next act_step actions.
                _, _, actions, _, _, rnn_states, _ = \
                    self.trainer.policy.get_actions(np.concatenate(obs),
                                                    np.concatenate(rnn_states),
                                                    np.concatenate(masks),
                                                    deterministic=True,
                                                    return_noise=True)
                actions_env = self._split_threads(actions, n_threads)
                rnn_states = self._split_threads(rnn_states, n_threads)

                for i in range(min(act_step, eval_length - step)):
                    obs, rewards, dones, infos = self.eval_envs.step(actions_env[:, :, i])

                    # Reset recurrent state and mask for threads whose episode ended.
                    done_mask = dones == True  # noqa: E712
                    rnn_states[done_mask] = 0.0
                    masks = np.ones((n_threads, self.num_agents, 1), dtype=np.float32)
                    masks[done_mask] = 0.0

                    obs = self._normalize_obs(obs)
                    self._accumulate_env_infos(env_infos, infos)

            total_num_steps = eval_length * n_threads

            print("Step", total_num_steps)
            end = time.time()
            print("\n Scenario {} Algo {} Exp {} updates {}/{} episodes, total num timesteps {}/{}, FPS {}.\n"
                  .format(self.all_args.scenario_name,
                          self.algorithm_name,
                          self.experiment_name,
                          episode,
                          episodes,
                          total_num_steps,
                          self.num_env_steps,
                          int(total_num_steps / (end - start))))

        env_infos = {f"eval-{k}": v for k, v in env_infos.items()}
        self.log_env(env_infos, total_steps)
        pprint({k: np.mean(v) for k, v in env_infos.items() if len(v) > 0})

        return np.mean(env_infos["eval-episode_return"])

    @torch.no_grad()
    def collect_critic(self, step):
        """Evaluate the critic on buffer index ``step``.

        :return: (values, rnn_states_critic) split per rollout thread; values
            keep only index 0 of the second axis.
        """
        value, rnn_states_critic = self.trainer.policy.critic_forward(
            np.concatenate(self.buffer.share_obs[step]),
            np.concatenate(self.buffer.rnn_states_critic[step]),
            None,
            np.concatenate(self.buffer.masks[step])
        )

        values = self._split_threads(value, self.n_rollout_threads)[:, 0]
        rnn_states_critic = self._split_threads(rnn_states_critic, self.n_rollout_threads)
        return values, rnn_states_critic

    def insert(self, data):
        """Store one transition in the buffer and update the episode trackers.

        ``data`` is the 13-tuple assembled in ``_collect_rollout``: (obs, rewards,
        dones, infos, values, latent_actions, sampled_actions, actions,
        action_log_probs, action_log_probs_last, rnn_states, rnn_states_critic,
        noise). Recurrent states of finished threads are zeroed IN PLACE and
        their mask is set to 0.
        """
        (obs, rewards, dones, infos, values, latent_actions, sampled_actions, actions,
         action_log_probs, action_log_probs_last, rnn_states, rnn_states_critic, noise) = data

        done_mask = dones == True  # noqa: E712
        rnn_states[done_mask] = 0.0
        rnn_states_critic[done_mask] = 0.0
        masks = np.ones((self.n_rollout_threads, 1, 1), dtype=np.float32)
        masks[done_mask] = 0.0

        # NOTE(review): obs fills all three leading positional slots of
        # buffer.insert (presumably share_obs / obs / ...); confirm against the buffer.
        self.buffer.insert(
            obs,
            obs,
            obs,
            rnn_states,
            rnn_states_critic,
            actions=actions,
            action_log_probs=action_log_probs,
            action_log_probs_last=action_log_probs_last,
            value_preds=values,
            rewards=rewards,
            masks=masks,
            latent_actions=latent_actions,
            sampled_actions=sampled_actions,
            noise=noise
        )

        # Running per-thread return/length; folded into the aggregates when a thread finishes.
        for i in range(self.n_rollout_threads):
            self.rewards_track[i] += rewards[i, 0, 0]
            self.lengths_track[i] += 1
            if dones[i, 0]:
                self.current_finish += 1
                self.current_reward_sum += self.rewards_track[i]
                self.current_length_sum += self.lengths_track[i]
                self.rewards_track[i] = 0
                self.lengths_track[i] = 0

    @torch.no_grad()
    def compute(self):
        """Bootstrap from the critic at the last buffer index and fill in the returns.

        :return: an empty dict, merged into the training info by ``run``.
        """
        self.trainer.prep_rollout()
        next_values = self.trainer.policy.get_values(
            np.concatenate(self.buffer.share_obs[-1]),
            None,
            np.concatenate(self.buffer.rnn_states_critic[-1]),
            np.concatenate(self.buffer.masks[-1]))
        next_values = self._split_threads(next_values, self.n_rollout_threads)[:, 0]
        compute_returns(self.buffer, next_values, self.trainer.value_normalizer)
        return {}

    # def save(self, episode=0):
    #     """Save policy's actor and critic networks."""
    #     if self.algorithm_name == "mat" or self.algorithm_name == "mat_dec":
    #         self.policy.save(self.save_dir, episode)
    #     else:
    #         if hasattr(self.trainer.policy, 'actor'):
    #             policy_actor = self.trainer.policy.actor
    #             torch.save(policy_actor.state_dict(), str(self.save_dir) + "/actor.pt")
    #             torch.save(policy_actor.state_dict(), str(self.save_dir) + "/actor_{}.pt".format(episode))
    #         if hasattr(self.trainer.policy, 'critic'):
    #             policy_critic = self.trainer.policy.critic
    #             torch.save(policy_critic.state_dict(), str(self.save_dir) + "/critic.pt")
    #             torch.save(policy_critic.state_dict(), str(self.save_dir) + "/critic_{}.pt".format(episode))

    # def restore(self, model_dir):
    #     """Restore policy's networks from a saved model."""
    #     if self.algorithm_name == "mat" or self.algorithm_name == "mat_dec":
    #         self.policy.restore(model_dir)
    #     else:
    #         policy_actor_state_dict = torch.load(str(self.model_dir) + '/actor.pt')
    #         self.policy.actor.load_state_dict(policy_actor_state_dict)
    #         if not self.all_args.use_render:
    #             policy_critic_state_dict = torch.load(str(self.model_dir) + '/critic.pt')
    #             self.policy.critic.load_state_dict(policy_critic_state_dict)