import torch
import torch.nn as nn
from torch.distributions import kl, Normal, kl_divergence
import numpy as np
import os
from copy import deepcopy

from offlinerl.utils.net.autoencoder import VariationalDecoder

from UtilsRL.monitor import Monitor
from UtilsRL.logger import TensorboardLogger
from UtilsRL.misc.decorator import profile

from offlinerl.data.d4rl import get_dateset_info, load_d4rl_buffer
from offlinerl.utils.env import get_env
from offlinerl.utils.net.ensemble import ModuleListTransition, EnsembleTransition
from offlinerl.utils.net.maple_net import DiscriminatorNetwork, SVNetwork
from offlinerl.utils.ensemble import ParallelRDynamics, ChunkRDynamics
from offlinerl.utils.simple_replay_pool import SimpleReplayTrajPool, SimpleReplayPool, TorchDatasetWrapper
from offlinerl.agent.sac import RNNSACAgent, SACAgent, TransformerSACAgent
import offlinerl.utils.loader as loader
from offlinerl.utils.rollout import fix_model_rollout, pbt_rollout, meta_rollout, model_rollout, policy_rollout
from offlinerl.utils.net.autoencoder import VariationalEncoder, TransformerEncoder
from offlinerl.utils.function import product_of_gaussians
from offlinerl.utils.meta_model import MetaDynamicsAgent
from offlinerl.utils.callback_function import EarlyStopCallback
from offlinerl.utils.net.terminal_check import is_terminal


def reset_param(env, type, param):
    """Rescale one physics parameter of ``env`` in place and return ``env``.

    Args:
        env: Environment whose ``unwrapped.model`` exposes ``opt.gravity`` and
            ``dof_damping`` (presumably a MuJoCo-backed env -- confirm).
        type: Which parameter to change. ``'grav'`` overwrites the third
            gravity component with ``param * (-9.81)``; ``'dof_damping'``
            multiplies every damping entry by ``param``. Any other value
            leaves ``env`` untouched.
        param: Scale factor.

    Returns:
        The same ``env`` object that was passed in.
    """
    if type == 'grav':
        env.unwrapped.model.opt.gravity[2] = param * (-9.81)
    elif type == 'dof_damping':
        damping = env.unwrapped.model.dof_damping
        for joint in range(len(damping)):
            damping[joint] *= param
    return env

