import numpy as np
import torch
import torch.nn as nn

from data import Batch
from ppo import PPOPolicy
from trainer import Trainer, test_episode
from mmd import mmd
from utils import Logger
from collector import Collector
from data import ReplayBuffer, VectorBuffer
import warnings


class NTLPPOPolicy(PPOPolicy):
    """PPO policy for non-transferable learning (NTL) on source/target envs.

    Collected transitions are grouped by ``info["env_id"]`` according to
    ``split_env`` (``learn`` needs at least the groups ``"source"`` and
    ``"target"``). Each training minibatch is built from half source and half
    target samples. The MMD between the actor features of the two halves is
    detached and used to negate and scale the target advantages, so raising
    the target-env advantage is penalised. The MMD is used only for that
    scaling and for logging; it is not added to the loss.

    This class computes the MMD on actor features only.
    """

    # Added to the advantage std before dividing, so a minibatch whose
    # advantages are all identical does not produce inf/nan.
    _adv_norm_eps = 1e-8

    def __init__(self, actor, critic, optim, dist_fn,
                 eps_clip=0.2, discount_factor=0.99, reward_normalization=False,
                 action_scaling=True, action_bound_method="clip", deterministic_eval=False,
                 dual_clip=None, value_clip=False, advantage_normalization=True,
                 recompute_advantage=False, max_batchsize=256, observation_space=None,
                 action_space=None, lr_scheduler=None, vf_coef=0.5, ent_coef=0.01,
                 max_grad_norm=None, gae_lambda=0.95, split_env=None, **kwargs):
        # The parent constructor is called positionally, so the order below
        # must match the positional order of PPOPolicy.__init__ exactly.
        super().__init__(actor, critic, optim, dist_fn,
                         eps_clip, discount_factor, reward_normalization, action_scaling,
                         action_bound_method, deterministic_eval, dual_clip, value_clip,
                         advantage_normalization, recompute_advantage, vf_coef, ent_coef,
                         max_grad_norm, gae_lambda, max_batchsize, observation_space,
                         action_space, lr_scheduler, **kwargs)
        # dict: group name -> env ids belonging to that group (see split_batch).
        self.split_env = split_env

    def process_fn(self, batch, buffer, **kwargs):
        """Run the parent PPO preprocessing, then split the result by env group.

        Returns the dict produced by ``split_batch`` (not a Batch); ``learn``
        expects exactly that dict.
        """
        batch = super().process_fn(batch, buffer)
        return self.split_batch(batch)

    def split_batch(self, batch, **kwargs):
        """Split ``batch`` into sub-batches according to ``self.split_env``.

        :param batch: a Batch whose ``info[i]["env_id"]`` identifies the env
            that produced transition ``i``.
        :return: dict mapping every key of ``self.split_env`` to the sub-batch
            of transitions whose env id is in that key's id list, plus the key
            ``"all"`` mapped to ``batch`` itself.
        """
        num = batch.info.shape[0]
        env_id = np.zeros(num)
        for i in range(num):
            env_id[i] = batch.info[i]["env_id"]

        # np.isin also accepts an empty id list or a scalar id, where the
        # previous "v[0] ... v[i]" accumulation raised.
        s_batch = {name: batch[np.isin(env_id, ids)]
                   for name, ids in self.split_env.items()}
        s_batch["all"] = batch
        return s_batch

    def learn(self, batch, batch_size, repeat, **kwargs):
        """Run ``repeat`` passes of PPO updates on paired source/target minibatches.

        :param batch: dict from ``split_batch`` with keys ``"source"``,
            ``"target"`` and ``"all"``.
        :param batch_size: total minibatch size; ``int(batch_size / 2)``
            samples are drawn from each of the source and target sub-batches.
        :param repeat: number of passes over the data.
        :return: dict of per-minibatch lists: ``loss``, ``loss/clip``,
            ``loss/vf``, ``loss/ent``, ``ratio``, ``kl`` (mean log-ratio) and
            ``mmd``.
        """
        s_batch, t_batch, batch = batch["source"], batch["target"], batch["all"]
        half = int(batch_size / 2)
        losses, clip_losses, vf_losses, ent_losses = [], [], [], []
        ratios, kl_data, mmd_data = [], [], []
        for step in range(repeat):
            if self._recompute_adv and step > 0:
                split = self.split_batch(self._compute_returns(batch, self._buffer))
                s_batch, t_batch, batch = split["source"], split["target"], split["all"]
            # zip() stops at the shorter sequence: if source and target yield a
            # different number of minibatches, the surplus ones are skipped.
            for s_minibatch, t_minibatch in zip(
                    s_batch.split(half, merge_last=True),
                    t_batch.split(half, merge_last=True)):
                # Rows [:n_src] of the concatenated minibatch are source
                # samples, rows [n_src:] are target samples.
                n_src = len(s_minibatch)
                minibatch = Batch()
                for k in s_minibatch.keys():
                    if k == 'adv':
                        continue  # concatenated below, after target rescaling
                    if isinstance(s_minibatch[k], np.ndarray):
                        minibatch[k] = np.concatenate((s_minibatch[k], t_minibatch[k]))
                    elif isinstance(s_minibatch[k], torch.Tensor):
                        minibatch[k] = torch.cat([s_minibatch[k], t_minibatch[k]])

                # Actor forward pass; split features by the real source length
                # (torch.chunk(..., 2) mislabels rows when the halves differ,
                # e.g. for a merged last minibatch).
                minibatch_forward = self(minibatch)
                dist, features = minibatch_forward.dist, minibatch_forward.feature
                mmd_loss = mmd(features[:n_src], features[n_src:])

                # Negate and scale target advantages by the detached MMD,
                # clipped below at eps_clip.
                t_minibatch.adv *= -0.1 * np.clip(
                    mmd_loss.detach().item(), self._eps_clip, 999999)
                if isinstance(s_minibatch['adv'], np.ndarray):
                    minibatch['adv'] = np.concatenate((s_minibatch['adv'], t_minibatch['adv']))
                elif isinstance(s_minibatch['adv'], torch.Tensor):
                    minibatch['adv'] = torch.cat([s_minibatch['adv'], t_minibatch['adv']])

                if self._norm_adv:
                    # Per-minibatch normalisation over source + target together.
                    mean, std = minibatch.adv.mean(), minibatch.adv.std()
                    minibatch.adv = (minibatch.adv - mean) / (std + self._adv_norm_eps)

                # Clipped PPO surrogate for the actor.
                ratio = (dist.log_prob(minibatch.act) -
                         minibatch.logp_old).exp().float()
                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                surr1 = ratio * minibatch.adv
                surr2 = ratio.clamp(
                    1.0 - self._eps_clip, 1.0 + self._eps_clip
                ) * minibatch.adv
                if self._dual_clip:
                    clip1 = torch.min(surr1, surr2)
                    clip2 = torch.max(clip1, self._dual_clip * minibatch.adv)
                    clip_loss = -torch.where(minibatch.adv < 0, clip2, clip1).mean()
                else:
                    clip_loss = -torch.min(surr1, surr2).mean()

                # Critic loss, optionally value-clipped around the old value.
                value = self.critic(minibatch.obs).flatten()
                if self._value_clip:
                    v_clip = minibatch.v_s + \
                        (value - minibatch.v_s).clamp(-self._eps_clip, self._eps_clip)
                    vf1 = (minibatch.returns - value).pow(2)
                    vf2 = (minibatch.returns - v_clip).pow(2)
                    vf_loss = torch.max(vf1, vf2).mean()
                else:
                    vf_loss = (minibatch.returns - value).pow(2).mean()

                # Entropy bonus and total loss.
                ent_loss = dist.entropy().mean()
                loss = clip_loss + self._weight_vf * vf_loss \
                    - self._weight_ent * ent_loss
                self.optim.zero_grad()
                loss.backward()
                if self._grad_norm:  # clip large gradient
                    nn.utils.clip_grad_norm_(
                        self._actor_critic.parameters(), max_norm=self._grad_norm
                    )
                self.optim.step()

                clip_losses.append(clip_loss.item())
                vf_losses.append(vf_loss.item())
                ent_losses.append(ent_loss.item())
                losses.append(loss.item())
                ratios.append(ratio.mean().item())
                kl_data.append(ratio.log().mean().item())
                mmd_data.append(mmd_loss.item())
        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {
            "loss": losses,
            "loss/clip": clip_losses,
            "loss/vf": vf_losses,
            "loss/ent": ent_losses,
            "ratio": ratios,
            "kl": kl_data,
            "mmd": mmd_data,
        }

