import torch
from torch.utils.data import Dataset
import numpy as np
import copy
import zarr
import random
from torch.utils.data import DataLoader

from diffusion_policy.common.pytorch_util import dict_apply
from diffusion_policy.model.common.normalizer import LinearNormalizer
from diffusion_policy.dataset.base_dataset import BaseImageDataset
from diffusion_policy.common.normalize_util import get_image_range_normalizer
from diffusion_policy.common.sampler import (
    SequenceSampler, get_val_mask, downsample_mask)


def create_indices_one_step(episode_ends, mask):
    """Build one-step ``(t, t + 1)`` transition index pairs for selected episodes.

    Args:
        episode_ends: Episode boundaries with a leading 0, so episode ``i``
            covers frames ``episode_ends[i]`` (inclusive) to
            ``episode_ends[i + 1]`` (exclusive).
        mask: Boolean sequence indexed by episode; only episodes whose entry
            is truthy contribute transitions.

    Returns:
        ``(indices, trajectory_index)``. ``indices`` is a list of ``(t, t + 1)``
        frame pairs: for an episode of ``L >= 2`` frames it holds the ``L - 1``
        in-episode transitions followed by a duplicate of the last one, so the
        episode contributes exactly ``L`` entries. ``trajectory_index`` is
        aligned with ``indices`` and holds the episode number of each entry.
        Episodes with fewer than 2 frames have no transition and are skipped.
    """
    indices = []
    trajectory_index = []
    for i in range(len(episode_ends) - 1):
        if not mask[i]:
            continue
        start = episode_ends[i]
        end = episode_ends[i + 1] - 1  # last frame of the episode (inclusive)
        if end <= start:
            # Fewer than 2 frames: no (t, t + 1) pair stays inside the episode,
            # and the duplicate below would reach into the previous episode
            # (or wrap around to frame -1 for the first episode).
            continue
        episode_pairs = [(j, j + 1) for j in range(start, end)]
        # Duplicate the final transition so the episode contributes one entry
        # per frame ("convenient for indexing").
        episode_pairs.append((end - 1, end))
        indices.extend(episode_pairs)
        trajectory_index.extend([i] * len(episode_pairs))
    return indices, trajectory_index

