# SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
# 
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
# Copyright (c) 2021 ETH Zurich, Nikita Rudin

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

from rsl_rl.modules import ActorCriticRMA
from rsl_rl.storage import RolloutStorage, ReplayBuffer
import wandb
from rsl_rl.utils import unpad_trajectories
import time

class RMS(object):
    """Running mean / variance tracker over the leading (batch) dimension.

    Each call merges the statistics of a new batch into the running ones using
    the parallel (Chan et al.) combination formula. ``M`` is the running mean
    and ``S`` is the running *variance* (not the standard deviation).
    """

    def __init__(self, device, epsilon=1e-4, shape=(1,)):
        self.M = torch.zeros(shape, device=device)
        self.S = torch.ones(shape, device=device)
        # Small pseudo-count for the (mean=0, var=1) prior so the first update
        # never divides by zero.
        self.n = epsilon

    def __call__(self, x):
        """Fold batch ``x`` (reduced over dim 0) into the running stats.

        Returns:
            ``(M, S)`` -- the updated running mean and running variance.
            An empty batch leaves the statistics unchanged.
        """
        bs = x.size(0)
        if bs == 0:
            # Mean/var of an empty batch are NaN and would poison the stats.
            return self.M, self.S

        batch_mean = torch.mean(x, dim=0)
        # The combination formula expects the population (biased) variance;
        # it is also well-defined for a single-row batch, unlike the unbiased one.
        batch_var = torch.var(x, dim=0, unbiased=False)

        total = self.n + bs
        delta = batch_mean - self.M
        new_M = self.M + delta * bs / total
        new_S = (self.S * self.n + batch_var * bs + (delta ** 2) * self.n * bs / total) / total

        self.M = new_M
        self.S = new_S
        self.n += bs

        return self.M, self.S

