from collections import deque
from functools import partial
from rlf.args import str2bool

import rlf.rl.utils as rutils
import torch
import torch.nn.functional as F
from goal_prox.method.prox_func import ProxFunc
from goal_prox.method.value_traj_dataset import *
from rlf.algos.nested_algo import NestedAlgo
from rlf.algos.on_policy.ppo import PPO

from tqdm import tqdm
import rlf.il.utils as iutils
import rlf.algos.utils as autils
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
import itertools

class DiscountedProxIL(NestedAlgo):
    """Nested algorithm pairing a ``DiscountedProxFunc`` with an RL updater.

    The proximity function is registered as sub-algorithm 0 and the RL
    updater as sub-algorithm 1, which is passed as ``designated_rl_idx``.
    """

    def __init__(self, agent_updater=None, get_pf_base=None, add_dropout=False):
        """Build the nested algorithm.

        Args:
            agent_updater: RL updater used as sub-algorithm 1. When omitted
                (``None``) a fresh ``PPO()`` is created for this instance.
            get_pf_base: forwarded unchanged to ``DiscountedProxFunc``.
            add_dropout: forwarded unchanged to ``DiscountedProxFunc``.
        """
        if agent_updater is None:
            # A `PPO()` default in the signature is evaluated once, at import
            # time, and the same object would then be shared by every
            # instance that relies on the default. Build it lazily instead.
            agent_updater = PPO()
        prox_func = DiscountedProxFunc(get_pf_base=get_pf_base, add_dropout=add_dropout)
        super().__init__([prox_func, agent_updater], designated_rl_idx=1)