class DynamicsModelDataset(torch.utils.data.Dataset):
    """One-step transition dataset ``(o_t, a_t, o_{t+1})`` loaded from a zarr store.

    The store is read fully into memory (``data/img``, ``data/state``,
    ``data/action``, ``data/n_contacts`` and ``meta/episode_ends``). The
    instance serves transitions from the training episodes;
    ``get_validation_dataset`` returns a copy serving the held-out episodes.
    """

    def __init__(self, zarr_path, val_ratio=0.2, random=True):
        """Load the zarr store and build the training transition index.

        Args:
            zarr_path: Path of the zarr store, opened read-only.
            val_ratio: Validation ratio forwarded to ``get_val_mask``.
            random: Forwarded to ``get_val_mask``'s ``random`` argument.
        """
        zarr_data = zarr.open(zarr_path, mode='r')
        zarr_data_np = {
            'data': {key: np.array(zarr_data['data'][key]) for key in zarr_data['data'].keys()},
            'meta': {key: np.array(zarr_data['meta'][key]) for key in zarr_data['meta'].keys()}
        }
        episode_ends = zarr_data_np['meta']['episode_ends']
        self.imgs = zarr_data_np['data']['img']
        self.states = zarr_data_np['data']['state']
        # to fix the previous bug in pusht_env: wrap state dimension 4 into [0, 2*pi)
        self.states[:, 4] = self.states[:, 4] % (2 * np.pi)
        self.actions = zarr_data_np['data']['action']
        n_contacts = zarr_data_np['data']['n_contacts']
        # Keep contacts 2-D (T, k) so they concatenate with the states both per
        # frame (__getitem__) and over the whole array (get_normalizer).
        if n_contacts.ndim == 1:
            n_contacts = n_contacts[:, None]
        self.n_contacts = n_contacts
        # Prepend 0 so episode i spans frames [episode_ends[i], episode_ends[i + 1]).
        self.episode_ends = np.insert(episode_ends, 0, 0)
        self.val_ratio = val_ratio

        # One mask entry per real episode: the prepended 0 is a boundary, not an
        # episode, so there are len(self.episode_ends) - 1 episodes.
        val_mask = get_val_mask(
            n_episodes=len(self.episode_ends) - 1,
            val_ratio=val_ratio,
            seed=42,
            random=random)
        self.train_mask = ~val_mask
        self.valid_transitions, self.trajectory_index = create_indices_one_step(
            self.episode_ends, self.train_mask)

    def get_validation_dataset(self):
        """Return a shallow copy of this dataset that serves the held-out episodes."""
        val_set = copy.copy(self)
        val_set.train_mask = ~self.train_mask
        # create_indices_one_step returns (indices, trajectory_index); unpack both.
        val_set.valid_transitions, val_set.trajectory_index = create_indices_one_step(
            self.episode_ends, val_set.train_mask)
        return val_set

    def __len__(self):
        """Return the number of transitions served by this dataset."""
        return len(self.valid_transitions)

    def _observation(self, frame_idx):
        """Return the observation dict (image, state + contacts, agent_pos) of one frame."""
        # Move the last (channel) axis to the front and scale pixel values by 1/255.
        image = np.moveaxis(self.imgs[frame_idx], -1, 0) / 255
        state = self.states[frame_idx].astype(np.float32)
        n_contacts = self.n_contacts[frame_idx].astype(np.float32)
        return {
            'image': image,
            'state': np.concatenate([state, n_contacts]),
            # first two state entries
            'agent_pos': self.states[frame_idx][:2].astype(np.float32),
        }

    def __getitem__(self, idx):
        """Return ``{'o_t', 'action', 'o_t1'}`` for transition ``idx``.

        Every leaf array is converted with ``dict_apply(..., torch.from_numpy)``.
        """
        idx_t, idx_t1 = self.valid_transitions[idx]
        data = {
            'o_t': self._observation(idx_t),
            'action': self.actions[idx_t],  # action taken at time t
            'o_t1': self._observation(idx_t1),
        }
        return dict_apply(data, torch.from_numpy)

    def get_normalizer(self, mode='limits', **kwargs):
        """Fit a ``LinearNormalizer`` on actions, state + contacts and agent_pos."""
        data = {
            'action': self.actions,
            'state': np.concatenate((self.states, self.n_contacts), axis=1),
            'agent_pos': self.states[..., :2],
        }
        normalizer = LinearNormalizer()
        normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
        normalizer['image'] = get_image_range_normalizer()
        return normalizer