class NTLPPOPolicyCriticMmd(NTLPPOPolicy):
    """NTL PPO variant whose MMD term uses both actor and critic features.

    Here ``self.critic(obs)`` is expected to return a ``(value, feature)``
    pair. The target-advantage scale is the sum of the clipped actor-feature
    MMD and the clipped critic-feature MMD.
    """

    # Added to the advantage std before dividing, so a minibatch whose
    # advantages are all identical does not produce inf/nan.
    _adv_norm_eps = 1e-8

    def _compute_returns(self, batch, buffer):
        """Compute ``v_s``, ``returns`` and ``adv`` for ``batch`` in place.

        Values come from the first element of the critic's output. When reward
        normalisation is enabled, values are un-normalised before computing
        the episodic return and the running return statistics are updated.
        ``returns`` and ``adv`` are stored as tensors with the dtype/device of
        ``v_s``.
        """
        v_s, v_s_ = [], []
        with torch.no_grad():
            for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
                v_s.append(self.critic(minibatch.obs)[0])
                v_s_.append(self.critic(minibatch.obs_next)[0])
        batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
        v_s = batch.v_s.cpu().numpy()
        v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy()
        if self._rew_norm:  # unnormalize v_s & v_s_
            v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
            v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
        unnormalized_returns, advantages = self.compute_episodic_return(
            batch,
            buffer,
            v_s_,
            v_s,
            gamma=self._gamma,
            gae_lambda=self._lambda
        )
        if self._rew_norm:
            batch.returns = unnormalized_returns / \
                np.sqrt(self.ret_rms.var + self._eps)
            self.ret_rms.update(unnormalized_returns)
        else:
            batch.returns = unnormalized_returns
        device = batch.v_s.device
        dtype = batch.v_s.dtype
        batch.returns = torch.from_numpy(batch.returns).type(dtype).to(device)
        batch.adv = torch.from_numpy(advantages).type(dtype).to(device)
        return batch

    def learn(self, batch, batch_size, repeat, **kwargs):
        """Run ``repeat`` passes of PPO updates on paired source/target minibatches.

        :param batch: dict from ``split_batch`` with keys ``"source"``,
            ``"target"`` and ``"all"``.
        :param batch_size: total minibatch size; ``int(batch_size / 2)``
            samples are drawn from each of the source and target sub-batches.
        :param repeat: number of passes over the data.
        :return: dict of per-minibatch lists; ``mmd`` is actor MMD + critic MMD.
        """
        s_batch, t_batch, batch = batch["source"], batch["target"], batch["all"]
        half = int(batch_size / 2)
        losses, clip_losses, vf_losses, ent_losses = [], [], [], []
        ratios, kl_data, mmd_data = [], [], []
        for step in range(repeat):
            if self._recompute_adv and step > 0:
                split = self.split_batch(self._compute_returns(batch, self._buffer))
                s_batch, t_batch, batch = split["source"], split["target"], split["all"]
            # zip() stops at the shorter sequence: if source and target yield a
            # different number of minibatches, the surplus ones are skipped.
            for s_minibatch, t_minibatch in zip(
                    s_batch.split(half, merge_last=True),
                    t_batch.split(half, merge_last=True)):
                # Rows [:n_src] of the concatenated minibatch are source
                # samples, rows [n_src:] are target samples.
                n_src = len(s_minibatch)
                minibatch = Batch()
                for k in s_minibatch.keys():
                    if k == 'adv':
                        continue  # concatenated below, after target rescaling
                    if isinstance(s_minibatch[k], np.ndarray):
                        minibatch[k] = np.concatenate((s_minibatch[k], t_minibatch[k]))
                    elif isinstance(s_minibatch[k], torch.Tensor):
                        minibatch[k] = torch.cat([s_minibatch[k], t_minibatch[k]])

                # Actor and critic forward passes; features are split by the
                # real source length rather than torch.chunk(..., 2).
                minibatch_forward = self(minibatch)
                dist, features = minibatch_forward.dist, minibatch_forward.feature
                mmd_loss = mmd(features[:n_src], features[n_src:])
                value, critic_features = self.critic(minibatch.obs)
                c_mmd_loss = mmd(critic_features[:n_src], critic_features[n_src:])

                if self._norm_adv:
                    # Normalise source and target advantages separately.
                    mean, std = s_minibatch.adv.mean(), s_minibatch.adv.std()
                    s_minibatch.adv = (s_minibatch.adv - mean) / (std + self._adv_norm_eps)
                    mean, std = t_minibatch.adv.mean(), t_minibatch.adv.std()
                    t_minibatch.adv = (t_minibatch.adv - mean) / (std + self._adv_norm_eps)

                # Scale AFTER normalising: scaling first and then normalising
                # cancels the MMD magnitude and leaves only the sign flip
                # (same order as NTLPPOPolicyFeatureCriticMmd).
                t_minibatch.adv *= -0.1 * (
                    np.clip(mmd_loss.detach().item(), self._eps_clip, 999999)
                    + np.clip(c_mmd_loss.detach().item(), self._eps_clip, 999999)
                )

                if isinstance(s_minibatch['adv'], np.ndarray):
                    minibatch['adv'] = np.concatenate((s_minibatch['adv'], t_minibatch['adv']))
                elif isinstance(s_minibatch['adv'], torch.Tensor):
                    minibatch['adv'] = torch.cat([s_minibatch['adv'], t_minibatch['adv']])

                # Clipped PPO surrogate for the actor.
                ratio = (dist.log_prob(minibatch.act) -
                         minibatch.logp_old).exp().float()
                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                surr1 = ratio * minibatch.adv
                surr2 = ratio.clamp(
                    1.0 - self._eps_clip, 1.0 + self._eps_clip
                ) * minibatch.adv
                if self._dual_clip:
                    clip1 = torch.min(surr1, surr2)
                    clip2 = torch.max(clip1, self._dual_clip * minibatch.adv)
                    clip_loss = -torch.where(minibatch.adv < 0, clip2, clip1).mean()
                else:
                    clip_loss = -torch.min(surr1, surr2).mean()

                # Critic loss, reusing the value from the forward pass above.
                value = value.flatten()
                if self._value_clip:
                    v_clip = minibatch.v_s + \
                        (value - minibatch.v_s).clamp(-self._eps_clip, self._eps_clip)
                    vf1 = (minibatch.returns - value).pow(2)
                    vf2 = (minibatch.returns - v_clip).pow(2)
                    vf_loss = torch.max(vf1, vf2).mean()
                else:
                    vf_loss = (minibatch.returns - value).pow(2).mean()

                # Entropy bonus and total loss.
                ent_loss = dist.entropy().mean()
                loss = clip_loss + self._weight_vf * vf_loss \
                    - self._weight_ent * ent_loss
                self.optim.zero_grad()
                loss.backward()
                if self._grad_norm:  # clip large gradient
                    nn.utils.clip_grad_norm_(
                        self._actor_critic.parameters(), max_norm=self._grad_norm
                    )
                self.optim.step()

                clip_losses.append(clip_loss.item())
                vf_losses.append(vf_loss.item())
                ent_losses.append(ent_loss.item())
                losses.append(loss.item())
                ratios.append(ratio.mean().item())
                kl_data.append(ratio.log().mean().item())
                mmd_data.append(mmd_loss.item() + c_mmd_loss.item())
        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {
            "loss": losses,
            "loss/clip": clip_losses,
            "loss/vf": vf_losses,
            "loss/ent": ent_losses,
            "ratio": ratios,
            "kl": kl_data,
            "mmd": mmd_data,
        }

