import os
import dill as pickle
import numpy as np
import torch

from task.abs_task import AbsTask
from utils.utils_fn import nest_dict


class FeedbackCollector:
    """Class for storing all feedback types"""

    def __init__(
        self,
        num_dem_segments: int,
        pref_buffer_capacity: int,
        preference_sampling_params: dict,
        data_filepath: str,
        device: str,
        env: AbsTask,
        score_name: str = "rl_sum",
        num_validation_pref: int = 0,
        cpl_sampling: bool = False,
        is_dense_pref: bool = False,
        use_cpu_storage: bool = False,
        only_best_model: bool = False,
    ) -> None:
        """Load the offline dataset and set up demonstration/preference storage.

        Args:
            num_dem_segments: Number of demonstration segments to extract.
            pref_buffer_capacity: Capacity of the sparse preference buffer;
                replaced by ``num_dem_segments // 2`` when ``cpl_sampling``.
            preference_sampling_params: Sampling configuration.
                NOTE(review): annotated as ``dict`` but other methods read it
                via attributes (``.num_prefs``, ``.pref_segment_size``).
            data_filepath: Path of the file read with ``np.load``.
            device: Compute device.
            env: Task object; only ``obs_dim`` and ``action_dim`` are read.
            score_name: Key in the dataset used as the score signal.
            num_validation_pref: Number of validation preference pairs.
            cpl_sampling: Pair the first half of the segments with the second.
            is_dense_pref: Use dense (all-pairs ranking) preferences.
            use_cpu_storage: Keep stored tensors on the CPU.
            only_best_model: Restrict sampling to the ``"best_model"`` label
                when that label exists in the data.
        """
        self.device = device

        self.env_obs_dim = env.obs_dim
        self.env_action_dim = env.action_dim

        # NOTE: allow_pickle=True executes pickled payloads; only load
        # trusted dataset files.
        self.data_filepath = data_filepath
        with open(self.data_filepath, "rb") as handle:
            raw_data = np.load(handle, allow_pickle=True)
            self.data = nest_dict(raw_data)

        self.use_cpu_storage = use_cpu_storage
        if self.use_cpu_storage:
            self.storage_device = "cpu"
        else:
            self.storage_device = self.device

        # Options read by the demonstration extraction below.
        self.num_dem_segments = num_dem_segments
        self.score_name = score_name
        self.cpl_sampling = cpl_sampling
        self.num_validation_pref = num_validation_pref
        self.only_best_model = only_best_model

        # May overwrite self.num_dem_segments (only_best_model case), so it
        # has to run before the buffer capacity is derived.
        self._extract_and_label_demonstrations(
            data=self.data,
            num_dem_segments=self.num_dem_segments,
            score_name=self.score_name,
        )

        self.preference_sampling_params = preference_sampling_params
        if self.cpl_sampling:
            pref_buffer_capacity = self.num_dem_segments // 2

        # Sparse preference ring buffer.
        self.pref_buffer_last_index = 0
        self.pref_buffer_full = False
        self.pref_buffer_capacity = int(pref_buffer_capacity)
        self.pref_1 = torch.tensor([], device=self.storage_device)
        self.pref_2 = torch.tensor([], device=self.storage_device)
        self.pref_label = torch.tensor([], device=self.storage_device)
        self.pref_1_idx = np.array([], dtype=int)
        self.pref_2_idx = np.array([], dtype=int)

        # Dense preference state and derived datasets (filled when sampling).
        self.is_dense_pref = is_dense_pref
        self.traj_rankings = None
        self.dense_pref_pairs = []
        self.pref_sa_dataset = None
        self.pref_r_dataset = None
        self.pref_segment_dataset = None
        self.pref_r_segment_dataset = None

    def _extract_and_label_demonstrations(
        self, data, num_dem_segments, label=1.0, score_name="rl_sum", action_eps=1e-5
    ):
        if len(data) == 0:
            return torch.tensor(
                torch.zeros((0, self.env_obs_dim + self.env_action_dim + 1)),
                device=self.device,
            )  # +1 for the label

        total_num_segments = data["obs"].shape[0]
        if self.cpl_sampling:
            random_indices = np.arange(num_dem_segments)
        else:
            if "label" in data.keys():
                unique_labels = np.unique(data["label"][:, 0])
                num_labels = len(unique_labels)
                if self.only_best_model:
                    if "best_model" in unique_labels:
                        self.num_dem_segments = num_dem_segments // num_labels
                        num_labels = 1
                        unique_labels = ["best_model"]
                        num_dem_segments = self.num_dem_segments
                num_individual_segments = num_dem_segments // num_labels
                random_indices_all = []
                for label_data in unique_labels:
                    idx_label = np.where(data["label"][:, 0] == label_data)[0]
                    random_indices_label = np.random.choice(
                        idx_label, size=num_individual_segments, replace=False
                    )
                    random_indices_all.append(random_indices_label)
                random_indices = np.concatenate(random_indices_all)
                np.random.shuffle(random_indices)
            else:
                random_indices = np.random.choice(
                    total_num_segments, size=num_dem_segments, replace=False
                )

        if self.num_validation_pref > 0:
            val_random_indices = np.random.choice(
                total_num_segments, size=2 * self.num_validation_pref, replace=False
            )
            random_indices = np.concatenate([random_indices, val_random_indices])

        lim = 1 - action_eps
        clipped_action = np.clip(data["action"], a_min=-lim, a_max=lim)
        actions = (
            torch.from_numpy(clipped_action[random_indices])
            .float()
            .to(self.storage_device)
        )
        obss = (
            torch.from_numpy(data["obs"][random_indices])
            .float()
            .to(self.storage_device)
        )
        rewards = (
            torch.from_numpy(data["reward"][random_indices]).float().to(self.device)
        )
        dones = torch.from_numpy(data["done"][random_indices]).float().to(self.device)
        # states = torch.from_numpy(data["state"][random_indices]).float().to(self.device)
        scores = (
            torch.from_numpy(data[score_name][random_indices])
            .float()
            .to(self.storage_device)
        )
        if len(scores.shape) == 1:
            segment_length = obss.shape[1]
            scores = (scores / segment_length).unsqueeze(-1) * torch.ones(
                1, segment_length, device=self.storage_device
            )

        if len(obss.shape) > 3:
            # Image has shape (batch, segment, 3, 64, 64) -> (batch, segment, 3*64*64)
            obss = obss.reshape(*actions.shape[:-1], -1)
        obs_action = torch.cat([obss, actions], dim=-1)

        selected_obs_action = obs_action
        selected_rewards = rewards
        obs_action_flat = selected_obs_action.view(-1, obs_action.shape[-1])

        label_matrix = torch.full(
            (obs_action_flat.shape[0], 1), fill_value=label, device=self.storage_device
        )

        obs_action_labelled = torch.cat([obs_action_flat, label_matrix], dim=-1)

        self.demonstrations = selected_obs_action[:num_dem_segments, ...]
        self.dem_segment_size = self.demonstrations.shape[1]

        self.rewards = selected_rewards[:num_dem_segments, ...]
        self.scores = scores[:num_dem_segments, ...]
        self.scores_flat = scores.view(-1, 1)
        if "label" in data.keys():
            labels = data["label"][random_indices]
            self.labels_segment = labels[:num_dem_segments, ...]
            self.labels = self.labels_segment.reshape(-1, 1)
        else:
            self.labels_segment = None
            self.labels = None
        self.demonstrations_labelled = obs_action_labelled[
            : num_dem_segments * self.dem_segment_size, ...
        ]
        self.num_dem_steps = self.demonstrations_labelled.shape[0]
        self.segment_idx = random_indices[:num_dem_segments]

        self.rewards_step = self.rewards.view(-1, 1)

        if self.num_validation_pref > 0:
            self.val_demonstrations = selected_obs_action[
                -2 * self.num_validation_pref :, ...
            ]
            self.val_rewards = selected_rewards[-2 * self.num_validation_pref :, ...]
            self.val_scores = scores[-2 * self.num_validation_pref :, ...]
            self.val_segment_idx = random_indices[-2 * self.num_validation_pref :]
            self.val_demonstrations_labelled = obs_action_labelled[
                -2 * self.num_validation_pref * self.dem_segment_size :, ...
            ]

    def get_pref_policy_data(
        self,
    ):
        if self.pref_sa_dataset is None:
            raise ValueError("No preference data available.")

        random_segment_indices = np.random.permutation(
            self.pref_segment_dataset.shape[0]
        )

        unique_ratio = len(np.unique(random_segment_indices)) / len(
            random_segment_indices
        )
        pref_segment_data = self.pref_segment_dataset[random_segment_indices]
        pref_r_segment_data = self.pref_r_segment_dataset[random_segment_indices]

        pref_sa_data = pref_segment_data.view(-1, pref_segment_data.shape[-1])
        pref_r_sa_data = pref_r_segment_data.view(-1, pref_r_segment_data.shape[-1])

        return (
            pref_sa_data,
            pref_r_sa_data,
            unique_ratio,
            random_segment_indices,
            pref_segment_data,
            pref_r_segment_data,
        )

    def collect_feedback(self, step, agent, reward_model, replay_buffer, logger):

        if step == 0:
            dem_reward_stats = self.calc_data_stats()
            logger.log("dem_stats", dem_reward_stats, step)

            num_pref = self.num_preferences()
            if num_pref == 0:
                stats = self.sample_preferences_offline(
                    agent=agent,
                    reward_model=reward_model,
                    replay_buffer=replay_buffer,
                    num_queries=self.preference_sampling_params.num_prefs,
                    validation_size=self.num_validation_pref,
                )
                logger.log("feedback_number", stats, step)
            logger.log("feedback_number", self.num_feedback(), step)

    def calc_data_stats(self):
        dem_reward_mean = self.rewards_step.mean()
        dem_reward_std = self.rewards_step.std()
        dem_reward_min = self.rewards_step.min()
        dem_reward_max = self.rewards_step.max()

        scores_mean = self.scores.mean()
        scores_std = self.scores.std()
        scores_min = self.scores.min()
        scores_max = self.scores.max()

        dem_reward_stats = {
            "mean": dem_reward_mean,
            "std": dem_reward_std,
            "min": dem_reward_min,
            "max": dem_reward_max,
            "scores_mean": scores_mean,
            "scores_std": scores_std,
            "scores_min": scores_min,
            "scores_max": scores_max,
        }
        return dem_reward_stats

    def add_preferences(
        self,
        option1,
        option2,
        rewards1,
        rewards2,
        validation_size=0,
    ):

        rewards1 = torch.sum(rewards1, dim=1)
        rewards2 = torch.sum(rewards2, dim=1)
        num_queries = option1.shape[0]

        labels = self.evaluate_options(
            option1=option1, option2=option2, rewards1=rewards1, rewards2=rewards2
        )

        if validation_size > 0:
            self.val_demonstrations
            self.val_pref1 = self.val_demonstrations[
                self.val_demonstrations.shape[0] // 2 :, ...
            ]
            self.val_pref2 = self.val_demonstrations[
                : self.val_demonstrations.shape[0] // 2, ...
            ]
            self.val_rewards1 = self.val_rewards[
                self.val_rewards.shape[0] // 2 :, ...
            ].sum(-1, keepdim=True)
            self.val_rewards2 = self.val_rewards[
                : self.val_rewards.shape[0] // 2, ...
            ].sum(-1, keepdim=True)

            self.val_scores1 = self.val_scores[
                self.val_scores.shape[0] // 2 :, ...
            ].sum(-1, keepdim=True)
            self.val_scores2 = self.val_scores[
                : self.val_scores.shape[0] // 2, ...
            ].sum(-1, keepdim=True)

            self.val_labels = self.evaluate_options(
                option1=self.val_pref1,
                option2=self.val_pref2,
                rewards1=self.val_scores1,
                rewards2=self.val_scores2,
            )

        num_queries = num_queries

        if type(option1) == np.ndarray:
            option1 = torch.from_numpy(option1).to(
                self.storage_device, dtype=torch.float32
            )
            option2 = torch.from_numpy(option2).to(
                self.storage_device, dtype=torch.float32
            )

        if (
            self.pref_buffer_last_index + num_queries > self.pref_buffer_capacity
            and self.pref_buffer_full == False
        ):
            self.pref_buffer_full = True
            num_pref_full_buffer = (
                self.pref_buffer_capacity - self.pref_buffer_last_index
            )

            self.pref_1 = torch.cat(
                [self.pref_1, option1[:num_pref_full_buffer, ...]], dim=0
            )
            self.pref_2 = torch.cat(
                [self.pref_2, option2[:num_pref_full_buffer, ...]], dim=0
            )
            self.pref_label = torch.cat(
                [self.pref_label, labels[:num_pref_full_buffer, ...]], dim=0
            )

            rest_pref = num_queries - num_pref_full_buffer
            if rest_pref > self.pref_buffer_capacity:
                raise ValueError(
                    f"Number of queries {num_queries} is larger than the preference buffer capacity {self.pref_buffer_capacity}."
                )

            pref_buffer_indices = [i for i in range(rest_pref)]
            self.pref_1[pref_buffer_indices, ...] = option1[num_pref_full_buffer:, ...]
            self.pref_2[pref_buffer_indices, ...] = option2[num_pref_full_buffer:, ...]
            self.pref_label[pref_buffer_indices, ...] = labels[
                num_pref_full_buffer:, ...
            ]

        elif self.pref_buffer_full:
            pref_buffer_indices = [
                int((self.pref_buffer_last_index + i) % self.pref_buffer_capacity)
                for i in range(num_queries)
            ]
            self.pref_1[pref_buffer_indices, ...] = option1
            self.pref_2[pref_buffer_indices, ...] = option2
            self.pref_label[pref_buffer_indices, ...] = labels
            self.pref_buffer_last_index = pref_buffer_indices[-1] + 1

        else:
            self.pref_buffer_last_index += num_queries

            self.pref_1 = torch.cat([self.pref_1, option1], dim=0)
            self.pref_2 = torch.cat([self.pref_2, option2], dim=0)
            self.pref_label = torch.cat([self.pref_label, labels], dim=0)

    def sample_preferences_sparse(
        self, agent, reward_model, replay_buffer, num_queries, validation_size=0
    ):

        if self.cpl_sampling:  # Compare first half with the second
            option1, option2, rewards1, rewards2, query_stats, segment_indices = (
                self.get_queries_cpl(
                    num_queries=num_queries,
                    size_segment=self.preference_sampling_params.pref_segment_size,
                )
            )

        else:

            option1, option2, rewards1, rewards2, query_stats, segment_indices = (
                self.get_queries(
                    inputs=self.demonstrations,
                    targets=self.scores,
                    num_queries=num_queries,
                    size_segment=self.preference_sampling_params.pref_segment_size,
                )
            )

        self.add_preferences(
            option1=option1,
            option2=option2,
            rewards1=rewards1,
            rewards2=rewards2,
            validation_size=validation_size,
        )

        if (
            self.preference_sampling_params.pref_segment_size
            == self.demonstrations.shape[1]
        ):  # whole segments
            merged_indices = np.concatenate(
                [segment_indices["index1"], segment_indices["index2"]]
            )

            segment_indices = np.unique(merged_indices)
            self.pref_segment_dataset = self.demonstrations[segment_indices]
            self.pref_r_segment_dataset = self.scores[segment_indices].unsqueeze(-1)

            self.pref_sa_dataset = self.pref_segment_dataset.view(
                -1, self.pref_segment_dataset.shape[-1]
            )
            self.pref_r_dataset = self.pref_r_segment_dataset.view(
                -1, self.pref_r_segment_dataset.shape[-1]
            )
        else:
            merged_indices = np.concatenate(
                [segment_indices["index1"], segment_indices["index2"]]
            )
            merged_start = np.concatenate(
                [segment_indices["start_1"], segment_indices["start_2"]]
            )
            merged_segments = np.concatenate([merged_indices, merged_start], axis=-1)

            unique_seg = np.unique(merged_segments, axis=0)

            self.pref_segment_dataset = self.demonstrations[unique_seg, :]
            self.pref_r_segment_dataset = self.scores[unique_seg].unsqueeze(-1)

            self.pref_sa_dataset = self.pref_segment_dataset.view(
                -1, self.pref_segment_dataset.shape[-1]
            )
            self.pref_r_dataset = self.pref_r_segment_dataset.view(
                -1, self.pref_r_segment_dataset.shape[-1]
            )

        return query_stats

    def get_queries(self, inputs, targets, num_queries, size_segment):

        max_len = inputs.shape[0]
        len_traj = inputs.shape[1]

        batch_index_1 = np.random.choice(max_len, size=num_queries, replace=True)
        batch_index_2 = np.random.choice(max_len, size=num_queries, replace=True)
        for i in range(num_queries):
            while batch_index_1[i] == batch_index_2[i]:
                batch_index_2[i] = np.random.choice(max_len)

        index1_unique = np.unique(batch_index_1)
        index2_unique = np.unique(batch_index_2)
        merged_index = np.unique(np.concatenate((batch_index_1, batch_index_2)))
        pref_coverage_stats = {
            "index1_unique": len(index1_unique) / max_len,
            "index2_unique": len(index2_unique) / max_len,
            "merged": len(merged_index) / max_len,
        }
        segment_indices = {"index1": batch_index_1, "index2": batch_index_2}

        traj_1 = inputs[batch_index_1]  # Batch x T x dim of s&a
        r_t_1 = targets[batch_index_1]  # Batch x T x 1

        traj_2 = inputs[batch_index_2]  # Batch x T x dim of s&a
        r_t_2 = targets[batch_index_2]  # Batch x T x 1

        if len_traj == size_segment:
            option1 = traj_1
            rewards1 = r_t_1.unsqueeze(-1)
            option2 = traj_2
            rewards2 = r_t_2.unsqueeze(-1)
        else:

            # random start indices in [0, len_traj - size_segment] inclusive
            random_offsets_1 = torch.randint(
                0,
                len_traj - size_segment + 1,
                (num_queries,),
                device=self.storage_device,
            )
            random_offsets_2 = torch.randint(
                0,
                len_traj - size_segment + 1,
                (num_queries,),
                device=self.storage_device,
            )

            # [N, 1]
            time_index_1 = random_offsets_1.view(-1, 1)
            time_index_2 = random_offsets_2.view(-1, 1)

            # [T] and then broadcast to [N, T]
            t_range = torch.arange(size_segment, device=self.storage_device)
            idx_1 = time_index_1 + t_range  # [N, T]
            idx_2 = time_index_2 + t_range  # [N, T]

            # Gather traj segments: expand to [N, T, D]
            idx_1_traj = idx_1.unsqueeze(-1).expand(
                -1, -1, traj_1.size(-1)
            )  # [N, T, D]
            idx_2_traj = idx_2.unsqueeze(-1).expand(
                -1, -1, traj_2.size(-1)
            )  # [N, T, D]

            option1 = torch.gather(traj_1, 1, idx_1_traj)  # [N, T, D]
            option2 = torch.gather(traj_2, 1, idx_2_traj)  # [N, T, D]

            # Gather rewards segments: handle both [N, L] and [N, L, 1]
            if r_t_1.dim() == 2:
                rewards1 = torch.gather(
                    r_t_1.to(self.storage_device), 1, idx_1
                )  # [N, T]
            else:
                rewards1 = torch.gather(
                    r_t_1.to(self.storage_device), 1, idx_1.unsqueeze(-1)
                )  # [N, T, 1]

            if r_t_2.dim() == 2:
                rewards2 = torch.gather(
                    r_t_2.to(self.storage_device), 1, idx_2
                )  # [N, T]
            else:
                rewards2 = torch.gather(
                    r_t_2.to(self.storage_device), 1, idx_2.unsqueeze(-1)
                )  # [N, T, 1]

            if rewards1.dim() == 2:
                rewards1 = rewards1.unsqueeze(-1)
                rewards2 = rewards2.unsqueeze(-1)

            segment_indices["start_1"] = time_index_1.squeeze(-1).cpu().numpy()
            segment_indices["start_2"] = time_index_2.squeeze(-1).cpu().numpy()

        stats = {"pref_coverage_stats": pref_coverage_stats}
        return option1, option2, rewards1, rewards2, stats, segment_indices

    def get_queries_cpl(self, num_queries, size_segment):

        train_inputs = self.demonstrations
        train_targets = self.scores

        max_len = train_inputs.shape[0]
        num_prefs = max_len // 2

        len_traj = train_inputs.shape[1]

        first_idx = np.arange(num_prefs)

        batch_index_1 = np.random.choice(first_idx, num_prefs, replace=False)

        batch_index_2 = batch_index_1 + num_prefs

        segment_indecies = {"index1": batch_index_1, "index2": batch_index_2}

        sa_t_1 = train_inputs[batch_index_1]  # Batch x T x dim of s&a
        r_t_1 = train_targets[batch_index_1]  # Batch x T x 1

        sa_t_2 = train_inputs[batch_index_2]  # Batch x T x dim of s&a
        r_t_2 = train_targets[batch_index_2]  # Batch x T x 1

        if len_traj == size_segment:
            option1 = sa_t_1
            rewards1 = r_t_1.unsqueeze(-1)
            option2 = sa_t_2
            rewards2 = r_t_2.unsqueeze(-1)
        else:
            time_index = torch.tensor(
                [
                    list(range(i * len_traj, i * len_traj + size_segment))
                    for i in range(num_queries)
                ]
            )

            # Random offsets for `time_index_2` and `time_index_1`
            random_offsets_2 = torch.randint(0, len_traj - size_segment, (num_queries,))
            random_offsets_1 = torch.randint(0, len_traj - size_segment, (num_queries,))

            # Add offsets to time_index
            time_index_2 = time_index + random_offsets_2.view(-1, 1)
            time_index_1 = time_index + random_offsets_1.view(-1, 1)
            time_index_2 = time_index_2.to(sa_t_2.device)
            time_index_1 = time_index_1.to(sa_t_1.device)

            option1 = torch.gather(
                sa_t_1, 1, time_index_1.unsqueeze(-1).expand(-1, -1, sa_t_1.size(-1))
            )
            rewards1 = torch.gather(r_t_1, 1, time_index_1.unsqueeze(-1))

            option2 = torch.gather(
                sa_t_2, 1, time_index_2.unsqueeze(-1).expand(-1, -1, sa_t_2.size(-1))
            )
            rewards2 = torch.gather(r_t_2, 1, time_index_2.unsqueeze(-1))

        stats = {}
        return option1, option2, rewards1, rewards2, stats, segment_indecies

    def evaluate_options(self, option1, option2, rewards1, rewards2):
        labels = (rewards1 <= rewards2).float()

        return labels

    def num_preferences(self):
        """Return how many preferences are currently available.

        Dense mode counts the ranked pairs; sparse mode reports the buffer
        capacity once the buffer is full and the write index before that.
        """
        if self.is_dense_pref:
            return len(self.dense_pref_pairs)
        if self.pref_buffer_full:
            return self.pref_buffer_capacity
        return self.pref_buffer_last_index

    def num_feedback(self):
        """Return the feedback counts as a dict keyed by feedback type."""
        return dict(num_preferences=self.num_preferences())

    def pref_data_analysis(self, indices, labels):
        labels = labels.squeeze().cpu().numpy()
        option_indices_1 = indices["index1"]
        option_indices_2 = indices["index2"]
        better_option1 = labels == 0
        better_option2 = labels == 1
        better_option1_idx = option_indices_1[better_option1]
        better_option2_idx = option_indices_2[better_option2]
        worse_option1_idx = option_indices_1[better_option2]
        worse_option2_idx = option_indices_2[better_option1]
        better_idx = np.concatenate([better_option1_idx, better_option2_idx])
        worse_idx = np.concatenate([worse_option1_idx, worse_option2_idx])

        unique_better = np.unique(better_idx)
        unique_worse = np.unique(worse_idx)

        num_same_elements = np.sum(np.isin(better_idx, worse_idx))

        unique_better_ratio = len(unique_better) / len(better_idx)
        unique_worse_ratio = len(unique_worse) / len(worse_idx)
        same_elements_ratio = num_same_elements / len(better_idx)

        pref_data_stats = {
            "unique_better_ratio": unique_better_ratio,
            "unique_worse_ratio": unique_worse_ratio,
            "same_elements_ratio": same_elements_ratio,
        }

        return pref_data_stats

    def segment_log_pi(self, agent, segment):
        obs = segment[:, : self.env_obs_dim].to(self.device)
        actions = segment[:, self.env_obs_dim :].to(self.device)
        with torch.no_grad():
            dist = agent.actor(obs)
            log_prob = dist.log_prob(actions)
        return log_prob.sum().item()

    def get_preference_data(self, idxs=None):
        if self.is_dense_pref:
            return self.get_preference_data_dense(idxs=idxs)
        else:
            return self.get_preference_data_sparse(idxs=idxs)

    def get_preference_data_sparse(self, idxs=None):
        if idxs is not None:
            pref_1 = self.pref_1[idxs, ...]
            pref_2 = self.pref_2[idxs, ...]
            pref_label = self.pref_label[idxs, ...]
        else:
            pref_1 = self.pref_1
            pref_2 = self.pref_2
            pref_label = self.pref_label
        return (
            pref_1,
            pref_2,
            pref_label,
        )

    def sample_preferences_dense(
        self, agent, reward_model, replay_buffer, num_queries, validation_size=0
    ):

        if len(self.scores.shape) == 2:
            score_metric = self.scores.sum(axis=1)
        else:
            score_metric = self.scores

        self.traj_rankings = np.argsort(-score_metric.cpu().numpy())

        pref_pairs = []
        for i in range(len(self.traj_rankings)):
            for j in range(i + 1, len(self.traj_rankings)):
                pref_pairs.append((self.traj_rankings[i], self.traj_rankings[j]))
        self.dense_pref_pairs = pref_pairs

        self.dense_pref_pairs1 = [
            self.dense_pref_pairs[i][0] for i in range(len(pref_pairs))
        ]
        self.dense_pref_pairs2 = [
            self.dense_pref_pairs[i][1] for i in range(len(pref_pairs))
        ]

        pref_sampling_stats = {
            "pref_pairs": len(pref_pairs),
        }

        self.pref_sa_dataset = self.demonstrations.view(
            -1, self.demonstrations.size(-1)
        )
        self.pref_r_dataset = self.scores_flat

        self.pref_segment_dataset = self.demonstrations

        self.pref_r_segment_dataset = self.scores
        if self.pref_r_segment_dataset.shape[-1] != 1:
            self.pref_r_segment_dataset = self.pref_r_segment_dataset.unsqueeze(-1)

        return pref_sampling_stats

    def get_preference_data_dense(self, idxs=None):
        if idxs is None:
            idxs = np.random.choice(
                len(self.dense_pref_pairs),
                size=self.pref_buffer_capacity,
                replace=False,
            )
        selected_pairs = [self.dense_pref_pairs[i] for i in idxs]
        selected_pref_1 = [selected_pairs[i][0] for i in range(len(selected_pairs))]
        selected_pref_2 = [selected_pairs[i][1] for i in range(len(selected_pairs))]
        pref_1 = self.demonstrations[selected_pref_1, ...]
        pref_2 = self.demonstrations[selected_pref_2, ...]
        pref_label = torch.zeros((pref_1.shape[0], 1), device=self.device)

        return (
            pref_1,
            pref_2,
            pref_label,
        )

    def sample_preferences_offline(
        self, agent, reward_model, replay_buffer, num_queries, validation_size=0
    ):
        if self.is_dense_pref:
            return self.sample_preferences_dense(
                agent=agent,
                reward_model=reward_model,
                replay_buffer=replay_buffer,
                num_queries=num_queries,
                validation_size=validation_size,
            )
        else:
            return self.sample_preferences_sparse(
                agent=agent,
                reward_model=reward_model,
                replay_buffer=replay_buffer,
                num_queries=num_queries,
                validation_size=validation_size,
            )