class MultiStepDynamicsModelDataset(Dataset):
    """Multi-step state dataset with negative states from other trajectories.

    Each sample starts at frame ``s_t`` of a training trajectory and returns the
    next ``horizon`` actions and future states/images. Windows running past the
    trajectory end are padded with the trajectory's last action/state/image.
    ``N_NEGATIVES`` states drawn from other training trajectories are added.
    """

    # Number of negative states drawn per sample.
    N_NEGATIVES = 5

    def __init__(self, zarr_path, horizon=1, val_ratio=0.0):
        """Load the zarr store and index every frame of the training episodes.

        Args:
            zarr_path: Path of the zarr store, opened read-only.
            horizon: Number of future actions/states returned per sample.
            val_ratio: Validation ratio forwarded to ``get_val_mask``.
        """
        zarr_data = zarr.open(zarr_path, mode='r')
        zarr_data_np = {
            'data': {key: np.array(zarr_data['data'][key]) for key in zarr_data['data'].keys()},
            'meta': {key: np.array(zarr_data['meta'][key]) for key in zarr_data['meta'].keys()}
        }

        # Exclusive end frame of each episode, e.g. [60, 110, ...]
        episode_ends = zarr_data_np['meta']['episode_ends']

        self.imgs = zarr_data_np['data']['img']
        self.states = zarr_data_np['data']['state']  # (total_timesteps, state_dim)
        self.actions = zarr_data_np['data']['action']  # (total_timesteps, action_dim)
        self.n_contacts = zarr_data_np['data']['n_contacts']

        # Wrap the 5th state dimension into [0, 2*pi), as per starter code
        self.states[:, 4] = self.states[:, 4] % (2 * np.pi)

        # Append n_contacts as extra state column(s); accepts (T,) or (T, k).
        n_contacts = self.n_contacts
        if n_contacts.ndim == 1:
            n_contacts = n_contacts[:, None]
        self.states = np.concatenate((self.states, n_contacts), axis=1)

        self.horizon = horizon
        self.val_ratio = val_ratio

        val_mask = get_val_mask(
            n_episodes=len(episode_ends),
            val_ratio=val_ratio,
            seed=42,
            random=False)
        self.train_mask = ~val_mask
        self.episode_ends = episode_ends[self.train_mask]
        # (start, end) frame range of each training episode, end exclusive.
        self.trajectories = self._trajectory_bounds(episode_ends, self.train_mask)

        # One sample per frame of every training trajectory, as (traj_idx, t).
        self.sample_indices = [
            (traj_idx, t)
            for traj_idx, (start, end) in enumerate(self.trajectories)
            for t in range(end - start)
        ]

    @staticmethod
    def _trajectory_bounds(episode_ends, train_mask):
        """Return ``[(start, end), ...]`` frame ranges (end exclusive) of the masked episodes.

        Each episode starts where the previous episode *of the full dataset*
        ends, so bounds stay correct when held-out episodes precede or sit
        between training episodes.
        """
        episode_ends = np.asarray(episode_ends)
        episode_starts = np.zeros_like(episode_ends)
        episode_starts[1:] = episode_ends[:-1]
        return [
            (int(start), int(end))
            for start, end in zip(episode_starts[train_mask], episode_ends[train_mask])
        ]

    @staticmethod
    def _padded_window(array, first, length, end):
        """Return ``length`` rows of ``array`` from ``first``, edge-padded at ``end``.

        Rows at or beyond ``end`` (the trajectory's exclusive end) are replaced
        by copies of ``array[end - 1]``.
        """
        window = array[first:min(first + length, end)]
        n_pad = length - len(window)
        if n_pad > 0:
            edge = np.repeat(array[end - 1:end], n_pad, axis=0)
            window = np.concatenate((window, edge), axis=0)
        return window

    def _sample_negatives(self, start, end):
        """Draw ``N_NEGATIVES`` distinct frame indices outside trajectory ``[start, end)``."""
        other_ranges = [np.arange(s, e) for s, e in self.trajectories if s != start]
        if not other_ranges:
            raise ValueError('Negative sampling needs at least one other trajectory')
        candidates = np.concatenate(other_ranges).astype(int)
        negative_t = np.random.choice(candidates, size=self.N_NEGATIVES, replace=False)
        # sanity check: negatives must come from a different trajectory
        if np.any((negative_t >= start) & (negative_t < end)):
            raise ValueError('Negative samples are from the same trajectory')
        return negative_t

    def __len__(self):
        """Return the number of samples (one per training frame)."""
        return len(self.sample_indices)

    def __getitem__(self, idx):
        """Return the current state, action/future windows and negative states."""
        traj_idx, t = self.sample_indices[idx]
        start, end = self.trajectories[traj_idx]

        s_t_idx = start + t
        s_t = self.states[s_t_idx]  # (state_dim + 1,)
        img_t = self.imgs[s_t_idx]

        # a_t .. a_{t+H-1} and s_{t+1} .. s_{t+H}, edge-padded at the trajectory end
        actions_seq = self._padded_window(self.actions, s_t_idx, self.horizon, end)
        s_future_seq = self._padded_window(self.states, s_t_idx + 1, self.horizon, end)
        img_future_seq = self._padded_window(self.imgs, s_t_idx + 1, self.horizon, end)

        data = {
            's_t': torch.tensor(s_t, dtype=torch.float32),
            # stored image layout, values not rescaled
            'img_t': torch.tensor(img_t, dtype=torch.float32),
            'actions_seq': torch.tensor(actions_seq, dtype=torch.float32),  # (horizon, action_dim)
            's_future_seq': torch.tensor(s_future_seq, dtype=torch.float32),  # (horizon, state_dim + 1)
            # left as a numpy array in the stored layout/dtype
            'img_future_seq': img_future_seq,
        }

        negative_t = self._sample_negatives(start, end)
        data['negative_states'] = torch.tensor(self.states[negative_t], dtype=torch.float32)
        return data

    def get_normalizer(self, mode='limits', **kwargs):
        """Fit a ``LinearNormalizer`` on actions and states (incl. contacts)."""
        data = {
            'action': self.actions,
            'state': self.states,
        }
        normalizer = LinearNormalizer()
        normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
        normalizer['image'] = get_image_range_normalizer()
        return normalizer
    