class PPO:
    """PPO learner with an auxiliary VAE-style estimator and an AMP discriminator.

    Besides the clipped-surrogate PPO step on ``actor_critic``, every minibatch
    update also takes one step on the AMP discriminator (``update_amp``), whose
    score is mixed into the environment reward in ``process_env_step``, and
    ``num_sub_steps`` steps on the estimator through ``estimator.loss_fn``.

    When ``train_with_estimated_states`` is set, the actor is fed
    ``obs[:, :est_start]`` followed by the estimator latent ``[z, vel]``
    (see ``_estimate_obs``) instead of the stored observation.
    """

    # String annotation so the class body does not need the project type at definition time.
    actor_critic: "ActorCriticRMA"

    def __init__(self,
                 env,
                 actor_critic,
                 estimator,
                 estimator_paras,
                 amp_discriminator, amp_normalizer, amp_paras,
                 num_learning_epochs=1,
                 num_mini_batches=1,
                 clip_param=0.2,
                 gamma=0.998,
                 lam=0.95,
                 value_loss_coef=1.0,
                 entropy_coef=0.0,
                 learning_rate=1e-3,
                 max_grad_norm=1.0,
                 use_clipped_value_loss=True,
                 schedule="fixed",
                 desired_kl=0.01,
                 device='cpu',
                 dagger_update_freq=20,
                 priv_reg_coef_schedual=None,
                 **kwargs
                 ):
        """Create optimizers and buffers, and pre-fill the AMP demo buffer.

        Args:
            env: environment; must provide ``fetch_amp_obs_demo(batch_size)``.
            actor_critic: policy/value module, moved to ``device``.
            estimator: module with ``sample`` and ``loss_fn`` and parameters.
            estimator_paras: dict read for keys ``n_demo``, ``priv_latent_dim``,
                ``priv_states_dim``, ``priv_start``, ``prop_dim``, ``prop_start``,
                ``priv_prop_start``, ``history_len``, ``learning_rate`` and
                ``train_with_estimated_states``.
            amp_discriminator: module with ``input_dim``, ``compute_grad_pen``
                and parameters, moved to ``device``.
            amp_normalizer: optional normalizer (``normalize_torch``/``update``) or None.
            amp_paras: dict read for keys ``amp_replay_buffer_size``,
                ``amp_demo_buffer_size``, ``amp_demo_fetch_batch_size``,
                ``amp_learn_batch_size``, ``amp_learning_rate``,
                ``amp_reward_coef`` and ``amp_grad_pen``.
            dagger_update_freq, priv_reg_coef_schedual, kwargs: accepted for
                interface compatibility; currently unused.
        """
        self.env = env
        self.device = device

        self.desired_kl = desired_kl
        self.schedule = schedule
        self.learning_rate = learning_rate

        # PPO components
        self.actor_critic = actor_critic
        self.actor_critic.to(self.device)
        self.storage = None  # created by init_storage()
        self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate)
        self.transition = RolloutStorage.Transition()

        # PPO parameters
        self.clip_param = clip_param
        self.num_learning_epochs = num_learning_epochs
        self.num_mini_batches = num_mini_batches
        self.value_loss_coef = value_loss_coef
        self.entropy_coef = entropy_coef
        self.gamma = gamma
        self.lam = lam
        self.max_grad_norm = max_grad_norm
        self.use_clipped_value_loss = use_clipped_value_loss

        # NOTE(review): the history-encoder optimizer used by update_dagger() and
        # the depth-encoder modules used by update_depth_*() are not created here.
        self.counter = 0

        # Estimator
        self.estimator = estimator
        self.n_demo = estimator_paras["n_demo"]
        self.priv_latent_dim = estimator_paras["priv_latent_dim"]
        self.priv_states_dim = estimator_paras["priv_states_dim"]
        self.est_start = estimator_paras["priv_start"]
        self.num_prop = estimator_paras["prop_dim"]
        self.prop_start = estimator_paras["prop_start"]
        self.priv_prop_start = estimator_paras["priv_prop_start"]
        self.history_len = estimator_paras["history_len"]
        self.estimator_optimizer = optim.Adam(self.estimator.parameters(), lr=estimator_paras["learning_rate"])
        self.train_with_estimated_states = estimator_paras["train_with_estimated_states"]

        # AMP
        self.amp_discriminator = amp_discriminator
        self.amp_discriminator.to(self.device)
        self.amp_storage = ReplayBuffer(amp_discriminator.input_dim, amp_paras["amp_replay_buffer_size"], device)
        self.amp_demo_storage = ReplayBuffer(amp_discriminator.input_dim, amp_paras["amp_demo_buffer_size"], device)
        self.amp_fetch_demo_batch_size = amp_paras["amp_demo_fetch_batch_size"]
        self.amp_learn_batch_size = amp_paras["amp_learn_batch_size"]
        self.init_amp_demo_buf()

        self.amp_normalizer = amp_normalizer
        self.amp_optimizer = optim.Adam(self.amp_discriminator.parameters(), lr=amp_paras["amp_learning_rate"], weight_decay=1e-3)
        self.amp_rew_scale = amp_paras["amp_reward_coef"]
        self.amp_grad_pen = amp_paras["amp_grad_pen"]

    def init_amp_demo_buf(self):
        """Fill the AMP demo buffer with ceil(buffer_size / fetch_batch_size) env demo batches."""
        buffer_size = self.amp_demo_storage.buffer_size
        num_batches = int(np.ceil(buffer_size / self.amp_fetch_demo_batch_size))
        for _ in range(num_batches):
            self.amp_demo_storage.insert(self.env.fetch_amp_obs_demo(self.amp_fetch_demo_batch_size))

    def update_amp_demos(self):
        """Insert one fresh batch of env demo AMP observations into the demo buffer."""
        self.amp_demo_storage.insert(self.env.fetch_amp_obs_demo(self.amp_fetch_demo_batch_size))

    def init_storage(self, num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape):
        """Create the on-policy rollout storage on ``self.device``."""
        self.storage = RolloutStorage(num_envs, num_transitions_per_env, actor_obs_shape, critic_obs_shape, action_shape, self.device)

    def test_mode(self):
        """Delegate to ``actor_critic.test()``."""
        self.actor_critic.test()

    def train_mode(self):
        """Put ``actor_critic`` in training mode."""
        self.actor_critic.train()

    def _history_window(self, obs):
        """Return the estimator input slice of ``obs``.

        A window of ``history_len * num_prop`` columns that ends at
        ``prop_start + num_prop``.
        """
        start = self.prop_start - (self.history_len - 1) * self.num_prop
        return obs[:, start:self.prop_start + self.num_prop]

    def _estimate_obs(self, obs):
        """Build the actor input ``[obs[:, :est_start], z, vel]`` from the estimator sample."""
        z, vel = self.estimator.sample(self._history_window(obs))
        latent = torch.cat([z, vel], dim=1)
        return torch.cat([obs[:, :self.est_start], latent], dim=1)

    def act(self, obs, critic_obs, info, hist_encoding=False):
        """Sample actions for one rollout step and stage the transition.

        The actor may see estimated privileged information, but the true
        ``obs`` is what gets stored, so ``update`` can re-estimate from it.

        Returns:
            The sampled (detached) actions.
        """
        if self.actor_critic.is_recurrent:
            self.transition.hidden_states = self.actor_critic.get_hidden_states()

        actor_obs = self._estimate_obs(obs) if self.train_with_estimated_states else obs
        self.transition.actions = self.actor_critic.act(actor_obs, hist_encoding).detach()

        self.transition.values = self.actor_critic.evaluate(critic_obs).detach()
        self.transition.actions_log_prob = self.actor_critic.get_actions_log_prob(self.transition.actions).detach()
        self.transition.action_mean = self.actor_critic.action_mean.detach()
        self.transition.action_sigma = self.actor_critic.action_std.detach()

        self.transition.observations = obs
        self.transition.critic_observations = critic_obs

        return self.transition.actions

    def process_env_step(self, rewards, dones, infos):
        """Mix in the AMP reward, finalize the staged transition and store it.

        Args:
            rewards: environment rewards for this step.
            dones: termination flags, forwarded to storage and ``actor_critic.reset``.
            infos: dict with ``"amp_obs"`` and optionally ``"time_outs"``.

        Returns:
            ``(total reward, env reward, AMP reward * 0.02, weighted AMP reward)``.
        """
        amp_obs = infos["amp_obs"]
        reward_amp = self.calc_amp_rewards(amp_obs).squeeze(1)
        # NOTE(review): 0.02 is a hard-coded scale (possibly the control dt) -- confirm.
        weighted_amp = self.amp_rew_scale * reward_amp * 0.02
        rewards_total = rewards + weighted_amp
        self.amp_storage.insert(amp_obs)

        self.transition.rewards = rewards_total.clone()
        self.transition.dones = dones
        # Bootstrapping on time-outs: add gamma * V(s) where the episode was cut by the time limit.
        if 'time_outs' in infos:
            self.transition.rewards += self.gamma * torch.squeeze(
                self.transition.values * infos['time_outs'].unsqueeze(1).to(self.device), 1)

        self.storage.add_transitions(self.transition)
        self.transition.clear()
        self.actor_critic.reset(dones)

        return rewards_total, rewards, reward_amp * 0.02, weighted_amp

    def compute_returns(self, last_critic_obs):
        """Compute returns/advantages in storage using the value of the last observation."""
        last_values = self.actor_critic.evaluate(last_critic_obs).detach()
        self.storage.compute_returns(last_values, self.gamma, self.lam)

    def update_amp(self):
        """Take one optimizer step on the AMP discriminator.

        Draws one batch from the policy replay buffer and one from the demo
        buffer, normalizes them when ``amp_normalizer`` is set, and minimizes
        ``mean(D(policy)) - mean(D(demo)) + grad_penalty(demo)``.

        Returns:
            Detached tensors ``(amp_loss, grad_pen_loss, mean policy score,
            mean demo score)`` for logging.
        """
        policy_generator = self.amp_storage.feed_forward_generator(1, self.amp_learn_batch_size)
        demo_generator = self.amp_demo_storage.feed_forward_generator(1, self.amp_learn_batch_size)
        raw_policy, raw_demo = next(policy_generator), next(demo_generator)

        sample_amp_policy, sample_amp_demo = raw_policy, raw_demo
        if self.amp_normalizer is not None:
            with torch.no_grad():
                sample_amp_policy = self.amp_normalizer.normalize_torch(raw_policy, self.device)
                sample_amp_demo = self.amp_normalizer.normalize_torch(raw_demo, self.device)
        policy_d = self.amp_discriminator(sample_amp_policy)
        expert_d = self.amp_discriminator(sample_amp_demo)

        # Wasserstein-1 style critic objective: raise demo scores, lower policy scores.
        # (Earlier LSGAN variant: 0.5 * (MSE(expert_d, +1) + MSE(policy_d, -1)).)
        amp_loss = -torch.mean(expert_d) + torch.mean(policy_d)

        grad_pen_loss = self.amp_discriminator.compute_grad_pen(sample_amp_demo, lambda_=self.amp_grad_pen)
        amp_loss_pen = amp_loss + grad_pen_loss
        self.amp_optimizer.zero_grad()
        amp_loss_pen.backward()
        nn.utils.clip_grad_norm_(self.amp_discriminator.parameters(), self.max_grad_norm)
        self.amp_optimizer.step()

        if self.amp_normalizer is not None:
            # The running statistics must see raw samples; feeding the
            # already-normalized ones would make them track the wrong distribution.
            self.amp_normalizer.update(raw_policy.cpu().numpy())
            self.amp_normalizer.update(raw_demo.cpu().numpy())
        # Detach so the logged values do not keep the autograd graph alive.
        return amp_loss.detach(), grad_pen_loss.detach(), policy_d.mean().detach(), expert_d.mean().detach()

    def update(self, kld_weight):
        """Run the PPO, AMP-discriminator and estimator updates over the rollout.

        Per minibatch: one discriminator step, ``num_sub_steps`` estimator
        steps, an optional adaptive learning-rate change from the policy KL,
        and one PPO step. Clears the storage and bumps ``counter`` afterwards.

        Args:
            kld_weight: currently ignored (see NOTE in the estimator loop).

        Returns:
            ``(mean_value_loss, mean_surrogate_loss, mean_estimator_loss,
            mean_recon_loss, mean_predict_loss, mean_discriminator_loss,
            mean_discriminator_acc, amp_loss, grad_pen_loss, policy_d,
            expert_d, entropy_coef)``. The two discriminator means are always
            0; the AMP values come from the last minibatch.
        """
        mean_value_loss = 0.
        mean_surrogate_loss = 0.
        mean_estimator_loss = 0.
        mean_recon_loss = 0.
        mean_predict_loss = 0.
        # Kept for the return signature; nothing accumulates into them.
        mean_discriminator_loss = 0.
        mean_discriminator_acc = 0.

        # Estimator optimizer steps per PPO minibatch.
        num_sub_steps = 2

        if self.actor_critic.is_recurrent:
            generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
        else:
            generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)

        for sample in generator:
            amp_loss, grad_pen_loss, policy_d, expert_d = self.update_amp()

            obs_batch, critic_obs_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, \
                old_actions_log_prob_batch, old_mu_batch, old_sigma_batch, hid_states_batch, masks_batch = sample

            # Re-evaluate the policy on the same kind of input act() used when the
            # actions were sampled, so the PPO ratio compares like with like.
            if self.train_with_estimated_states:
                with torch.no_grad():
                    actor_obs_batch = self._estimate_obs(obs_batch)
            else:
                actor_obs_batch = obs_batch
            self.actor_critic.act(actor_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[0])  # match distribution dimension

            actions_log_prob_batch = self.actor_critic.get_actions_log_prob(actions_batch)
            value_batch = self.actor_critic.evaluate(critic_obs_batch, masks=masks_batch, hidden_states=hid_states_batch[1])
            mu_batch = self.actor_critic.action_mean
            sigma_batch = self.actor_critic.action_std
            entropy_batch = self.actor_critic.entropy

            # Estimator (VAE) update; its inputs do not change across sub-steps.
            history_batch = self._history_window(obs_batch)
            latent_start = self.priv_prop_start + self.num_prop + self.n_demo
            priv_obs_batch = torch.cat([
                critic_obs_batch[:, self.priv_prop_start:self.priv_prop_start + self.num_prop],
                critic_obs_batch[:, latent_start:latent_start + self.priv_latent_dim],
            ], dim=1)
            next_vel_zmp_batch = obs_batch[:, self.est_start:self.est_start + self.priv_states_dim]
            for _ in range(num_sub_steps):
                # NOTE(review): the last loss_fn argument is hard-coded to 1.0 while
                # ``kld_weight`` is ignored; presumably it is the KL weight -- confirm.
                loss_dict = self.estimator.loss_fn(history_batch, priv_obs_batch, next_vel_zmp_batch, 1.0)

                estimator_loss = torch.mean(loss_dict['loss'])
                recon_loss = torch.mean(loss_dict['recons_loss'])
                predict_loss = torch.mean(loss_dict['vel_loss'])

                self.estimator_optimizer.zero_grad()
                estimator_loss.backward()
                nn.utils.clip_grad_norm_(self.estimator.parameters(), self.max_grad_norm)
                self.estimator_optimizer.step()

                mean_estimator_loss += estimator_loss.item()
                mean_recon_loss += recon_loss.item()
                mean_predict_loss += predict_loss.item()

            # Adaptive learning rate from the KL between old and new Gaussian policies.
            if self.desired_kl is not None and self.schedule == 'adaptive':
                with torch.inference_mode():
                    kl = torch.sum(
                        torch.log(sigma_batch / old_sigma_batch + 1.e-5) + (torch.square(old_sigma_batch) + torch.square(old_mu_batch - mu_batch)) / (2.0 * torch.square(sigma_batch)) - 0.5, axis=-1)
                    kl_mean = torch.mean(kl)

                    if kl_mean > self.desired_kl * 2.0:
                        self.learning_rate = max(1e-5, self.learning_rate / 1.5)
                    elif kl_mean < self.desired_kl / 2.0 and kl_mean > 0.0:
                        self.learning_rate = min(1e-2, self.learning_rate * 1.5)

                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = self.learning_rate

            # Surrogate loss
            ratio = torch.exp(actions_log_prob_batch - torch.squeeze(old_actions_log_prob_batch))
            surrogate = -torch.squeeze(advantages_batch) * ratio
            surrogate_clipped = -torch.squeeze(advantages_batch) * torch.clamp(ratio, 1.0 - self.clip_param,
                                                                            1.0 + self.clip_param)
            surrogate_loss = torch.max(surrogate, surrogate_clipped).mean()

            # Value function loss
            if self.use_clipped_value_loss:
                value_clipped = target_values_batch + (value_batch - target_values_batch).clamp(-self.clip_param,
                                                                                                self.clip_param)
                value_losses = (value_batch - returns_batch).pow(2)
                value_losses_clipped = (value_clipped - returns_batch).pow(2)
                value_loss = torch.max(value_losses, value_losses_clipped).mean()
            else:
                value_loss = (returns_batch - value_batch).pow(2).mean()

            loss = surrogate_loss + \
                self.value_loss_coef * value_loss - \
                self.entropy_coef * entropy_batch.mean()

            # Gradient step
            self.optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.actor_critic.parameters(), self.max_grad_norm)
            self.optimizer.step()

            mean_value_loss += value_loss.item()
            mean_surrogate_loss += surrogate_loss.item()

        num_updates = self.num_learning_epochs * self.num_mini_batches
        num_estimator_updates = num_updates * num_sub_steps
        mean_value_loss /= num_updates
        mean_surrogate_loss /= num_updates
        mean_estimator_loss /= num_estimator_updates
        mean_recon_loss /= num_estimator_updates
        mean_predict_loss /= num_estimator_updates

        self.storage.clear()
        self.update_counter()
        return mean_value_loss, mean_surrogate_loss, mean_estimator_loss, mean_recon_loss, mean_predict_loss, \
            mean_discriminator_loss, mean_discriminator_acc, amp_loss, grad_pen_loss, policy_d, expert_d, self.entropy_coef

    def update_dagger(self):
        """Regress the actor's history latent onto its privileged latent (L2 distance).

        NOTE(review): uses ``self.hist_encoder_optimizer``, which ``__init__``
        does not create, so this raises AttributeError unless it is set externally.

        Returns:
            Mean history-latent loss over all minibatches.
        """
        mean_hist_latent_loss = 0
        if self.actor_critic.is_recurrent:
            generator = self.storage.reccurent_mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
        else:
            generator = self.storage.mini_batch_generator(self.num_mini_batches, self.num_learning_epochs)
        for obs_batch, critic_obs_batch, actions_batch, target_values_batch, advantages_batch, returns_batch, old_actions_log_prob_batch, \
                old_mu_batch, old_sigma_batch, hid_states_batch, masks_batch in generator:
            with torch.inference_mode():
                self.actor_critic.act(obs_batch, hist_encoding=True, masks=masks_batch, hidden_states=hid_states_batch[0])

            # Adaptation module update: the privileged latent is the fixed target.
            with torch.inference_mode():
                priv_latent_batch = self.actor_critic.actor.infer_priv_latent(obs_batch)
            hist_latent_batch = self.actor_critic.actor.infer_hist_latent(obs_batch)
            hist_latent_loss = (priv_latent_batch.detach() - hist_latent_batch).norm(p=2, dim=1).mean()
            self.hist_encoder_optimizer.zero_grad()
            hist_latent_loss.backward()
            nn.utils.clip_grad_norm_(self.actor_critic.actor.history_encoder.parameters(), self.max_grad_norm)
            self.hist_encoder_optimizer.step()

            mean_hist_latent_loss += hist_latent_loss.item()
        num_updates = self.num_learning_epochs * self.num_mini_batches
        mean_hist_latent_loss /= num_updates
        self.storage.clear()
        self.update_counter()
        return mean_hist_latent_loss

    def update_depth_encoder(self, depth_latent_batch, scandots_latent_batch):
        """Distill the depth latent toward the (detached) scandots latent.

        NOTE(review): ``self.if_depth`` and the depth optimizers are not set
        in ``__init__``; calling this without setting them raises AttributeError.

        Returns:
            The loss as a float when ``if_depth`` is true, else None.
        """
        if self.if_depth:
            # TODO: needs to save hidden states
            depth_encoder_loss = (scandots_latent_batch.detach() - depth_latent_batch).norm(p=2, dim=1).mean()

            self.depth_encoder_optimizer.zero_grad()
            depth_encoder_loss.backward()
            nn.utils.clip_grad_norm_(self.depth_encoder.parameters(), self.max_grad_norm)
            self.depth_encoder_optimizer.step()
            return depth_encoder_loss.item()

    def update_depth_actor(self, actions_student_batch, actions_teacher_batch, yaw_student_batch, yaw_teacher_batch):
        """Distill student actions and yaw toward the (detached) teacher outputs.

        NOTE(review): depends on ``if_depth`` / ``depth_actor`` attributes not set in ``__init__``.

        Returns:
            ``(action loss, yaw loss)`` as floats when ``if_depth`` is true, else None.
        """
        if self.if_depth:
            depth_actor_loss = (actions_teacher_batch.detach() - actions_student_batch).norm(p=2, dim=1).mean()
            yaw_loss = (yaw_teacher_batch.detach() - yaw_student_batch).norm(p=2, dim=1).mean()

            loss = depth_actor_loss + yaw_loss

            self.depth_actor_optimizer.zero_grad()
            loss.backward()
            nn.utils.clip_grad_norm_(self.depth_actor.parameters(), self.max_grad_norm)
            self.depth_actor_optimizer.step()
            return depth_actor_loss.item(), yaw_loss.item()

    def update_depth_both(self, depth_latent_batch, scandots_latent_batch, actions_student_batch, actions_teacher_batch):
        """Jointly distill the depth latent and the student actions in one optimizer step.

        NOTE(review): depends on ``if_depth`` / depth attributes not set in ``__init__``.

        Returns:
            ``(encoder loss, actor loss)`` as floats when ``if_depth`` is true, else None.
        """
        if self.if_depth:
            depth_encoder_loss = (scandots_latent_batch.detach() - depth_latent_batch).norm(p=2, dim=1).mean()
            depth_actor_loss = (actions_teacher_batch.detach() - actions_student_batch).norm(p=2, dim=1).mean()

            depth_loss = depth_encoder_loss + depth_actor_loss

            self.depth_actor_optimizer.zero_grad()
            depth_loss.backward()
            nn.utils.clip_grad_norm_([*self.depth_actor.parameters(), *self.depth_encoder.parameters()], self.max_grad_norm)
            self.depth_actor_optimizer.step()
            return depth_encoder_loss.item(), depth_actor_loss.item()

    def update_counter(self):
        """Increment the update counter."""
        self.counter += 1

    def calc_amp_rewards(self, amp_obs):
        """Score AMP observations with the discriminator, without tracking gradients.

        Inputs are normalized with ``amp_normalizer`` when one is set, matching
        the preprocessing the discriminator is trained on in ``update_amp``.
        The raw discriminator output is returned as the reward.
        """
        with torch.no_grad():
            if self.amp_normalizer is not None:
                amp_obs = self.amp_normalizer.normalize_torch(amp_obs, self.device)
            disc_logits = self.amp_discriminator(amp_obs)
            # Alternatives previously used here:
            #   -log(max(1 - sigmoid(d), 1e-4))
            #   clamp(1 - 0.25 * (d - 1) ** 2, min=0)   (noted as beneficial for stable training)
            disc_r = disc_logits
        return disc_r

    def compute_apt_reward(self, source, target):
        """Particle-based (k-NN distance) novelty reward of ``source`` w.r.t. ``target``.

        NOTE(review): reads ``knn_k``, ``knn_avg``, ``rms``, ``knn_clip`` and
        ``disc_state_rms``, none of which ``__init__`` sets. If
        ``disc_state_rms`` is an ``RMS`` instance, its second output is a
        variance, so the division below normalizes by variance, not std.

        Returns:
            ``log(1 + clipped k-NN distance)`` per source row.
        """
        b1, b2 = source.size(0), target.size(0)
        # (b1, 1, c) - (1, b2, c) -> (b1, b2, c) -> (b1, b2) pairwise L2 distances
        sim_matrix = torch.norm(source[:, None, :].view(b1, 1, -1) - target[None, :, :].view(1, b2, -1), dim=-1, p=2)

        reward, _ = sim_matrix.topk(self.knn_k, dim=1, largest=False, sorted=True)  # (b1, k)

        if not self.knn_avg:  # only keep k-th nearest neighbor
            reward = reward[:, -1]
            reward = reward.reshape(-1, 1)  # (b1, 1)
            if self.rms:
                moving_mean, moving_std = self.disc_state_rms(reward)
                reward = reward / moving_std
            reward = torch.clamp(reward - self.knn_clip, 0)  # (b1, )
        else:  # average over all k nearest neighbors
            reward = reward.reshape(-1, 1)  # (b1 * k, 1)
            if self.rms:
                moving_mean, moving_std = self.disc_state_rms(reward)
                reward = reward / moving_std
            reward = torch.clamp(reward - self.knn_clip, 0)
            reward = reward.reshape((b1, self.knn_k))  # (b1, k)
            reward = reward.mean(dim=1)  # (b1,)
        reward = torch.log(reward + 1.0)
        return reward