# from email import policy
from json import load
import tarfile
import torch
import torch.nn as nn
import torch.nn.functional as F
from copy import deepcopy
import numpy as np
from collections import OrderedDict
import os
from torch.distributions import kl, Normal, kl_divergence

from offlinerl.utils.env import get_env
from offlinerl.utils.net.recurrent import RecurrentGRU
from offlinerl.utils.net.maple_net import GaussianOutputHead, ValueHead, GaussianPolicyNetwork, ValueNetwork
from offlinerl.utils.simulator import save_env, reset_env, update_config_files
from offlinerl.utils.net.transformer import Transformer

import gym.spaces.discrete as discrete

class SACAgent(nn.Module):
    def __init__(self, args):
        """Build the actor, twin critics, target critics, temperature and optimizers.

        Args:
            args: Configuration mapping. Keys read here: ``obs_shape``,
                ``action_shape``, ``policy_hidden_dims``, ``value_hidden_dims``,
                ``device``, ``actor_lr``, ``critic_lr``, ``rew_max``,
                ``rew_min``, ``discount`` and ``logger``.
        """
        super().__init__()
        self.args = args
        device = args["device"]
        obs_dim = args["obs_shape"]
        action_dim = args["action_shape"]

        self.actor = GaussianPolicyNetwork(
            obs_dim=obs_dim,
            action_dim=action_dim,
            policy_hidden_dims=args["policy_hidden_dims"],
        ).to(device)

        # Both critics share one architecture but are built (and initialised) separately.
        critic_kwargs = dict(
            obs_dim=obs_dim,
            action_dim=action_dim,
            value_hidden_dims=args["value_hidden_dims"],
        )
        self.q1 = ValueNetwork(**critic_kwargs).to(device)
        self.q2 = ValueNetwork(**critic_kwargs).to(device)

        # Target critics start as exact copies and are never trained by
        # gradient; they are only moved by _soft_update.
        self.target_q1 = deepcopy(self.q1)
        self.target_q2 = deepcopy(self.q2)
        for frozen in (self.target_q1, self.target_q2):
            frozen.requires_grad_(False)

        # Log of the entropy temperature, initialised to 0 (alpha = 1).
        self.log_alpha = nn.Parameter(torch.zeros(1, device=device), requires_grad=True)

        self.actor_optim = torch.optim.Adam(self.actor.parameters(), lr=args["actor_lr"])
        critic_params = [*self.q1.parameters(), *self.q2.parameters()]
        self.critic_optim = torch.optim.Adam(critic_params, lr=args["critic_lr"])
        # The temperature shares the actor's learning rate.
        self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=args["actor_lr"])

        self.rew_max = args["rew_max"]
        self.rew_min = args["rew_min"]
        self.discount = args["discount"]
        self.device = device
        self.logger = args["logger"]
        
    def evaluate_action(self, state, action):
        mu, log_std = self.actor(state)
        # action_prev_tanh = torch.atanh(action)
        # action_prev_tanh = torch.log((1+action+1e-5) / (1-action+1e-5)) / 2 

        action = torch.clamp(action, -1.0+1e-6, 1.0-1e-6)
        action_prev_tanh = 0.5 * (action.log1p() - (-action).log1p())
        log_prob = torch.distributions.Normal(mu, log_std.exp()).log_prob(action_prev_tanh+1e-6).sum(dim=-1)
        log_prob -= torch.sum(2*(np.log(2) - action_prev_tanh - torch.nn.functional.softplus(-2*action_prev_tanh)), dim=-1)
        return log_prob
        
    def get_action(self, state, deterministic=False, out_mean_std=False):
        action, log_prob, mu, logstd = self.actor.sample(state, deterministic=deterministic)
        return [action, log_prob] if not out_mean_std else [action, log_prob, mu, logstd]
    
    def train_policy(self, batch, behavior_cloning=False):
        # rewards = torch.from_numpy(batch.rew).to(self.device)
        # terminals = torch.from_numpy(batch.done).to(self.device)
        # obs = torch.from_numpy(batch.obs).to(self.device)
        # actions = torch.from_numpy(batch.act).to(self.device)
        # next_obs = torch.from_numpy(batch.obs_next).to(self.device)
        
        try:
            rewards = torch.from_numpy(batch.rew).to(self.device)
            terminals = torch.from_numpy(batch.done).to(self.device)
            obs = torch.from_numpy(batch.obs).to(self.device)
            actions = torch.from_numpy(batch.act).to(self.device)
            next_obs = torch.from_numpy(batch.obs_next).to(self.device)
        except AttributeError:
            rewards = torch.from_numpy(batch["rewards"]).to(self.device)
            terminals = torch.from_numpy(batch["terminals"]).to(self.device)
            obs = torch.from_numpy(batch["observations"]).to(self.device)
            actions = torch.from_numpy(batch["actions"]).to(self.device)
            next_obs = torch.from_numpy(batch["next_observations"]).to(self.device)
        
        with torch.no_grad():
            next_action, next_log_prob, _, __ = self.actor.sample(next_obs)
            next_q1 = self.target_q1(next_obs, next_action)
            next_q2 = self.target_q2(next_obs, next_action)
            next_q = torch.min(next_q1, next_q2) - self.log_alpha.exp() * torch.unsqueeze(next_log_prob, dim=-1)
            q_target = rewards + self.discount * (~terminals) * next_q
            if self.args["q_target_clip"]:
                q_target = torch.clip(
                    q_target, 
                    self.rew_min / (1-self.discount),
                    self.rew_max / (1-self.discount)
                )
        q1 = self.q1(obs, actions)
        q2 = self.q2(obs, actions)
        q_loss = torch.nn.functional.mse_loss(q1, q_target) + torch.nn.functional.mse_loss(q2, q_target)
        self.critic_optim.zero_grad()
        q_loss.backward()
        self.critic_optim.step()
        
        self._soft_update(self.target_q1, self.q1)
        self._soft_update(self.target_q2, self.q2)
        
        new_actions, new_log_prob, _, __ = self.actor.sample(obs)
        new_log_prob = torch.unsqueeze(new_log_prob, dim=-1)
        new_q1 = self.q1(obs, new_actions)
        new_q2 = self.q2(obs, new_actions)
        new_q = torch.min(new_q1, new_q2)
        
        if self.args["learnable_alpha"]:
            alpha_loss = - (self.log_alpha * (new_log_prob+ \
                                                         self.args['target_entropy']).detach()).mean()
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
            
        if behavior_cloning:
            sac_loss = (self.log_alpha.exp().detach() * new_log_prob - new_q).mean()
            sac_loss_scaled = 2.5*sac_loss / new_q.detach().mean()
            bc_loss = (actions - new_actions) ** 2
            bc_loss = bc_loss.sum(-1).mean()
            policy_loss = sac_loss_scaled + bc_loss*self.args["BC"]["bc_loss_coeff"]
        else:
            sac_loss = (self.log_alpha.exp().detach() * new_log_prob - new_q).mean()
            policy_loss = sac_loss
            
        self.actor_optim.zero_grad()
        policy_loss.backward()
        self.actor_optim.step()
        
        ret = dict()
        ret["min_q"] = new_q.detach().cpu().mean().numpy()
        ret["q_loss"] = q_loss.detach().cpu().numpy()
        if self.args["learnable_alpha"]:
            ret["alpha_loss"] = alpha_loss.detach().cpu().numpy()
        if behavior_cloning:
            ret["policy_loss"] = policy_loss.detach().cpu().numpy()
            ret["sac_loss"] = sac_loss.detach().cpu().numpy()
            ret["bc_loss"] = bc_loss.detach().cpu().numpy()
        else:
            ret["policy_loss"] = policy_loss.detach().cpu().numpy()
        return ret 
            
    def _soft_update(self, net_target, net, soft_target_tau=5e-3):
        for o, n in zip(net_target.parameters(), net.parameters()):
            o.data.copy_(o.data*(1.0-soft_target_tau) + n.data*soft_target_tau) 

    def eval_on_real_env(self):
        """Evaluate the policy for ``args["eval_runs"]`` episodes and summarise.

        Returns:
            OrderedDict with ``Reward_Mean_Env`` and ``Length_Mean_Env``, plus
            ``Score`` when the environment can compute a normalized score.
        """
        # This env is only used for score normalisation; every call to
        # test_one_trail builds its own environment.
        env = get_env(self.args["task"])
        results = [self.test_one_trail() for _ in range(self.args["eval_runs"])]
        rewards = [result[0] for result in results]
        episode_lengths = [result[1] for result in results]
        rew_mean = np.mean(rewards)
        len_mean = np.mean(episode_lengths)

        res = OrderedDict()
        res["Reward_Mean_Env"] = rew_mean
        try:
            res["Score"] = env.get_normalized_score(rew_mean)
        except Exception:
            # Best effort: not every environment offers a normalized score.
            # Catching Exception (not a bare except) lets KeyboardInterrupt
            # and SystemExit propagate.
            print("no data")
        res["Length_Mean_Env"] = len_mean

        return res
    
    def test_one_trail(self):
        """Run one evaluation episode with the deterministic policy.

        A fresh environment is created from ``args["task"]``. The loop uses
        the gym API in which ``reset`` returns the observation and ``step``
        returns a 4-tuple.

        Returns:
            ``(total_reward, episode_length)``.
        """
        rewards = lengths = 0
        env = get_env(self.args["task"])
        with torch.no_grad():
            state, done = env.reset(), False
            while not done:
                # Add a leading batch dimension of size 1.
                state = torch.from_numpy(state[None, :]).float().to(self.device)
                action, _ = self.get_action(state, deterministic=True)
                # reshape(-1) always yields a 1-D action vector; squeeze would
                # collapse a single-dimensional action to a 0-d array, which
                # environments that index into the action cannot handle.
                action_use = action.cpu().numpy().reshape(-1)
                if isinstance(env.action_space, discrete.Discrete):
                    # Discrete spaces take the index of the largest output.
                    action_use = np.argmax(action_use)
                state, reward, done, _ = env.step(action_use)
                rewards += reward
                lengths += 1

        return (rewards, lengths)
                
        
    def save(self, save_path):
        assert save_path, "save path cannot be None!"
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        torch.save(self.actor.state_dict(), os.path.join(save_path, "actor.pt"))
        torch.save(self.q1.state_dict(), os.path.join(save_path, "q1.pt"))
        torch.save(self.q2.state_dict(), os.path.join(save_path, "q2.pt"))
        torch.save(self.log_alpha.data, os.path.join(save_path, "log_alpha.pt"))
        torch.save(self.target_q1.state_dict(), os.path.join(save_path, "target_q1.pt"))
        torch.save(self.target_q2.state_dict(), os.path.join(save_path, "target_q2.pt"))
        torch.save(self.actor_optim.state_dict(), os.path.join(save_path, "actor_optim.pt"))
        torch.save(self.critic_optim.state_dict(), os.path.join(save_path, "critic_optim.pt"))
        torch.save(self.alpha_optim.state_dict(), os.path.join(save_path, "alpha_optim.pt"))

    def load(self, load_path):
        assert load_path, "load path cannot be None!"
        self.actor.load_state_dict(torch.load(os.path.join(load_path, "actor.pt"), map_location=self.device))
        self.q1.load_state_dict(torch.load(os.path.join(load_path, "q1.pt"), map_location=self.device))
        self.q2.load_state_dict(torch.load(os.path.join(load_path, "q2.pt"), map_location=self.device))
        self.target_q1.load_state_dict(torch.load(os.path.join(load_path, "target_q1.pt"), map_location=self.device))
        self.target_q2.load_state_dict(torch.load(os.path.join(load_path, "target_q2.pt"), map_location=self.device))
        self.log_alpha.data = torch.load(os.path.join(load_path, "log_alpha.pt"), map_location=self.device)
        self.actor_optim.load_state_dict(torch.load(os.path.join(load_path, "actor_optim.pt"), map_location=self.device))
        self.critic_optim.load_state_dict(torch.load(os.path.join(load_path, "critic_optim.pt"), map_location=self.device))
        self.alpha_optim.load_state_dict(torch.load(os.path.join(load_path, "alpha_optim.pt"), map_location=self.device))

        
    def state_dict(self):
        return {
            "actor": self.actor.state_dict(),
            "q1": self.q1.state_dict(),
            "q2": self.q2.state_dict(),
            "target_q1": self.target_q1.state_dict(),
            "target_q2": self.target_q2.state_dict(),
            "log_alpha": self.log_alpha.data,
            "actor_optim": self.actor_optim.state_dict(),
            "critic_optim": self.critic_optim.state_dict(),
            "alpha_optim": self.alpha_optim.state_dict(),
        }
        
    def load_state_dict(self, state_dict):
        self.actor.load_state_dict(state_dict["actor"])
        self.q1.load_state_dict(state_dict["q1"])
        self.q2.load_state_dict(state_dict["q2"])
        self.target_q1.load_state_dict(state_dict["target_q1"])
        self.target_q2.load_state_dict(state_dict["target_q2"])
        self.log_alpha.data = state_dict["log_alpha"]
        self.actor_optim.load_state_dict(state_dict["actor_optim"])
        self.critic_optim.load_state_dict(state_dict["critic_optim"])
        self.alpha_optim.load_state_dict(state_dict["alpha_optim"])