class MultiStepImageDynamicsModelDataset(Dataset):
    """Multi-step image dataset with negative images from other trajectories.

    Each sample starts at frame ``s_t`` of a training trajectory and returns the
    current state/image, the next ``horizon`` actions and future states/images
    (edge-padded past the trajectory end), plus ``n_neg`` images drawn from
    other training trajectories. Images are returned channel-first, scaled by 1/255.
    """

    # The last ANCHOR_TAIL_MARGIN frames of each trajectory are never used as
    # the starting frame of a sample.
    # NOTE(review): the reason for the value 6 is not visible here — confirm.
    ANCHOR_TAIL_MARGIN = 6

    def __init__(self, zarr_path, horizon=1, val_ratio=0.0, n_neg=5):
        """Load the zarr store and index the usable frames of the training episodes.

        Args:
            zarr_path: Path of the zarr store, opened read-only.
            horizon: Number of future actions/states/images returned per sample.
            val_ratio: Validation ratio forwarded to ``get_val_mask``.
            n_neg: Number of negative images drawn per sample.
        """
        zarr_data = zarr.open(zarr_path, mode='r')
        zarr_data_np = {
            'data': {key: np.array(zarr_data['data'][key]) for key in zarr_data['data'].keys()},
            'meta': {key: np.array(zarr_data['meta'][key]) for key in zarr_data['meta'].keys()}
        }

        # Exclusive end frame of each episode, e.g. [60, 110, ...]
        episode_ends = zarr_data_np['meta']['episode_ends']

        self.imgs = zarr_data_np['data']['img']
        self.states = zarr_data_np['data']['state']  # (total_timesteps, state_dim)
        self.actions = zarr_data_np['data']['action']  # (total_timesteps, action_dim)
        self.n_contacts = zarr_data_np['data']['n_contacts']

        # Wrap the 5th state dimension into [0, 2*pi), as per starter code
        self.states[:, 4] = self.states[:, 4] % (2 * np.pi)

        # Append n_contacts as extra state column(s); accepts (T,) or (T, k).
        n_contacts = self.n_contacts
        if n_contacts.ndim == 1:
            n_contacts = n_contacts[:, None]
        self.states = np.concatenate((self.states, n_contacts), axis=1)

        self.horizon = horizon
        self.val_ratio = val_ratio

        val_mask = get_val_mask(
            n_episodes=len(episode_ends),
            val_ratio=val_ratio,
            seed=42,
            random=False)
        self.train_mask = ~val_mask
        print('self.train_mask ', self.train_mask)
        self.original_episode_ends = episode_ends
        self.episode_ends = episode_ends[self.train_mask]
        # (start, end) frame range of each training episode, end exclusive.
        self.trajectories = self._trajectory_bounds(episode_ends, self.train_mask)
        print('self.trajectories ', self.trajectories)

        # One sample per usable frame, as (traj_idx, t).
        self.sample_indices = [
            (traj_idx, t)
            for traj_idx, (start, end) in enumerate(self.trajectories)
            for t in range(end - start - self.ANCHOR_TAIL_MARGIN)
        ]

        self.n_neg = n_neg

    @staticmethod
    def _trajectory_bounds(episode_ends, train_mask):
        """Return ``[(start, end), ...]`` frame ranges (end exclusive) of the masked episodes.

        Each episode starts where the previous episode *of the full dataset*
        ends, so bounds stay correct when held-out episodes precede or sit
        between training episodes.
        """
        episode_ends = np.asarray(episode_ends)
        episode_starts = np.zeros_like(episode_ends)
        episode_starts[1:] = episode_ends[:-1]
        return [
            (int(start), int(end))
            for start, end in zip(episode_starts[train_mask], episode_ends[train_mask])
        ]

    @staticmethod
    def _padded_window(array, first, length, end):
        """Return ``length`` rows of ``array`` from ``first``, edge-padded at ``end``.

        Rows at or beyond ``end`` (the trajectory's exclusive end) are replaced
        by copies of ``array[end - 1]``.
        """
        window = array[first:min(first + length, end)]
        n_pad = length - len(window)
        if n_pad > 0:
            edge = np.repeat(array[end - 1:end], n_pad, axis=0)
            window = np.concatenate((window, edge), axis=0)
        return window

    def __len__(self):
        """Return the number of samples."""
        return len(self.sample_indices)

    def __getitem__(self, idx):
        """Return current state/image, action/future windows and negative images."""
        traj_idx, t = self.sample_indices[idx]
        start, end = self.trajectories[traj_idx]

        s_t_idx = start + t
        # a_t .. a_{t+H-1} and s_{t+1} .. s_{t+H}, edge-padded at the trajectory end
        actions_seq = self._padded_window(self.actions, s_t_idx, self.horizon, end)
        s_future_seq = self._padded_window(self.states, s_t_idx + 1, self.horizon, end)
        img_future_seq = self._padded_window(self.imgs, s_t_idx + 1, self.horizon, end)

        # Channel axis moved to the front (per image), pixel values scaled by 1/255.
        img_t = np.moveaxis(self.imgs[s_t_idx], -1, 0) / 255
        img_future_seq = np.moveaxis(img_future_seq, -1, 1) / 255

        data = {
            's_t': torch.tensor(self.states[s_t_idx], dtype=torch.float32),  # (state_dim + 1,)
            'img_t': torch.tensor(img_t, dtype=torch.float32),  # (C, H, W)
            'actions_seq': torch.tensor(actions_seq, dtype=torch.float32),  # (horizon, action_dim)
            's_future_seq': torch.tensor(s_future_seq, dtype=torch.float32),  # (horizon, state_dim + 1)
            'img_future_seq': torch.tensor(img_future_seq, dtype=torch.float32),  # (horizon, C, H, W)
        }

        # Negatives: n_neg distinct frames from any other training trajectory.
        other_ranges = [np.arange(s, e) for s, e in self.trajectories if s != start]
        if not other_ranges:
            raise ValueError('Negative sampling needs at least one other trajectory')
        candidates = np.concatenate(other_ranges).astype(int)
        negative_t = np.random.choice(candidates, size=self.n_neg, replace=False)
        negative_imgs = np.moveaxis(self.imgs[negative_t], -1, 1) / 255
        data['negative_imgs'] = torch.tensor(negative_imgs, dtype=torch.float32)  # (n_neg, C, H, W)

        return data

    def get_normalizer(self, mode='limits', **kwargs):
        """Fit a ``LinearNormalizer`` on actions and states (incl. contacts)."""
        data = {
            'action': self.actions,
            'state': self.states,
        }
        normalizer = LinearNormalizer()
        normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
        normalizer['image'] = get_image_range_normalizer()
        return normalizer
    