class NTLPPOPolicyFeature(NTLPPOPolicy):
    """NTL PPO variant whose critic consumes ``[obs, actor feature]``.

    The critic input is ``np.concatenate([obs, feature], axis=1)`` and the
    critic returns the value directly. The MMD term uses actor features only.
    """

    # Added to the advantage std before dividing, so a minibatch whose
    # advantages are all identical does not produce inf/nan.
    _adv_norm_eps = 1e-8

    @staticmethod
    def _as_numpy(x):
        """Return ``x`` as a NumPy array, detaching/moving it to CPU if it is a tensor."""
        if isinstance(x, np.ndarray):
            return x
        return x.detach().cpu().numpy()

    def _compute_returns(self, batch, buffer):
        """Compute ``v_s``, ``returns`` and ``adv`` for ``batch`` in place.

        The critic is evaluated on ``[obs, feature]`` and
        ``[obs_next, feature_next]``. When reward normalisation is enabled,
        values are un-normalised before computing the episodic return and the
        running return statistics are updated.
        """
        v_s, v_s_ = [], []
        with torch.no_grad():
            for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
                feature = self._as_numpy(minibatch.feature)
                feature_next = self._as_numpy(minibatch.feature_next)
                v_s.append(self.critic(np.concatenate([minibatch.obs, feature], axis=1)))
                v_s_.append(self.critic(np.concatenate([minibatch.obs_next, feature_next], axis=1)))

        batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
        v_s = batch.v_s.cpu().numpy()
        v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy()
        if self._rew_norm:  # unnormalize v_s & v_s_
            v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
            v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
        unnormalized_returns, advantages = self.compute_episodic_return(
            batch,
            buffer,
            v_s_,
            v_s,
            gamma=self._gamma,
            gae_lambda=self._lambda
        )
        if self._rew_norm:
            batch.returns = unnormalized_returns / \
                np.sqrt(self.ret_rms.var + self._eps)
            self.ret_rms.update(unnormalized_returns)
        else:
            batch.returns = unnormalized_returns
        device = batch.v_s.device
        dtype = batch.v_s.dtype
        batch.returns = torch.from_numpy(batch.returns).type(dtype).to(device)
        batch.adv = torch.from_numpy(advantages).type(dtype).to(device)
        return batch

    def learn(self, batch, batch_size, repeat, **kwargs):
        """Run ``repeat`` passes of PPO updates on paired source/target minibatches.

        :param batch: dict from ``split_batch`` with keys ``"source"``,
            ``"target"`` and ``"all"``.
        :param batch_size: total minibatch size; ``int(batch_size / 2)``
            samples are drawn from each of the source and target sub-batches.
        :param repeat: number of passes over the data.
        :return: dict of per-minibatch loss/statistic lists.
        """
        s_batch, t_batch, batch = batch["source"], batch["target"], batch["all"]
        half = int(batch_size / 2)
        losses, clip_losses, vf_losses, ent_losses = [], [], [], []
        ratios, kl_data, mmd_data = [], [], []
        for step in range(repeat):
            if self._recompute_adv and step > 0:
                split = self.split_batch(self._compute_returns(batch, self._buffer))
                s_batch, t_batch, batch = split["source"], split["target"], split["all"]
            # zip() stops at the shorter sequence: if source and target yield a
            # different number of minibatches, the surplus ones are skipped.
            for s_minibatch, t_minibatch in zip(
                    s_batch.split(half, merge_last=True),
                    t_batch.split(half, merge_last=True)):
                # Rows [:n_src] of the concatenated minibatch are source
                # samples, rows [n_src:] are target samples.
                n_src = len(s_minibatch)
                minibatch = Batch()
                for k in s_minibatch.keys():
                    if k == 'adv':
                        continue  # concatenated below, after target rescaling
                    if isinstance(s_minibatch[k], np.ndarray):
                        minibatch[k] = np.concatenate((s_minibatch[k], t_minibatch[k]))
                    elif isinstance(s_minibatch[k], torch.Tensor):
                        minibatch[k] = torch.cat([s_minibatch[k], t_minibatch[k]])

                # Actor forward pass; split features by the real source length.
                minibatch_forward = self(minibatch)
                dist, features = minibatch_forward.dist, minibatch_forward.feature
                mmd_loss = mmd(features[:n_src], features[n_src:])

                # Negate and scale target advantages by the detached MMD,
                # clipped below at eps_clip.
                t_minibatch.adv *= -0.1 * np.clip(
                    mmd_loss.detach().item(), self._eps_clip, 999999)
                if isinstance(s_minibatch['adv'], np.ndarray):
                    minibatch['adv'] = np.concatenate((s_minibatch['adv'], t_minibatch['adv']))
                elif isinstance(s_minibatch['adv'], torch.Tensor):
                    minibatch['adv'] = torch.cat([s_minibatch['adv'], t_minibatch['adv']])

                if self._norm_adv:
                    # Per-minibatch normalisation over source + target together.
                    mean, std = minibatch.adv.mean(), minibatch.adv.std()
                    minibatch.adv = (minibatch.adv - mean) / (std + self._adv_norm_eps)

                # Clipped PPO surrogate for the actor.
                ratio = (dist.log_prob(minibatch.act) -
                         minibatch.logp_old).exp().float()
                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                surr1 = ratio * minibatch.adv
                surr2 = ratio.clamp(
                    1.0 - self._eps_clip, 1.0 + self._eps_clip
                ) * minibatch.adv
                if self._dual_clip:
                    clip1 = torch.min(surr1, surr2)
                    clip2 = torch.max(clip1, self._dual_clip * minibatch.adv)
                    clip_loss = -torch.where(minibatch.adv < 0, clip2, clip1).mean()
                else:
                    clip_loss = -torch.min(surr1, surr2).mean()

                # Critic input is [obs, stored feature]. The stored feature may
                # be a NumPy array (calling .detach() on it used to crash) or a
                # tensor; either way it enters the critic without gradient.
                critic_input = np.concatenate(
                    [minibatch.obs, self._as_numpy(minibatch.feature)], axis=1)
                value = self.critic(critic_input).flatten()
                if self._value_clip:
                    v_clip = minibatch.v_s + \
                        (value - minibatch.v_s).clamp(-self._eps_clip, self._eps_clip)
                    vf1 = (minibatch.returns - value).pow(2)
                    vf2 = (minibatch.returns - v_clip).pow(2)
                    vf_loss = torch.max(vf1, vf2).mean()
                else:
                    vf_loss = (minibatch.returns - value).pow(2).mean()

                # Entropy bonus and total loss.
                ent_loss = dist.entropy().mean()
                loss = clip_loss + self._weight_vf * vf_loss \
                    - self._weight_ent * ent_loss
                self.optim.zero_grad()
                loss.backward()
                if self._grad_norm:  # clip large gradient
                    nn.utils.clip_grad_norm_(
                        self._actor_critic.parameters(), max_norm=self._grad_norm
                    )
                self.optim.step()

                clip_losses.append(clip_loss.item())
                vf_losses.append(vf_loss.item())
                ent_losses.append(ent_loss.item())
                losses.append(loss.item())
                ratios.append(ratio.mean().item())
                kl_data.append(ratio.log().mean().item())
                mmd_data.append(mmd_loss.item())
        # update learning rate if lr_scheduler is given
        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return {
            "loss": losses,
            "loss/clip": clip_losses,
            "loss/vf": vf_losses,
            "loss/ent": ent_losses,
            "ratio": ratios,
            "kl": kl_data,
            "mmd": mmd_data,
        }

