import copy
from components.episode_buffer import EpisodeBatch
from modules.mixers.nmix import Mixer
from modules.mixers.vdn import VDNMixer
from modules.mixers.qatten import QattenMixer
from envs.matrix_game import print_matrix_status
from utils.rl_utils import build_td_lambda_targets, build_q_lambda_targets
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import RMSprop, Adam
import numpy as np
from utils.th_utils import get_parameters_num
from typing import Optional, Union, Dict, Type, Tuple
from utils.ensembles import EnsembleMLP, Normalizer,EPS
from utils.info_utils import BaseIntrinsicReward,DisagreementIntrinsicReward,show_Visitation_heatmaps


class EECENQLearner:
    def __init__(self, mac, scheme, logger, args, env_info):
        self.args = args
        self.mac = mac
        self.logger = logger

        self.params = list(mac.parameters())

        self.last_target_update_episode = 0
        self.n_agents = args.n_agents

        self.device = th.device("cuda" if args.use_cuda else "cpu")

        if args.mixer == "qatten":
            self.mixer = QattenMixer(args)
        elif args.mixer == "vdn":
            self.mixer = VDNMixer()
        elif args.mixer == "qmix":
            self.mixer = Mixer(args)
        else:
            raise "mixer error"
        
        if args.int_mixer == "qatten":
            self.int_mixer = QattenMixer(args)
        elif args.int_mixer == "vdn":
            self.int_mixer = VDNMixer()
        elif args.int_mixer == "qmix":
            self.int_mixer = Mixer(args)
        else:
            raise "int_mixer error"
        
        self.target_mixer = copy.deepcopy(self.mixer)
        self.int_target_mixer = copy.deepcopy(self.int_mixer)
        self.params += list(self.mixer.parameters())
        self.int_params = list(mac.int_parameters())
        self.int_params += list(self.int_mixer.parameters())


        print("Mixer Size: ")
        print(get_parameters_num(self.mixer.parameters()))

        self.normalize_ensemble_training = True
        self.pred_diff = True
        self.learn_rewards = True
        intrinsic_reward_weights = None
        # int_reward = DisagreementIntrinsicReward

        ensemble_model_kwargs = {
            'learn_std': args.learn_std,
            'features': (args.features, args.features),
            'optimizer_kwargs': {'lr': args.ensemble_model_lr, 'weight_decay': args.weight_decay},
            'use_entropy': args.use_entropy,
        }

        state_dim = int(np.prod(self.args.state_shape))
        self.input_dim = state_dim + self.args.action_embed_dim * self.args.n_agents
        # self.action_embed_net = nn.Embedding(self.args.n_actions, self.args.action_embed_dim)

        output_dict = self._get_ensemble_targets(th.zeros((1,state_dim)),
                                                 th.zeros((1,state_dim)),
                                                 rewards=th.zeros((1, 1))
                                                 )
        
        self.input_normalizer = Normalizer(input_dim=self.input_dim, update=self.normalize_ensemble_training,
                                           device=self.device)
        output_normalizers = {}
        for key, val in output_dict.items():
            output_normalizers[key] = Normalizer(input_dim=val.shape[-1], update=self.normalize_ensemble_training,
                                                 device=self.device)
        self.output_normalizers = output_normalizers

        self.entropy_normalizer = Normalizer(input_dim=1, update=self.normalize_ensemble_training,
                                             device=self.device)
        
        self.cooperative_normalizer = Normalizer(input_dim=1, update=self.normalize_ensemble_training,
                                             device=self.device)
        

        self.ensemble_model = EnsembleMLP(args=args,
            input_dim=self.input_dim,
            output_dict=output_dict,
            **ensemble_model_kwargs,
        )
        self.ensemble_model.to(self.device)

        self.target_ensemble_model = EnsembleMLP(args=args,
            input_dim=self.input_dim,
            output_dict=output_dict,
            **ensemble_model_kwargs,
        )
        self.target_ensemble_model.to(self.device)

        self.target_ensemble_model.load_state_dict(self.ensemble_model.state_dict())


        if intrinsic_reward_weights is not None:
            assert intrinsic_reward_weights.keys() == output_dict.keys()
        else:
            intrinsic_reward_weights = {k: 1.0 for k in output_dict.keys()}

        self.intrinsic_reward_model = DisagreementIntrinsicReward(
            intrinsic_reward_weights=intrinsic_reward_weights,
            ensemble_model=self.target_ensemble_model,
            agg_intrinsic_reward='mean',
        )

        if self.args.optimizer == 'adam':
            self.optimiser = Adam(params=self.params,  lr=args.lr, weight_decay=getattr(args, "weight_decay", 0))
        else:
            self.optimiser = RMSprop(params=self.params, lr=args.lr, alpha=args.optim_alpha, eps=args.optim_eps)

        if self.args.int_optimizer == 'adam':
            self.int_optimiser = Adam(params=self.int_params,  lr=args.int_lr, weight_decay=getattr(args, "weight_decay", 0))
        else:
            self.int_optimiser = RMSprop(params=self.int_params, lr=args.int_lr, alpha=args.optim_alpha, eps=args.optim_eps)

        # a little wasteful to deepcopy (e.g. duplicates action selector), but should work for any MAC
        self.target_mac = copy.deepcopy(mac)
        self.log_stats_t = -self.args.learner_log_interval - 1
        self.world_log_stats_t = -self.args.learner_log_interval - 1
        self.train_t = 0

        # priority replay
        self.use_per = getattr(self.args, 'use_per', False)
        self.return_priority = getattr(self.args, "return_priority", False)
        if self.use_per:
            self.priority_max = float('-inf')
            self.priority_min = float('inf')
    
    def get_intrinsic_reward(self, inp: th.Tensor, labels: Dict) -> th.Tensor:
        # calculate intrinsic reward
        entropy = self.intrinsic_reward_model(inp=inp, labels=labels)
        if not self.ensemble_model.learn_std:
            info_gain = entropy - th.log(th.ones_like(entropy) * EPS)
            return info_gain
        else:
            return entropy
        
    def _get_ensemble_targets(self, next_obs: Union[th.Tensor, Dict], obs: Union[th.Tensor, Dict],
                              rewards: th.Tensor) -> Dict:
        if self.pred_diff:
            if self.learn_rewards:
                return {
                    'next_obs': next_obs - obs,
                    'reward': rewards,
                }
            else:
                return {
                    'next_obs': next_obs - obs,
                }
        else:
            if self.learn_rewards:
                return {
                    'next_obs': next_obs,
                    'reward': rewards,
                }
            else:
                return {
                    'next_obs': next_obs,
                }
            

    def train_world(self, batch: EpisodeBatch, predict_epoch: int):
        '''
         :todo  mask_sample?
        '''
        s = batch["state"][:, :-1, :]
        next_s = batch["state"][:, 1:, :]
        bs, seq_len_cut, _ = s.shape
        rewards = batch["reward"][:, :-1].reshape(bs * seq_len_cut, -1)
       
        s = s.reshape(bs * seq_len_cut, -1)
        next_s = next_s.reshape(bs * seq_len_cut, -1)
        a = batch["actions"][:, :-1, :, 0].long()
        a_all = a[:, :seq_len_cut, :].reshape(bs * seq_len_cut, -1)

        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        filter_mask = mask.reshape(-1).bool()
        

        filter_s = s[filter_mask]
        filter_next_s = next_s[filter_mask]
        filter_a_all = a_all[filter_mask]
        filter_rewards = rewards[filter_mask]

        # print(s.shape,filter_s.shape)

        inp = th.cat([filter_s, self.ensemble_model.action_embed_net(filter_a_all).flatten(-2)], dim=-1)
        self.input_normalizer.update(inp.detach())
        labels = self._get_ensemble_targets(filter_next_s, filter_s,filter_rewards)
        for key, y in labels.items():
            self.output_normalizers[key].update(y)

        inp = self.input_normalizer.normalize(inp)
        for key, y in labels.items():
            labels[key] = self.output_normalizers[key].normalize(y)


        self.ensemble_model.train()
        for _ in range(predict_epoch):
            self.ensemble_model.optimizer.zero_grad()
            prediction = self.ensemble_model(inp)
            loss = self.ensemble_model.loss(prediction=prediction, target=labels)
            stacked_losses = []
            for key, val in loss.items():
                stacked_losses.append(val)
            stacked_losses = th.stack(stacked_losses)
            total_loss = stacked_losses.mean()
            total_loss.backward(retain_graph=True)
            self.ensemble_model.optimizer.step()

        self.ensemble_model.eval()
        self.ensemble_model.train(False)

        return total_loss.item()

    def train(
        self,
        batch: EpisodeBatch,
        t_env: int,
        episode_num: int,
        per_weight=None,
    ):
        # Get the relevant quantities
        rewards = batch["reward"][:, :-1]
        actions = batch["actions"][:, :-1]
        terminated = batch["terminated"][:, :-1].float()
        mask = batch["filled"][:, :-1].float()
        mask[:, 1:] = mask[:, 1:] * (1 - terminated[:, :-1])
        avail_actions = batch["avail_actions"]

        # Calculate estimated Q-Values
        self.mac.agent.train()
        mac_out = []
        self.mac.init_hidden(batch.batch_size)
        for t in range(batch.max_seq_length):
            agent_outs = self.mac.forward(batch, t=t)
            mac_out.append(agent_outs)
        mac_out = th.stack(mac_out, dim=1)  # Concat over time

        # Pick the Q-Values for the actions taken by each agent
        chosen_action_qvals = th.gather(mac_out[:, :-1], dim=3, index=actions).squeeze(
            3
        )  # Remove the last dim
        chosen_action_qvals_ = chosen_action_qvals

        # Calculate the Q-Values necessary for the target
        with th.no_grad():
            self.target_mac.agent.train()
            target_mac_out = []
            self.target_mac.init_hidden(batch.batch_size)
            for t in range(batch.max_seq_length):
                target_agent_outs = self.target_mac.forward(batch, t=t)
                target_mac_out.append(target_agent_outs)

            # We don't need the first timesteps Q-Value estimate for calculating targets
            target_mac_out = th.stack(target_mac_out, dim=1)  # Concat across time

            # Max over target Q-Values/ Double q learning
            mac_out_detach = mac_out.clone().detach()
            mac_out_detach[avail_actions == 0] = -9999999
            cur_max_actions = mac_out_detach.max(dim=3, keepdim=True)[1]
            target_max_qvals = th.gather(target_mac_out, 3, cur_max_actions).squeeze(3)

            # Calculate n-step Q-Learning targets
            target_max_qvals = self.target_mixer(target_max_qvals, batch["state"])

            rewards = rewards + self.args.step_penalty

            if getattr(self.args, "q_lambda", False):
                qvals = th.gather(target_mac_out, 3, batch["actions"]).squeeze(3)
                qvals = self.target_mixer(qvals, batch["state"])

                targets = build_q_lambda_targets(
                    rewards,
                    terminated,
                    mask,
                    target_max_qvals,
                    qvals,
                    self.args.gamma,
                    self.args.td_lambda,
                )
            else:
                targets = build_td_lambda_targets(
                    rewards,
                    terminated,
                    mask,
                    target_max_qvals,
                    self.args.n_agents,
                    self.args.gamma,
                    self.args.td_lambda,
                )

        # Mixer
        chosen_action_qvals = self.mixer(chosen_action_qvals, batch["state"][:, :-1])

        td_error = chosen_action_qvals - targets.detach()
        td_error2 = 0.5 * td_error.pow(2)

        mask = mask.expand_as(td_error2)
        masked_td_error = td_error2 * mask
        # important sampling for PER
        if self.use_per:
            per_weight = th.from_numpy(per_weight).unsqueeze(-1).to(device=self.device)
            masked_td_error = masked_td_error.sum(1) * per_weight

        loss = masked_td_error.sum() / mask.sum()

        # Optimise
        self.optimiser.zero_grad()
        loss.backward()
        grad_norm = th.nn.utils.clip_grad_norm_(self.params, self.args.grad_norm_clip)
        self.optimiser.step()


        ########## intrinsic
        # Calculate estimated Q-Values
        self.mac.int_agent.train()
        int_mac_out = []
        for t in range(batch.max_seq_length):
            int_agent_outs = self.mac.int_forward(batch, t=t)
            int_mac_out.append(int_agent_outs)
        int_mac_out = th.stack(int_mac_out, dim=1)  # Concat over time

        # Pick the Q-Values for the actions taken by each agent
        int_chosen_action_qvals = th.gather(int_mac_out[:, :-1], dim=3, index=actions).squeeze(
            3
        )  # Remove the last dim

        # Calculate the Q-Values necessary for the target
        with th.no_grad():
            self.target_mac.int_agent.train()
            int_target_mac_out = []
            for t in range(batch.max_seq_length):
                int_target_agent_outs = self.target_mac.int_forward(batch, t=t)
                int_target_mac_out.append(int_target_agent_outs)

            # We don't need the first timesteps Q-Value estimate for calculating targets
            int_target_mac_out = th.stack(int_target_mac_out, dim=1)  # Concat across time

            # Max over target Q-Values/ Double q learning
            int_mac_out_detach = int_mac_out.clone().detach()
            int_mac_out_detach[avail_actions == 0] = -9999999
            int_cur_max_actions = int_mac_out_detach.max(dim=3, keepdim=True)[1]
            int_target_max_qvals = th.gather(int_target_mac_out, 3, int_cur_max_actions).squeeze(3)

            # Calculate n-step Q-Learning targets
            int_target_max_qvals = self.int_target_mixer(int_target_max_qvals, batch["state"])

            
            # IG 
            s = batch["state"][:, :-1, :]
            next_s = batch["state"][:, 1:, :]
            bs, seq_len_cut, s_dim = s.shape
            rewards = batch["reward"][:, :-1].reshape(bs * seq_len_cut, -1)
            s = s.reshape(bs * seq_len_cut, -1)
            next_s = next_s.reshape(bs * seq_len_cut, -1)
            a = batch["actions"][:, :-1, :, 0].long()
            a_all = a[:, :seq_len_cut, :].reshape(bs * seq_len_cut, -1)
            
            inp = th.cat([s, self.target_ensemble_model.action_embed_net(a_all).flatten(-2)], dim=-1)
            labels = self._get_ensemble_targets(next_s,s,rewards)
            inp = self.input_normalizer.normalize(inp)
            for key, y in labels.items():
                labels[key] = self.output_normalizers[key].normalize(y)
            
            int_IG_rewards = self.get_intrinsic_reward(
                    inp=inp,
                    labels=labels
                )
            
            
            filter_mask = mask.reshape(-1).bool()
            int_IG_rewards = int_IG_rewards.reshape(-1,1)
            self.entropy_normalizer.update(int_IG_rewards.detach()[filter_mask])
            int_IG_rewards = self.entropy_normalizer.normalize(int_IG_rewards)


            # CI
            real_result = self.target_ensemble_model(inp)
            real = real_result['next_obs']
            mean_real,std_real = self.nanmean_std(real,dim=-1)
            actions_onehot = batch["actions_onehot"][:, :-1]
            b, t, n_agents, n_actions = actions_onehot.shape
            full_actions = th.ones((b, t, n_agents, n_actions, n_actions)) * th.eye(n_actions)
            full_actions = full_actions.type_as(s).to(a.device)
            full_s = batch["state"][:, :-1, :].unsqueeze(-2).repeat(1, 1, n_actions, 1)
            full_s = full_s.view(b*t*n_actions, -1)
            full_a = actions_onehot.unsqueeze(-2).repeat(1, 1, 1, n_actions, 1)
            cooperative_intrinsic = th.zeros((b*t, n_agents)).to(a.device)


            full_a_avail = avail_actions[:, :-1]
            for i in range(n_agents):
                ATE_a = (full_a.clone())
                ATE_a[..., i, :, :] = full_actions[..., i, :, :]
                ATE_a = ATE_a.transpose(-2, -3)
                ATE_action_ids = ATE_a.argmax(dim=-1).long()
                ATE_action_ids_reshaped = ATE_action_ids.view(-1, n_agents)
                counterfactual_inp = th.cat([full_s, self.target_ensemble_model.action_embed_net(ATE_action_ids_reshaped).flatten(-2)], dim=-1)
                counterfactual_inp = self.input_normalizer.normalize(counterfactual_inp)
                counterfactual = self.target_ensemble_model(counterfactual_inp)
                ds = counterfactual['next_obs']
                ds = ds.view(b,t,n_actions,s_dim,5)
                avail_mask = full_a_avail[:, :, i, :].unsqueeze(-1).unsqueeze(-1)
                avail_mask = avail_mask.expand(-1, -1, -1, ds.shape[3], ds.shape[4])
                ds_masked = ds.masked_fill(avail_mask == 0, float('nan')) 
                ds_flat = ds_masked.permute(0, 1, 3, 2, 4).reshape(b*t, s_dim, -1)
                mean,std = self.nanmean_std(ds_flat,dim=-1)
                div = self.KL_div(mean_real,std_real,mean,std)
                cooperative_intrinsic[:,i] = div

           

            cooperative_intrinsic = cooperative_intrinsic.mean(dim=-1).unsqueeze(-1)
            cooperative_intrinsic = cooperative_intrinsic.reshape(-1,1)
            self.cooperative_normalizer.update(cooperative_intrinsic.detach()[filter_mask])
            cooperative_intrinsic = self.cooperative_normalizer.normalize(cooperative_intrinsic)
            
            # # integrated
            int_IG_rewards = int_IG_rewards.view_as(batch["reward"][:, :-1])*mask
            cooperative_intrinsic = cooperative_intrinsic.view_as(batch["reward"][:, :-1])*mask


            alpha = self.args.alpha + (1 - self.args.alpha) * np.exp(-self.args.alpha_k * t_env / self.args.t_max)

            int_rewards = alpha*int_IG_rewards + (1-alpha)*cooperative_intrinsic
  

            if getattr(self.args, "q_lambda", False):
                qvals = th.gather(int_target_mac_out, 3, batch["actions"]).squeeze(3)
                qvals = self.int_target_mixer(qvals, batch["state"])
                int_targets = build_q_lambda_targets(
                    int_rewards,
                    terminated,
                    mask,
                    int_target_max_qvals,
                    qvals,
                    self.args.gamma,
                    self.args.td_lambda,
                )
            else:
                int_targets = build_td_lambda_targets(
                    int_rewards,
                    terminated,
                    mask,
                    int_target_max_qvals,
                    self.args.n_agents,
                    self.args.gamma,
                    self.args.td_lambda,
                )

        # Mixer
        int_chosen_action_qvals = self.int_mixer(int_chosen_action_qvals, batch["state"][:, :-1])

        int_td_error = int_chosen_action_qvals - int_targets.detach()
        int_td_error2 = 0.5 * int_td_error.pow(2)

        mask = mask.expand_as(int_td_error2)
        int_masked_td_error = int_td_error2 * mask

        int_loss = int_masked_td_error.sum() / mask.sum()

        # Optimise
        self.int_optimiser.zero_grad()
        int_loss.backward()
        int_grad_norm = th.nn.utils.clip_grad_norm_(self.int_params, self.args.grad_norm_clip)
        self.int_optimiser.step()

        total_loss = self.train_world(batch, self.args.predict_epoch)

        if (
            episode_num - self.last_target_update_episode
        ) / self.args.target_update_interval >= 1.0:
            self._update_targets()
            self.last_target_update_episode = episode_num
   

        if t_env - self.log_stats_t >= self.args.learner_log_interval:

            self.logger.log_stat(
                    "ensemble_loss", total_loss, t_env
                )
            self.logger.log_stat("loss_td", loss.item(), t_env)
            # self.logger.log_stat("loss", loss.item(), t_env)
            self.logger.log_stat("int_loss", int_loss.item(), t_env)
            # self.logger.log_stat("loss_int_c", loss_int_c.item(), t_env)
            self.logger.log_stat("grad_norm", grad_norm, t_env)
            self.logger.log_stat("int_grad_norm", int_grad_norm, t_env)
            # self.logger.log_stat("int_c_grad_norm", int_c_grad_norm, t_env)
            mask_elems = mask.sum().item()
            self.logger.log_stat(
                "td_error_abs", (masked_td_error.abs().sum().item() / mask_elems), t_env
            )
            self.logger.log_stat(
                "q_taken_mean",
                (chosen_action_qvals * mask).sum().item()
                / (mask_elems * self.args.n_agents),
                t_env,
            )
            self.logger.log_stat(
                "target_mean",
                (targets * mask).sum().item() / (mask_elems * self.args.n_agents),
                t_env,
            )
            self.logger.log_stat(
                "int_IG_rewards_mean", int_IG_rewards.reshape(-1,1)[filter_mask].mean().item(), t_env
            )
            self.logger.log_stat(
                "int_IG_rewards_std", int_IG_rewards.reshape(-1,1)[filter_mask].std().item(), t_env
            )
            self.logger.log_stat(
                "cooperative_intrinsic_mean", cooperative_intrinsic.reshape(-1,1)[filter_mask].mean().item(), t_env
            )
            self.logger.log_stat(
                "cooperative_intrinsic_std", cooperative_intrinsic.reshape(-1,1)[filter_mask].std().item(), t_env
            )

            self.logger.log_stat(
                "int_IG_rewards_mean_", self.entropy_normalizer.mean.item(), t_env
            )
            self.logger.log_stat(
                "int_IG_rewards_std_", self.entropy_normalizer.std.item(), t_env
            )
            self.logger.log_stat(
                "cooperative_intrinsic_mean_", self.cooperative_normalizer.mean.item(), t_env
            )
            self.logger.log_stat(
                "cooperative_intrinsic_std_", self.cooperative_normalizer.std.item(), t_env
            )


            self.log_stats_t = t_env

            # print estimated matrix
            if self.args.env == "one_step_matrix_game":
                print_matrix_status(batch, self.mixer, mac_out)

        # return info
        info = {}
        # calculate priority
        if self.use_per:
            if self.return_priority:
                info["td_errors_abs"] = rewards.sum(1).detach().to("cpu")
                # normalize to [0, 1]
                self.priority_max = max(
                    th.max(info["td_errors_abs"]).item(), self.priority_max
                )
                self.priority_min = min(
                    th.min(info["td_errors_abs"]).item(), self.priority_min
                )
                info["td_errors_abs"] = (info["td_errors_abs"] - self.priority_min) / (
                    self.priority_max - self.priority_min + 1e-5
                )
            else:
                info["td_errors_abs"] = (
                    ((td_error.abs() * mask).sum(1) / th.sqrt(mask.sum(1)))
                    .detach()
                    .to("cpu")
                )
        return info

    def _update_targets(self):
        self.target_mac.load_state(self.mac)
        if self.mixer is not None:
            self.target_mixer.load_state_dict(self.mixer.state_dict())
        if self.int_mixer is not None:
            self.int_target_mixer.load_state_dict(self.int_mixer.state_dict())
        self.target_ensemble_model.load_state_dict(self.ensemble_model.state_dict())
        self.logger.console_logger.info("Updated target network")

    def cuda(self):
        self.mac.cuda()
        self.target_mac.cuda()
        if self.mixer is not None:
            self.mixer.cuda()
            self.target_mixer.cuda()
        if self.int_mixer is not None:
            self.int_mixer.cuda()
            self.int_target_mixer.cuda()

    def save_models(self, path):
        self.mac.save_models(path)
        if self.mixer is not None:
            th.save(self.mixer.state_dict(), "{}/mixer.th".format(path))
        th.save(self.optimiser.state_dict(), "{}/opt.th".format(path))

    def load_models(self, path):
        self.mac.load_models(path)
        # Not quite right but I don't want to save target networks
        self.target_mac.load_models(path)
        if self.mixer is not None:
            self.mixer.load_state_dict(
                th.load(
                    "{}/mixer.th".format(path),
                    map_location=lambda storage, loc: storage,
                )
            )
        self.optimiser.load_state_dict(
            th.load("{}/opt.th".format(path), map_location=lambda storage, loc: storage)
        )

    def nanmean_std(self,o, dim, unbiased=True, eps=1e-8):
        count = th.sum(~th.isnan(o), dim=dim, keepdim=True).clamp(min=1)  # 有效样本数
        mean = th.nanmean(o, dim=dim, keepdim=True)  # 均值
        squared_diff = (o - mean) ** 2
        squared_diff[th.isnan(squared_diff)] = 0.0  # 把NaN视为0贡献

        if unbiased:
            count = th.clamp(count - 1, min=1)  # 无偏估计分母调整
        
        var = th.sum(squared_diff, dim=dim, keepdim=True) / count
        std = th.sqrt(var + eps)

        mean = mean.squeeze(dim)
        std = std.squeeze(dim)
        mean = th.nan_to_num(mean, nan=0.0)
        std = th.nan_to_num(std, nan=0.0)
        return mean,std
    
    def KL_div(self,p_mu, p_sigma, q_mu, q_sigma):
        """_summary_
        Args:
            p_mu (bs, dist_dim): _description_
            p_sigma (bs, dist_dim): _description_
            q_mu (bs, dist_dim): _description_
            q_sigma (bs, dist_dim): _description_
        """
        div = (
            th.log2(q_sigma)
            - th.log2(p_sigma)
            + (p_sigma**2 + (p_mu - q_mu) ** 2) / (2 * q_sigma**2)
            - 0.5
        )
        div = div.mean(dim=-1)
        # 归一化
        div = 1.0 - th.exp(-div)
        return div
    
    def Minus_div(self,p_mu, p_sigma, q_mu, q_sigma):
        """_summary_
        Args:
            p_mu (bs, dist_dim): _description_
            p_sigma (bs, dist_dim): _description_
            q_mu (bs, dist_dim): _description_
            q_sigma (bs, dist_dim): _description_
        """
        div = p_mu - q_mu
        div = div.mean(dim=-1)
        div = th.clamp(div, min=0.0, max=20.0)
        # 归一化
        div = 1.0 - th.exp(-div)
        
        return div