# Module-wide switch: copied onto REDMTrainer.use_tanh and forwarded to
# ParallelRDynamics as `tanh=` / `use_tanh=` wherever an ensemble is built.
use_tanh=False
class REDMTrainer():
    def __init__(self, args):
        """Set up trainer state from the run configuration.

        Loads the D4RL buffer named by ``args["task"]`` to collect dataset
        statistics, writes those statistics plus a few derived sizes back into
        ``args``, builds the meta policy and the value network, and hands the
        trajectory pool to ``loader.restore_pool_d4rl``.

        Args:
            args: Run configuration. It is read both by key (``args["task"]``)
                and by attribute (``args.horizon``), so it is presumably an
                attribute-dict such as EasyDict -- confirm with the caller.
                It is mutated in place.
        """
        self.args = args
        self.logger: TensorboardLogger = args["logger"]
        self.device = args["device"]
        self.name = self.args["task"][5:-3]
        self.use_tanh = use_tanh

        # Dataset statistics: numpy arrays become float32 tensors on the
        # training device, everything else is passed through unchanged.
        dataset = load_d4rl_buffer(args["task"])
        dataset_info = get_dateset_info(args["task"], buffer=dataset)
        stat_keys = (
            "obs_max", "obs_min", "obs_mean", "obs_std",
            "rew_max", "rew_min", "rew_mean", "rew_std",
            "obs_shape", "obs_space", "action_shape", "action_space",
        )
        env = {}
        for stat_key in stat_keys:
            stat = dataset_info[stat_key]
            if isinstance(stat, np.ndarray):
                stat = torch.tensor(stat, dtype=torch.float32, requires_grad=False, device=self.device)
            env[stat_key] = stat

        # Widen the observation bounds by a fraction of the observed range.
        margin = (env["obs_max"] - env["obs_min"]) * args["soft_expanding"]
        env["obs_max"] = env["obs_max"] + margin
        env["obs_min"] = env["obs_min"] - margin
        for stat_key, stat in env.items():
            args[stat_key] = stat

        # Derived configuration entries; pool sizes carry a 1.2x head-room.
        num_transitions = dataset.shape[0]
        args["data_name"] = args["task"][5:]
        args["env_pool_size"] = int((num_transitions / args["horizon"]) * 1.2)
        args["env_pool_size_step"] = int(num_transitions * 1.2)
        args["target_entropy"] = - args["action_shape"]
        del dataset
        torch.cuda.empty_cache()

        self.dynamics = None
        self.candidate_model = []

        # Recurrent SAC agent as the meta policy, plus a separate value network.
        self.meta_policy = RNNSACAgent(args).to(self.device)
        self.value = SVNetwork(args, args.obs_shape, args.action_shape).to(self.device)

        self.traj_env_pool = SimpleReplayTrajPool(
            args.obs_space, args.action_space, args.horizon, args.rnn_hidden_dim, args.env_pool_size,
        )
        loader.restore_pool_d4rl(
            self.traj_env_pool,
            args["data_name"],
            adapt=True,
            maxlen=args["horizon"],
            policy_hook=self.meta_policy.policy_gru,
            value_hook=self.meta_policy.value_gru,
            device=self.device,
        )

        # Progress counters.
        self.probe_epoch = 0
        self.discr_epoch = 0
        self.candidate_epoch = 0
        self.meta_epoch = 0
        self.probe_num = 0
        self.candidate_num = 0

        # Scale factors swept per physics parameter during evaluation.
        self.param_range = {
            'grav': [0.5, 1.0, 1.5],
            'dens': [0.5, 0.8, 1.0, 1.2, 1.5],
            'fric': [0.5, 0.8, 1.0, 1.2, 1.5],
            'dof_damping': [0.5, 0.8, 1.0, 1.2, 1.5],
        }
        self.type_list = ['grav']
        self.type = args.eval_type

    def train_bc_policy(self, path):
        """Train a behaviour-cloning SAC agent, save it to ``path`` and reload it.

        For each of ``args["BC"]["train_epoch"]`` epochs the agent is updated
        ``args["BC"]["train_update"]`` times on batches sampled from the task
        buffer. The per-key training metrics are averaged over the epoch,
        merged with the result of ``eval_on_real_env()`` and logged under "BC".

        Args:
            path: Location handed to ``bc_policy.save`` and afterwards to
                ``self.load_bc_policy``.
        """
        args = self.args
        bc_cfg = args["BC"]

        torch.cuda.empty_cache()
        # NOTE(review): unlike load_bc_policy, the agent is not moved with
        # .to(self.device) here -- confirm SACAgent handles placement itself.
        bc_policy = SACAgent(args)
        env_buffer = load_d4rl_buffer(args["task"])
        for i_epoch in Monitor("train BC").listen(range(bc_cfg["train_epoch"])):
            torch.cuda.empty_cache()
            epoch_metrics = {}
            for _ in range(bc_cfg["train_update"]):
                batch = env_buffer.sample(bc_cfg["batch_size"])
                step_metrics = bc_policy.train_policy(batch, behavior_cloning=True)
                for key, value in step_metrics.items():
                    epoch_metrics[key] = epoch_metrics.get(key, 0) + value
            # Turn the accumulated sums into per-update means.
            for key in epoch_metrics:
                epoch_metrics[key] /= bc_cfg["train_update"]
            epoch_metrics.update(bc_policy.eval_on_real_env())
            self.logger.log_scalars("BC", epoch_metrics, step=i_epoch)

        bc_policy.save(path)
        self.load_bc_policy(path)

    def load_bc_policy(self, path):
        """Load a saved behaviour-cloning agent into ``self.bc_policy`` and freeze it.

        Args:
            path: Location passed to the agent's ``load`` method.
        """
        self.bc_policy = SACAgent(self.args).to(self.device)
        policy = self.bc_policy
        policy.load(path)
        policy.requires_grad_(False)
        # The evaluation still runs, but its result is not stored or logged.
        policy.eval_on_real_env()
        torch.cuda.empty_cache()

    def train_dynamics(self, path):
        """Train an ensemble of transition models and keep the best members.

        The task buffer is split into train/validation parts, a
        ``ParallelRDynamics`` ensemble with ``args["Dynamics"]["init_num"]``
        members is fitted by minimising the negative log-likelihood of
        ``[obs_next, rew]``, and after every epoch each member is scored with
        ``self._eval_dynamics``. The best snapshot of every member is tracked;
        the ``args["Dynamics"]["select_num"]`` members with the lowest
        validation loss are saved under ``path`` (file names ``"0"``, ``"1"``,
        ...) and then reloaded through ``self.load_dynamics``.

        Args:
            path: Directory for the selected models; created if missing.
        """
        # self.logger.log_str("start to train dynamics ...", type="WARNING")
        args = self.args

        torch.cuda.empty_cache()
        buffer = load_d4rl_buffer(self.args["task"])
        data_size = len(buffer)
        # Hold out 20% (+1) of the data for validation, capped at 1000 samples.
        val_size = min(int(data_size * 0.2) + 1, 1000)
        train_size = data_size - val_size
        train_splits, val_splits = torch.utils.data.random_split(range(data_size), (train_size, val_size))
        train_buffer = buffer[train_splits.indices]
        val_buffer = buffer[val_splits.indices]

        dynamics = ParallelRDynamics(
            obs_dim=args["obs_shape"],
            action_dim=args["action_shape"],
            hidden_features=args["Dynamics"]["hidden_layer_size"],
            hidden_layers=args["Dynamics"]["hidden_layer_num"],
            ensemble_size=args["Dynamics"]["init_num"],
            normalizer=args["Dynamics"]["normalizer"],
            obs_mean=args["obs_mean"],
            obs_std=args["obs_std"],
            tanh=use_tanh,
        ).to(self.device)
        dynamics_optim = torch.optim.AdamW(dynamics.split_parameters(), lr=args["Dynamics"]["lr"],
                                           weight_decay=args["Dynamics"]["l2_loss_coef"])
        # Per-member bookkeeping: best validation loss so far, the epoch it
        # came from, and the corresponding single-model snapshot.
        val_losses = [100000 for i in range(dynamics.ensemble_size)]
        from_which_epoch = [-1 for i in range(dynamics.ensemble_size)]
        best_snapshot = [dynamics.get_single_transition(i) for i in range(dynamics.ensemble_size)]

        batch_step = 0
        cnt = 0  # consecutive epochs in which no member improved
        batch_size = args["Dynamics"]["batch_size"]
        for epoch in Monitor("train transition").listen(range(args["Dynamics"]["max_epoch"])):
            # Every member draws its own indices with replacement, so each one
            # sees a different resampling of the training split.
            idxs = np.random.randint(train_buffer.shape[0], size=[dynamics.ensemble_size, train_buffer.shape[0]])
            for batch_num in range(int(np.ceil(idxs.shape[-1] / batch_size))):
                batch_step += 1
                batch_idxs = idxs[:, batch_num * batch_size:(batch_num + 1) * batch_size]
                batch = train_buffer[batch_idxs].to_torch(device=self.device)
                dist = dynamics(torch.cat([batch["obs"], batch["act"]], dim=-1))

                # norm_dist = Normal(
                #     torch.zeros_like(dist.mean),
                #     torch.ones_like(dist.stddev)
                # )
                # kl_cons = kl_divergence(dist, norm_dist).mean()
                kl_cons = 0.

                # Negative log-likelihood of the observed next observation and reward.
                model_loss = (- dist.log_prob(torch.cat([batch["obs_next"], batch["rew"]], dim=-1))).mean(dim=1).mean()
                # decay_loss = dynamics.get_decay_loss()
                # clip_loss is computed but currently not added to the loss below.
                clip_loss = 0.01 * (2. * dynamics.max_logstd).mean() - 0.01 * (2. * dynamics.min_logstd).mean() if \
                    args["Dynamics"]["train_with_clip_loss"] else 0
                loss = model_loss #+ decay_loss  #+ clip_loss
                self.logger.log_scalars("Dynamics", {
                    "model_loss": model_loss.detach().cpu().item(),
                    "all_loss": loss.detach().cpu().item(),
                    "max_logstd": dynamics.max_logstd.data.cpu().mean().item(),
                    "min_logstd": dynamics.min_logstd.data.cpu().mean().item(),
                    # NOTE(review): the key says "logstd" but the value is the
                    # mean of dist.scale -- confirm which one is intended.
                    "mean_logstd": dist.scale.cpu().mean().item(),
                    # "decay_loss": decay_loss.detach().cpu().item()
                    # "kl_loss": kl_cons.cpu().item()
                }, step=batch_step)
                dynamics_optim.zero_grad()
                loss.backward()
                dynamics_optim.step()

            # One validation loss per ensemble member.
            new_val_losses = list(self._eval_dynamics(dynamics, val_buffer, inc_var_loss=args["Dynamics"][
                "eval_with_var_loss"]).cpu().numpy())

            # Collect the members whose validation loss improved this epoch.
            indexes = []
            for i, new_loss, old_loss in zip(range(len(val_losses)), new_val_losses, val_losses):
                # if new_loss < old_loss:
                if (old_loss - new_loss) / np.abs(old_loss) > 0.0:
                    indexes.append(i)
                    val_losses[i] = new_loss
            # self.logger.log_str(f"Epoch {epoch}: updated {len(indexes)} models {indexes}", type="LOG")
            # self.logger.log_str(f"model losses are {val_losses}", type="LOG")
            self.logger.log_scalar("Dynamics/val_loss", np.mean(new_val_losses), batch_step)

            if len(indexes) > 0:
                for idx in indexes:
                    best_snapshot[idx] = dynamics.get_single_transition(idx)
                    from_which_epoch[idx] = epoch
                cnt = 0
            else:
                cnt += 1

            # Early stop: five stale epochs in a row, but never before min_epoch.
            if cnt >= 5 and epoch > self.args["Dynamics"]["min_epoch"]:
                # self.logger.log_str(f"early stopping, final best losses are {val_losses}", type="LOG")
                # self.logger.log_str(f"from which epoch: {from_which_epoch}", type="LOG")
                break

        # Rank members by best validation loss and keep the top select_num.
        pairs = [(idx, val_loss) for idx, val_loss in enumerate(val_losses)]
        pairs = sorted(pairs, key=lambda x: x[1])
        selected_indexes = [p[0] for p in pairs[:args["Dynamics"]["select_num"]]]
        if not os.path.exists(path):
            os.makedirs(path)
        for i, idx in enumerate(selected_indexes):
            torch.save(best_snapshot[idx], os.path.join(path, str(i)))
            self.candidate_model.append(best_snapshot[idx])
        del dynamics

        # self.logger.log_str(f"dynamics training is done, saving to {path}")
        print(len(self.candidate_model))
        # load_dynamics resets self.candidate_model and reloads from disk, so
        # every file present in `path` ends up in the candidate set.
        self.load_dynamics(path)

    def load_dynamics(self, path):
        """Load every saved transition model in ``path`` and rebuild the ensemble.

        Resets ``self.candidate_model`` and ``self.base_model`` to the loaded
        models, clears ``self.new_model`` and rebuilds ``self.dynamics`` from
        the candidates on ``self.device``.

        Files are loaded in a deterministic order: names made only of decimal
        digits come first in numeric order, any other names follow
        alphabetically. ``os.listdir`` itself returns entries in arbitrary
        order, which would make the index of each ensemble member differ
        between runs and file systems.

        Args:
            path: Directory containing one ``torch.save``-d model per file.
        """
        def _order(name):
            # Numeric names sort numerically ("2" before "10"); the rest keep
            # plain string order after them.
            return (0, int(name), "") if name.isdecimal() else (1, 0, name)

        # NOTE: torch.load unpickles the files; only point this at trusted
        # checkpoints.
        models = [
            torch.load(os.path.join(path, name), map_location='cpu')
            for name in sorted(os.listdir(path), key=_order)
        ]
        self.candidate_model = list(models)
        self.base_model = list(models)
        self.new_model = []
        print(len(self.candidate_model))
        self.dynamics = ParallelRDynamics.from_single_transition(self.candidate_model, use_tanh=use_tanh).to(self.device)
        del models
        torch.cuda.empty_cache()

    def train_mainloop(self, path):
        """Run the main loop: meta-policy training, candidate training, evaluation, checkpoints.

        When ``args.start_epoch`` is positive the state saved for that epoch is
        restored through ``self.load_mainloop``; otherwise the meta policy is
        first trained for ``args["Meta"]["init_epoch"]`` warm-up epochs with
        ``self.init_flag`` set.

        Each main epoch then:
          1. trains the meta policy;
          2. trains one more candidate model while ``len(self.new_model)`` does
             not exceed ``args["Meta"]["max_model"]``;
          3. calls ``self.eval_model_data`` every fifth epoch;
          4. for every parameter type in ``self.type_list`` and every scale in
             ``self.param_range``, builds a fresh environment, rescales it
             with ``reset_param`` and logs the four values returned by
             ``self.eval_env_cover``;
          5. every ``args.mainloop_save_interval`` epochs saves all candidate
             models and the meta policy under ``path/<epoch>/``.

        Args:
            path: Root directory for checkpoints; also forwarded to the
                training and evaluation helpers.
        """
        torch.cuda.empty_cache()
        if self.args.start_epoch > 0:
            self.load_mainloop(path, self.args.start_epoch)
        else:
            self.init_flag = True
            for i_epoch in Monitor("main loop").listen(range(self.args["Meta"]["init_epoch"])):
                self.train_meta_policy(path)
        self.init_flag = False

        self.eval_model_data(path)

        self.logger.log_scalars("Meta", {"Meta-{}".format(self.args["data_name"]): self.args.start_epoch})
        for i_epoch in Monitor("main loop").listen(range(self.args.start_epoch, self.args.total_epoch)):
            print(self.candidate_num)
            # The candidate set must hold the initially selected models plus
            # one entry per candidate trained so far.
            assert len(self.candidate_model) == self.candidate_num + self.args["Dynamics"]["select_num"]
            self.train_meta_policy(path)

            if len(self.new_model) <= self.args["Meta"]["max_model"]:
                # train_candidate_set sets train_flag once it enters its
                # training loop; the assert catches a run that did nothing.
                self.train_flag = False
                self.train_candidate_set(path)
                assert self.train_flag

            if i_epoch % 5 == 0:
                self.eval_model_data(path)

            for t in self.type_list:
                self.type = t
                for param in self.param_range[self.type]:
                    env = get_env(self.args["task"])
                    reset_param(env, self.type, param)
                    loss_max, loss_min, lossind_max, lossind_min = self.eval_env_cover(env=env)

                    self.logger.log_scalars("Coverloss_{}_{}".format(self.type, str(param)), {
                        "res_loss_max": loss_max.detach().cpu().item(),
                        "res_loss_min": loss_min.detach().cpu().item(),
                        "res_lossind_max": lossind_max.detach().cpu().item(),
                        "res_lossind_min": lossind_min.detach().cpu().item(),
                    }, step=i_epoch)

            if i_epoch % self.args.mainloop_save_interval == 0:
                epoch_dir = os.path.join(path, str(i_epoch))

                # save candidate models; the directory is created once, and
                # exist_ok avoids the check-then-create race.
                target = os.path.join(epoch_dir, "candidate_model")
                os.makedirs(target, exist_ok=True)
                for cmid, cm in enumerate(self.candidate_model):
                    torch.save(cm, os.path.join(target, str(cmid)))

                # save meta policy
                self.meta_policy.save(os.path.join(epoch_dir, "meta_policy"))

        # torch.cuda.empty_cache()
        # self.dynamics = ParallelRDynamics.from_single_transition(self.candidate_model).to(self.device)
        # for i_epoch in Monitor("meta loop").listen(range(self.args.start_epoch, self.args.total_epoch)):
        #     self.train_meta_policy(path)
        #     if i_epoch % self.args.mainloop_save_interval == 0:
        #         self.meta_policy.save(os.path.join(path, str(i_epoch + self.args.total_epoch), "meta_policy"))

    def load_mainloop(self, path, start_epoch):
        epoch_path = os.path.join(path, str(start_epoch))
        if not os.path.exists(epoch_path):
            raise ValueError("load path not found")

        # load candidate
        # self.logger.log_str(f"load candidate set from {path}", type="WARNING")
        self.candidate_model = [
            torch.load(os.path.join(epoch_path, "candidate_model", id), map_location='cpu') \
            for id in os.listdir(epoch_path, "candidate_model")
        ]
        # load meta policy
        self.meta_policy.load(os.path.join(epoch_path, "meta_policy"))


    def compute_ret(self, rewards, masks, max_length):
        args = self.args
        returns = torch.zeros_like(rewards)
        prev_returns = 0.

        for i in reversed(range(max_length)):
            returns[:, i, :] = rewards[:, i, :] + args["discount"] * (masks[:, i, :]) * prev_returns
            prev_returns = returns[:, i, :]

        return returns

    def compute_gae(self, rewards, values, masks, max_length):
        args = self.args
        returns = torch.zeros_like(rewards)
        advantages = torch.zeros_like(rewards)
        deltas = torch.zeros_like(rewards)
        prev_returns = 0.
        prev_values = 0.
        prev_advantages = 0.
        lam = 0.95

        for i in reversed(range(max_length)):
            returns[:, i, :] = rewards[:, i, :] + args["discount"] * (masks[:, i, :]) * prev_returns
            deltas[:, i, :] = rewards[:, i, :] + args["discount"] * (masks[:, i, :]) * prev_values
            advantages[:, i, :] = deltas[:, i, :] + args["discount"] * lam * (masks[:, i, :]) * prev_advantages

            prev_returns = returns[:, i, :]
            prev_values = values[:, i, :]
            prev_advantages = advantages[:, i, :]

        if advantages[:, 0, :].std() > 0.001:
            advantages = (advantages - advantages.mean(dim=0, keepdim=True)) / (
                    advantages.std(dim=0, keepdim=True) + 1e-5)

        return returns, advantages

    def compute_pg_loss(self, states, actions, next_states, masks, advantages, returns, old_logprobs):
        clip = 0.2
        alpha = 0.99
        coef_value = 0.5
        coef_entropy = 0.001
        values = self.value(states, actions)
        next_obs_dists = self.new_dynamics(torch.cat([states, actions], dim=-1),
                                           use_res=True)  # 这里得到的是一个分布

        rewards = next_obs_dists.sample()[:, :, -1:]
        next_samples = torch.cat([next_states.unsqueeze(0) - states.unsqueeze(0), rewards], dim=-1)

        new_logprobs = next_obs_dists.log_prob(next_samples)
        new_logprobs = new_logprobs[:, :, :-1].mean(dim=-1, keepdim=True).clamp(-20, 2)
        total_num = masks.sum() if masks.sum() > 1 else 10

        ratio = torch.exp(new_logprobs - old_logprobs)
        surr1 = ratio * advantages
        surr2 = ratio.clamp(1 - clip, 1 + clip) * advantages
        loss_surr = - torch.sum(torch.min(surr1, surr2) * masks * alpha) / total_num
        loss_value = torch.sum((values - returns).pow(2) * masks) / total_num
        loss_entropy = torch.sum(torch.exp(new_logprobs) * new_logprobs * masks) / total_num
        loss = loss_surr + coef_value * loss_value + coef_entropy * loss_entropy

        return loss

    # @profile
    def train_candidate_set(self, path):
        print("train candidate model")
        args = self.args
        args["ablation"]["sl"] = False
        # idx = np.random.choice(self.args["Dynamics"]["select_num"])
        idx = len(self.candidate_model) % self.args["Dynamics"]["select_num"]
        # idx = 0
        self.dynamics = ParallelRDynamics.from_single_transition(self.candidate_model, use_tanh=use_tanh)
        self.base_dynamics = ParallelRDynamics.from_single_transition(self.candidate_model, use_tanh=use_tanh)
        self.new_dynamics = self.base_dynamics.get_single_transition(idx)
        # old_dynamics = self.base_dynamics.get_single_transition(idx)
        # self.new_dynamics = ParallelRDynamics(
        #     obs_dim=args["obs_shape"],
        #     action_dim=args["action_shape"],
        #     hidden_features=args["Dynamics"]["hidden_layer_size"],
        #     hidden_layers=args["Dynamics"]["hidden_layer_num"],
        #     ensemble_size=1,
        #     normalizer=args["Dynamics"]["normalizer"],
        #     obs_mean=args["obs_mean"],
        #     obs_std=args["obs_std"],
        #     tanh=True,
        # ).to(self.device)

        # self.dynamics = self.dynamics.to(self.device)
        self.new_dynamics = self.new_dynamics.to(self.device)
        # self.new_dynamics.max_logstd.requires_grad_(False)
        # self.new_dynamics.min_logstd.requires_grad_(False)
        # old_dynamics = old_dynamics.to(self.device)

        optim_params = list(self.new_dynamics.parameters())
        if self.value is not None:
            optim_params += list(self.value.parameters())
        new_dynamics_optim = torch.optim.Adam(params=optim_params, lr=args["Candidate"]["lr"],
                                               weight_decay=args["Candidate"]["l2_loss_coef"])

        ### train dynamics
        # self.logger.log_str(f"Start to train candidate model {self.candidate_num}", type="WARNING")
        # model_pool = None
        batch_step = 0
        batch_step_aux = 0
        torch.cuda.empty_cache()
        # for i_epoch in range(args["Candidate"]["train_epoch"]):
            # for i_iter in range(args["Candidate"]["train_update"]):
        for i_epoch in range(args["Candidate"]["train_epoch"]):
            self.train_flag = True
            for i_iter in range(args["Candidate"]["train_update"]):
                batch = self.traj_env_pool.random_batch_for_initial(args["Candidate"]["train_batch_size"])
                # sample
                obs = torch.from_numpy(batch["observations"]).to(self.device)
                # buffer_obs = obs
                # buffer_action = torch.from_numpy(batch["actions"]).to(args.device)
                # buffer_next_obs = torch.from_numpy(batch["next_observations"]).to(self.device)
                # buffer_rew = torch.from_numpy(batch["rewards"]).to(self.device)
                # action = torch.from_numpy(batch["actions"]).to(self.device)
                lst_action = torch.from_numpy(batch["last_actions"]).to(args.device)
                value_hidden = torch.from_numpy(batch["value_hidden"]).to(args.device)
                policy_hidden = torch.from_numpy(batch["policy_hidden"]).to(args.device)
                self.meta_policy.reset()
                current_nonterm = np.ones([len(obs), 1], dtype=bool)

                logprobs = []
                rewards = []
                rewards_meta = []
                rewards_search = []
                masks = []
                states = []
                actions = []
                next_states = []
                max_length = 0
                with torch.no_grad():
                    for h in range(args["Candidate"]["candidate_horizon"]):
                        batch_size = obs.shape[0]
                        with torch.no_grad():
                            action, _, mu, logstd, policy_hidden_next = self.meta_policy.get_action(obs, lst_action,
                                                                                                    policy_hidden,
                                                                                                    deterministic=False,
                                                                                                    out_mean_std=True)
                            value, value_hidden_next = self.meta_policy.get_value(obs, action, lst_action,
                                                                                                    value_hidden)

                        next_obs_dists = self.new_dynamics(torch.cat([obs, action], dim=-1),
                                                           use_res=True)  # 这里得到的是一个分布

                        next_sample = next_obs_dists.sample()
                        next_log_prob = next_obs_dists.log_prob(next_sample)
                        next_log_prob = next_log_prob[:, :, :-1].mean(dim=-1, keepdim=True).clamp(-20, 2)
                        logprobs.append(next_log_prob[0])
                        next_obs = next_sample[:, :, :-1] + obs
                        next_obs = next_obs[0, np.arange(batch_size)]
                        next_rew = next_sample[:, :, -1:]
                        next_rew = next_rew[0, np.arange(batch_size)]
                        obs = torch.clamp(obs, args["obs_min"], args["obs_max"])
                        next_obs = torch.clamp(next_obs, args["obs_min"], args["obs_max"])

                        with torch.no_grad():
                            next_action, _, _, _, _ = self.meta_policy.get_action(next_obs, action, policy_hidden_next,
                                                                                                    deterministic=False,
                                                                                                    out_mean_std=True)
                            value_next, _ = self.meta_policy.get_value(next_obs, next_action, action,
                                                                                                    value_hidden_next)

                        term = is_terminal(obs.detach().cpu().numpy(), action.detach().cpu().numpy(),
                                           next_obs.detach().cpu().numpy(),
                                           args["task"])

                        dones = torch.from_numpy(term).to(self.args.device).float()
                        # rew1 = - rew_fn(obs, action, next_obs, torch.from_numpy(term).to(self.device))
                        # rew1 = - next_rew.clamp(args["rew_min"], args["rew_max"])
                        # rew1 = rew1 * (1. - dones) + (-value) * (dones)
                        rew1 = (value - self.args['discount'] * value_next) * (1. - dones)
                        rew2 = self.search_rew(next_obs) * torch.from_numpy(~term).to(self.device)
                        # rew2 = torch.zeros_like(rew1)
                        # rew = - next_rew.clamp(self.rew_min, self.rew_max)
                        nonterm_mask = ~term
                        current_nonterm = current_nonterm & nonterm_mask
                        rew = rew1 + args["Candidate"]["ratio_aux"] * rew2
                        # rew = rew1 + rew2

                        masks.append(torch.from_numpy(current_nonterm).to(self.device))
                        states.append(obs)
                        next_states.append(next_obs)
                        rewards.append(rew)
                        rewards_meta.append(rew1)
                        rewards_search.append(rew2)
                        actions.append(action)

                        obs = next_obs.detach()
                        lst_action = action
                        policy_hidden = policy_hidden_next
                        value_hidden = value_hidden_next

                        max_length += 1
                        if (current_nonterm).sum() <= 0:
                            break

                old_logprobs = torch.stack(logprobs, dim=1)
                rewards = torch.stack(rewards, dim=1)
                rewards1 = torch.stack(rewards_meta, dim=1)
                rewards2 = torch.stack(rewards_search, dim=1)
                masks = torch.stack(masks, dim=1)
                states = torch.stack(states, dim=1)
                actions = torch.stack(actions, dim=1)
                next_states = torch.stack(next_states, dim=1)
                values = self.value(states, actions).detach()
                returns, advantages = self.compute_gae(rewards, values, masks, max_length)
                returns1 = self.compute_ret(rewards1, masks, max_length)
                returns2 = self.compute_ret(rewards2, masks, max_length)

                self.logger.log_scalars("Candidate", {
                    "norm_returns1_mean": returns1.mean().detach().cpu().item() if not args["ablation"][
                        "sl"] else 0,
                    "norm_returns1_std": returns1.std().detach().cpu().item() if not args["ablation"]["sl"] else 0,
                    "norm_returns2_mean": returns2.mean().detach().cpu().item() if not args["ablation"][
                        "sl"] else 0,
                    "norm_returns2_std": returns2.std().detach().cpu().item() if not args["ablation"]["sl"] else 0,
                    "norm_returns_mean": returns.mean().detach().cpu().item() if not args["ablation"]["sl"] else 0,
                    "norm_returns_std": returns.std().detach().cpu().item() if not args["ablation"]["sl"] else 0,
                    "rewards1_mean": rewards1.mean().detach().cpu().item() if not args["ablation"]["sl"] else 0,
                    "rewards1_std": rewards1.std().detach().cpu().item() if not args["ablation"]["sl"] else 0,
                    "rewards2_mean": rewards2.mean().detach().cpu().item() if not args["ablation"]["sl"] else 0,
                    "rewards2_std": rewards2.std().detach().cpu().item() if not args["ablation"]["sl"] else 0,
                    "rewards_mean": rewards.mean().detach().cpu().item() if not args["ablation"]["sl"] else 0,
                    "rewards_std": rewards.std().detach().cpu().item() if not args["ablation"]["sl"] else 0,
                    "rewards1_max": rewards1.max().detach().cpu().item() if not args["ablation"]["sl"] else 0,
                    "rewards1_min": rewards1.min().detach().cpu().item() if not args["ablation"]["sl"] else 0,
                    "rewards2_max": rewards2.max().detach().cpu().item() if not args["ablation"]["sl"] else 0,
                    "rewards2_min": rewards2.min().detach().cpu().item() if not args["ablation"]["sl"] else 0,
                    "rewards_max": rewards.max().detach().cpu().item() if not args["ablation"]["sl"] else 0,
                    "rewards_min": rewards.min().detach().cpu().item() if not args["ablation"]["sl"] else 0,
                    "length": max_length,
                }, step=batch_step)

                states = states.view(args["Candidate"]["train_batch_size"] * max_length, -1)
                actions = actions.view(args["Candidate"]["train_batch_size"] * max_length, -1)
                next_states = next_states.view(args["Candidate"]["train_batch_size"] * max_length, -1)
                masks = masks.view(args["Candidate"]["train_batch_size"] * max_length, -1)
                advantages = advantages.view(args["Candidate"]["train_batch_size"] * max_length, -1)
                returns = returns.view(args["Candidate"]["train_batch_size"] * max_length, -1)
                old_logprobs = old_logprobs.view(args["Candidate"]["train_batch_size"] * max_length, -1)

                # pg loss
                kl_cons = 0.
                pg_loss = 0.
                for k in range(self.args["Candidate"]["ppo_epoch"]):
                    for _ in range(int(args["Candidate"]["train_batch_size"] * max_length //
                                       self.args["Candidate"]["ppo_batch_size"])):
                        index = np.random.choice(args["Candidate"]["train_batch_size"] * max_length,
                                                 self.args["Candidate"]["ppo_batch_size"], replace=False)

                        bsta = states[index]
                        bact = actions[index]
                        bnes = next_states[index]
                        bmas = masks[index]
                        badv = advantages[index]
                        bret = returns[index]
                        bold = old_logprobs[index]

                        buffer_batch = self.traj_env_pool.random_batch_for_initial(args["Candidate"]["cons_batch_size"])

                        # sample
                        buffer_obs = torch.from_numpy(buffer_batch["observations"]).to(self.device)
                        buffer_action = torch.from_numpy(buffer_batch["actions"]).to(args.device)
                        buffer_next_obs = torch.from_numpy(buffer_batch["next_observations"]).to(self.device)
                        buffer_rew = torch.from_numpy(buffer_batch["rewards"]).to(self.device)

                        buffer_dist = self.new_dynamics(torch.cat([buffer_obs, buffer_action], dim=-1))

                        model_loss = (- buffer_dist.log_prob(torch.cat([buffer_next_obs, buffer_rew], dim=-1).unsqueeze(0))).mean(
                            dim=1).mean()
                        # decay_loss = self.new_dynamics.get_decay_loss()
                        clip_loss = 0.01 * (2. * self.new_dynamics.max_logstd).mean() - 0.01 * (
                                    2. * self.new_dynamics.min_logstd).mean() if \
                            args["Dynamics"]["train_with_clip_loss"] else 0
                        cons_loss = model_loss #+ decay_loss + clip_loss

                        # kl_cons = kl_divergence(bdist, btdist).sum()

                        pg_loss = self.compute_pg_loss(bsta, bact, bnes, bmas, badv, bret, bold)

                        # loss = 0.01 * pg_loss + cons_loss
                        loss = cons_loss + 0.01 * pg_loss

                        new_dynamics_optim.zero_grad()
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(self.new_dynamics.parameters(), 10)
                        new_dynamics_optim.step()

                        self.logger.log_scalars("Candidate_{}".format(self.candidate_num), {
                            "model_loss": model_loss.detach().cpu().item(),
                            # "decay_loss": decay_loss.detach().cpu().item(),
                            # "sl_loss": sl_loss.detach().cpu().item(),
                            "cons_loss": cons_loss.detach().cpu().item(),
                            "pg_loss": pg_loss.detach().cpu().item() if not args["ablation"]["sl"] else 0,
                            "max_logstd": self.new_dynamics.max_logstd.data.cpu().mean().item(),
                            "min_logstd": self.new_dynamics.min_logstd.data.cpu().mean().item(),
                            # "mean_logstd": dist.scale.cpu().mean().item()
                        }, step=batch_step)
                        batch_step += 1
        ###
        self.candidate_epoch += args["Candidate"]["train_epoch"]
        self.candidate_num += 1

        new_dynamics = self.new_dynamics.cpu()
        self.candidate_model.append(new_dynamics)
        torch.save(new_dynamics, os.path.join(path, str(self.candidate_model)))
        self.new_model.append(new_dynamics)

    # @profile
    # def train_candidate_set(self, path):
    #     args = self.args
    #     idx = np.random.choice(self.args["Dynamics"]["select_num"])
    #     self.dynamics = ParallelRDynamics.from_single_transition(self.candidate_model, use_tanh=True)
    #     self.new_dynamics = self.dynamics.get_single_transition(idx)
    #     old_dynamics = self.dynamics.get_single_transition(idx)
    #     # self.new_dynamics = ParallelRDynamics(
    #     #     obs_dim=args["obs_shape"],
    #     #     action_dim=args["action_shape"],
    #     #     hidden_features=args["Dynamics"]["hidden_layer_size"],
    #     #     hidden_layers=args["Dynamics"]["hidden_layer_num"],
    #     #     ensemble_size=1,
    #     #     normalizer=args["Dynamics"]["normalizer"],
    #     #     obs_mean=args["obs_mean"],
    #     #     obs_std=args["obs_std"],
    #     #     tanh=True,
    #     # ).to(self.device)
    #
    #     self.dynamics = self.dynamics.to(self.device)
    #     self.new_dynamics = self.new_dynamics.to(self.device)
    #     old_dynamics = old_dynamics.to(self.device)
    #     new_dynamics_optim = torch.optim.AdamW(self.new_dynamics.parameters(), lr=args["Candidate"]["lr"],
    #                                            weight_decay=args["Candidate"]["l2_loss_coef"])
    #     ### train dynamics
    #     self.logger.log_str(f"Start to train candidate model {self.candidate_num}", type="WARNING")
    #     # model_pool = None
    #     batch_step = 0
    #     batch_step_aux = 0
    #     torch.cuda.empty_cache()
    #     for i_epoch in range(args["Candidate"]["train_epoch"]):
    #
    #         for i_iter in range(args["Candidate"]["train_update"]):
    #             ### SL
    #             # batch = self.get_train_policy_batch(self.env_pool, self.env_pool,
    #             #                                     args["Discriminator"]["train_batch_size"])
    #             batch = self.traj_env_pool.random_batch_for_initial(args["Candidate"]["train_batch_size"])
    #             obs = torch.from_numpy(batch["observations"]).to(self.device)
    #             actions = torch.from_numpy(batch["actions"]).to(self.device)
    #             next_obs = torch.from_numpy(batch["next_observations"]).to(self.device)
    #             rewards = torch.from_numpy(batch["rewards"]).to(self.device)
    #
    #             self.logger.log_scalars("Candidate_{}".format(self.candidate_num), {
    #                 "obs_mean": obs.mean().detach().cpu().item(),
    #                 "next_obs_mean": next_obs.mean().detach().cpu().item()
    #             }, step=batch_step)
    #
    #             dist = self.new_dynamics(torch.cat([obs, actions], dim=-1))
    #             old_dist = old_dynamics(torch.cat([obs, action], dim=-1))
    #
    #             kl_cons = kl_divergence(dist, old_dist).mean()
    #
    #             ### PG
    #             pg_loss = 0
    #             # kl_cons = 0.
    #             if not args["ablation"]["sl"]:
    #                 obs = torch.from_numpy(batch["observations"]).to(self.device)
    #                 lst_action = torch.from_numpy(batch["last_actions"]).to(args.device)
    #                 value_hidden = torch.from_numpy(batch["value_hidden"]).to(args.device)
    #                 policy_hidden = torch.from_numpy(batch["policy_hidden"]).to(args.device)
    #                 current_nonterm = np.ones([len(obs), 1], dtype=bool)
    #
    #                 logprobs = []
    #                 rewards = []
    #                 masks = []
    #                 max_length = 0
    #                 for h in range(args["Candidate"]["candidate_horizon"]):
    #                     batch_size = obs.shape[0]
    #                     with torch.no_grad():
    #                         action, _, mu, logstd, policy_hidden_next = self.meta_policy.get_action(obs, lst_action,
    #                                                                                                 policy_hidden,
    #                                                                                                 deterministic=False,
    #                                                                                                 out_mean_std=True)
    #
    #                     next_obs_dists = self.new_dynamics(torch.cat([obs, action], dim=-1),
    #                                                        use_res=True)  # 这里得到的是一个分布
    #
    #                     next_sample = next_obs_dists.sample()
    #                     next_log_prob = next_obs_dists.log_prob(next_sample)
    #
    #                     # next_sample_pre = next_sample
    #                     # next_sample_pre = 0.5 * (next_sample[..., :-1].log1p() - (-next_sample[..., :-1]).log1p())
    #                     next_log_prob = next_log_prob.mean(dim=-1, keepdim=True).clamp(-20, 2)
    #                     # next_log_prob[:, :, 0] -= torch.sum(2 * (np.log(2) - next_sample_pre - torch.nn.functional.softplus(
    #                     #     -2 * next_sample_pre)), dim=-1)
    #
    #                     logprobs.append(next_log_prob[0])
    #                     next_obs = next_sample[:, :, :-1] + obs
    #                     next_obs = next_obs[0, np.arange(batch_size)]
    #                     next_rew = next_sample[:, :, -1:]
    #                     next_rew = next_rew[0, np.arange(batch_size)]
    #
    #                     term = is_terminal(obs.detach().cpu().numpy(), action.detach().cpu().numpy(),
    #                                        next_obs.detach().cpu().numpy(),
    #                                        args["task"])
    #
    #                     rew = - next_rew.clamp(self.args.rew_min, self.args.rew_max)
    #                     rewards.append(rew)
    #                     nonterm_mask = ~term
    #                     current_nonterm = current_nonterm & nonterm_mask
    #
    #                     masks.append(torch.from_numpy(current_nonterm).to(self.device))
    #                     obs = next_obs.detach()
    #                     lst_action = action
    #                     policy_hidden = policy_hidden_next
    #
    #                     max_length += 1
    #                     if (current_nonterm).sum() <= 0:
    #                         break
    #
    #                 logprobs = torch.stack(logprobs, dim=1)
    #                 rewards = torch.stack(rewards, dim=1)
    #                 masks = torch.stack(masks, dim=1)
    #                 returns = torch.zeros_like(rewards)
    #
    #                 for i in reversed(range(max_length)):
    #                     if i == (max_length - 1):
    #                         returns[:, i, :] = rewards[:, i, :]
    #                         continue
    #                     returns[:, i, :] = rewards[:, i, :] + args["discount"] * (masks[:, i, :]) * returns[:,
    #                                                                                                 i + 1, :]
    #                 if returns[:, 0, :].std() > 1.0:
    #                     norm_returns = (returns - returns.mean(dim=0, keepdim=True)) / (
    #                             returns.std(dim=0, keepdim=True) + 1e-5)
    #                 else:
    #                     norm_returns = returns
    #
    #                 pg_loss += ((-norm_returns * logprobs) * masks).sum() / (
    #                             masks.sum() + 1e-2)
    #             ###
    #             loss = self.args["Candidate"]["pg_coef"] * pg_loss + kl_cons
    #
    #             new_dynamics_optim.zero_grad()
    #             loss.backward()
    #             torch.nn.utils.clip_grad_norm_(self.new_dynamics.parameters(), 10)
    #             new_dynamics_optim.step()
    #
    #             self.logger.log_scalars("Candidate_{}".format(self.candidate_num), {
    #                 # "model_loss": model_loss.detach().cpu().item(),
    #                 # "sl_loss": sl_loss.detach().cpu().item(),
    #                 "cons_loss": kl_cons.detach().cpu().item(),
    #                 "pg_loss": pg_loss.detach().cpu().item() if not args["ablation"]["sl"] else 0,
    #                 "norm_returns_mean": norm_returns.mean().detach().cpu().item() if not args["ablation"]["sl"] else 0,
    #                 "norm_returns_std": norm_returns.std().detach().cpu().item() if not args["ablation"]["sl"] else 0,
    #                 "max_logstd": self.new_dynamics.max_logstd.data.cpu().mean().item(),
    #                 "min_logstd": self.new_dynamics.min_logstd.data.cpu().mean().item(),
    #                 "mean_logstd": dist.scale.cpu().mean().item()
    #             }, step=batch_step)
    #             batch_step += 1
    #
    #         for i_iter in range(args["Candidate"]["train_update"]):
    #             # batch = self.get_train_policy_batch(self.env_pool, self.env_pool,
    #             #                                     args["Candidate"]["train_batch_size"])
    #             batch = self.traj_env_pool.random_batch_for_initial(args["Candidate"]["train_batch_size"])
    #
    #             ### PG
    #             pg_loss = 0
    #             # kl_cons = 0.
    #             if not args["ablation"]["sl"]:
    #
    #                 obs = torch.from_numpy(batch["observations"]).to(self.device)
    #                 lst_action = torch.from_numpy(batch["last_actions"]).to(args.device)
    #                 value_hidden = torch.from_numpy(batch["value_hidden"]).to(args.device)
    #                 policy_hidden = torch.from_numpy(batch["policy_hidden"]).to(args.device)
    #                 current_nonterm = np.ones([len(obs), 1], dtype=bool)
    #
    #                 logprobs = []
    #                 rewards = []
    #                 masks = []
    #                 max_length = 0
    #                 for h in range(args["Candidate"]["candidate_horizon"]):
    #                     batch_size = obs.shape[0]
    #                     with torch.no_grad():
    #                         action, _ = self.bc_policy.get_action(obs, deterministic=False, out_mean_std=False)
    #
    #                     next_obs_dists = self.new_dynamics(torch.cat([obs, action], dim=-1),
    #                                                        use_res=True)
    #
    #                     next_sample = next_obs_dists.sample()
    #                     next_log_prob = next_obs_dists.log_prob(next_sample)
    #
    #                     # next_sample_pre = next_sample
    #                     # next_sample_pre = 0.5 * (next_sample[..., :-1].log1p() - (-next_sample[..., :-1]).log1p())
    #                     next_log_prob = next_log_prob[:, :, :-1].mean(dim=-1, keepdim=True).clamp(-20, 2)
    #                     # next_log_prob[:, :, 0] -= torch.sum(
    #                     #     2 * (np.log(2) - next_sample_pre - torch.nn.functional.softplus(
    #                     #         -2 * next_sample_pre)), dim=-1)
    #
    #                     logprobs.append(next_log_prob[0])
    #                     next_obs = next_sample[:, :, :-1] + obs
    #                     next_obs = next_obs[0, np.arange(batch_size)]
    #                     next_rew = next_sample[:, :, -1:]
    #                     next_rew = next_rew[0, np.arange(batch_size)]
    #
    #                     term = is_terminal(obs.detach().cpu().numpy(), action.detach().cpu().numpy(),
    #                                        next_obs.detach().cpu().numpy(),
    #                                        args["task"])
    #
    #                     # rew = torch.ones_like(next_rew)
    #                     rew = self.search_rew(next_obs)
    #
    #                     # rew = next_rew.clamp(self.rew_min, self.rew_max)
    #                     rewards.append(rew)
    #                     nonterm_mask = ~term
    #                     current_nonterm = current_nonterm & nonterm_mask
    #
    #                     masks.append(torch.from_numpy(current_nonterm).to(self.device))
    #                     obs = next_obs.detach()
    #
    #                     max_length += 1
    #                     if (current_nonterm).sum() <= 0:
    #                         break
    #
    #                 logprobs = torch.stack(logprobs, dim=1)
    #                 rewards = torch.stack(rewards, dim=1)
    #                 masks = torch.stack(masks, dim=1)
    #                 returns = torch.zeros_like(rewards)
    #
    #                 # print(returns.shape)
    #
    #                 for i in reversed(range(max_length)):
    #                     if i == (max_length - 1):
    #                         returns[:, i, :] = rewards[:, i, :]
    #                         continue
    #                     returns[:, i, :] = rewards[:, i, :] + args["discount"] * (masks[:, i, :]) * returns[:,
    #                                                                                                 i + 1, :]
    #                 if returns[:, 0, :].std() > 1.0:
    #                     norm_returns = (returns - returns.mean(dim=0, keepdim=True)) / (
    #                             returns.std(dim=0, keepdim=True) + 1e-5)
    #                 else:
    #                     norm_returns = returns
    #
    #                 pg_loss += ((-norm_returns * logprobs) * masks).sum() / (
    #                         masks.sum() + 1e-2)
    #             ###
    #             loss = self.args["Candidate"]["pg_coef"] * pg_loss
    #
    #             new_dynamics_optim.zero_grad()
    #             loss.backward()
    #             torch.nn.utils.clip_grad_norm_(self.new_dynamics.parameters(), 10)
    #             new_dynamics_optim.step()
    #
    #             self.logger.log_scalars("Candidate_{}".format(self.candidate_num), {
    #                 # "model_loss": model_loss.detach().cpu().item(),
    #                 # "sl_loss": sl_loss.detach().cpu().item(),
    #                 # "cons_loss": kl_cons.detach().cpu().item(),
    #                 "pg_loss_aux": pg_loss.detach().cpu().item() if not args["ablation"]["sl"] else 0,
    #                 "norm_returns_mean_aux": norm_returns.mean().detach().cpu().item() if not args["ablation"]["sl"] else 0,
    #                 "norm_returns_std_aux": norm_returns.std().detach().cpu().item() if not args["ablation"]["sl"] else 0,
    #                 "rewards_mean_aux": rewards.mean().detach().cpu().item() if not args["ablation"]["sl"] else 0,
    #                 "rewards_std_aux": rewards.std().detach().cpu().item() if not args["ablation"]["sl"] else 0,
    #                 "max_logstd_aux": self.new_dynamics.max_logstd.data.cpu().mean().item(),
    #                 "min_logstd_aux": self.new_dynamics.min_logstd.data.cpu().mean().item(),
    #                 "max_rew_aux": self.args.rew_max,
    #                 "min_rew_aux": self.args.rew_min,
    #             }, step=batch_step_aux)
    #             batch_step_aux += 1
    #
    #     ###
    #     self.candidate_epoch += args["Candidate"]["train_epoch"]
    #     self.candidate_num += 1
    #
    #     new_dynamics = self.new_dynamics.cpu()
    #     self.candidate_model.append(new_dynamics)

    # def load_candidate_set(self, path):
    #     self.logger.log_str("load candidate set from {}".format(path), type="WARNING")
    #     models = [
    #         torch.load(os.path.join(path, name), map_location='cpu') \
    #         for name in os.listdir(path)
    #     ]
    #     self.candidate_model = []
    #     for model in models:
    #         self.candidate_model.append(model)
    #     self.dynamics = ParallelRDynamics.from_single_transition(self.candidate_model).to(self.device)
    #     # self.dynamics.requires_grad_(False)
    #     del models
    #     torch.cuda.empty_cache()

    def search_rew(self, obs):
        """Score start states by a short look-ahead through ``self.new_dynamics``.

        Every row of ``obs`` is replicated ``args["Candidate"]["seed"]`` times
        and rolled forward for ``args["Candidate"]["search_horizon"]`` steps
        with actions sampled from ``self.bc_policy``. At each step the last
        channel of the dynamics sample is taken as the reward, clamped to
        ``[args["rew_min"], args["rew_max"]]`` and added for the rollouts whose
        transition ``is_terminal`` did not flag at that step.

        Args:
            obs: 2-D tensor of start observations (one row per state).

        Returns:
            Tensor of shape ``(obs.shape[0], 1)``: the best accumulated reward
            over the replicas of each start state, divided by the horizon.
        """
        cfg = self.args
        n_replicas = self.args["Candidate"]["seed"]

        n_start = obs.shape[0]
        # Tile the batch so replica r of start state i sits at row r * n_start + i.
        obs = obs.repeat(n_replicas, 1)
        total_rew = torch.zeros_like(obs.sum(dim=-1, keepdim=True))
        alive = np.ones([len(obs), 1], dtype=bool)

        with torch.no_grad():
            for _step in range(self.args["Candidate"]["search_horizon"]):
                rows = np.arange(obs.shape[0])
                action, _ = self.bc_policy.get_action(obs, deterministic=False, out_mean_std=False)

                sample = self.new_dynamics(torch.cat([obs, action], dim=-1), use_res=True).sample()

                # The sample carries a leading dimension; only entry 0 is used.
                # All channels but the last are added to obs, the last is the reward.
                next_obs = (sample[:, :, :-1] + obs)[0, rows]
                step_rew = sample[:, :, -1][0, rows].unsqueeze(-1).clamp(cfg["rew_min"], cfg["rew_max"])

                term = is_terminal(obs.detach().cpu().numpy(), action.detach().cpu().numpy(),
                                   next_obs.detach().cpu().numpy(),
                                   cfg["task"])
                step_alive = ~term

                # NOTE(review): the cumulative mask `alive` is maintained but never
                # read; rewards are gated by the per-step mask only. Confirm whether
                # the cumulative mask was meant to be used for the accumulation.
                alive = alive & step_alive
                total_rew += step_rew * torch.from_numpy(step_alive).to(self.device)

                # NOTE(review): observations are clamped with the *reward* bounds
                # here; confirm this is intended rather than observation bounds.
                obs = torch.clamp(next_obs, cfg["rew_min"], cfg["rew_max"]).detach()

        per_replica = total_rew.view(self.args['Candidate']['seed'], n_start, 1)
        return per_replica.max(dim=0)[0] / self.args['Candidate']['search_horizon']

    def eval_model_data(self, path=None):
        """Roll the meta policy through each candidate dynamics model and save the data.

        For every index in ``self.candidate_model`` a fresh batch of start
        states is drawn from ``self.traj_env_pool`` and rolled out for at most
        ``args["Candidate"]["candidate_horizon"]`` steps, stopping early once
        ``is_terminal`` has flagged every rollout in the batch. The per-step
        observations, actions and next observations are stacked along a new
        leading step dimension and stored under that index. The resulting dict
        is written with ``torch.save`` to
        ``<path>/model_data_<train_epoch>_<candidate_num>.pt``.

        Args:
            path: Directory to write the file into; created if it is missing.

        Raises:
            TypeError: If ``path`` is None.
        """
        # Validate up front: previously a missing path only surfaced as an
        # opaque TypeError from os.path after all rollouts had already run.
        if path is None:
            raise TypeError("eval_model_data() requires `path`, the directory to save the rollout data in")

        args = self.args
        meta_dynamics = ParallelRDynamics.from_single_transition(self.candidate_model, use_tanh=False).to(self.device)

        total_data = {}
        for idx in range(len(self.candidate_model)):

            single_obs = []
            single_act = []
            single_next_obs = []

            batch = self.traj_env_pool.random_batch_for_initial(args["Candidate"]["train_batch_size"])
            obs = torch.from_numpy(batch["observations"]).to(self.device)
            lst_action = torch.from_numpy(batch["last_actions"]).to(args.device)
            policy_hidden = torch.from_numpy(batch["policy_hidden"]).to(args.device)
            self.meta_policy.reset()
            # Stays True for a rollout until is_terminal flags one of its transitions.
            current_nonterm = np.ones([len(obs), 1], dtype=bool)
            with torch.no_grad():
                for _ in range(args["Candidate"]["candidate_horizon"]):
                    batch_size = obs.shape[0]
                    action, _, _, _, policy_hidden_next = self.meta_policy.get_action(obs, lst_action,
                                                                                      policy_hidden,
                                                                                      deterministic=False,
                                                                                      out_mean_std=True)

                    # The dynamics call returns a distribution over the next step.
                    next_obs_dists = meta_dynamics(torch.cat([obs, action], dim=-1),
                                                   use_res=True)

                    next_sample = next_obs_dists.sample()
                    # All channels but the last are added to obs; only the
                    # entry of the model under evaluation (idx) is kept.
                    next_obs = next_sample[:, :, :-1] + obs
                    next_obs = next_obs[idx, np.arange(batch_size)]

                    term = is_terminal(obs.detach().cpu().numpy(), action.detach().cpu().numpy(),
                                       next_obs.detach().cpu().numpy(),
                                       args["task"])

                    current_nonterm = current_nonterm & ~term

                    single_obs.append(obs.detach().cpu())
                    single_next_obs.append(next_obs.detach().cpu())
                    single_act.append(action.detach().cpu())

                    obs = next_obs.detach()
                    lst_action = action
                    policy_hidden = policy_hidden_next

                    if current_nonterm.sum() <= 0:
                        break
                total_data[idx] = {
                    "obs": torch.stack(single_obs, dim=0),
                    "act": torch.stack(single_act, dim=0),
                    "next_obs": torch.stack(single_next_obs, dim=0)
                }

        # exist_ok=True already covers the "directory exists" case.
        os.makedirs(path, exist_ok=True)
        torch.save(total_data, os.path.join(path, "model_data_{}_{}.pt".format(args["Candidate"]["train_epoch"], self.candidate_num)))


    def eval_env_cover(self, env=None):
        """Measure how well the candidate models predict real environment transitions.

        Runs ``args["Eval"]["num_traj"]`` episodes of the deterministic meta
        policy in ``env``, then samples a next-step prediction from the
        ensemble built out of ``self.candidate_model`` for every visited
        (observation, action) pair and compares it with the real next
        observation via squared error averaged over observation features.

        Args:
            env: Environment to roll out in; when None one is created with
                ``get_env(self.args["task"])``.

        Returns:
            Tuple ``(loss_max, loss_min, loss_ind_max, loss_ind_min)``.
            ``loss_max`` / ``loss_min`` take the max / min of the error over the
            prediction's first dimension (presumably one entry per candidate
            model) for each transition and average over transitions.
            ``loss_ind_max`` / ``loss_ind_min`` are the largest / smallest
            transition-averaged error along that first dimension.
        """
        if env is None:
            env = get_env(self.args["task"])

        visited_obs, taken_act, reached_obs = [], [], []
        with torch.no_grad():
            for _ in range(self.args["Eval"]['num_traj']):
                state = env.reset()
                done = False
                self.meta_policy.reset()
                lst_action = torch.zeros((1, 1, self.args['action_shape'])).to(self.device)
                hidden_policy = torch.zeros((1, 1, self.args['rnn_hidden_dim'])).to(self.device)
                while not done:
                    # Add a leading axis so the single state is handled as a batch of one.
                    state = torch.from_numpy(state[np.newaxis]).float().to(self.device)
                    action, extra, hidden_policy = self.meta_policy.get_action(state, lst_action, hidden_policy,
                                                                               deterministic=True)
                    assert extra is None
                    state_next, _reward, done, _info = env.step(action.cpu().numpy().reshape(-1))
                    lst_action = action

                    visited_obs.append(state)
                    taken_act.append(action)
                    reached_obs.append(torch.from_numpy(state_next[np.newaxis]).float().to(self.device))
                    state = state_next

            real_obs = torch.stack(visited_obs, dim=1).squeeze(0)
            real_act = torch.stack(taken_act, dim=1).squeeze(0)
            # Deliberately not squeezed: keeps a size-1 leading axis to broadcast
            # against the prediction's first dimension.
            real_obs_next = torch.stack(reached_obs, dim=1)

            assert real_obs.dim() == real_act.dim() == 2

            # Predict the same transitions with the candidate ensemble.
            dynamics = ParallelRDynamics.from_single_transition(self.candidate_model, use_tanh=use_tanh).to(self.device)
            prediction = dynamics(torch.cat([real_obs, real_act], dim=-1)).sample()
            pred_obs = prediction[:, :, :-1]  # drop the last channel

            sq_err = torch.pow(pred_obs - real_obs_next, 2).mean(dim=-1)
            loss_max = sq_err.max(dim=0)[0].mean()
            loss_min = sq_err.min(dim=0)[0].mean()

            err_per_entry = sq_err.mean(dim=-1)
            loss_ind_max = err_per_entry.max(dim=0)[0]
            loss_ind_min = err_per_entry.min(dim=0)[0]

        return loss_max, loss_min, loss_ind_max, loss_ind_min

    def train_meta_policy(self, path):
        """Train ``self.meta_policy`` on rollouts through the candidate ensemble.

        Each of the ``args["Meta"]["train_epoch"]`` epochs:
          1. rolls the policy out with ``policy_rollout`` into a fresh model pool,
          2. runs ``args["Meta"]["train_update"]`` policy updates on batches from
             ``get_train_policy_batch`` and averages the returned losses,
          3. evaluates the policy on one environment per entry of
             ``self.param_range`` for every type in ``self.type_list``,
          4. logs the averaged losses together with the rollout statistics,
          5. every ``args["Meta"]["reset_interval"]`` global epochs calls
             ``loader.reset_hidden_state`` on ``self.traj_env_pool``.

        ``self.meta_epoch`` is advanced by the number of epochs on exit, and
        ``self.type`` is left at the last evaluated type.

        Args:
            path: Not used by the current implementation (checkpoint saving is
                disabled).
        """
        args = self.args

        ensemble = ParallelRDynamics.from_single_transition(self.candidate_model, use_tanh=use_tanh).to(self.device)
        model_pool = SimpleReplayTrajPool(args.obs_space, args.action_space, args.horizon, args.rnn_hidden_dim,
                                          args.Meta.model_pool_size)
        torch.cuda.empty_cache()

        rollout_batch_size = args["Meta"]["rollout_batch_size"]
        for i_epoch in range(1, args["Meta"]["train_epoch"] + 1):
            global_step = i_epoch + self.meta_epoch

            rollout_res = policy_rollout(args, self.meta_policy, ensemble, self.traj_env_pool, model_pool,
                                         rollout_batch_size,
                                         deterministic=False)

            # --- policy updates; losses are summed here and averaged below ---
            loss_sums = dict()
            for _ in range(args["Meta"]["train_update"]):
                # While init_flag is set: sample real data only and train by
                # behaviour cloning; otherwise use the configured real-data ratio.
                ratio = 1.0 if self.init_flag else None
                bc = bool(self.init_flag)
                batch = self.get_train_policy_batch(self.traj_env_pool, model_pool, args["Meta"]["train_batch_size"], ratio=ratio)
                train_res = self.meta_policy.train_policy(batch,
                                                          sac_embedding_infer=args["ablation"]["sac_embedding_infer"],
                                                          behavior_cloning=bc,
                                                          cons_policy=None,
                                                          cons_batch=None)
                for name in train_res:
                    loss_sums[name] = loss_sums.get(name, 0) + train_res[name]
            train_loss = {name: total / args["Meta"]["train_update"] for name, total in loss_sums.items()}

            # --- evaluation on environments with modified parameters ---
            for t in self.type_list:
                self.type = t
                for param in self.param_range[self.type]:
                    env = get_env(self.args["task"])
                    reset_param(env, self.type, param)
                    eval_res = self.meta_policy.eval_policy(env=env)
                    print(eval_res)
                    self.logger.log_scalars("Eval_{}_{}".format(self.type, str(param)), eval_res,
                                            step=global_step)

            train_loss.update(rollout_res)
            self.logger.log_scalars("Meta", train_loss, step=global_step)

            if global_step % args["Meta"]["reset_interval"] == 0:
                loader.reset_hidden_state(self.traj_env_pool, args["data_name"], maxlen=args.horizon,
                                          policy_hook=self.meta_policy.policy_gru,
                                          value_hook=self.meta_policy.value_gru,
                                          device=args.device)
            torch.cuda.empty_cache()

            # Periodic checkpointing to `path` is currently disabled.

        self.meta_epoch += args["Meta"]["train_epoch"]

    # def load_meta_policy(self, path):
    #     checkpoints = [int(i) for i in os.listdir(path)]
    #     self.logger.log_str(f"Loading meta policy from {path}, found checkpoints {checkpoints}")
    #     last_checkpoint = max(checkpoints)
    #     self.meta_policy = RNNSACAgent(self.args)
    #     self.meta_policy.load(os.path.join(path, str(last_checkpoint)))

    #     del self.meta_dynamics
    #     del self.meta_agent
    #     torch.cuda.empty_cache()

    def _eval_dynamics(self, dynamics, valdata, inc_var_loss=True):
        """Compute a validation loss for ``dynamics`` on ``valdata`` without gradients.

        ``valdata`` is moved to ``self.device`` in place via ``to_torch``. The
        model is fed ``[obs, act]`` and its predicted distribution is compared
        against ``[obs_next, rew]``.

        Args:
            dynamics: Callable returning a distribution with ``mean`` and
                ``variance``; must expose ``max_logstd`` / ``min_logstd``.
            valdata: Batch indexable by ``'obs'``, ``'act'``, ``'obs_next'``,
                ``'rew'`` and providing ``to_torch(device=...)``.
            inc_var_loss: If True, return the variance-weighted squared error
                plus the mean soft-clamped log-variance; otherwise the plain
                mean squared error.

        Returns:
            Loss reduced over dimensions 1 and 2 of the prediction, i.e. one
            value per entry of its first dimension.
        """
        with torch.no_grad():
            valdata.to_torch(device=self.device)
            dist = dynamics(torch.cat([valdata['obs'], valdata['act']], dim=-1))
            target = torch.cat([valdata['obs_next'], valdata['rew']], dim=-1)
            sq_err = (dist.mean - target) ** 2

            if not inc_var_loss:
                return sq_err.mean(dim=(1, 2))

            softplus = torch.nn.functional.softplus
            upper = 2 * dynamics.max_logstd
            lower = 2 * dynamics.min_logstd
            # Softly clamp the log-variance into [lower, upper].
            logvar = upper - softplus(upper - dist.variance.log())
            logvar = lower + softplus(logvar - lower)

            weighted_err = (sq_err / (dist.variance + 1e-8)).mean(dim=(1, 2))
            return weighted_err + logvar.mean(dim=(1, 2))

    def get_train_policy_batch(self, env_pool, model_pool, batch_size, ratio=None):
        """Draw a training batch mixing real and model-generated samples.

        Args:
            env_pool: Pool of real data; must provide ``random_batch(n)``.
            model_pool: Pool of model rollouts; must provide ``random_batch(n)``.
            batch_size: Total number of samples requested.
            ratio: Fraction of the batch taken from ``env_pool``. When None,
                ``self.args["real_data_ratio"]`` is used.

        Returns:
            The env batch unchanged when no model samples are needed; otherwise
            a dict holding, for every key present in both batches, the env
            samples followed by the model samples.
        """
        if ratio is None:
            ratio = self.args["real_data_ratio"]

        # int() truncates, so any rounding remainder goes to the model share.
        n_real = int(batch_size * ratio)
        n_model = batch_size - n_real

        real_batch = env_pool.random_batch(n_real)
        if n_model <= 0:
            return real_batch

        model_batch = model_pool.random_batch(n_model)
        shared = set(real_batch.keys()) & set(model_batch.keys())
        return {name: np.concatenate((real_batch[name], model_batch[name])) for name in shared}