class NTLPPOPolicyFeatureCriticMmd(NTLPPOPolicy):
    """NTL PPO variant: critic consumes ``[obs, actor feature]``; MMD uses both nets.

    The critic input is ``np.concatenate([obs, feature], axis=1)`` and the
    critic returns a ``(value, feature)`` pair. The target-advantage scale is
    the sum of the clipped actor-feature and critic-feature MMDs. The learning
    rate scheduler is only stepped while the current lr exceeds ``_min_lr``.
    """

    # Added to the advantage std before dividing, so a minibatch whose
    # advantages are all identical does not produce inf/nan.
    _adv_norm_eps = 1e-8
    # The lr scheduler stops stepping once the current lr is <= this value.
    _min_lr = 1e-5

    @staticmethod
    def _as_numpy(x):
        """Return ``x`` as a NumPy array, detaching/moving it to CPU if it is a tensor."""
        if isinstance(x, np.ndarray):
            return x
        return x.detach().cpu().numpy()

    def _compute_returns(self, batch, buffer):
        """Compute ``v_s``, ``returns`` and ``adv`` for ``batch`` in place.

        The critic is evaluated on ``[obs, feature]`` and
        ``[obs_next, feature_next]``; its first output is the value. When
        reward normalisation is enabled, values are un-normalised before
        computing the episodic return and the running return statistics are
        updated.
        """
        v_s, v_s_ = [], []
        with torch.no_grad():
            for minibatch in batch.split(self._batch, shuffle=False, merge_last=True):
                feature = self._as_numpy(minibatch.feature)
                feature_next = self._as_numpy(minibatch.feature_next)
                v_s.append(self.critic(np.concatenate([minibatch.obs, feature], axis=1))[0])
                v_s_.append(self.critic(np.concatenate([minibatch.obs_next, feature_next], axis=1))[0])

        batch.v_s = torch.cat(v_s, dim=0).flatten()  # old value
        v_s = batch.v_s.cpu().numpy()
        v_s_ = torch.cat(v_s_, dim=0).flatten().cpu().numpy()
        if self._rew_norm:  # unnormalize v_s & v_s_
            v_s = v_s * np.sqrt(self.ret_rms.var + self._eps)
            v_s_ = v_s_ * np.sqrt(self.ret_rms.var + self._eps)
        unnormalized_returns, advantages = self.compute_episodic_return(
            batch,
            buffer,
            v_s_,
            v_s,
            gamma=self._gamma,
            gae_lambda=self._lambda
        )
        if self._rew_norm:
            batch.returns = unnormalized_returns / \
                np.sqrt(self.ret_rms.var + self._eps)
            self.ret_rms.update(unnormalized_returns)
        else:
            batch.returns = unnormalized_returns
        device = batch.v_s.device
        dtype = batch.v_s.dtype
        batch.returns = torch.from_numpy(batch.returns).type(dtype).to(device)
        batch.adv = torch.from_numpy(advantages).type(dtype).to(device)
        return batch

    def learn(self, batch, batch_size, repeat, **kwargs):
        """Run ``repeat`` passes of PPO updates on paired source/target minibatches.

        :param batch: dict from ``split_batch`` with keys ``"source"``,
            ``"target"`` and ``"all"``.
        :param batch_size: total minibatch size; ``int(batch_size / 2)``
            samples are drawn from each of the source and target sub-batches.
        :param repeat: number of passes over the data.
        :return: dict of per-minibatch lists; ``mmd`` is actor MMD + critic MMD.
        """
        s_batch, t_batch, batch = batch["source"], batch["target"], batch["all"]
        half = int(batch_size / 2)
        losses, clip_losses, vf_losses, ent_losses = [], [], [], []
        ratios, kl_data, mmd_data = [], [], []
        for step in range(repeat):
            if self._recompute_adv and step > 0:
                split = self.split_batch(self._compute_returns(batch, self._buffer))
                s_batch, t_batch, batch = split["source"], split["target"], split["all"]
            # zip() stops at the shorter sequence: if source and target yield a
            # different number of minibatches, the surplus ones are skipped.
            for s_minibatch, t_minibatch in zip(
                    s_batch.split(half, merge_last=True),
                    t_batch.split(half, merge_last=True)):
                # Rows [:n_src] of the concatenated minibatch are source
                # samples, rows [n_src:] are target samples.
                n_src = len(s_minibatch)
                minibatch = Batch()
                for k in s_minibatch.keys():
                    if k == 'adv':
                        continue  # concatenated below, after target rescaling
                    if isinstance(s_minibatch[k], np.ndarray):
                        minibatch[k] = np.concatenate((s_minibatch[k], t_minibatch[k]))
                    elif isinstance(s_minibatch[k], torch.Tensor):
                        minibatch[k] = torch.cat([s_minibatch[k], t_minibatch[k]])

                # Actor and critic forward passes; features are split by the
                # real source length rather than torch.chunk(..., 2).
                minibatch_forward = self(minibatch)
                dist, features = minibatch_forward.dist, minibatch_forward.feature
                mmd_loss = mmd(features[:n_src], features[n_src:])
                # The stored feature may be a NumPy array or a tensor.
                critic_input = np.concatenate(
                    [minibatch.obs, self._as_numpy(minibatch.feature)], axis=1)
                value, critic_features = self.critic(critic_input)
                c_mmd_loss = mmd(critic_features[:n_src], critic_features[n_src:])

                if self._norm_adv:
                    # Normalise source and target advantages separately.
                    mean, std = s_minibatch.adv.mean(), s_minibatch.adv.std()
                    s_minibatch.adv = (s_minibatch.adv - mean) / (std + self._adv_norm_eps)
                    mean, std = t_minibatch.adv.mean(), t_minibatch.adv.std()
                    t_minibatch.adv = (t_minibatch.adv - mean) / (std + self._adv_norm_eps)

                # Scale after normalising so the MMD magnitude is preserved.
                t_minibatch.adv *= -0.1 * (
                    np.clip(mmd_loss.detach().item(), self._eps_clip, 999999)
                    + np.clip(c_mmd_loss.detach().item(), self._eps_clip, 999999)
                )

                if isinstance(s_minibatch['adv'], np.ndarray):
                    minibatch['adv'] = np.concatenate((s_minibatch['adv'], t_minibatch['adv']))
                elif isinstance(s_minibatch['adv'], torch.Tensor):
                    minibatch['adv'] = torch.cat([s_minibatch['adv'], t_minibatch['adv']])

                # Clipped PPO surrogate for the actor.
                ratio = (dist.log_prob(minibatch.act) -
                         minibatch.logp_old).exp().float()
                ratio = ratio.reshape(ratio.size(0), -1).transpose(0, 1)
                surr1 = ratio * minibatch.adv
                surr2 = ratio.clamp(
                    1.0 - self._eps_clip, 1.0 + self._eps_clip
                ) * minibatch.adv
                if self._dual_clip:
                    clip1 = torch.min(surr1, surr2)
                    clip2 = torch.max(clip1, self._dual_clip * minibatch.adv)
                    clip_loss = -torch.where(minibatch.adv < 0, clip2, clip1).mean()
                else:
                    clip_loss = -torch.min(surr1, surr2).mean()

                # Critic loss, reusing the value from the forward pass above.
                value = value.flatten()
                if self._value_clip:
                    v_clip = minibatch.v_s + \
                        (value - minibatch.v_s).clamp(-self._eps_clip, self._eps_clip)
                    vf1 = (minibatch.returns - value).pow(2)
                    vf2 = (minibatch.returns - v_clip).pow(2)
                    vf_loss = torch.max(vf1, vf2).mean()
                else:
                    vf_loss = (minibatch.returns - value).pow(2).mean()

                # Entropy bonus and total loss.
                ent_loss = dist.entropy().mean()
                loss = clip_loss + self._weight_vf * vf_loss \
                    - self._weight_ent * ent_loss
                self.optim.zero_grad()
                loss.backward()
                if self._grad_norm:  # clip large gradient
                    nn.utils.clip_grad_norm_(
                        self._actor_critic.parameters(), max_norm=self._grad_norm
                    )
                self.optim.step()

                clip_losses.append(clip_loss.item())
                vf_losses.append(vf_loss.item())
                ent_losses.append(ent_loss.item())
                losses.append(loss.item())
                ratios.append(ratio.mean().item())
                kl_data.append(ratio.log().mean().item())
                mmd_data.append(mmd_loss.item() + c_mmd_loss.item())
        # Step the lr scheduler only while the lr is above the floor.
        # get_last_lr() is the supported reader; calling get_lr() outside
        # step() is deprecated and can be wrong for chainable schedulers.
        if self.lr_scheduler is not None:
            read_lr = getattr(self.lr_scheduler, "get_last_lr", None) \
                or self.lr_scheduler.get_lr
            if read_lr()[0] > self._min_lr:
                self.lr_scheduler.step()

        return {
            "loss": losses,
            "loss/clip": clip_losses,
            "loss/vf": vf_losses,
            "loss/ent": ent_losses,
            "ratio": ratios,
            "kl": kl_data,
            "mmd": mmd_data,
        }