class MultiStepWithHistoryImageDynamicsModelDataset(Dataset):
    """Image dataset of history/future windows around an anchor frame, with negatives."""

    # 'future_as_neg' draws negatives starting this many frames after the anchor.
    FUTURE_NEG_OFFSET = 25
    # Supported values of the `neg_sampling` argument.
    NEG_SAMPLING_MODES = ('other_ep_as_neg', 'future_as_neg')

    def __init__(self, zarr_path, horizon=1, val_ratio=0.0, n_neg=5, neg_sampling='other_ep_as_neg'):
        """
        Initializes the dataset by loading data from a Zarr file and precomputing valid anchor indices.

        Args:
            zarr_path (str): Path to the Zarr dataset.
            horizon (int): Number of steps for history and future.
            val_ratio (float): Fraction of episodes to use for validation.
            n_neg (int): Number of negative images drawn per sample.
            neg_sampling (str): 'other_ep_as_neg' draws negatives from other
                training episodes; 'future_as_neg' draws them from frames at
                least FUTURE_NEG_OFFSET steps after the anchor in the same
                episode, falling back to other episodes when too few remain.

        Raises:
            ValueError: If `neg_sampling` is not one of NEG_SAMPLING_MODES.
        """
        if neg_sampling not in self.NEG_SAMPLING_MODES:
            raise ValueError(
                f"neg_sampling must be one of {self.NEG_SAMPLING_MODES}, got {neg_sampling!r}")

        zarr_data = zarr.open(zarr_path, mode='r')
        zarr_data_np = {
            'data': {key: np.array(zarr_data['data'][key]) for key in zarr_data['data'].keys()},
            'meta': {key: np.array(zarr_data['meta'][key]) for key in zarr_data['meta'].keys()}
        }

        # Exclusive end frame of each episode, e.g. [60, 110, ...]
        episode_ends = zarr_data_np['meta']['episode_ends']
        n_episodes = len(episode_ends)

        self.imgs = zarr_data_np['data']['img']
        self.states = zarr_data_np['data']['state']  # (total_timesteps, state_dim)
        self.actions = zarr_data_np['data']['action']  # (total_timesteps, action_dim)
        self.n_contacts = zarr_data_np['data']['n_contacts']

        # Wrap the 5th state dimension into [0, 2*pi), as per starter code
        if self.states.shape[1] > 4:
            self.states[:, 4] = self.states[:, 4] % (2 * np.pi)

        # Append n_contacts as extra state column(s); accepts (T,) or (T, k).
        n_contacts = self.n_contacts
        if n_contacts.ndim == 1:
            n_contacts = n_contacts[:, None]
        self.states = np.concatenate((self.states, n_contacts), axis=1)

        self.horizon = horizon
        self.val_ratio = val_ratio
        self.n_neg = n_neg
        self.neg_sampling = neg_sampling

        val_mask = get_val_mask(
            n_episodes=n_episodes,
            val_ratio=val_ratio,
            seed=42,
            random=False)
        self.train_mask = ~val_mask
        self.episode_ends_train = episode_ends[self.train_mask]
        print('episode_ends_train ', self.episode_ends_train)

        # First frame of every episode in the FULL dataset, then keep the
        # training ones, so held-out episodes before/between training episodes
        # are never folded into a training trajectory.
        episode_starts = np.zeros_like(episode_ends)
        episode_starts[1:] = episode_ends[:-1]
        self.episode_start_indices = episode_starts[self.train_mask]
        self.episode_end_indices = self.episode_ends_train - 1  # last (inclusive) frame

        # Valid anchors leave `horizon` frames of history and future inside the episode.
        anchor_ranges = [
            np.arange(start + self.horizon, end - self.horizon + 1)
            for start, end in zip(self.episode_start_indices, self.episode_end_indices)
        ]
        self.valid_anchor_indices = (
            np.concatenate(anchor_ranges) if anchor_ranges else np.array([], dtype=int))
        self.num_valid = len(self.valid_anchor_indices)

    def __len__(self):
        """
        Returns the number of valid anchor samples.
        """
        return self.num_valid

    @staticmethod
    def _imgs_to_tensor(imgs):
        """Move the last (channel) axis to axis 1, scale by 1/255 and return float32."""
        return torch.tensor(np.moveaxis(imgs, -1, 1) / 255, dtype=torch.float32)

    def _sample_negative_frames(self, anchor_idx):
        """Draw `n_neg` distinct negative frame indices for the given anchor."""
        # side='left' finds the first episode whose last frame is >= anchor_idx,
        # i.e. the episode containing the anchor (even if it is the last frame).
        episode_idx = np.searchsorted(self.episode_end_indices, anchor_idx, side='left')
        episode_start = self.episode_start_indices[episode_idx]
        episode_end = self.episode_end_indices[episode_idx]

        candidate_ranges = []
        use_other_episodes = self.neg_sampling == 'other_ep_as_neg'
        if self.neg_sampling == 'future_as_neg':
            neg_start = anchor_idx + self.FUTURE_NEG_OFFSET
            if neg_start + self.n_neg >= episode_end:
                use_other_episodes = True  # too few future frames left; fall back
            else:
                candidate_ranges.append(np.arange(neg_start, episode_end))
        if use_other_episodes:
            # NOTE(review): `end` is the inclusive last frame, so np.arange(start, end)
            # never draws an episode's final frame — confirm this is intended.
            candidate_ranges.extend(
                np.arange(start, end)
                for start, end in zip(self.episode_start_indices, self.episode_end_indices)
                if start != episode_start)

        candidates = (np.concatenate(candidate_ranges).astype(int)
                      if candidate_ranges else np.empty(0, dtype=int))
        # Raises ValueError if there are fewer candidates than n_neg.
        return np.random.choice(candidates, size=self.n_neg, replace=False)

    def __getitem__(self, idx):
        """
        Retrieves the sample at the given index.

        Args:
            idx (int): Index of the sample to retrieve.

        Returns:
            dict: A dictionary containing grouped history and future images/actions
            and negative images.

        Raises:
            IndexError: If `idx` is outside [0, len(self)).
        """
        if idx < 0 or idx >= self.num_valid:
            raise IndexError(f"Index {idx} out of range for dataset of size {self.num_valid}")

        anchor_idx = self.valid_anchor_indices[idx]
        history_start = anchor_idx - self.horizon
        future_end = anchor_idx + self.horizon  # inclusive

        history_imgs = self.imgs[history_start:anchor_idx + 1]  # frames t-H .. t
        history_actions = self.actions[history_start:anchor_idx]  # a_{t-H} .. a_{t-1}
        future_actions = self.actions[anchor_idx:future_end]  # a_t .. a_{t+H-1}
        future_imgs = self.imgs[anchor_idx + 1:future_end + 1]  # frames t+1 .. t+H

        negative_t = self._sample_negative_frames(anchor_idx)
        return {
            'history_imgs': self._imgs_to_tensor(history_imgs),  # (horizon + 1, C, H, W)
            'history_actions': torch.tensor(history_actions, dtype=torch.float32),  # (horizon, action_dim)
            'future_actions': torch.tensor(future_actions, dtype=torch.float32),  # (horizon, action_dim)
            'future_imgs': self._imgs_to_tensor(future_imgs),  # (horizon, C, H, W)
            'negative_imgs': self._imgs_to_tensor(self.imgs[negative_t]),  # (n_neg, C, H, W)
        }

    def get_normalizer(self, mode='limits', **kwargs):
        """Fit a `LinearNormalizer` on actions and states (incl. contacts)."""
        data = {
            'action': self.actions,
            'state': self.states,
        }
        normalizer = LinearNormalizer()
        normalizer.fit(data=data, last_n_dims=1, mode=mode, **kwargs)
        normalizer['image'] = get_image_range_normalizer()
        return normalizer


