# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import time
import os
from collections import deque
import statistics

# from torch.utils.tensorboard import SummaryWriter
import torch
import torch.optim as optim
import wandb
# import ml_runlog
import datetime
import torch
from collections import deque
import random
import numpy as np


from rsl_rl.algorithms import PPO
from rsl_rl.modules import *
from rsl_rl.env import VecEnv
import sys
from copy import copy, deepcopy
import warnings
from rsl_rl.utils.utils import Normalizer

import torch.nn as nn
import torch.optim as optim

from rsl_rl.storage import RolloutStorage, ReplayBuffer
import torch.nn.functional as F



# Replay Buffer
class ReplayBuffer_selector:
    """Bounded FIFO store of selector transitions.

    Every entry is a ``(state, action, reward, next_state, done)`` tuple whose
    fields were converted with ``np.asarray``. The backing container is a
    ``deque(maxlen=capacity)``, so once ``capacity`` entries are held each new
    append evicts the oldest one.
    """

    def __init__(self, capacity, device):
        # Device that sampled batches are moved to.
        self.device = device
        self.buffer = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        """Append one transition, converting each field to a NumPy array."""
        fields = (state, action, reward, next_state, done)
        self.buffer.append(tuple(np.asarray(item) for item in fields))

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            A 5-tuple of tensors ``(states, actions, rewards, next_states,
            dones)`` on ``self.device``. Each tensor stacks the corresponding
            field of the sampled transitions along a new leading axis.

        Raises:
            ValueError: raised by ``np.random.choice(..., replace=False)``
                when ``batch_size`` exceeds the number of stored transitions.
        """
        picked = np.random.choice(len(self.buffer), batch_size, replace=False)
        rows = [self.buffer[idx] for idx in picked]

        # Transpose the list of transitions into one stacked array per field.
        states, actions, rewards, next_states, dones = (
            np.array(column) for column in zip(*rows)
        )

        batch = (states, actions, rewards, next_states, dones)
        return tuple(torch.tensor(column).to(self.device) for column in batch)

    def __len__(self):
        """Number of transitions currently stored."""
        return len(self.buffer)


# Reward Normalizer
class RewardNormalizer:
    """Normalize rewards with exponential-moving-average statistics.

    ``running_mean`` and ``var`` are exponential moving averages with decay
    ``gamma``. ``normalize`` updates both with the new reward and returns
    ``(reward - running_mean) / (sqrt(var) + 1e-8)``.

    Args:
        gamma: decay factor of the moving averages (weight kept from the
            previous estimate on every update).
    """

    def __init__(self, gamma=0.99):
        self.gamma = gamma
        self.running_mean = 0
        self.var = 1
        # Number of rewards seen so far (plus a small initial offset). Kept
        # for callers that read it; it no longer enters the scaling.
        self.count = 1e-4

    def normalize(self, reward):
        """Update the running statistics with ``reward`` and normalize it.

        Args:
            reward: a scalar or an array/tensor supporting arithmetic with
                Python floats.

        Returns:
            The centred reward divided by the running standard deviation.
        """
        keep = self.gamma
        self.running_mean = keep * self.running_mean + (1 - keep) * reward
        centred = reward - self.running_mean
        self.var = keep * self.var + (1 - keep) * centred ** 2
        self.count += 1
        # ``var`` is already an averaged (per-sample) estimate. The previous
        # code divided it by ``count`` as if it were a sum of squares, which
        # drove the standard deviation towards zero and made the normalized
        # rewards grow like sqrt(count).
        std = self.var ** 0.5
        return centred / (std + 1e-8)


# class SelectorNetwork(nn.Module):
#     def __init__(self, input_dim):
#         super(SelectorNetwork, self).__init__()
#         # 增加网络深度
#         self.fc1 = nn.Linear(input_dim, 256)  # 增大第一层
#         self.fc2 = nn.Linear(256, 128)
#         self.fc3 = nn.Linear(128, 64)
#         self.fc4 = nn.Linear(64, 2)  # 输出 Q 值，每个动作一个

#     def forward(self, x):
#         x = torch.relu(self.fc1(x))
#         x = torch.relu(self.fc2(x))
#         x = torch.relu(self.fc3(x))
#         return self.fc4(x)  # 输出 Q 值，未受激活函数限制


class SelectorNetwork(nn.Module):
    """MLP that maps a selector observation to one Q-value per action.

    The network has three hidden layers of 256, 128 and 64 units, each
    followed by LayerNorm and ReLU. Dropout is applied after the first two
    hidden layers only. The output layer is linear (no activation).

    Args:
        input_dim: size of the observation vector.
        output_dim: number of Q-values produced (one per action).
        dropout_prob: drop probability of the shared dropout layer.
    """

    def __init__(self, input_dim, output_dim=2, dropout_prob=0.1):
        super().__init__()

        # Fully connected stack. Attribute names and creation order define the
        # state-dict keys, so they are kept as fc1..fc4.
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 64)
        self.fc4 = nn.Linear(64, output_dim)

        # One dropout module reused after the first and second hidden layer.
        self.dropout = nn.Dropout(p=dropout_prob)

        # Per-layer normalization of the hidden activations.
        self.ln1 = nn.LayerNorm(256)
        self.ln2 = nn.LayerNorm(128)
        self.ln3 = nn.LayerNorm(64)

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for every linear layer."""
        for linear in (self.fc1, self.fc2, self.fc3, self.fc4):
            nn.init.xavier_uniform_(linear.weight)
            nn.init.zeros_(linear.bias)

    def forward(self, x):
        """Return the Q-values for the observations ``x``."""
        hidden = self.dropout(F.relu(self.ln1(self.fc1(x))))
        hidden = self.dropout(F.relu(self.ln2(self.fc2(hidden))))
        hidden = F.relu(self.ln3(self.fc3(hidden)))
        return self.fc4(hidden)