class NTLTrainer(Trainer):
    """Trainer whose test step reports source and target rewards separately."""

    def test_step(self):
        """Run one evaluation round and return source/target reward statistics.

        The per-episode rewards from ``test_episode`` are split in half. The
        first half is reported as "source" and the remainder as "target". This
        presumes the test collector yields source episodes first -- TODO
        confirm.

        The best epoch/reward are tracked on the source reward only. A new best
        is accepted only while its std is below 500. On a new best,
        ``save_best_fn`` (if set) is called with the policy.
        """
        assert self.episode_per_test is not None
        assert self.test_collector is not None
        rews = test_episode(
            self.policy, self.test_collector,
            self.episode_per_test
        )["rews"]
        half = rews.shape[0] // 2
        source_rews, target_rews = rews[:half], rews[half:]
        print("source test: ", source_rews, source_rews.mean())
        print("target test: ", target_rews, target_rews.mean())
        source_rew, source_rew_std = source_rews.mean(), source_rews.std()
        target_rew, target_rew_std = target_rews.mean(), target_rews.std()

        no_best_yet = self.best_epoch < 0
        if no_best_yet or (self.best_reward < source_rew and source_rew_std < 500):
            self.best_epoch = self.epoch
            self.best_reward = float(source_rew)
            self.best_reward_std = source_rew_std
            if self.save_best_fn:
                self.save_best_fn(self.policy)

        return dict(
            source_test_reward=source_rew,
            source_test_reward_std=source_rew_std,
            target_test_reward=target_rew,
            target_test_reward_std=target_rew_std,
            best_reward=self.best_reward,
            best_reward_std=self.best_reward_std,
            best_epoch=self.best_epoch,
        )