class DiscountedProxFunc(ProxFunc):
    def prox_infer_loss(self, next_prox, cur_prox_pred):
        pf_type = self.args.pf_reward_type.lower()
        eps = 1e-20  # for 'airl' / 'nofinal_airl'

        if pf_type == 'reg':
            # diff_prox = next - cur  ->  cur = next - diff
            diff = next_prox - cur_prox_pred
            cur_prox_target = next_prox - diff               # == cur_prox_pred  (no change)

        elif pf_type == 'nodiff':
            # pseudo-target = next_prox directly
            cur_prox_target = next_prox

        elif pf_type == 'nofinal':
            # diff reward only: cur = next - diff
            diff = next_prox - cur_prox_pred
            cur_prox_target = next_prox - diff               # same as 'reg'

        elif pf_type == 'nofinal_airl':
            # AIRL-style logits transformation
            s_next = next_prox
            s_cur  = cur_prox_pred
            next_logit = (s_next + eps).log() - (1 - s_next + eps).log()
            cur_logit  = (s_cur  + eps).log() - (1 - s_cur  + eps).log()
            diff = next_logit - cur_logit
            cur_prox_target = next_logit - diff              # back to same value

        elif pf_type == 'nofinal_pen':
            # constant penalty, use current prediction as target
            cur_prox_target = cur_prox_pred.clone()

        elif pf_type == 'none':
            # treat prox as next_prox (similar to 'nodiff')
            cur_prox_target = next_prox

        elif pf_type == 'pen':
            # only constant penalty, again keep as is
            cur_prox_target = cur_prox_pred.clone()

        elif pf_type == 'airl':
            s_next = next_prox
            cur_prox_target = (s_next + eps).log() - (1 - s_next + eps).log()

        else:
            raise ValueError(f"Unrecognized pf_reward_type '{self.args.pf_reward_type}'")



    def _prox_func_iter(self, data_batch):
        states = data_batch["state"].to(self.args.device)
        proximity = data_batch["prox"].to(self.args.device)
        actions = data_batch["actions"].to(self.args.device)

        guess_proximity = self._get_prox_vals(states, actions)

        n_ensembles = guess_proximity.shape[0]

        loss = F.mse_loss(
            guess_proximity.view(n_ensembles, -1),
            proximity.view(1, -1).repeat(n_ensembles, 1),
        )
        
        if self.args.add_regularizer_on_proximity:
            # only consider loss from the actions in expert distribution to be 0
            # Average over ensemble predictions: shape [batch, 1]
            mean_guess_prox = guess_proximity.mean(dim=0)
            masked_squared_error_seperate = (mean_guess_prox - proximity.view(-1, 1)) ** 2  # shape: [batch, 1]
            if 'MiniGrid' in self.args.env_name:
                # Only keep entries with actions < 4
                valid_mask = (actions < 4).float()  # shape: [batch, 1]

                # Apply mask to both guess and target
                masked_squared_error = masked_squared_error_seperate * valid_mask

                # Avoid dividing by zero if no valid entries
                if valid_mask.sum() > 0:
                    valid_loss = masked_squared_error.sum() / valid_mask.sum()
                else:
                    valid_loss = torch.tensor(0.0, device=self.args.device)
                
            elif 'MBRLmaze2d' in self.args.env_name:
                # Only keep entries with each dim of actions < 0.1
                valid_mask = (actions.abs() < 0.1).all(dim=1, keepdim=True).float()  # shape: [batch, 1]

                # Apply mask to both guess and target
                masked_squared_error = masked_squared_error_seperate * valid_mask

                # Avoid dividing by zero if no valid entries
                if valid_mask.sum() > 0:
                    valid_loss = masked_squared_error.sum() / valid_mask.sum()
                else:
                    valid_loss = torch.tensor(0.0, device=self.args.device)
                
            elif 'FetchPickAndPlaceDiffHoldoutTS150' in self.args.env_name:
                # Only keep entries with each dim of actions < 0.1
                valid_mask = (actions.abs() < 0.1).all(dim=1, keepdim=True).float()  # shape: [batch, 1]

                # Apply mask to both guess and target
                masked_squared_error = masked_squared_error_seperate * valid_mask

                # Avoid dividing by zero if no valid entries
                if valid_mask.sum() > 0:
                    valid_loss = masked_squared_error.sum() / valid_mask.sum()
                else:
                    valid_loss = torch.tensor(0.0, device=self.args.device)
            elif 'FetchPushEnvCustomTS500' in self.args.env_name:
                # Only keep entries with each dim of actions < 0.05
                valid_mask = (actions.abs() < 0.05).all(dim=1, keepdim=True).float()  # shape: [batch, 1]

                # Apply mask to both guess and target
                masked_squared_error = masked_squared_error_seperate * valid_mask

                # Avoid dividing by zero if no valid entries
                if valid_mask.sum() > 0:
                    valid_loss = masked_squared_error.sum() / valid_mask.sum()
                else:
                    valid_loss = torch.tensor(0.0, device=self.args.device)
            else:
                raise NotImplementedError

            frac = self.anneal_regularizer_on_proximity
            loss = frac * valid_loss + (1.0 - frac) * loss
        else:
            loss = F.mse_loss(
                guess_proximity.view(n_ensembles, -1),
                proximity.view(1, -1).repeat(n_ensembles, 1),
            )
        
        
        
        return loss

    def _get_current_prox_value(self, next_prox):
        # -------------------------------------------------
        # delta for each transition  (you can define it differently if you wish)
        # here we use L2-norm of the action vector
        delta = self.args.pf_delta       # shape [B,1]

        # -------------------------------------------------
        # Build cur_prox_target from next_prox, per dmode
        # -------------------------------------------------
        pf_type = self.args.dmode.lower()

        if pf_type == 'exp':
            # next_prox = cur_prox / delta  --> cur_prox_target = next_prox * delta
            cur_prox_target = torch.clamp(next_prox * delta, min=0.0)

        elif pf_type == 'linear':
            # next_prox = cur_prox + delta  --> cur_prox_target = next_prox - delta
            cur_prox_target = torch.clamp(next_prox - delta, min=0.0)

        elif pf_type == 'big':
            # same additive rule as 'linear'
            cur_prox_target = torch.clamp(next_prox - delta, min=0.0)

        elif pf_type == 'one':
            # target is constant 1
            cur_prox_target = torch.ones_like(next_prox)

        else:
            raise ValueError(f"Unrecognized dmode '{self.args.dmode}'")
        
        return cur_prox_target


    def _get_prox_val_fn(self):
        use_fn = None
        if self.args.dmode == "exp":
            use_fn = exp_discounted
            def_delta = 0.95
        elif self.args.dmode == "linear":
            use_fn = linear_discounted
            def_delta = 0.001
        elif self.args.dmode == "big":
            use_fn = partial(big_discounted, start_val=self.args.start_delta)
            def_delta = 0.001
        elif self.args.dmode == "exp_subsample":
            use_fn = partial(exp_discounted_subsample, subsample_freq=self.args.pf_subsample_freq)
            def_delta = 0.95
        else:
            raise ValueError("Must specify discounting mode")

        if self.args.pf_delta is None:
            self.args.pf_delta = def_delta

        return partial(use_fn, delta=self.args.pf_delta)


    def _get_traj_dataset(self, traj_load_path):
        """Build the trajectory dataset and cache the discounting function.

        Stores the result of ``_get_prox_val_fn()`` on
        ``self.compute_prox_fn`` and passes it to the dataset. The dataset
        class is ``ValueTrajIncludeNextStateDataset`` when
        ``args.add_regularizer_on_proximity_infer_from_next`` is truthy,
        otherwise ``ValueTrajDataset``.
        """
        prox_fn = self._get_prox_val_fn()
        self.compute_prox_fn = prox_fn

        needs_next_state = self.args.add_regularizer_on_proximity_infer_from_next
        dataset_cls = (
            ValueTrajIncludeNextStateDataset if needs_next_state else ValueTrajDataset
        )
        return dataset_cls(traj_load_path, prox_fn, self.args)

    def compute_good_traj_prox(self, obs, actions):
        """Return discounted proximity labels for a trajectory of ``len(obs)`` steps.

        ``actions`` is accepted for interface compatibility and not used.
        """
        horizon = len(obs)
        labels = compute_discounted_prox(horizon, self.compute_prox_fn)
        return torch.tensor(labels)

    def compute_bad_traj_prox(self, obs, actions):
        return torch.zeros(len(obs))

    def get_add_args(self, parser):
        """Register the discounting options on ``parser``.

        Adds ``--dmode``, ``--start-delta``, ``--pf-delta`` and
        ``--pf-subsample-freq`` after the options of the parent class.

        Args:
            parser: argument parser exposing ``add_argument``.
        """
        super().get_add_args(parser)
        #########################################
        # New args
        parser.add_argument("--dmode", type=str, default="exp")
        parser.add_argument("--start-delta", type=float, default=1.0)
        # Leave the delta unset by default so that `_get_prox_val_fn` can fill
        # in its per-mode default (0.95 for the exponential modes, 0.001 for
        # 'linear'/'big'). A fixed 0.95 default made that fallback dead and
        # handed 0.95 to the additive modes as well.
        parser.add_argument("--pf-delta", type=float, default=None)
        parser.add_argument("--pf-subsample-freq", type=int, default=10)