class ContrastiveDynamicsModelDataset(DynamicsModelDataset):
    """DynamicsModelDataset that also draws positive and negative samples per item.

    ``__getitem__(idx)`` returns ``(data, pos_data, neg_data)``:

    * ``data`` -- the base-class sample for ``idx``;
    * ``pos_data`` -- ``n_pos`` base-class samples a few steps away from ``idx``
      inside the same episode (padded with ``idx`` itself near episode edges);
    * ``neg_data`` -- up to ``max_neg`` base-class samples that are far from
      ``idx`` in the same episode and/or come from other training episodes.

    The ``'action'`` and ``'o_t1'`` entries are deleted from every
    positive/negative sample. ``idx`` is treated as a position on the
    concatenated time axis delimited by ``self.episode_ends``.
    """

    # Positive step offsets, keyed by ``str(pos_index)``.
    _POS_OFFSETS = {
        '12': (1, 2),
        '123': (1, 2, 3),
        '135': (1, 3, 5),
    }

    def __init__(self, zarr_path, val_ratio=0.2, random=True, max_neg=5,
                 n_pos=3, pos_index='123', neg_index=25, use_same_ep=True,
                 use_other_ep=True, use_past_as_pos=False):
        """
        Args:
            zarr_path, val_ratio, random: forwarded unchanged to
                ``DynamicsModelDataset.__init__``.
            max_neg: maximum number of negative samples per item.
            n_pos: number of positive samples per item.
            pos_index: positive step offsets: ``'12'`` -> 1, 2;
                ``'123'`` -> 1, 2, 3; ``'135'`` -> 1, 3, 5. The ints 12, 123
                and 135 are accepted as well.
            neg_index: exclusion radius (in steps) around ``idx`` for
                same-episode negatives (a minimum distance, not a maximum).
            use_same_ep: draw negatives from the episode containing ``idx``.
            use_other_ep: draw negatives from the other training episodes.
            use_past_as_pos: also use ``idx - offset`` as positive candidates.

        Raises:
            ValueError: if both negative sources are disabled, or ``pos_index``
                is not one of the supported values.
        """
        # Validate before the (possibly large) zarr store is loaded.
        if not (use_same_ep or use_other_ep):
            raise ValueError('At least one of use_same_ep / use_other_ep must be True')
        if str(pos_index) not in self._POS_OFFSETS:
            raise ValueError(
                f'Unsupported pos_index {pos_index!r}; expected one of '
                f'{sorted(self._POS_OFFSETS)}')
        super().__init__(zarr_path, val_ratio, random)
        self.max_neg = max_neg
        self.n_pos = n_pos
        self.pos_index = pos_index
        self.neg_index = neg_index
        self.use_same_ep = use_same_ep
        self.use_other_ep = use_other_ep
        self.use_past_as_pos = use_past_as_pos

    def __getitem__(self, idx):
        """Return ``(data, pos_data, neg_data)`` for time step ``idx``.

        Positives and negatives are drawn without replacement with
        ``np.random.choice``, so the result depends on NumPy's global RNG state.
        """
        data = super().__getitem__(idx)

        # self.episode_ends has a leading 0, so episode e spans
        # [episode_ends[e], episode_ends[e + 1]).
        episode = int(np.searchsorted(self.episode_ends, idx, side='right')) - 1
        episode_start = self.episode_ends[episode]
        episode_end = self.episode_ends[episode + 1]

        pos_candidates = self._positive_candidates(idx, episode_start, episode_end)
        pos_indices = np.random.choice(
            pos_candidates, size=min(self.n_pos, len(pos_candidates)), replace=False)

        neg_candidates = self._negative_candidates(idx, episode, episode_start, episode_end)
        neg_indices = np.random.choice(
            neg_candidates, size=min(self.max_neg, len(neg_candidates)), replace=False)

        pos_data = [self._context_sample(i) for i in pos_indices]
        neg_data = [self._context_sample(i) for i in neg_indices]
        return data, pos_data, neg_data

    def _positive_candidates(self, idx, episode_start, episode_end):
        """List in-episode time steps near ``idx``, padded with ``idx`` up to ``n_pos``."""
        offsets = self._POS_OFFSETS[str(self.pos_index)]
        candidates = [idx + k for k in offsets]
        if self.use_past_as_pos:
            # Past offsets come first, farthest-to-nearest: idx-3, idx-2, idx-1, idx+1, ...
            candidates = [idx - k for k in reversed(offsets)] + candidates
        candidates = [p for p in candidates if episode_start <= p < episode_end]
        # Near an episode boundary there may be too few neighbours; reuse the anchor.
        if len(candidates) < self.n_pos:
            candidates.extend([idx] * (self.n_pos - len(candidates)))
        return candidates

    def _negative_candidates(self, idx, episode, episode_start, episode_end):
        """Return candidate negative time steps as an int array (same episode first)."""
        pieces = []
        if self.use_same_ep:
            # NOTE(review): before idx, the step exactly neg_index away is excluded;
            # after idx it is included. Kept as in the original implementation.
            pieces.append(np.arange(episode_start, max(episode_start, idx - self.neg_index)))
            pieces.append(np.arange(min(idx + self.neg_index, episode_end), episode_end))
        if self.use_other_ep:
            # train_mask is indexed per episode (as in create_indices_one_step), so
            # select whole training episodes instead of masking the boundary array.
            # Masking the boundaries would pair non-adjacent ends and pull in
            # validation episodes.
            for ep in range(len(self.episode_ends) - 1):
                if ep != episode and self.train_mask[ep]:
                    pieces.append(np.arange(self.episode_ends[ep], self.episode_ends[ep + 1]))
        if not pieces:
            return np.empty(0, dtype=np.int64)
        return np.concatenate(pieces).astype(np.int64, copy=False)

    def _context_sample(self, i):
        """Fetch base-class sample ``i`` without its 'action' and 'o_t1' entries."""
        sample = super().__getitem__(i)
        del sample['action']
        del sample['o_t1']
        return sample


def test():
    """Smoke-test MultiStepImageDynamicsModelDataset by iterating one full epoch.

    Builds the dataset from ``data/pusht/pusht_playdata_augment_v1`` (``~`` is
    expanded if present) with ``val_ratio=0.2`` and ``horizon=8``. Every batch
    is then pulled through a single-worker, unshuffled DataLoader with batch
    size 64. Any error raised while indexing or collating samples propagates
    to the caller.
    """
    import os

    zarr_path = os.path.expanduser('data/pusht/pusht_playdata_augment_v1')
    dataset = MultiStepImageDynamicsModelDataset(zarr_path, val_ratio=0.2, horizon=8)
    train_loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=1)
    # Only fetching each batch matters here; the batch contents are discarded.
    for _ in train_loader:
        pass


if __name__ == '__main__':
    # Executing this module as a script runs the dataset smoke test.
    test()