class RNNSACAgent(nn.Module):
    def __init__(self, args):
        """Build the recurrent encoders, actor/critic heads, targets and optimizers.

        Args:
            args: Configuration mapping. Keys read here: ``obs_shape``,
                ``action_shape``, ``device``, ``rnn_hidden_dim``,
                ``rnn_layer_num``, ``decoder_hidden_dims``,
                ``head_hidden_dims``, ``actor_lr``, ``critic_lr``,
                ``rew_max``, ``rew_min``, ``discount`` and ``logger``.
        """
        super().__init__()
        self.args = args
        device = args["device"]

        # Both encoders consume the observation concatenated with an action.
        gru_kwargs = dict(
            input_dim=args["obs_shape"] + args["action_shape"],
            device=device,
            rnn_hidden_dim=args["rnn_hidden_dim"],
            rnn_layer_num=args["rnn_layer_num"],
        )
        self.policy_gru = RecurrentGRU(**gru_kwargs).to(device)
        self.value_gru = RecurrentGRU(**gru_kwargs).to(device)

        # The actor head and both critic heads share the same constructor arguments.
        head_kwargs = dict(
            obs_dim=args["obs_shape"],
            action_dim=args["action_shape"],
            embedding_dim=args["rnn_hidden_dim"],
            decoder_hidden_dims=args["decoder_hidden_dims"],
            head_hidden_dims=args["head_hidden_dims"],
        )
        self.actor = GaussianOutputHead(**head_kwargs).to(device)
        self.q1 = ValueHead(**head_kwargs).to(device)
        self.q2 = ValueHead(**head_kwargs).to(device)

        # Target critics start as exact copies and are only moved by _soft_update.
        self.target_q1 = deepcopy(self.q1)
        self.target_q2 = deepcopy(self.q2)
        for frozen in (self.target_q1, self.target_q2):
            frozen.requires_grad_(False)

        # Log of the entropy temperature, initialised to 0 (alpha = 1).
        self.log_alpha = nn.Parameter(torch.zeros(1, device=device), requires_grad=True)

        # Each encoder is trained together with the head(s) that consume it.
        actor_params = [*self.policy_gru.parameters(), *self.actor.parameters()]
        critic_params = [*self.value_gru.parameters(), *self.q1.parameters(), *self.q2.parameters()]
        self.actor_optim = torch.optim.Adam(actor_params, lr=args["actor_lr"])
        self.critic_optim = torch.optim.Adam(critic_params, lr=args["critic_lr"])
        self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=args["actor_lr"])

        self.rew_max = args["rew_max"]
        self.rew_min = args["rew_min"]
        self.discount = args["discount"]

        self.device = device
        self.logger = args["logger"]
    
    
    def get_action(self, state, lst_action, hidden, deterministic=False, out_mean_std=False):
        if len(state.shape) == 2:
            state = torch.unsqueeze(state, 1)
        if len(lst_action.shape) == 2:
            lst_action = torch.unsqueeze(lst_action, 1)
        if len(hidden.shape) == 2:
            hidden = torch.unsqueeze(hidden, 0)
            
        # only take the first element
        lens = [1] * state.shape[0]
        rnn_input_pair = torch.cat([state, lst_action], dim=-1)
        policy_embedding, next_hidden = self.policy_gru(rnn_input_pair, lens, hidden)
        action, log_prob, mu, logstd = self.actor.sample(state, policy_embedding, deterministic=deterministic)
        
        action, log_prob, mu, logstd = torch.squeeze(action, 1), torch.squeeze(log_prob, 1) if log_prob is not None else None, torch.squeeze(mu, 1), torch.squeeze(logstd, 1)
        
        ret = []
        ret += [action, log_prob]     # here we don't need to worry about deterministic, cause if deterministic, then actor.sample will return torch.tanh(mu) and None as action and log_prob
        ret += ([mu, logstd] if out_mean_std else [])
        ret += [next_hidden]
        return ret

    def reset(self):
        return None
        
    def get_value(self, state, action, lst_action, hidden):
        if len(state.shape) == 2:
            state = torch.unsqueeze(state, dim=1)
        if len(action.shape) == 2:
            action = torch.unsqueeze(action, dim=1)
        if len(lst_action.shape) == 2:
            lst_action = torch.unsqueeze(lst_action, dim=1)
        if len(hidden.shape) == 2:
            hidden = torch.unsqueeze(hidden, dim=0)
        
        lens = [1] * state.shape[0]
        rnn_input_pair = torch.cat([state, lst_action], dim=-1)
        value_embedding, next_hidden = self.value_gru(rnn_input_pair, lens, hidden)
        value_q1 = self.q1(state, action, value_embedding)
        value_q2 = self.q2(state, action, value_embedding)
        value_min = torch.min(value_q1, value_q2)
        
        value_min = torch.squeeze(value_min, 1)
        return value_min, next_hidden

    def train_policy(self, batch, behavior_cloning=False, sac_embedding_infer="concat", cons_batch=None, cons_policy=None):
        # we use 3D slices to train the policy
        batch['valid'] = batch['valid'].astype(int)
        lens = np.sum(batch['valid'], axis=1).squeeze(-1)
        max_len = np.max(lens)
        for k in batch:
            batch[k] = torch.from_numpy(batch[k][:,:max_len]).to(self.device)

        if cons_batch is not None:
            cons_batch['valid'] = cons_batch['valid'].astype(int)
            cons_lens = np.sum(cons_batch['valid'], axis=1).squeeze(-1)
            cons_max_len = np.max(cons_lens)
            for k in cons_batch:
                cons_batch[k] = torch.from_numpy(cons_batch[k][:, :cons_max_len]).to(self.device)
        
        if sac_embedding_infer == "concat":
            value_hidden = batch["value_hidden"][:, 0]
            policy_hidden = batch["policy_hidden"][:, 0]
            value_embedding, value_hidden_next = self.value_gru(torch.cat([batch["observations"], batch["last_actions"]], dim=-1), lens, pre_hidden=value_hidden)
            policy_embedding, policy_hidden_next = self.policy_gru(torch.cat([batch["observations"], batch["last_actions"]], dim=-1), lens, pre_hidden=policy_hidden)

            lens_next = torch.ones(len(lens)).int()
            value_embedding_next, _ = self.value_gru(torch.cat([batch["next_observations"][:, -1:], batch["actions"][:, -1:]], dim=-1), lens_next, pre_hidden=value_hidden_next)
            policy_embedding_next, _ = self.policy_gru(torch.cat([batch["next_observations"][:, -1:], batch["actions"][:, -1:]], dim=-1), lens_next, pre_hidden=policy_hidden_next)
            value_embedding_next = torch.cat([value_embedding[:, 1:], value_embedding_next], dim=1)
            policy_embedding_next = torch.cat([policy_embedding[:, 1:], policy_embedding_next], dim=1)
        elif sac_embedding_infer == "direct":
            _observations = torch.cat([batch["observations"], batch["next_observations"][:, -1:]], dim=1)
            _last_actions = torch.cat([batch["last_actions"], batch["actions"][:, -1:]], dim=1)
            _value_embedding, _ = self.value_gru(torch.cat([_observations, _last_actions], dim=-1), lens+1, pre_hidden=batch["value_hidden"][:, 0])
            _policy_embedding, _ = self.policy_gru(torch.cat([_observations, _last_actions], dim=-1), lens+1, pre_hidden=batch["policy_hidden"][:, 0])
            value_embedding = _value_embedding[:, :-1]
            policy_embedding = _policy_embedding[:, :-1]
            value_embedding_next = _value_embedding[:, 1:]
            policy_embedding_next = _policy_embedding[:, 1:]
        
        with torch.no_grad():
            action_target, log_prob_target, mu_target, logstd_target = self.actor.sample(batch["next_observations"], policy_embedding_next)
            q1_target = self.target_q1(batch["next_observations"], action_target, value_embedding_next)
            q2_target = self.target_q2(batch["next_observations"], action_target, value_embedding_next)
            q_target = torch.min(q1_target, q2_target)
            q_target = q_target - self.log_alpha.exp() * torch.unsqueeze(log_prob_target, dim=-1)
            q_target = batch["rewards"] + self.discount * (~batch["terminals"]) * (q_target)
            if self.args["q_target_clip"]:
                q_target = torch.clip(q_target, 
                                    self.rew_min / (1-self.discount),
                                    self.rew_max / (1-self.discount)
                )
        
        # update critic
        q1 = self.q1(batch["observations"], batch["actions"], value_embedding)
        q2 = self.q2(batch["observations"], batch["actions"], value_embedding)
        valid_num = torch.sum(batch["valid"])
        
        q1_loss = torch.sum(((q1-q_target)**2)*batch['valid'])/valid_num
        q2_loss = torch.sum(((q2-q_target)**2)*batch['valid'])/valid_num
        q_loss = (q1_loss+q2_loss)
        self.critic_optim.zero_grad()
        q_loss.backward()
        self.critic_optim.step()
        
        self._soft_update(self.target_q1, self.q1, soft_target_tau = self.args["soft_target_tau"])
        self._soft_update(self.target_q2, self.q2, soft_target_tau = self.args["soft_target_tau"])
        
        # update alpha and actor
        actions, log_prob, mu, logstd = self.actor.sample(batch["observations"], policy_embedding)
        log_prob = log_prob.unsqueeze(dim=-1)     # (B, T, 1)
        if cons_policy is None and cons_batch is None:
            cons_loss = 0
        else:
            cons_policy_hidden = cons_batch["policy_hidden"][:, 0]
            cons_policy_embedding, _ = self.policy_gru(
                torch.cat([cons_batch["observations"], cons_batch["last_actions"]], dim=-1), lens, pre_hidden=cons_policy_hidden)

            with torch.no_grad():
                cons_dist = cons_policy.actor.inference_dist(cons_batch["observations"])
            dist = self.actor.inference_dist(cons_batch['observations'], cons_policy_embedding)
            cons_loss = (kl_divergence(dist, cons_dist) * cons_batch["valid"]).sum() / valid_num

        # if cons_policy is None:
        #     cons_loss = 0
        # else:
        #     with torch.no_grad():
        #         actions_cons, _, _, _ = cons_policy.actor.sample(batch["observations"], policy_embedding, deterministic=True)
        #     cons_loss = F.mse_loss(actions, actions_cons)

        if self.args["learnable_alpha"]:
            alpha_loss = - torch.sum(self.log_alpha * ((log_prob+ \
                                                         self.args['target_entropy'])*batch['valid']).detach())/valid_num
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()
        
        q1_ = self.q1(batch["observations"], actions, value_embedding.detach())
        q2_ = self.q2(batch["observations"], actions, value_embedding.detach())
        min_q_ = torch.min(q1_, q2_)
        
        if behavior_cloning:
            sac_loss = self.log_alpha.exp().detach() * log_prob - min_q_
            sac_loss = torch.sum(sac_loss * batch["valid"]) / valid_num
            sac_loss_scaled = 2.5*sac_loss / min_q_.detach().mean()
            bc_loss = (batch["actions"] - actions) ** 2
            bc_loss = torch.sum(bc_loss * batch["valid"]) / valid_num
            policy_loss = sac_loss_scaled + bc_loss*self.args["BC"]["bc_loss_coeff"]
        else:
            sac_loss = self.log_alpha.exp().detach() * log_prob - min_q_
            sac_loss = torch.sum(sac_loss * batch["valid"]) / valid_num
            policy_loss = sac_loss

        coef_c = self.args["cc"]
        policy_loss += coef_c * cons_loss
            
        self.actor_optim.zero_grad()
        policy_loss.backward()
        self.actor_optim.step()
        
        ret = dict()
        ret["min_q"] = min_q_.detach().cpu().mean().numpy()
        ret["q_loss"] = q_loss.detach().cpu().numpy()
        if self.args["learnable_alpha"]:
            ret["alpha_loss"] = alpha_loss.detach().cpu().numpy()
        if behavior_cloning:
            ret["policy_loss"] = policy_loss.detach().cpu().numpy()
            ret["sac_loss"] = sac_loss.detach().cpu().numpy()
            ret["bc_loss"] = bc_loss.detach().cpu().numpy()
        else:
            if cons_policy is not None:
                ret["cons_loss"] = cons_loss.detach().cpu().numpy()
                ret["policy_loss"] = policy_loss.detach().cpu().numpy() - coef_c * ret["cons_loss"]
            else:
                ret["policy_loss"] = policy_loss.detach().cpu().numpy()
        return ret
        
    def _soft_update(self, net_target, net, soft_target_tau=5e-3):
        for o, n in zip(net_target.parameters(), net.parameters()):
            o.data.copy_(o.data*(1.0-soft_target_tau) + n.data*soft_target_tau)       
  
    def eval_policy(self, env):
        # env = get_env(self.args["task"])
        # res = self.test_on_real_env(self.args["eval_runs"], env)
        res = self.eval_on_real_env(env)
        return res

    def eval_on_real_env(self, env):
        env = get_env(self.args["task"]) if env is None else env
        results = ([self.test_one_trail(env) for _ in range(self.args["eval_runs"])])
        rewards = [result[0] for result in results]
        episode_lengths = [result[1] for result in results]
        rew_mean = np.mean(rewards)
        len_mean = np.mean(episode_lengths)

        res = OrderedDict()
        res["Reward_Mean_Env"] = rew_mean
        try:
            res["Score"] = env.get_normalized_score(rew_mean)
        except:
            print("no data")
        res["Length_Mean_Env"] = len_mean

        return res

    def test_one_trail(self, env):
        """Run one evaluation episode on ``env`` with the deterministic policy.

        The previous action and the policy hidden state both start at zero and
        are carried across steps. The loop uses the gym API in which ``reset``
        returns the observation and ``step`` returns a 4-tuple.

        Returns:
            ``(total_reward, episode_length)``.
        """
        with torch.no_grad():
            state, done = env.reset(), False
            lst_action = torch.zeros((1, 1, self.args['action_shape'])).to(self.device)
            hidden_policy = torch.zeros((1, 1, self.args['rnn_hidden_dim'])).to(self.device)
            rewards = 0
            lengths = 0
            while not done:
                # Add a leading dimension so the state is handled as a batch of size 1.
                state = state[np.newaxis]
                state = torch.from_numpy(state).float().to(self.device)
                action, log_prob, hidden_policy = self.get_action(
                    state, lst_action, hidden_policy, deterministic=True)
                # Kept from the original: deterministic sampling is expected
                # to return no log-probability.
                assert log_prob is None
                use_action = action.cpu().numpy().reshape(-1)
                if isinstance(env.action_space, discrete.Discrete):
                    # Discrete spaces take the index of the largest output.
                    use_action = np.argmax(use_action)
                state_next, reward, done, _ = env.step(use_action)
                # The encoder is fed the raw policy output, not the argmax index.
                lst_action = action
                state = state_next
                rewards += reward
                lengths += 1
        return (rewards, lengths)

    def eval_on_sim_env(self, degree=[0.5, 0.75, 1.5, 2.0], type=None):
        """Evaluate the policy on the unmodified task and on perturbed variants.

        Args:
            degree: One number or a sequence of numbers; each is passed to
                ``update_config_files`` together with a perturbation name.
                (The default list is never mutated.)
            type: Perturbation name or list of names; defaults to
                ``["fric", "grav", "dens"]``.

        Returns:
            dict mapping ``"initial"`` and every ``"<type>-<degree>"`` to the
            result of ``eval_on_real_env`` on that environment.

        Each evaluation runs ``args["eval_runs"]`` episodes. The original
        calls passed a ``number_runs`` keyword that ``eval_on_real_env`` does
        not accept, and called a ``test_on_real_env`` method that does not
        exist, so this method could not run at all.
        """
        print("[DEBUG]: test on sim env")
        # presumably backs up the simulator configuration so reset_env can
        # restore it -- TODO confirm against offlinerl.utils.simulator
        save_env(self.args["task"])
        if isinstance(degree, (int, float)):
            degree = [degree]
        if type is None:
            perturbations = ["fric", "grav", "dens"]
        elif isinstance(type, str):
            perturbations = [type]
        else:
            perturbations = type

        tested = dict()
        try:
            # Evaluate on the original environment first.
            env = get_env(self.args["task"])
            tested["initial"] = self.eval_on_real_env(env)

            # Then evaluate on each perturbed variant of the task.
            for t in perturbations:
                for d in degree:
                    info = t+"-"+str(d)
                    print(f"[DEBUG]: test on sim env ${info}")
                    update_config_files(self.args["task"], t, d)
                    env = get_env(self.args["task"])
                    tested[info] = self.eval_on_real_env(env)
        finally:
            # Always undo the configuration changes, even if an evaluation fails.
            reset_env(self.args["task"])
        return tested

    
    def save(self, save_path):
        assert save_path, "save path cannot be None!"
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        torch.save(self.policy_gru.state_dict(), os.path.join(save_path, "policy_gru.pt"))
        torch.save(self.actor.state_dict(), os.path.join(save_path, "actor.pt"))
        torch.save(self.value_gru.state_dict(), os.path.join(save_path, "value_gru.pt"))
        torch.save(self.q1.state_dict(), os.path.join(save_path, "q1.pt"))
        torch.save(self.q2.state_dict(), os.path.join(save_path, "q2.pt"))
        torch.save(self.log_alpha.data, os.path.join(save_path, "log_alpha.pt"))
        torch.save(self.target_q1.state_dict(), os.path.join(save_path, "target_q1.pt"))
        torch.save(self.target_q2.state_dict(), os.path.join(save_path, "target_q2.pt"))
        torch.save(self.actor_optim.state_dict(), os.path.join(save_path, "actor_optim.pt"))
        torch.save(self.critic_optim.state_dict(), os.path.join(save_path, "critic_optim.pt"))
        torch.save(self.alpha_optim.state_dict(), os.path.join(save_path, "alpha_optim.pt"))

    def load(self, load_path):
        assert load_path, "load path cannot be None!"
        self.policy_gru.load_state_dict(torch.load(os.path.join(load_path, "policy_gru.pt"), map_location=self.device))
        self.actor.load_state_dict(torch.load(os.path.join(load_path, "actor.pt"), map_location=self.device))
        self.value_gru.load_state_dict(torch.load(os.path.join(load_path, "value_gru.pt"), map_location=self.device))
        self.q1.load_state_dict(torch.load(os.path.join(load_path, "q1.pt"), map_location=self.device))
        self.q2.load_state_dict(torch.load(os.path.join(load_path, "q2.pt"), map_location=self.device))
        self.target_q1.load_state_dict(torch.load(os.path.join(load_path, "target_q1.pt"), map_location=self.device))
        self.target_q2.load_state_dict(torch.load(os.path.join(load_path, "target_q2.pt"), map_location=self.device))
        self.log_alpha.data = torch.load(os.path.join(load_path, "log_alpha.pt"), map_location=self.device)
        self.actor_optim.load_state_dict(torch.load(os.path.join(load_path, "actor_optim.pt"), map_location=self.device))
        self.critic_optim.load_state_dict(torch.load(os.path.join(load_path, "critic_optim.pt"), map_location=self.device))
        self.alpha_optim.load_state_dict(torch.load(os.path.join(load_path, "alpha_optim.pt"), map_location=self.device))

        
    def state_dict(self):
        return {
            "policy_gru": self.policy_gru.state_dict(),
            "value_gru": self.value_gru.state_dict(),
            "actor": self.actor.state_dict(),
            "q1": self.q1.state_dict(),
            "q2": self.q2.state_dict(),
            "target_q1": self.target_q1.state_dict(),
            "target_q2": self.target_q2.state_dict(),
            "log_alpha": self.log_alpha.data,
            "actor_optim": self.actor_optim.state_dict(),
            "critic_optim": self.critic_optim.state_dict(),
            "alpha_optim": self.alpha_optim.state_dict(),
        }
        
    def load_state_dict(self, state_dict):
        self.policy_gru.load_state_dict(state_dict["policy_gru"])
        self.value_gru.load_state_dict(state_dict["value_gru"])
        self.actor.load_state_dict(state_dict["actor"])
        self.q1.load_state_dict(state_dict["q1"])
        self.q2.load_state_dict(state_dict["q2"])
        self.target_q1.load_state_dict(state_dict["target_q1"])
        self.target_q2.load_state_dict(state_dict["target_q2"])
        self.log_alpha.data = state_dict["log_alpha"]
        self.actor_optim.load_state_dict(state_dict["actor_optim"])
        self.critic_optim.load_state_dict(state_dict["critic_optim"])
        self.alpha_optim.load_state_dict(state_dict["alpha_optim"])