# RL Trainer
class Selector_Trainer_AMP:
    def __init__(self,
            env: VecEnv,
            train_cfg,
            log_dir=None,
            init_wandb=True,
            device='cpu', **kwargs):
        """Build the selector DQN, load the two frozen policies and set up AMP.

        Args:
            env: vectorized environment. This constructor reads ``env.device``
                and (through ``init_amp_demo_buf``) calls
                ``env.fetch_amp_obs_demo``.
            train_cfg: mapping with the sections ``"policy"``, ``"estimator"``
                and ``"amp"``.
            log_dir: accepted but not read by this constructor.
            init_wandb: accepted but not read by this constructor.
            device: device for the selector networks, the discriminator and
                the replay buffers.
            **kwargs: accepted but not read by this constructor.
        """
        self.env = env
        self.device = device

        # ---- DQN hyper-parameters and replay memory ----
        self.batch_size = 128
        self.lr = 1e-4
        self.gamma = 0.99
        self.max_grad_norm = 1.0
        self.replay_buffer = ReplayBuffer_selector(capacity=1000000, device=device)
        self.reward_normalizer = RewardNormalizer(gamma=self.gamma)

        # Previous per-environment choice; filled in by ``learn``.
        self.prev_action = None
        # Magnitude of the penalty applied whenever the choice changes.
        self.switch_penalty = 0.01

        # ---- Selector (online) network, target network and optimizer ----
        self.policy_cfg = train_cfg["policy"]
        self.estimator_cfg = train_cfg["estimator"]

        selector_input_dim = self.policy_cfg['selector_input']
        online_net = SelectorNetwork(input_dim=selector_input_dim).to(device)
        frozen_net = deepcopy(online_net).to(device)
        frozen_net.eval()  # the target network is only used for inference
        self.selector = online_net
        self.target_selector = frozen_net
        # Compared against the ``update_step`` passed to ``update_selector``.
        self.target_update_frequency = 50

        self.optimizer = optim.Adam(self.selector.parameters(), lr=self.lr)
        self.loss_fn = nn.MSELoss()

        # ---- Epsilon-greedy schedule ----
        self.epsilon = 0.05
        self.epsilon_min = 0.0001
        self.epsilon_decay = 0.995

        # ---- Frozen low-level policies (TorchScript checkpoints) ----
        # Earlier experiments loaded checkpoints from sibling directories
        # under /data1/selector_policy/change/ (12_4, deploy, deploy2,
        # deploy3, deploy5).
        checkpoint_dir = "/data1/selector_policy/test/1"
        self.model_save_path = "/data1/selector_policy/test/1"

        locomotion_file = os.path.join(checkpoint_dir, "walk_locomotion_test8-19800-actor_jit.pt")
        recovery_file = os.path.join(checkpoint_dir, "walk_recovery_test8-19800-actor_jit.pt")

        self.locomotion_policy = torch.jit.load(locomotion_file, map_location=env.device)
        self.recovery_policy = torch.jit.load(recovery_file, map_location=env.device)

        # ---- AMP discriminator, its buffers and optimizer ----
        self.amp_cfg = train_cfg["amp"]
        self.amp_discriminator = AMPDiscriminator(
            self.amp_cfg['amp_input_dim'],
            self.amp_cfg['amp_disc_hidden_dims'],
            device,
        ).to(self.device)
        # Repeats the device move above; kept so construction is unchanged.
        self.amp_discriminator.to(self.device)

        # Start-up diagnostics.
        print(self.amp_discriminator.input_dim)
        print(self.amp_cfg["amp_replay_buffer_size"])
        print(self.device)
        print(ReplayBuffer)  # shows which ReplayBuffer class was imported

        amp_feature_dim = self.amp_discriminator.input_dim
        self.amp_storage = ReplayBuffer(amp_feature_dim, self.amp_cfg["amp_replay_buffer_size"], self.device)
        self.amp_demo_storage = ReplayBuffer(amp_feature_dim, self.amp_cfg["amp_demo_buffer_size"], self.device)
        self.amp_fetch_demo_batch_size = self.amp_cfg["amp_demo_fetch_batch_size"]
        self.amp_learn_batch_size = self.amp_cfg["amp_learn_batch_size"]
        # Needs the demo storage and the fetch batch size set above.
        self.init_amp_demo_buf()

        self.amp_normalizer = None
        self.amp_optimizer = optim.Adam(
            self.amp_discriminator.parameters(),
            lr=self.amp_cfg["amp_learning_rate"],
            weight_decay=1e-3,
        )
        self.amp_rew_scale = self.amp_cfg["amp_reward_coef"]
        self.amp_grad_pen = self.amp_cfg["amp_grad_pen"]


    def load_policy(self, policy_path):
        # Dummy policy for demonstration
        return lambda obs: torch.tanh(obs.sum(dim=1, keepdim=True))  # Replace with actual policy loading logic

    def update_selector(self, update_step):
        if len(self.replay_buffer) < self.batch_size:
            return

        # Sample from the replay buffer
        states, actions, rewards, next_states, dones = self.replay_buffer.sample(self.batch_size)
        # print('states', states.shape)
        # Compute Q(s, a) using main network
        q_values = self.selector(states)  # 主网络输出 Q 值
        q_values = q_values.gather(1, actions.unsqueeze(-1).long()).squeeze(-1)  # 提取选择动作的 Q 值

        # Compute Q_target using target network
        with torch.no_grad():
            next_q_values = self.target_selector(next_states)  # 目标网络计算 Q 值
            max_next_q_values = next_q_values.max(dim=1)[0]  # 针对下一状态选择最大 Q 值
            q_targets = rewards + self.gamma * max_next_q_values * (1 - dones.float())  # Bellman 方程

        # Compute loss
        loss = self.loss_fn(q_values, q_targets)

        # Backward pass and optimizer step
        self.optimizer.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(self.selector.parameters(), 1.0)  # 梯度裁剪
        self.optimizer.step()

        # Update target network every few steps
        if update_step % self.target_update_frequency == 0:
            self.target_selector.load_state_dict(self.selector.state_dict())  # 同步参数

        return loss


    def select_action(self, q_values):
        # ε-greedy 策略
        if random.random() < self.epsilon:
            # print(torch.randint(0, 2, (q_values.shape[0]), device=self.device).long().shape)
            # print('random', torch.randint(0, 2, (q_values.shape[0],), device=self.device).long())
            return torch.randint(0, 2, (q_values.shape[0],), device=self.device).long()
        else:
            # print(q_values.argmax(dim=1).long().shape)
            return q_values.argmax(dim=1).long()  # 选择具有最大 Q 值的动作


    def save_selector(self, filename):
        """将 Selector Network 保存为 state_dict 格式"""
        # 获取模型的状态字典
        model_state_dict = self.selector.state_dict()
        
        # 保存为普通的字典格式
        filename = os.path.join(self.model_save_path, filename)
        torch.save(model_state_dict, filename)


    def load_selector(self, filename):
        self.selector = torch.load(self.model_save_path, filename).to(self.device)



    def learn(self, num_learning_iterations, num_steps_per_env, init_at_random_ep_len=False):
        if init_at_random_ep_len:
            self.env.episode_length_buf = torch.randint_like(
                self.env.episode_length_buf, high=int(self.env.max_episode_length)
            )

        cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
        cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
        cur_amp_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
        cur_switch_penalty_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)

        rewbuffer = deque(maxlen=100)
        lenbuffer = deque(maxlen=100)
        amp_rew_scaled_buf = deque(maxlen=100)
        switch_penalty_buf = deque(maxlen=100)
        
        # selected_action_history = torch.zeros(self.env.num_envs, 10 , device=self.device)

        frequence_period = 100

        locomotion_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)
        recovery_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device)

        obs = self.env.get_observations()
        self.locomotion_policy 
        z, v = self.locomotion_policy.estimator(obs.detach()[:, self.estimator_cfg['prop_start'] - 3 * self.estimator_cfg['prop_dim'] : self.estimator_cfg['prop_start']])
        latent = torch.cat([z,v],dim = 1)

            
        obs = torch.cat([obs[:, :self.estimator_cfg['prop_start'] + self.estimator_cfg['prop_dim']], latent],dim = 1)
        self.prev_action = torch.zeros(self.env.num_envs, device=self.device)
        # obs_h = torch.cat((obs, self.prev_action.unsqueeze(1)), dim=-1)

        ep_infos = []

        for it in range(num_learning_iterations):
            locomotion_percentages = []  # 存储每个 episode 中 locomotion 策略选择的百分比
            with torch.inference_mode():
                for step in range(num_steps_per_env):
                    # obs_h = torch.cat((obs, selected_action_history), dim = -1)
                    with torch.no_grad():
                        # Q-values for both actions (0: locomotion, 1: recovery)
                        q_values = self.selector(obs)
                        # q_values = self.selector(obs_h)
                        selected_action = self.select_action(q_values)  # 使用 ε-greedy 选择动作

                    # 更新selected_action时计算切换惩罚


                    switch_penalty = - (self.prev_action != selected_action).float() * self.switch_penalty
                    # print("self.prev_action", self.prev_action.shape)  # [env_nums]

                    # Execute the selected policy for all environments
                    actions = torch.empty((obs.shape[0], 19), device=self.device)  # Initialize action tensor

                    locomotion_mask = selected_action == 0  # Mask for locomotion policy
                    recovery_mask = selected_action == 1  # Mask for recovery policy

                    # Compute actions for each policy
                    if locomotion_mask.any():
                        actions[locomotion_mask] = self.locomotion_policy(obs[locomotion_mask])
                        z, v = self.locomotion_policy.estimator(obs.detach()[:, self.estimator_cfg['prop_start'] - 2 * self.estimator_cfg['prop_dim'] : self.estimator_cfg['prop_start'] + self.estimator_cfg['prop_dim']])
                        loco_latent = torch.cat([z,v],dim = 1)
                        latent[locomotion_mask] = loco_latent[locomotion_mask]
                        locomotion_sum[locomotion_mask] +=1
                    if recovery_mask.any():
                        actions[recovery_mask] = self.recovery_policy(obs[recovery_mask])
                        z, v = self.recovery_policy.estimator(obs.detach()[:, self.estimator_cfg['prop_start'] - 2 * self.estimator_cfg['prop_dim'] : self.estimator_cfg['prop_start'] + self.estimator_cfg['prop_dim']])
                        reco_latent = torch.cat([z,v],dim = 1)
                        latent[recovery_mask] = reco_latent[recovery_mask]
                        recovery_sum[recovery_mask] +=1

                    # Interact with the environment
                    next_obs, privileged_obs, reward, dones, info = self.env.step(actions.to(self.device))
                    # next_obs = torch.tensor(next_obs, dtype=torch.float32).to(self.device)

                    next_obs = torch.cat([next_obs[:, :self.estimator_cfg['prop_start'] + self.estimator_cfg['prop_dim']], latent],dim = 1)
                    # next_obs_h = torch.cat((next_obs, self.prev_action.unsqueeze(1)), dim = -1)

                    # Normalize rewards
                    # normalized_reward = self.reward_normalizer.normalize(reward)
                    amp_rew_scaled = self.amp_rew_scale * 0.02 * self.calc_amp_rewards(info["amp_obs"]).squeeze(1)

                    self.amp_storage.insert(info["amp_obs"])
                    # print('switch_penalty', switch_penalty.shape)
                    normalized_reward = reward + amp_rew_scaled + switch_penalty

                    if 'episode' in info:
                        ep_infos.append(info['episode'])

                    # print('obs', obs.shape)
                    cur_reward_sum += normalized_reward
                    cur_amp_reward_sum += amp_rew_scaled
                    cur_switch_penalty_sum += switch_penalty
                    cur_episode_length +=1
                    new_ids = (dones > 0).nonzero(as_tuple=False)

                    rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist())
                    lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist())
                    amp_rew_scaled_buf.extend(cur_amp_reward_sum[new_ids][:, 0].cpu().numpy().tolist())
                    switch_penalty_buf.extend(cur_switch_penalty_sum[new_ids][:, 0].cpu().numpy().tolist())

                    cur_reward_sum[new_ids] = 0
                    cur_amp_reward_sum[new_ids] = 0
                    cur_switch_penalty_sum[new_ids] = 0
                    cur_episode_length[new_ids] = 0
                    locomotion_sum[new_ids] = 0
                    recovery_sum[new_ids] = 0

                    # for i in range(obs.shape[0]):  # 遍历每个环境的状态
                    #     self.replay_buffer.add(
                    #         obs_h[i].detach().cpu(),                  # 单个环境的状态
                    #         selected_action[i].detach().cpu(),      # 对应的动作
                    #         normalized_reward[i].detach().cpu(),               # 对应的奖励
                    #         next_obs_h[i].detach().cpu(),             # 下一状态
                    #         dones[i].detach().cpu()                 # 完成标志
                    #     )
                    for i in range(obs.shape[0]):  # 遍历每个环境的状态
                        self.replay_buffer.add(
                            obs[i].detach().cpu(),                  # 单个环境的状态
                            selected_action[i].detach().cpu(),      # 对应的动作
                            normalized_reward[i].detach().cpu(),               # 对应的奖励
                            next_obs[i].detach().cpu(),             # 下一状态
                            dones[i].detach().cpu()                 # 完成标志
                        )
                    
                    obs = next_obs
                    # obs_h = next_obs_h
                    # 更新prev_action
                    self.prev_action = selected_action.clone()
                    

            self.update_amp_demos()
            loco_percentage = (locomotion_sum / (locomotion_sum + recovery_sum + 1e-6)).mean()
            reco_percentage = (recovery_sum / (locomotion_sum + recovery_sum + 1e-6)).mean()

            # Update selector
            for i in range(10):
                dqn_loss = self.update_selector(update_step = it)


            amp_loss, grad_pen_loss, policy_d, expert_d = self.update_amp()

            # 衰减 ε
            if self.epsilon > self.epsilon_min:
                self.epsilon *= self.epsilon_decay


            # Logging
            # 打印本次迭代的 locomotion 策略选择的平均百分比
            # print(dqn_loss)
            
            print(f"Iteration {it}, Avg Locomotion Percentage: {loco_percentage * 100:.2f}%")
            print(f"Iteration {it}, Avg recovery Percentage: {reco_percentage * 100:.2f}%")
            print(f"Iteration {it}, Replay Buffer Size: {len(self.replay_buffer)}")
            print('Current mean reward', cur_reward_sum.mean())

            print('Current mean episode length', cur_episode_length.mean())
            print('Current Style reward', cur_amp_reward_sum.mean())
            print(f"Q-Learning Loss: {dqn_loss.item()}")


            locs = locals()
            wandb_dict = {
                "Iteration": it,
                "DQN_Loss": dqn_loss.item(),
                "Locomotion_Percentage": loco_percentage.item(),
                "Recovery_Percentage": reco_percentage.item(),
                "Replay_Buffer_Size": len(self.replay_buffer),
                "Mean_Reward": torch.mean(torch.tensor(rewbuffer, dtype=torch.float32, device=self.device)).item(),
                "Mean_Episode_Length": torch.mean(torch.tensor(lenbuffer, dtype=torch.float32, device=self.device)).item(),
                "AMP Loss": amp_loss.clone().detach().item(),
                "Style_reward": torch.mean(torch.tensor(amp_rew_scaled_buf, dtype=torch.float32, device=self.device)).item(),
                "Switch_penalty": torch.mean(torch.tensor(switch_penalty_buf, dtype=torch.float32, device=self.device)).item(),
                'Epsilon': self.epsilon
            }


            if locs.get('ep_infos') and len(locs['ep_infos']) > 0:  # 确保 ep_infos 存在且非空
                for key in locs['ep_infos'][0]:
                    infotensor = torch.empty(0, device=self.device)  # 初始化空张量
                    for ep_info in locs['ep_infos']:
                        # 将 ep_info[key] 转换为张量，处理标量和零维张量情况
                        value_tensor = torch.as_tensor(ep_info[key], device=self.device, dtype=torch.float32).flatten()
                        infotensor = torch.cat((infotensor, value_tensor))  # 拼接张量
                    value = infotensor.mean()  # 计算均值
                    # 根据 key 名称分类并记录到 wandb_dict
                    if "tracking" in key:
                        wandb_dict[f"Episode_rew_tracking/{key}"] = value
                    elif "curriculum" in key:
                        wandb_dict[f"Episode_curriculum/{key}"] = value
                    elif "terrain_level" in key:
                        wandb_dict[f"Episode_terrain_level/{key}"] = value
                    else:
                        wandb_dict[f"Episode_rew_regularization/{key}"] = value
    
            wandb.log(wandb_dict)


            # 每个学习迭代结束后保存 selector 模型
            if it < 2000:
                if (it + 1) % 500 == 0:  # 每10次迭代保存一次模型
                    # path = "/data1/selector_policy"
                    selector_filename = f"selector_model_{it + 1}.pt"
                    self.save_selector(selector_filename)
                    # self.load_selector(selector_filename)  # 加载最新的模型
                    print(f"Selector model saved to {selector_filename}")
            else:
                if (it + 1) % 200 == 0:  # 每10次迭代保存一次模型
                    # path = "/data1/selector_policy"
                    selector_filename = f"selector_model_{it + 1}.pt"
                    self.save_selector(selector_filename)
                    # self.load_selector(selector_filename)  # 加载最新的模型
                    print(f"Selector model saved to {selector_filename}")
            
            ep_infos.clear()



    def init_amp_demo_buf(self):
        buffer_size = self.amp_demo_storage.buffer_size
        num_batches = int(np.ceil(buffer_size / self.amp_fetch_demo_batch_size))

        for i in range(num_batches):
            curr_samples = self.env.fetch_amp_obs_demo(self.amp_fetch_demo_batch_size)
            self.amp_demo_storage.insert(curr_samples)
        return
    
    def update_amp_demos(self):
        curr_samples = self.env.fetch_amp_obs_demo(self.amp_fetch_demo_batch_size)
        self.amp_demo_storage.insert(curr_samples)
        return

    def calc_amp_rewards(self, amp_obs):
        with torch.no_grad():
            disc_logits = self.amp_discriminator(amp_obs)
            # prob = 1 / (1 + torch.exp(-disc_logits)) 
            # disc_r = -torch.log(torch.maximum(1 - prob, torch.tensor(0.0001, device=self.device)))
            disc_r = torch.clamp(1 - (1/4) * torch.square(disc_logits - 1), min=0)
        return disc_r


    def update_amp(self):
        amp_policy_generator = self.amp_storage.feed_forward_generator(1,self.amp_learn_batch_size)
        amp_demo_generator = self.amp_demo_storage.feed_forward_generator(1,self.amp_learn_batch_size)
        sample_amp_policy, sample_amp_demo = next(amp_policy_generator), next(amp_demo_generator)
        if self.amp_normalizer is not None:
            with torch.no_grad():
                sample_amp_policy = self.amp_normalizer.normalize_torch(sample_amp_policy, self.device)
                sample_amp_demo = self.amp_normalizer.normalize_torch(sample_amp_demo, self.device)
        policy_d = self.amp_discriminator(sample_amp_policy)
        expert_d = self.amp_discriminator(sample_amp_demo)
        
        # # Original AMP Loss
        # expert_loss = torch.nn.MSELoss()(expert_d, torch.ones(expert_d.size(), device=self.device))
        # policy_loss = torch.nn.MSELoss()(policy_d, -1 * torch.ones(policy_d.size(), device=self.device))
        # amp_loss = 0.5 * (expert_loss + policy_loss)

        # Wasserstein-1 距离
        amp_loss = -torch.mean(expert_d) + torch.mean(policy_d)


        grad_pen_loss = self.amp_discriminator.compute_grad_pen(sample_amp_demo, lambda_=self.amp_grad_pen)
        amp_loss_pen = amp_loss + grad_pen_loss
        self.amp_optimizer.zero_grad()
        amp_loss_pen.backward()
        nn.utils.clip_grad_norm_(self.amp_discriminator.parameters(), self.max_grad_norm)
        self.amp_optimizer.step()
        if self.amp_normalizer is not None:
            self.amp_normalizer.update(sample_amp_policy.cpu().numpy())
            self.amp_normalizer.update(sample_amp_demo.cpu().numpy())
        return amp_loss, grad_pen_loss, policy_d.mean(), expert_d.mean()