class NTLLogger(Logger):
    """Logger that writes NTL test and update statistics through ``self.write``."""

    def log_test_data(self, test_result, step):
        """Write source/target test rewards (as returned by ``NTLTrainer.test_step``).

        :param test_result: dict with keys ``source_test_reward``,
            ``source_test_reward_std``, ``best_reward``, ``target_test_reward``
            and ``target_test_reward_std``.
        :param step: step value passed through to ``self.write``.
        """
        log_data = {
            "test/s_reward": test_result["source_test_reward"],
            "test/s_reward_std": test_result["source_test_reward_std"],
            "test/s_best_reward": test_result["best_reward"],
            "test/t_reward": test_result["target_test_reward"],
            "test/t_reward_std": test_result["target_test_reward_std"]
        }
        self.write(step, log_data)

    def log_update_data(self, update_result, step):
        """Write the mean of every per-minibatch list returned by the NTL policies' ``learn``.

        :param update_result: dict with list values under ``loss``,
            ``loss/clip``, ``loss/vf``, ``loss/ent``, ``ratio``, ``kl`` and
            ``mmd``.
        :param step: step value passed through to ``self.write``.
        """
        log_data = {
            "update/mean/loss": np.array(update_result["loss"]).mean(),
            "update/mean/clip_loss": np.array(update_result["loss/clip"]).mean(),
            "update/mean/vf_loss": np.array(update_result["loss/vf"]).mean(),
            "update/mean/ent_loss": np.array(update_result["loss/ent"]).mean(),
            "update/mean/ratio": np.array(update_result["ratio"]).mean(),
            "update/mean/kl": np.array(update_result["kl"]).mean(),
            "update/mean/mmd": np.array(update_result["mmd"]).mean(),
        }
        self.write(step, log_data)