class TransformerSACAgent(nn.Module):
    def __init__(self, args):
        """Build the Transformer encoders, actor/critic heads, target critics and optimizers.

        Args:
            args: configuration mapping. Keys read here: ``obs_shape``,
                ``action_shape``, ``emb_hidden_dim``, ``rnn_hidden_dim``,
                ``decoder_hidden_dims``, ``head_hidden_dims``, ``device``,
                ``actor_lr``, ``critic_lr``, ``rew_max``, ``rew_min``,
                ``discount`` and ``logger``.
        """
        super(TransformerSACAgent, self).__init__()
        self.args = args
        device = args["device"]

        # The encoders are Transformers; the ``*_gru`` attribute names are kept
        # because save()/load()/state_dict() use them as checkpoint keys.
        # Each input token is the concatenation (obs, act, next_obs), hence
        # obs_shape * 2 + action_shape features.
        self.policy_gru = Transformer(input_dim=args["obs_shape"] * 2 + args["action_shape"],
                                      embed_dim=args["emb_hidden_dim"],
                                      heads=3, depth=2,
                                      output_dim=args["rnn_hidden_dim"]).to(device)

        self.value_gru = Transformer(input_dim=args["obs_shape"] * 2 + args["action_shape"],
                                     embed_dim=args["emb_hidden_dim"],
                                     heads=3, depth=2,
                                     output_dim=args["rnn_hidden_dim"]).to(device)

        # Create the rollout histories with all three keys ('obs', 'act',
        # 'next_obs') so get_action()/get_value() work before an explicit reset().
        self.reset()

        head_kwargs = dict(
            obs_dim=args["obs_shape"], action_dim=args["action_shape"],
            embedding_dim=args["rnn_hidden_dim"], decoder_hidden_dims=args["decoder_hidden_dims"],
            head_hidden_dims=args["head_hidden_dims"],
        )
        self.actor = GaussianOutputHead(**head_kwargs).to(device)
        self.q1 = ValueHead(**head_kwargs).to(device)
        self.q2 = ValueHead(**head_kwargs).to(device)

        # Target critics start as frozen copies of the online critics and are
        # only changed through _soft_update().
        self.target_q1 = deepcopy(self.q1)
        self.target_q2 = deepcopy(self.q2)
        self.target_q1.requires_grad_(False)
        self.target_q2.requires_grad_(False)

        # Learnable log of the entropy temperature, initialised to 0.
        self.log_alpha = nn.Parameter(torch.zeros(1, device=device), requires_grad=True)

        # Each encoder is optimised together with the head(s) that consume it.
        self.actor_optim = torch.optim.Adam([*self.policy_gru.parameters(), *self.actor.parameters()],
                                            lr=args["actor_lr"])
        self.critic_optim = torch.optim.Adam(
            [*self.value_gru.parameters(), *self.q1.parameters(), *self.q2.parameters()], lr=args["critic_lr"])
        self.alpha_optim = torch.optim.Adam([self.log_alpha], lr=args["actor_lr"])

        self.rew_max = args["rew_max"]
        self.rew_min = args["rew_min"]
        self.discount = args["discount"]

        self.device = device
        self.logger = args["logger"]

    def reset(self):
        self.policy_context = {
            'obs': [],
            'act': [],
            'next_obs': [],
        }
        self.value_context = {
            'obs': [],
            'act': [],
            'next_obs': [],
        }

    def get_action(self, state, lst_action, hidden, deterministic=False, out_mean_std=False):
        if len(state.shape) == 2:
            state = torch.unsqueeze(state, 1)
        if len(lst_action.shape) == 2:
            lst_action = torch.unsqueeze(lst_action, 1)
        if len(hidden.shape) == 2:
            hidden = torch.unsqueeze(hidden, 0)

        lst_state = self.policy_context['next_obs'][-1] if len(self.policy_context['next_obs']) > 0 else torch.zeros_like(state)
        self.policy_context['obs'].append(lst_state)
        self.policy_context['act'].append(lst_action)
        self.policy_context['next_obs'].append(state)

        # only take the first element
        lens = [1] * state.shape[0]
        # rnn_input_pair = torch.cat([state, lst_action], dim=-1)
        rnn_input_pair = torch.cat([torch.cat(self.policy_context['obs'], dim=1),
                                    torch.cat(self.policy_context['act'], dim=1),
                                    torch.cat(self.policy_context['next_obs'], dim=1)], dim=-1)
        policy_embedding, next_hidden = self.policy_gru(rnn_input_pair)
        action, log_prob, mu, logstd = self.actor.sample(state, policy_embedding, deterministic=deterministic)

        action, log_prob, mu, logstd = torch.squeeze(action, 1), torch.squeeze(log_prob,
                                                                               1) if log_prob is not None else None, torch.squeeze(
            mu, 1), torch.squeeze(logstd, 1)

        ret = []
        ret += [action,
                log_prob]  # here we don't need to worry about deterministic, cause if deterministic, then actor.sample will return torch.tanh(mu) and None as action and log_prob
        ret += ([mu, logstd] if out_mean_std else [])
        ret += [next_hidden]
        return ret

    def get_value(self, state, action, lst_action, hidden):
        if len(state.shape) == 2:
            state = torch.unsqueeze(state, dim=1)
        if len(action.shape) == 2:
            action = torch.unsqueeze(action, dim=1)
        if len(lst_action.shape) == 2:
            lst_action = torch.unsqueeze(lst_action, dim=1)
        if len(hidden.shape) == 2:
            hidden = torch.unsqueeze(hidden, dim=0)

        lst_state = self.value_context['next_obs'][-1] if len(
            self.value_context['next_obs']) > 0 else torch.zeros_like(state)
        self.value_context['obs'].append(lst_state)
        self.value_context['act'].append(lst_action)
        self.value_context['next_obs'].append(state)
        # if len(self.value_context['obs']) > self.args["horizon"]:
        #     self.value_context['obs'].pop(0)
        #     self.value_context['act'].pop(0)

        lens = [1] * state.shape[0]
        rnn_input_pair = torch.cat([torch.cat(self.value_context['obs'], dim=1),
                                    torch.cat(self.value_context['act'], dim=1),
                                    torch.cat(self.value_context['next_obs'], dim=1)], dim=-1)
        # rnn_input_pair = torch.cat([state, lst_action], dim=-1)
        value_embedding, next_hidden = self.value_gru(rnn_input_pair)
        value_q1 = self.q1(state, action, value_embedding)
        value_q2 = self.q2(state, action, value_embedding)
        value_min = torch.min(value_q1, value_q2)

        value_min = torch.squeeze(value_min, 1)
        return value_min, next_hidden

    def train_policy(self, batch, behavior_cloning=False, sac_embedding_infer="concat", cons_policy=None, cons_batch=None):
        # we use 3D slices to train the policy
        batch['valid'] = batch['valid'].astype(int)
        lens = np.sum(batch['valid'], axis=1).squeeze(-1)
        max_len = np.max(lens)
        for k in batch:
            batch[k] = torch.from_numpy(batch[k][:, :max_len]).to(self.device)

        if cons_batch is not None:
            cons_batch['valid'] = cons_batch['valid'].astype(int)
            cons_lens = np.sum(cons_batch['valid'], axis=1).squeeze(-1)
            cons_max_len = np.max(cons_lens)
            for k in cons_batch:
                cons_batch[k] = torch.from_numpy(cons_batch[k][:, :cons_max_len]).to(self.device)

        value_embedding = []
        policy_embedding = []
        for t in range(max_len):
            obs = batch['observations'][:, :t] if t > 0 else torch.zeros_like(batch['observations'][:, :1])
            next_obs = batch['next_observations'][:, :t] if t > 0 else torch.zeros_like(batch['next_observations'][:, :1])
            act = batch['actions'][:, :t] if t > 0 else torch.zeros_like(batch['actions'][:, :1])
            inputs = torch.cat([obs, act, next_obs], dim=-1)
            ve, _ = self.value_gru(inputs)
            pe, _ = self.policy_gru(inputs)
            value_embedding.append(ve)
            policy_embedding.append(pe)
        value_embedding = torch.cat(value_embedding, dim=1)
        policy_embedding = torch.cat(policy_embedding, dim=1)
        value_embedding_next = torch.zeros_like(value_embedding)
        policy_embedding_next = torch.zeros_like(policy_embedding)
        value_embedding_next[:, :-1] = value_embedding[:, 1:]
        policy_embedding_next[:, :-1] = policy_embedding[:, 1:]

        with torch.no_grad():
            action_target, log_prob_target, mu_target, logstd_target = self.actor.sample(batch["next_observations"],
                                                                                         policy_embedding_next)
            q1_target = self.target_q1(batch["next_observations"], action_target, value_embedding_next)
            q2_target = self.target_q2(batch["next_observations"], action_target, value_embedding_next)
            q_target = torch.min(q1_target, q2_target)
            q_target = q_target - self.log_alpha.exp() * torch.unsqueeze(log_prob_target, dim=-1)
            q_target = batch["rewards"] + self.discount * (~batch["terminals"]) * (q_target)
            if self.args["q_target_clip"]:
                q_target = torch.clip(q_target,
                                      self.rew_min / (1 - self.discount),
                                      self.rew_max / (1 - self.discount)
                                      )

        # update critic
        q1 = self.q1(batch["observations"], batch["actions"], value_embedding)
        q2 = self.q2(batch["observations"], batch["actions"], value_embedding)
        valid_num = torch.sum(batch["valid"])

        q1_loss = torch.sum(((q1 - q_target) ** 2) * batch['valid']) / valid_num
        q2_loss = torch.sum(((q2 - q_target) ** 2) * batch['valid']) / valid_num
        q_loss = (q1_loss + q2_loss)
        self.critic_optim.zero_grad()
        q_loss.backward()
        self.critic_optim.step()

        self._soft_update(self.target_q1, self.q1, soft_target_tau=self.args["soft_target_tau"])
        self._soft_update(self.target_q2, self.q2, soft_target_tau=self.args["soft_target_tau"])

        # update alpha and actor
        actions, log_prob, mu, logstd = self.actor.sample(batch["observations"], policy_embedding)
        log_prob = log_prob.unsqueeze(dim=-1)  # (B, T, 1)
        if cons_policy is None:
            cons_loss = 0
        else:
            cons_policy_embedding = []
            for t in range(max_len):
                cpe, _ = self.policy_gru(torch.cat([batch["observations"][:, :t+1], batch["last_actions"][:, :t+1]], dim=-1))
                cons_policy_embedding.append(cpe)
            cons_policy_embedding = torch.cat(cons_policy_embedding, dim=1)
            with torch.no_grad():
                cons_dist = cons_policy.actor.inference_dist(cons_batch["observations"])
            dist = self.actor.inference_dist(cons_batch['observations'], cons_policy_embedding)
            cons_loss = (kl_divergence(dist, cons_dist) * cons_batch["valid"]).sum() / valid_num

        if self.args["learnable_alpha"]:
            alpha_loss = - torch.sum(self.log_alpha * ((log_prob + \
                                                        self.args['target_entropy']) * batch[
                                                           'valid']).detach()) / valid_num
            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

        q1_ = self.q1(batch["observations"], actions, value_embedding.detach())
        q2_ = self.q2(batch["observations"], actions, value_embedding.detach())
        min_q_ = torch.min(q1_, q2_)

        if behavior_cloning:
            sac_loss = self.log_alpha.exp().detach() * log_prob - min_q_
            sac_loss = torch.sum(sac_loss * batch["valid"]) / valid_num
            sac_loss_scaled = 2.5 * sac_loss / min_q_.detach().mean()
            bc_loss = (batch["actions"] - actions) ** 2
            bc_loss = torch.sum(bc_loss * batch["valid"]) / valid_num
            policy_loss = sac_loss_scaled + bc_loss * self.args["BC"]["bc_loss_coeff"]
        else:
            sac_loss = self.log_alpha.exp().detach() * log_prob - min_q_
            sac_loss = torch.sum(sac_loss * batch["valid"]) / valid_num
            policy_loss = sac_loss

        policy_loss += 0.1 * cons_loss

        self.actor_optim.zero_grad()
        policy_loss.backward()
        self.actor_optim.step()

        ret = dict()
        ret["min_q"] = min_q_.detach().cpu().mean().numpy()
        ret["q_loss"] = q_loss.detach().cpu().numpy()
        if self.args["learnable_alpha"]:
            ret["alpha_loss"] = alpha_loss.detach().cpu().numpy()
        if behavior_cloning:
            ret["cons_loss"] = cons_loss.detach().cpu().numpy()
            ret["policy_loss"] = policy_loss.detach().cpu().numpy()
            ret["sac_loss"] = sac_loss.detach().cpu().numpy()
            ret["bc_loss"] = bc_loss.detach().cpu().numpy()
        else:
            # ret["cons_loss"] = cons_loss.detach().cpu().numpy()
            ret["policy_loss"] = policy_loss.detach().cpu().numpy() #- 0.1 * ret["cons_loss"]
        return ret

    def _soft_update(self, net_target, net, soft_target_tau=5e-3):
        for o, n in zip(net_target.parameters(), net.parameters()):
            o.data.copy_(o.data * (1.0 - soft_target_tau) + n.data * soft_target_tau)

    def eval_policy(self, env=None):
        env = get_env(self.args["task"]) if env is None else env
        # res = self.test_on_real_env(self.args["eval_runs"], env)
        res = self.eval_on_real_env(env)
        return res

    def eval_on_real_env(self, env=None):
        env = get_env(self.args["task"]) if env is None else env
        results = ([self.test_one_trail(env) for _ in range(self.args["eval_runs"])])
        rewards = [result[0] for result in results]
        episode_lengths = [result[1] for result in results]
        rew_mean = np.mean(rewards)
        len_mean = np.mean(episode_lengths)

        res = OrderedDict()
        res["Reward_Mean_Env"] = rew_mean
        try:
            res["Score"] = env.get_normalized_score(rew_mean)
        except:
            print("no data")
        res["Length_Mean_Env"] = len_mean

        return res

    def test_one_trail(self, env=None):
        """Run one deterministic episode and return its total reward and length.

        Args:
            env: environment to roll out in; when ``None`` one is created with
                ``get_env(self.args["task"])``.

        Returns:
            tuple: ``(rewards, lengths)`` — the sum of step rewards and the
            number of steps taken until ``done``.
        """
        env = get_env(self.args["task"]) if env is None else env
        with torch.no_grad():
            state, done = env.reset(), False
            lst_action = torch.zeros((1, 1, self.args['action_shape'])).to(self.device)
            hidden_policy = torch.zeros((1, 1, self.args['rnn_hidden_dim'])).to(self.device)
            self.reset()
            rewards = 0
            lengths = 0
            while not done:
                # Add a leading axis so the single observation is handled as a batch of size 1.
                state = torch.from_numpy(state[np.newaxis]).float().to(self.device)
                action, log_prob, hidden_policy = self.get_action(state, lst_action, hidden_policy, deterministic=True)
                # A deterministic sample carries no log-probability.
                assert log_prob is None
                use_action = action.cpu().numpy().reshape(-1)
                if isinstance(env.action_space, discrete.Discrete):
                    # Discrete environments take the index of the largest output.
                    use_action = np.argmax(use_action)
                state, reward, done, _ = env.step(use_action)
                lst_action = action
                rewards += reward
                lengths += 1
        return (rewards, lengths)

    def sample_one_trail(self, env, pool, random=False):
        """Roll out one episode and write it into a single row of ``pool``.

        Step ``i`` of the episode is stored at ``pool.fields[key][row, i]`` for
        every recorded key; afterwards ``pool._pointer`` advances by one row and
        ``pool._size`` grows by one (capped at ``pool._max_size``).
        NOTE(review): nothing here bounds the episode length by the pool's
        per-row capacity — confirm callers guarantee it.

        Args:
            env: environment to roll out in; when ``None`` one is created with
                ``get_env(self.args["task"])``.
            pool: trajectory buffer exposing ``fields``, ``_pointer``,
                ``_max_size`` and ``_size``.
            random: when True, the executed action is drawn from
                ``self.args["action_space"].sample()`` instead of the policy
                (the policy is still queried so its history stays in step).

        Returns:
            tuple: ``(rewards, lengths)`` — accumulated reward (summed from the
            reshaped per-step reward arrays) and number of steps.
        """
        # Use the caller's environment; only build one when none was supplied.
        env = get_env(self.args["task"]) if env is None else env
        num_samples = 1  # exactly one trajectory row is written per call
        with torch.no_grad():
            state, done = env.reset(), False
            self.reset()
            lst_action = torch.zeros((1, 1, self.args['action_shape'])).to(self.device)
            hidden_policy = torch.zeros((1, 1, self.args['rnn_hidden_dim'])).to(self.device)
            hidden_value = torch.zeros((1, 1, self.args['rnn_hidden_dim'])).to(self.device)
            rewards = 0
            lengths = 0
            # Add a leading axis so the single observation is handled as a batch of size 1.
            state = torch.from_numpy(state[np.newaxis]).float().to(self.device)
            current_nonterm = np.ones((1), dtype=bool)
            # The pointer does not move during the episode, so the row is fixed.
            index = np.arange(pool._pointer, pool._pointer + num_samples) % pool._max_size
            while not done:
                action, _, hidden_policy_next = self.get_action(state, lst_action, hidden_policy, deterministic=False)
                if random:
                    action = self.args["action_space"].sample()
                    action = torch.from_numpy(action[np.newaxis]).float().to(self.device)
                _, hidden_value_next = self.get_value(state, action, lst_action, hidden_value)
                use_action = action.cpu().numpy().reshape(-1)
                next_state, reward, done, _ = env.step(use_action)

                next_state = torch.from_numpy(next_state[np.newaxis]).float().to(self.device)
                reward = np.array(reward)[np.newaxis, np.newaxis]
                done = np.array(done)[np.newaxis, np.newaxis]

                nonterm_mask = ~done.squeeze(-1)

                samples = {
                    "observations": state.cpu().numpy(),
                    "actions": action.cpu().numpy(),
                    "next_observations": next_state.cpu().numpy(),
                    "rewards": reward,
                    "terminals": done,
                    "last_actions": lst_action.cpu().numpy(),
                    "valid": current_nonterm.reshape(-1, 1),
                    "value_hidden": hidden_value.cpu().numpy(),
                    "policy_hidden": hidden_policy.cpu().numpy()
                }

                for k, v in samples.items():
                    # Insert a time axis, then store this step at position ``lengths``.
                    pool.fields[k][index, lengths] = np.expand_dims(v, 1)[:, 0]

                current_nonterm = current_nonterm & nonterm_mask
                lst_action = action
                state = next_state
                hidden_policy = hidden_policy_next
                hidden_value = hidden_value_next

                rewards += reward
                lengths += 1

            pool._pointer += 1
            pool._pointer %= pool._max_size
            pool._size = min(pool._max_size, pool._size + num_samples)
        return (rewards, lengths)

    def sample_one_trail_(self, env, pool, random=False, log_transition=False):
        """Roll out one episode into ``pool``, spilling into further rows when a row fills up.

        Steps are written at ``pool.fields[key][row, horizon_pos]``. Whenever
        ``horizon_pos`` reaches ``pool.max_traj_len`` the row is committed
        (pointer and size advance) and writing continues at position 0 of the
        next row. A trailing partially-filled row is committed at the end.
        NOTE(review): positions of a reused row beyond the last written step
        are not cleared here — confirm the pool handles stale entries.

        Args:
            env: environment to roll out in; when ``None`` one is created with
                ``get_env(self.args["task"])``.
            pool: trajectory buffer exposing ``fields``, ``_pointer``,
                ``_max_size``, ``_size`` and ``max_traj_len``.
            random: when True, the executed action is drawn from
                ``self.args["action_space"].sample()`` instead of the policy
                (the policy is still queried so its history stays in step).
            log_transition: when True, the visited raw states are returned as well.

        Returns:
            tuple: ``(rewards, lengths)`` or, with ``log_transition``,
            ``(rewards, lengths, np.array(traj_states))``.
        """
        traj_states = []
        # Use the caller's environment; only build one when none was supplied.
        env = get_env(self.args["task"]) if env is None else env
        num_samples = 1  # rows are committed one at a time
        with torch.no_grad():
            state, done = env.reset(), False
            traj_states.append(state)
            self.reset()
            lst_action = torch.zeros((1, 1, self.args['action_shape'])).to(self.device)
            hidden_policy = torch.zeros((1, 1, self.args['rnn_hidden_dim'])).to(self.device)
            hidden_value = torch.zeros((1, 1, self.args['rnn_hidden_dim'])).to(self.device)
            rewards = 0
            lengths = 0
            horizon_pos = 0
            # Add a leading axis so the single observation is handled as a batch of size 1.
            state = torch.from_numpy(state[np.newaxis]).float().to(self.device)
            current_nonterm = np.ones((1), dtype=bool)
            while not done:
                action, _, hidden_policy_next = self.get_action(state, lst_action, hidden_policy, deterministic=False)
                if random:
                    action = self.args["action_space"].sample()
                    action = torch.from_numpy(action[np.newaxis]).float().to(self.device)
                _, hidden_value_next = self.get_value(state, action, lst_action, hidden_value)
                use_action = action.cpu().numpy().reshape(-1)
                next_state, reward, done, _ = env.step(use_action)
                traj_states.append(next_state)

                next_state = torch.from_numpy(next_state[np.newaxis]).float().to(self.device)
                reward = np.array(reward)[np.newaxis, np.newaxis]
                done = np.array(done)[np.newaxis, np.newaxis]

                nonterm_mask = ~done.squeeze(-1)

                samples = {
                    "observations": state.cpu().numpy(),
                    "actions": action.cpu().numpy(),
                    "next_observations": next_state.cpu().numpy(),
                    "rewards": reward,
                    "terminals": done,
                    "last_actions": lst_action.cpu().numpy(),
                    "valid": current_nonterm.reshape(-1, 1),
                    "value_hidden": hidden_value.cpu().numpy(),
                    "policy_hidden": hidden_policy.cpu().numpy()
                }

                # The pointer can move mid-episode, so the row is recomputed every step.
                index = np.arange(
                    pool._pointer, pool._pointer + num_samples
                ) % pool._max_size
                for k, v in samples.items():
                    # Insert a time axis, then store this step at ``horizon_pos``.
                    pool.fields[k][index, horizon_pos] = np.expand_dims(v, 1)[:, 0]

                current_nonterm = current_nonterm & nonterm_mask
                lst_action = action
                state = next_state
                hidden_policy = hidden_policy_next
                hidden_value = hidden_value_next

                rewards += reward
                lengths += 1
                horizon_pos += 1
                if horizon_pos >= pool.max_traj_len:
                    # Row is full: commit it and continue in the next one.
                    horizon_pos -= pool.max_traj_len
                    pool._pointer += 1
                    pool._pointer %= pool._max_size
                    pool._size = min(pool._max_size, pool._size + num_samples)

            # Commit the trailing, partially filled row. If the episode ended
            # exactly on a row boundary the rollover above already committed
            # it; advancing again would count a row that was never written.
            if horizon_pos > 0:
                pool._pointer += 1
                pool._pointer %= pool._max_size
                pool._size = min(pool._max_size, pool._size + num_samples)

        if log_transition:
            return rewards, lengths, np.array(traj_states)

        return (rewards, lengths)


    def eval_on_sim_env(self, degree=(0.5, 0.75, 1.5, 2.0), type=None):
        """Evaluate on the unmodified simulator and on perturbed variants of it.

        The simulator configuration is saved first, rewritten once per
        (``type``, ``degree``) pair via ``update_config_files``, and restored
        with ``reset_env`` afterwards — also when an evaluation fails.
        Each evaluation delegates to ``eval_on_real_env``, so the number of
        episodes per setting is ``self.args["eval_runs"]``.

        Args:
            degree: a single number or an iterable of numbers passed to
                ``update_config_files`` as the perturbation degree.
            type: perturbation name or iterable of names; defaults to
                ``["fric", "grav", "dens"]``.

        Returns:
            dict: ``"initial"`` plus one ``"<type>-<degree>"`` entry per
            perturbed setting, each holding ``eval_on_real_env``'s result.
        """
        print("[DEBUG]: test on sim env")
        task = self.args["task"]
        save_env(task)
        if isinstance(degree, (int, float)):
            degree = [degree]
        if type is None:
            type = ["fric", "grav", "dens"]
        elif isinstance(type, str):
            type = [type]

        tested = dict()
        try:
            # Baseline: the unmodified environment.
            env = get_env(task)
            tested["initial"] = self.eval_on_real_env(env=env)

            # Perturbed variants of the task.
            for t in type:
                for d in degree:
                    info = t + "-" + str(d)
                    print(f"[DEBUG]: test on sim env ${info}")
                    update_config_files(task, t, d)
                    env = get_env(task)
                    tested[info] = self.eval_on_real_env(env=env)
        finally:
            # Always put the original simulator configuration back.
            reset_env(task)
        return tested

    def save(self, save_path):
        assert save_path, "save path cannot be None!"
        if not os.path.exists(save_path):
            os.makedirs(save_path)
        torch.save(self.policy_gru.state_dict(), os.path.join(save_path, "policy_gru.pt"))
        torch.save(self.actor.state_dict(), os.path.join(save_path, "actor.pt"))
        torch.save(self.value_gru.state_dict(), os.path.join(save_path, "value_gru.pt"))
        torch.save(self.q1.state_dict(), os.path.join(save_path, "q1.pt"))
        torch.save(self.q2.state_dict(), os.path.join(save_path, "q2.pt"))
        torch.save(self.log_alpha.data, os.path.join(save_path, "log_alpha.pt"))
        torch.save(self.target_q1.state_dict(), os.path.join(save_path, "target_q1.pt"))
        torch.save(self.target_q2.state_dict(), os.path.join(save_path, "target_q2.pt"))
        torch.save(self.actor_optim.state_dict(), os.path.join(save_path, "actor_optim.pt"))
        torch.save(self.critic_optim.state_dict(), os.path.join(save_path, "critic_optim.pt"))
        torch.save(self.alpha_optim.state_dict(), os.path.join(save_path, "alpha_optim.pt"))

    def load(self, load_path):
        assert load_path, "load path cannot be None!"
        self.policy_gru.load_state_dict(torch.load(os.path.join(load_path, "policy_gru.pt"), map_location=self.device))
        self.actor.load_state_dict(torch.load(os.path.join(load_path, "actor.pt"), map_location=self.device))
        self.value_gru.load_state_dict(torch.load(os.path.join(load_path, "value_gru.pt"), map_location=self.device))
        self.q1.load_state_dict(torch.load(os.path.join(load_path, "q1.pt"), map_location=self.device))
        self.q2.load_state_dict(torch.load(os.path.join(load_path, "q2.pt"), map_location=self.device))
        self.target_q1.load_state_dict(torch.load(os.path.join(load_path, "target_q1.pt"), map_location=self.device))
        self.target_q2.load_state_dict(torch.load(os.path.join(load_path, "target_q2.pt"), map_location=self.device))
        self.log_alpha.data = torch.load(os.path.join(load_path, "log_alpha.pt"), map_location=self.device)
        self.actor_optim.load_state_dict(
            torch.load(os.path.join(load_path, "actor_optim.pt"), map_location=self.device))
        self.critic_optim.load_state_dict(
            torch.load(os.path.join(load_path, "critic_optim.pt"), map_location=self.device))
        self.alpha_optim.load_state_dict(
            torch.load(os.path.join(load_path, "alpha_optim.pt"), map_location=self.device))

    def state_dict(self):
        return {
            "policy_gru": self.policy_gru.state_dict(),
            "value_gru": self.value_gru.state_dict(),
            "actor": self.actor.state_dict(),
            "q1": self.q1.state_dict(),
            "q2": self.q2.state_dict(),
            "target_q1": self.target_q1.state_dict(),
            "target_q2": self.target_q2.state_dict(),
            "log_alpha": self.log_alpha.data,
            "actor_optim": self.actor_optim.state_dict(),
            "critic_optim": self.critic_optim.state_dict(),
            "alpha_optim": self.alpha_optim.state_dict(),
        }

    def load_state_dict(self, state_dict):
        self.policy_gru.load_state_dict(state_dict["policy_gru"])
        self.value_gru.load_state_dict(state_dict["value_gru"])
        self.actor.load_state_dict(state_dict["actor"])
        self.q1.load_state_dict(state_dict["q1"])
        self.q2.load_state_dict(state_dict["q2"])
        self.target_q1.load_state_dict(state_dict["target_q1"])
        self.target_q2.load_state_dict(state_dict["target_q2"])
        self.log_alpha.data = state_dict["log_alpha"]
        self.actor_optim.load_state_dict(state_dict["actor_optim"])
        self.critic_optim.load_state_dict(state_dict["critic_optim"])
        self.alpha_optim.load_state_dict(state_dict["alpha_optim"])