class EstimatedProxIL(NestedAlgo):
    """Nested algorithm pairing an ``EstimatedProxFunc`` with an RL updater.

    The proximity function is registered as sub-algorithm 0 and the RL
    updater as sub-algorithm 1, which is passed as ``designated_rl_idx``.
    """

    def __init__(self, agent_updater=None, get_pf_base=None, add_dropout=False):
        """Build the nested algorithm.

        Args:
            agent_updater: RL updater used as sub-algorithm 1. When omitted
                (``None``) a fresh ``PPO()`` is created for this instance.
            get_pf_base: forwarded unchanged to ``EstimatedProxFunc``.
            add_dropout: forwarded unchanged to ``EstimatedProxFunc``.
        """
        if agent_updater is None:
            # A `PPO()` default in the signature is evaluated once, at import
            # time, and the same object would then be shared by every
            # instance that relies on the default. Build it lazily instead.
            agent_updater = PPO()
        prox_func = EstimatedProxFunc(get_pf_base=get_pf_base, add_dropout=add_dropout)
        super().__init__([prox_func, agent_updater], designated_rl_idx=1)


class EstimatedProxFunc(ProxFunc):
    def __init__(self, *args, **kwargs):
        """Initialise the parent class and the confidence window.

        All arguments are forwarded unchanged to the parent constructor.
        """
        super().__init__(*args, **kwargs)
        # Bounded window of the most recent confidence values (at most 100
        # entries); `confidence_value` reports the minimum of this window.
        recent_confidences = deque(maxlen=100)
        self.confidence_queue = recent_confidences
    
    def confidence_value(self):
        if len(self.confidence_queue) == 0:
            return 0
        else:
            return min(self.confidence_queue)

    def prox_infer_loss(self, next_prox, cur_prox_pred):
        pf_type = self.args.pf_reward_type.lower()
        eps = 1e-20  # for 'airl' / 'nofinal_airl'

        if pf_type == 'reg':
            # diff_prox = next - cur  ->  cur = next - diff
            diff = next_prox - cur_prox_pred
            cur_prox_target = next_prox - diff               # == cur_prox_pred  (no change)

        elif pf_type == 'nodiff':
            # pseudo-target = next_prox directly
            cur_prox_target = next_prox

        elif pf_type == 'nofinal':
            # diff reward only: cur = next - diff
            diff = next_prox - cur_prox_pred
            cur_prox_target = next_prox - diff               # same as 'reg'

        elif pf_type == 'nofinal_airl':
            # AIRL-style logits transformation
            s_next = next_prox
            s_cur  = cur_prox_pred
            next_logit = (s_next + eps).log() - (1 - s_next + eps).log()
            cur_logit  = (s_cur  + eps).log() - (1 - s_cur  + eps).log()
            diff = next_logit - cur_logit
            cur_prox_target = next_logit - diff              # back to same value

        elif pf_type == 'nofinal_pen':
            # constant penalty, use current prediction as target
            cur_prox_target = cur_prox_pred.clone()

        elif pf_type == 'none':
            # treat prox as next_prox (similar to 'nodiff')
            cur_prox_target = next_prox

        elif pf_type == 'pen':
            # only constant penalty, again keep as is
            cur_prox_target = cur_prox_pred.clone()

        elif pf_type == 'airl':
            s_next = next_prox
            cur_prox_target = (s_next + eps).log() - (1 - s_next + eps).log()

        else:
            raise ValueError(f"Unrecognized pf_reward_type '{self.args.pf_reward_type}'")



    def first_train(self, log, eval_policy, env_interface):
        """Load the proximity function from disk, or pre-train it on expert data.

        When ``self.args.pf_load_path`` is set, the ``'prox_func'`` entry of
        that checkpoint is loaded into ``self.prox_func`` and no training is
        done. Otherwise the function is trained for
        ``self.args.pre_num_epochs`` epochs over ``self.expert_train_loader``,
        printing the mean training loss (and the validation loss when
        ``self.val_train_loader`` is available) after every epoch.

        Args:
            log, eval_policy, env_interface: part of the interface; not used
                by this implementation.
        """
        if self.args.pf_load_path is not None:
            # map_location lets a checkpoint saved on another device (e.g. a
            # GPU) be read on the device this run actually uses.
            checkpoint = torch.load(
                self.args.pf_load_path, map_location=self.args.device
            )
            self.prox_func.load_state_dict(checkpoint['prox_func'])
            print('Loaded proximity function from %s' % self.args.pf_load_path)
            return

        # Train the proximity function from scratch
        rutils.pstart_sep()
        print('Pre-training proximity function')

        self.prox_func.train()

        for epoch_i in tqdm(range(self.args.pre_num_epochs)):
            epoch_losses = []
            for expert_batch in self.expert_train_loader:
                # est_success_conf=True also records the confidence of these
                # expert samples in the confidence window.
                loss = self._prox_func_iter(expert_batch, est_success_conf=True)
                epoch_losses.append(loss.item())

                self.opt.zero_grad()
                loss.backward()
                self.opt.step()

            avg_loss = np.mean(epoch_losses)
            print('Epoch %i: Loss %.5f' % (epoch_i, avg_loss))

            if self.val_train_loader is not None:
                avg_val_loss = self._compute_val_loss()
                print('Val Loss %.5f' % avg_val_loss)

        # Plot the debug visualisation for the expert data at step 0.
        self.debug_viz.plot(0, ["expert"], self._get_plot_funcs())

        rutils.pend_sep()


    def _update_reward_func(self, storage):
        """Update the proximity function from agent experience and return log values.

        Returns an empty dict without training when ``args.pf_with_agent`` is
        falsy, or when failure trajectories are in use but fewer than
        ``args.exp_sample_size`` of them have been collected.

        Otherwise runs ``args.pf_num_epochs`` epochs of paired
        success/failure mini-batches, optimising

            pf_expert_coef * expert_loss
            + (2 - pf_expert_coef) * agent_loss      (failure data only)
            + disc_grad_pen * wass_grad_pen(...)     (failure data only)

        and returns a dict of averaged statistics.

        ``storage`` is not used by this implementation.

        NOTE(review): ``defaultdict`` and ``np`` are not imported explicitly
        in this file; presumably they arrive through the star import of
        ``value_traj_dataset`` — confirm.
        """
        if not self.args.pf_with_agent:
            # Don't use agent experience to update the proximity function.
            return {}

        take_count = self.args.exp_sample_size

        if self.should_use_failure() and len(self.failure_agent_trajs) < take_count:
            # We don't have enough agent experience yet to update the proximity
            # function.
            return {}

        # Mix successful agent trajectories with expert data.
        # NOTE(review): the mix is requested with size
        # `exp_succ_scale * take_count`, but the sampler below only draws
        # indices from range(take_count) — confirm this is intended.
        success_trajs = iutils.mix_data(self.success_agent_trajs, self.expert_dataset,
                self.args.exp_succ_scale * take_count, 0.5)
        success_sampler = BatchSampler(SubsetRandomSampler(range(take_count)),
                self.args.traj_batch_size, drop_last=True)

        success_trajs = iutils.convert_list_dict(success_trajs,
                self.args.device)

        if self.should_use_failure():
            failure_trajs = self.failure_agent_trajs
            if len(self.failure_agent_trajs) > take_count:
                # Subsample without replacement down to take_count entries.
                failure_trajs = np.random.choice(failure_trajs,
                        take_count, replace=False)
            failure_sampler = BatchSampler(SubsetRandomSampler(range(take_count)),
                    self.args.traj_batch_size, drop_last=True)
            failure_trajs = iutils.convert_list_dict(failure_trajs,
                    self.args.device)
        else:
            # Endless stream of placeholders so that the zip() below is driven
            # by the success sampler alone.
            failure_sampler = itertools.repeat({})

        log_vals = defaultdict(list)
        self.prox_func.train()
        for epoch_i in range(self.args.pf_num_epochs):
            for success_idx, failure_idx in zip(success_sampler, failure_sampler):
                viz_dict = {}
                combined_loss = 0.0

                success_agent_batch = iutils.select_idx_from_dict(success_idx,
                                                                success_trajs)
                viz_dict['success'] = success_agent_batch
                # est_success_conf=True also records the confidence of these
                # samples in the confidence window.
                expert_loss = self._prox_func_iter(success_agent_batch, est_success_conf=True)
                log_vals['expert_loss'].append(expert_loss.item())
                assert self.args.pf_expert_coef >= 0 and self.args.pf_expert_coef <= 2, 'pf_expert_coef should be within range of [0, 2]'
                combined_loss += self.args.pf_expert_coef * expert_loss

                if self.should_use_failure():
                    failure_agent_batch = iutils.select_idx_from_dict(failure_idx,
                                                                    failure_trajs)
                    agent_loss = self._prox_func_iter(failure_agent_batch)
                    log_vals['agent_loss'].append(agent_loss.item())
                    viz_dict['failure'] = failure_agent_batch

                    # The two loss weights always sum to 2.
                    combined_loss += (2.0 - self.args.pf_expert_coef) * agent_loss

                    grad_pen = 0
                    if self.args.disc_grad_pen != 0:
                        grad_pen = self.args.disc_grad_pen * autils.wass_grad_pen(
                                success_agent_batch['state'],
                                success_agent_batch['actions'],
                                failure_agent_batch['state'],
                                failure_agent_batch['actions'],
                                self.args.action_input, self._get_prox_vals)
                    combined_loss += grad_pen

                self.debug_viz.add(viz_dict)

                self.opt.zero_grad()
                combined_loss.backward()
                self.opt.step()

                log_vals['combined_loss'].append(combined_loss.item())

        # here compute the confident state percentage: the share of failure
        # samples whose confidence exceeds the current threshold
        if self.should_use_failure() and len(self.confidence_queue) > 0:
            conf_value = self.confidence_value()
            cnt = 0 # stop after 20 randomly sampled batches
            for failure_idx in failure_sampler:
                failure_agent_batch = iutils.select_idx_from_dict(failure_idx,
                                                                failure_trajs)
                states = failure_agent_batch["state"]
                actions = failure_agent_batch["actions"]
                uncert = self._get_prox_uncert(states, actions).squeeze()
                confidence = 1 - uncert
                confident_count = torch.sum(confidence > conf_value).item()
                log_vals['conf_percent'].append(confident_count / len(confidence))
                cnt += 1
                if cnt >= 20:
                    break

        # Collapse every list of per-batch values into its mean.
        for k in log_vals:
            log_vals[k] = np.mean(log_vals[k])
        if self.val_train_loader is not None:
            log_vals['expert_val_loss'] = self._compute_val_loss()

        if self.update_i % self.args.pf_viz_interval == 0:
            self.debug_viz.plot(self.update_i, ['success', 'failure'],
                                self._get_plot_funcs())
        # Still clear the viz statistics, even if we did not log.
        self.debug_viz.reset()

        # Report and then reset the proximity statistics accumulated on
        # `self.avg_proxs` / `self.start_proxs`.
        if len(self.avg_proxs) != 0:
            log_vals['avg_traj_prox'] = np.mean(self.avg_proxs)
        if len(self.start_proxs) != 0:
            log_vals['start_traj_proxs'] = np.mean(self.start_proxs)
        self.start_proxs = []
        self.avg_proxs = []

        return log_vals


    def _prox_func_iter(self, data_batch, est_success_conf=False):
        states = data_batch["state"].to(self.args.device)
        proximity = data_batch["prox"].to(self.args.device)
        actions = data_batch["actions"].to(self.args.device)

        guess_proximity = self._get_prox_vals(states, actions)

        if est_success_conf:
            # if success, we need to estimate the confidence
            uncert = self._get_prox_uncert(states, actions).squeeze()
            confidence = 1 - uncert
            # push all values into the queue
            self.confidence_queue.extend(confidence.detach().cpu().numpy().tolist())

        n_ensembles = guess_proximity.shape[0]

        loss = F.mse_loss(
            guess_proximity.view(n_ensembles, -1),
            proximity.view(1, -1).repeat(n_ensembles, 1),
        )
        
        if self.args.add_regularizer_on_proximity:
            # only consider loss from the actions in expert distribution to be 0
            # Average over ensemble predictions: shape [batch, 1]
            mean_guess_prox = guess_proximity.mean(dim=0)
            masked_squared_error_seperate = (mean_guess_prox - proximity.view(-1, 1)) ** 2  # shape: [batch, 1]
            if 'MiniGrid' in self.args.env_name:
                # Only keep entries with actions < 4
                valid_mask = (actions < 4).float()  # shape: [batch, 1]

                # Apply mask to both guess and target
                masked_squared_error = masked_squared_error_seperate * valid_mask

                # Avoid dividing by zero if no valid entries
                if valid_mask.sum() > 0:
                    valid_loss = masked_squared_error.sum() / valid_mask.sum()
                else:
                    valid_loss = torch.tensor(0.0, device=self.args.device)
                
            elif 'MBRLmaze2d' in self.args.env_name:
                # Only keep entries with each dim of actions < 0.1
                valid_mask = (actions.abs() < 0.1).all(dim=1, keepdim=True).float()  # shape: [batch, 1]

                # Apply mask to both guess and target
                masked_squared_error = masked_squared_error_seperate * valid_mask

                # Avoid dividing by zero if no valid entries
                if valid_mask.sum() > 0:
                    valid_loss = masked_squared_error.sum() / valid_mask.sum()
                else:
                    valid_loss = torch.tensor(0.0, device=self.args.device)
                
            elif 'FetchPickAndPlaceDiffHoldoutTS150' in self.args.env_name:
                # Only keep entries with each dim of actions < 0.1
                valid_mask = (actions.abs() < 0.1).all(dim=1, keepdim=True).float()  # shape: [batch, 1]

                # Apply mask to both guess and target
                masked_squared_error = masked_squared_error_seperate * valid_mask

                # Avoid dividing by zero if no valid entries
                if valid_mask.sum() > 0:
                    valid_loss = masked_squared_error.sum() / valid_mask.sum()
                else:
                    valid_loss = torch.tensor(0.0, device=self.args.device)
            elif 'FetchPushEnvCustomTS500' in self.args.env_name:
                # Only keep entries with each dim of actions < 0.05
                valid_mask = (actions.abs() < 0.05).all(dim=1, keepdim=True).float()  # shape: [batch, 1]

                # Apply mask to both guess and target
                masked_squared_error = masked_squared_error_seperate * valid_mask

                # Avoid dividing by zero if no valid entries
                if valid_mask.sum() > 0:
                    valid_loss = masked_squared_error.sum() / valid_mask.sum()
                else:
                    valid_loss = torch.tensor(0.0, device=self.args.device)
            else:
                raise NotImplementedError

            frac = self.anneal_regularizer_on_proximity
            loss = frac * valid_loss + (1.0 - frac) * loss

            if self.args.add_regularizer_on_proximity_infer_from_next:
                # Average over ensemble predictions: shape [batch, 1]

                with torch.no_grad():
                    next_states = data_batch["next_state"].to(self.args.device)
                    next_actions = actions # in the prox function, action is a useless input. Thus I directly used the current action
                    guess_next_proximity = self._get_prox_vals(next_states, next_actions)

                    mean_guess_next_prox = guess_next_proximity.mean(dim=0)
                    # Compute the current proximity target based on next proximity
                    cur_prox_target = self._get_current_prox_value(mean_guess_next_prox)
                
                # calculate the mse loss between the current proximity target and the guessed proximity
                mse_error_infer_from_next_seperate = (cur_prox_target - mean_guess_prox) ** 2  # shape: [batch, 1]

                # we need to build a relationship between current and next
                if 'MiniGrid' in self.args.env_name:
                    # Only keep entries with actions >= 4
                    valid_mask = (actions >= 4).float()  # shape: [batch, 1]
                elif 'MBRLmaze2d' in self.args.env_name:
                    # Only keep entries with each dim of actions < 0.1
                    valid_mask = 1 - (actions.abs() < 0.1).all(dim=1, keepdim=True).float()  # shape: [batch, 1]
                    
                elif 'FetchPickAndPlaceDiffHoldoutTS150' in self.args.env_name:
                    # Only keep entries with each dim of actions < 0.1
                    valid_mask = 1 - (actions.abs() < 0.1).all(dim=1, keepdim=True).float()  # shape: [batch, 1]

                elif 'FetchPushEnvCustomTS500' in self.args.env_name:
                    # Only keep entries with each dim of actions < 0.05
                    valid_mask = 1 - (actions.abs() < 0.05).all(dim=1, keepdim=True).float()  # shape: [batch, 1]
                else:
                    raise NotImplementedError
                

                # Apply mask to both guess and target
                mse_error_infer_from_next = mse_error_infer_from_next_seperate * valid_mask

                # Avoid dividing by zero if no valid entries
                if valid_mask.sum() > 0:
                    ooc_loss = mse_error_infer_from_next.sum() / valid_mask.sum()
                else:
                    ooc_loss = torch.tensor(0.0, device=self.args.device)

                loss = frac * ooc_loss + loss
        


        else:
            loss = F.mse_loss(
                guess_proximity.view(n_ensembles, -1),
                proximity.view(1, -1).repeat(n_ensembles, 1),
            )
        
        
        
        return loss


    def _get_current_prox_value(self, next_prox):
        # -------------------------------------------------
        # delta for each transition  (you can define it differently if you wish)
        # here we use L2-norm of the action vector
        delta = self.args.pf_delta       # shape [B,1]

        # -------------------------------------------------
        # Build cur_prox_target from next_prox, per dmode
        # -------------------------------------------------
        pf_type = self.args.dmode.lower()

        if pf_type == 'exp':
            # next_prox = cur_prox / delta  --> cur_prox_target = next_prox * delta
            cur_prox_target = torch.clamp(next_prox * delta, min=0.0)

        elif pf_type == 'linear':
            # next_prox = cur_prox + delta  --> cur_prox_target = next_prox - delta
            cur_prox_target = torch.clamp(next_prox - delta, min=0.0)

        elif pf_type == 'big':
            # same additive rule as 'linear'
            cur_prox_target = torch.clamp(next_prox - delta, min=0.0)

        elif pf_type == 'one':
            # target is constant 1
            cur_prox_target = torch.ones_like(next_prox)

        else:
            raise ValueError(f"Unrecognized dmode '{self.args.dmode}'")
        
        return cur_prox_target


    def _get_prox_val_fn(self):
        use_fn = None
        if self.args.dmode == "exp":
            use_fn = exp_discounted
            def_delta = 0.95
        elif self.args.dmode == "linear":
            use_fn = linear_discounted
            def_delta = 0.001
        elif self.args.dmode == "big":
            use_fn = partial(big_discounted, start_val=self.args.start_delta)
            def_delta = 0.001
        elif self.args.dmode == "exp_subsample":
            use_fn = partial(exp_discounted_subsample, subsample_freq=self.args.pf_subsample_freq)
            def_delta = 0.95
        else:
            raise ValueError("Must specify discounting mode")

        if self.args.pf_delta is None:
            self.args.pf_delta = def_delta

        return partial(use_fn, delta=self.args.pf_delta)


    def _get_traj_dataset(self, traj_load_path):
        """Build the trajectory dataset and cache the discounting function.

        Stores the result of ``_get_prox_val_fn()`` on
        ``self.compute_prox_fn`` and passes it to the dataset. The dataset
        class is ``ValueTrajIncludeNextStateDataset`` when
        ``args.add_regularizer_on_proximity_infer_from_next`` is truthy,
        otherwise ``ValueTrajDataset``.
        """
        prox_fn = self._get_prox_val_fn()
        self.compute_prox_fn = prox_fn

        if not self.args.add_regularizer_on_proximity_infer_from_next:
            return ValueTrajDataset(traj_load_path, prox_fn, self.args)
        return ValueTrajIncludeNextStateDataset(traj_load_path, prox_fn, self.args)

    def compute_good_traj_prox(self, obs, actions):
        """Return discounted proximity labels for a trajectory of ``len(obs)`` steps.

        ``actions`` is accepted for interface compatibility and not used.
        """
        horizon = len(obs)
        labels = compute_discounted_prox(horizon, self.compute_prox_fn)
        return torch.tensor(labels)
    

    def compute_bad_traj_prox(self, obs, actions):
        """
        For bad trajectories, assign interpolated proximity values between confident states.
        Each segment between two confident states is assigned linearly interpolated proximity
        values between their corresponding proximity scores.
        """
        with torch.no_grad():
            # Step 1: Get uncertainty and turn into normalized confidence
            uncert = self._get_prox_uncert(obs, actions).squeeze()  # [T]
            confidence = 1 - uncert

            # Step 2: Identify confident indices
            threshold = self.confidence_value()
            valid_mask = (confidence > threshold)
            confident_idxs = valid_mask.nonzero(as_tuple=False).squeeze(-1)

            # Step 3: Fallback if no confident points
            T = len(obs)
            prox = torch.zeros(T, device=obs.device)
            if len(confident_idxs) == 0:
                return prox

            # Step 4: Get proximity values for confident states
            prox_confident = self._get_prox_vals(obs[confident_idxs], actions)
            prox_confident = prox_confident.mean(dim=0).squeeze(-1)  # [N_conf]

            # Step 5: Interpolate between confident segments
            for i in range(len(confident_idxs) - 1):
                start_idx = confident_idxs[i].item()
                end_idx = confident_idxs[i + 1].item()
                start_val = prox_confident[i].item()
                end_val = prox_confident[i + 1].item()

                if self.args.interpolation_dmode == 'linear':
                    steps = end_idx - start_idx + 1
                    interp_vals = torch.linspace(start_val, end_val, steps=steps, device=obs.device)
                    prox[start_idx:end_idx + 1] = interp_vals
                elif self.args.interpolation_dmode == 'exp':
                    # Exponential interpolation
                    # Avoid log(0)
                    start_val = max(start_val, 1e-6)
                    start_val = min(start_val, 1)
                    end_val = max(end_val, 1e-6)
                    start_val = min(start_val, 1)

                    log_start = torch.log(torch.tensor(start_val, device=obs.device))
                    log_end = torch.log(torch.tensor(end_val, device=obs.device))
                    steps = end_idx - start_idx + 1
                    interp_log_vals = torch.linspace(log_start, log_end, steps=steps, device=obs.device)
                    interp_vals = torch.exp(interp_log_vals)
                    prox[start_idx:end_idx + 1] = interp_vals
                else:
                    raise ValueError(f"Unrecognized interpolation_dmode '{self.args.interpolation_dmode}'")

            # Optional: assign confident values at those exact indices (ensures numerical consistency)
            for i, idx in enumerate(confident_idxs):
                prox[idx] = prox_confident[i]
            
            # for each prox, we will have a  proabability to set them to 0
            probs = torch.rand_like(prox)
            # if probs <= self.policy.pf_est_prob, set them to 0
            prox[probs <= self.pf_est_prob] = 0.0

            return prox


    def get_add_args(self, parser):
        """Register the discounting and interpolation options on ``parser``.

        The parent's options are added first, followed by ``--dmode``,
        ``--interpolation-dmode``, ``--start-delta``, ``--pf-delta`` and
        ``--pf-subsample-freq`` in that order.
        """
        super().get_add_args(parser)

        # (flag, keyword arguments for add_argument), in registration order.
        new_options = (
            ("--dmode", {"type": str, "default": "exp"}),
            ("--interpolation-dmode", {"type": str, "default": "exp"}),
            ("--start-delta", {"type": float, "default": 1.0}),
            # None lets `_get_prox_val_fn` pick the per-mode default delta.
            ("--pf-delta", {"type": float, "default": None}),
            ("--pf-subsample-freq", {"type": int, "default": 10}),
        )
        for flag, spec in new_options:
            parser.add_argument(flag, **spec)