class NTLCollectorFeature(Collector):
    """Collector that also records policy features for ``obs`` and ``obs_next``.

    In addition to the usual transition fields, every step stores
    ``feature`` (the policy's ``feature`` output for the current observation)
    and ``feature_next`` (the policy's ``feature`` output for the next
    observation) into the buffer.
    """

    def collect(self, n_step=None, n_episode=None):
        """Run the policy in the vectorized env and add transitions to the buffer.

        ``n_step`` takes precedence over ``n_episode`` for choosing which
        envs run; the loop stops when either requested count is reached.

        :param n_step: number of transitions to collect, summed over envs.
        :param n_episode: number of episodes to collect.
        :return: dict with ``n/ep`` (episodes finished), ``n/st`` (steps
            taken), ``rews`` / ``lens`` and their means/stds ``rew``,
            ``rew_std``, ``len``, ``len_std``. ``rews`` / ``lens`` are masked
            arrays indexed by env id (slot i holds the most recently finished
            episode of env i; envs with no finished episode are masked), or
            empty arrays if no episode finished.
        :raises TypeError: if neither ``n_step`` nor ``n_episode`` is given.
        """
        if n_step is not None:
            if n_step % self.env_num != 0:
                warnings.warn(
                    f"n_step={n_step} is not a multiple of #env ({self.env_num}), "
                    "which may cause extra transitions collected into the buffer."
                )
            ready_env_ids = np.arange(self.env_num)
        elif n_episode is not None:
            ready_env_ids = np.arange(min(self.env_num, n_episode))
            self.data = self.data[:min(self.env_num, n_episode)]
        else:
            raise TypeError(
                "Please specify at least one (either n_step or n_episode) "
                f"in {type(self).__name__}.collect()."
            )

        step_count = 0
        episode_count = 0
        # Positional per-env storage: slot i is overwritten each time env i
        # finishes an episode; -inf means "no episode yet" and is masked out
        # below. The test-stat code presumably slices these by env position.
        episode_rews = np.zeros_like(ready_env_ids, dtype=np.float64) - np.inf
        episode_lens = np.zeros_like(ready_env_ids, dtype=np.float64) - np.inf

        while True:
            assert len(self.data) == len(ready_env_ids)

            with torch.no_grad():
                result = self.policy(self.data)
            act = result.act.detach().cpu().numpy()
            feature = result.feature.detach().cpu().numpy()
            self.data.update(act=act, feature=feature)

            action_remap = self.policy.map_action(self.data.act)
            obs_next, rew, done, info = self.env.step(action_remap, ready_env_ids)
            # Only the feature values are needed; without no_grad this forward
            # pass would build an autograd graph that is immediately discarded.
            with torch.no_grad():
                feature_next = self.policy(Batch(obs=obs_next)).feature
            feature_next = feature_next.detach().cpu().numpy()

            self.data.update(
                obs_next=obs_next, rew=rew, done=done, info=info,
                feature_next=feature_next,
            )

            ep_rew, ep_len = self.buffer.add(self.data, ready_env_ids)
            step_count += len(ready_env_ids)

            self.data.obs = self.data.obs_next.copy()
            if np.any(done):
                env_ind_local = np.where(done)[0]
                env_ind_global = ready_env_ids[env_ind_local]
                episode_count += len(env_ind_local)
                episode_rews[env_ind_global] = ep_rew[env_ind_local].copy()
                episode_lens[env_ind_global] = ep_len[env_ind_local].copy()
                obs_reset = self.env.reset(env_ind_global)
                self.data.obs[env_ind_local] = obs_reset

                if n_episode:
                    # Retire envs that would otherwise overshoot n_episode.
                    surplus_env_num = len(ready_env_ids) - (n_episode - episode_count)
                    if surplus_env_num > 0:
                        mask = np.ones_like(ready_env_ids, dtype=bool)
                        mask[env_ind_local[:surplus_env_num]] = False
                        ready_env_ids = ready_env_ids[mask]
                        self.data = self.data[mask]

            if (n_step and step_count >= n_step) or \
                    (n_episode and episode_count >= n_episode):
                break

        self.collect_step += step_count
        self.collect_episode += episode_count

        if n_episode:
            # Episode mode may have shrunk self.data; start fresh next call.
            self.data = Batch(
                obs={}, act={}, rew={}, done={}, obs_next={}, info={}
            )
            self.reset_env()

        if episode_count > 0:
            rews = np.ma.masked_equal(episode_rews, -np.inf)
            rew_mean, rew_std = rews.mean(), rews.std()
            lens = np.ma.masked_equal(episode_lens, -np.inf)
            len_mean, len_std = lens.mean(), lens.std()
        else:
            rews = np.array([])
            rew_mean = rew_std = 0
            lens = np.array([])
            len_mean = len_std = 0

        return {
            "n/ep": episode_count,
            "n/st": step_count,
            "rews": rews,
            "rew": rew_mean,
            "rew_std": rew_std,
            "lens": lens,
            "len": len_mean,
            "len_std": len_std
        }

class NTLReplayBufferFeature(ReplayBuffer):
    """Replay buffer whose storage also holds ``feature`` and ``feature_next``.

    Note: ``__init__`` does not call ``ReplayBuffer.__init__``; it sets up
    every storage attribute itself.
    """

    def __init__(self, size):
        # Each field gets its own fresh empty dict; key order is preserved.
        field_names = (
            "obs", "act", "obs_next", "rew", "done", "info",
            "feature", "feature_next",
        )
        self.batch_dict = {name: {} for name in field_names}
        self.size = size
        self._data = Batch(self.batch_dict)
        self._indices = np.arange(size)
        # Presumably the write pointer and running episode reward/length
        # accumulators used by the base class -- confirm against ReplayBuffer.
        self.ptr = self._ep_rew = self._ep_len = 0

class NTLVectorBufferFeature(VectorBuffer):
    def __init__(self, total_size, buffer_num):
        self.buffer_list = []
        self.buffer_num = buffer_num
        size = int(np.ceil(total_size / buffer_num))
        for _ in range(self.buffer_num):
            self.buffer_list.append(NTLReplayBufferFeature(size,))