import os
import torch
import pickle
import random
import numpy as np
import gymnasium as gym
from datetime import datetime
from sklearn.cluster import MiniBatchKMeans
import collections

from collections import defaultdict
from torch.utils.data import Dataset
    
def extract_discrete_id_to_data_id_map(discrete_goals, dones, last_valid_traj):
    """Group data indices by their discrete goal label.

    Walks ``discrete_goals`` in order and records, for every label, the
    positions at which it occurs. The walk stops right after position
    ``last_valid_traj`` has been recorded (that position is included).

    Args:
        discrete_goals: iterable of hashable labels, one per data index.
        dones: not used by this function.
        last_valid_traj: last data index to include in the map.

    Returns:
        The ``defaultdict(list)`` built during the walk, with every value
        converted to a NumPy array of data indices.
    """
    label_to_positions = defaultdict(list)
    mapped = 0

    for position, label in enumerate(discrete_goals):
        label_to_positions[label].append(position)
        mapped = position + 1

        # Periodic progress report.
        if mapped % 200000 == 0:
            print('Goals mapped:', mapped)

        if position == last_valid_traj:
            break

    # Convert every bucket from a Python list to a NumPy array.
    for label in list(label_to_positions):
        label_to_positions[label] = np.array(label_to_positions[label])

    print('Total goals mapped:', mapped)
    return label_to_positions

def extract_done_markers(dones, episode_ids):
    """Return per-step episode-end indices and the indices of non-done steps.

    Only the steps up to and including the last done step are considered.

    Args:
        dones: 1-D per-timestep done flags.
        episode_ids: per-timestep integer ids, used to index the array of
            done positions (assumes 0-based consecutive ids -- TODO confirm).

    Returns:
        A 2-tuple ``(per_step_end, not_done_idx)``: for each considered step
        the index of the done step selected by its episode id, and the
        indices of the considered steps whose done flag is zero.
    """
    (done_positions,) = np.where(dones)
    stop = done_positions[-1] + 1

    per_step_end = done_positions[episode_ids[:stop]]
    not_done_idx = np.where(1 - dones[:stop])[0]
    return per_step_end, not_done_idx

def discount_cumsum(x, h, gamma):
    """Accumulate ``x`` backwards with discount factor ``gamma``.

    The last entry of the result is seeded with ``x[-1]``; then, for
    ``t = h - 2`` down to ``0``, ``out[t] = x[t] + gamma * out[t + 1]``.
    Entries the recursion does not visit keep their initial value of zero,
    so the result is a full discounted suffix sum only when ``h == len(x)``.

    Args:
        x: array of per-step values.
        h: horizon; the recursion covers indices ``0 .. h - 2``.
        gamma: discount factor.

    Returns:
        Array with the same shape and dtype as ``x``.
    """
    out = np.zeros_like(x)
    out[-1] = x[-1]

    t = h - 2
    while t >= 0:
        out[t] = x[t] + gamma * out[t + 1]
        t -= 1
    return out

class VisionKMeansEpisodicTrajectoryDataset(Dataset):
    """Context-length windows with optional k-means based goal stitching.

    Each item is ``(state, proprio, goal, action)`` where the first, second
    and fourth entries hold ``context_len`` consecutive rows and ``goal`` is
    one achieved goal reshaped to ``(1, -1)``. With ``augment_data`` on, an
    item may instead be stitched from two places in the data that share a
    k-means cluster label.
    """

    def __init__(self, dataset_name, dataset_size, context_len, augment_data, augment_prob, nclusters=40, vision = False):
        super().__init__()

        # Same file name in either case; only the folder differs.
        folder = 'data/pixels/' if vision else 'data/state/'
        path = folder + dataset_name + '-' + str(dataset_size) + '.pkl'

        with open(path, 'rb') as fp:
            self.dataset = pickle.load(fp)

        obs_group = self.dataset['observations']
        self.episode_ids = obs_group['episode_id']
        self.observations = obs_group['observation']
        print(self.observations.shape)

        self.proprio = obs_group['proprio']
        self.achieved_goals = obs_group['achieved_goal']
        self.actions = self.dataset['actions']

        # Rows flagged as terminations close an episode; the next row opens one.
        (terminal_rows,) = np.where(self.dataset['terminations'])
        self.starts = np.concatenate(([0], terminal_rows[:-1] + 1))
        # Rows per episode (number of actions taken + 1).
        self.lengths = terminal_rows - self.starts + 1
        # Per-row lookup: index of the terminal row selected by the row's episode id.
        self.ends = terminal_rows[obs_group['episode_id'][: terminal_rows[-1] + 1]]

        # Keep only episodes with more rows than the context length.
        long_enough = self.lengths > context_len
        print('Throwing away ', np.sum(self.lengths[~long_enough] - 1), 'number of transitions')
        self.starts = self.starts[long_enough]
        self.lengths = self.lengths[long_enough]

        self.num_trajectories = len(self.starts)

        if augment_data:
            tic = datetime.now().replace(microsecond=0)
            print('starting kmeans ... ')
            clusterer = MiniBatchKMeans(n_clusters=nclusters, n_init="auto")
            if len(self.observations.shape) > 2:
                # Observations with more than two axes are flattened per row and rescaled.
                features = self.observations.reshape(self.observations.shape[0], -1) / 255.0 - 0.5
            else:
                features = self.observations
            kmeans = clusterer.fit(features)
            time_elapsed = str(datetime.now().replace(microsecond=0) - tic)
            print('kmeans done! time taken :', time_elapsed)

            self.discrete_goal_to_data_idx = extract_discrete_id_to_data_id_map(
                kmeans.labels_, self.dataset['terminations'], self.ends[-1]
            )
            self.achieved_discrete_goals = kmeans.labels_
            del kmeans

        self.goal_dim = self.achieved_goals.shape[-1]
        self.context_len = context_len
        self.augment_data = augment_data
        self.augment_prob = augment_prob
        # Drop the reference to the loaded container.
        self.dataset = None

    def __len__(self):
        return 100 * self.num_trajectories

    def __getitem__(self, idx):
        """Sample one window; ``np.random.randint`` draws from ``[low, high)``."""
        episode = idx % self.num_trajectories
        n_actions = self.lengths[episode] - 1       # T: actions taken in this episode
        first_row = self.starts[episode]
        assert self.ends[first_row] == first_row + n_actions
        ctx = self.context_len

        stitch = self.augment_data and np.random.uniform(0, 1) <= self.augment_prob

        if not stitch:
            # Window start in first_row + [0, T - C]; goal row in [si + C, first_row + T].
            si = first_row + np.random.randint(0, n_actions - ctx + 1)
            gi = np.random.randint(si + ctx, first_row + n_actions + 1)
            window = slice(si, si + ctx)

            goal = torch.tensor(self.achieved_goals[gi]).view(1, -1)
            state = torch.tensor(self.observations[window])
            proprio = torch.tensor(self.proprio[window])
            action = torch.tensor(self.actions[window])
            return state, proprio, goal, action

        # Redraw until the head [si, gi) plus the rows from the jump target to
        # the end of its episode span more than one context.
        while True:
            si = first_row + np.random.randint(0, n_actions)        # first_row + [0, T - 1]
            gi = np.random.randint(si, first_row + n_actions) + 1   # [si + 1, first_row + T]
            cluster = self.achieved_discrete_goals[gi]
            jump = np.random.choice(self.discrete_goal_to_data_idx[cluster])
            jump_end = self.ends[jump]
            if (gi - si) + (jump_end - jump) + 1 > ctx:
                break

        head = gi - si
        if head < ctx:
            # Fill the rest of the window with rows starting at the jump target.
            tail_stop = jump + ctx - head
            goal_row = np.random.randint(tail_stop, jump_end + 1)
            goal = torch.tensor(self.achieved_goals[goal_row]).view(1, -1)
            state = torch.tensor(np.concatenate([self.observations[si:gi], self.observations[jump:tail_stop]]))
            proprio = torch.tensor(np.concatenate([self.proprio[si:gi], self.proprio[jump:tail_stop]]))
            action = torch.tensor(np.concatenate([self.actions[si:gi], self.actions[jump:tail_stop]]))
        else:
            # The head alone fills the window; only the goal comes from the jump episode.
            goal_row = np.random.randint(jump, jump_end + 1)
            window = slice(si, si + ctx)
            goal = torch.tensor(self.achieved_goals[goal_row]).view(1, -1)
            state = torch.tensor(self.observations[window])
            proprio = torch.tensor(self.proprio[window])
            action = torch.tensor(self.actions[window])

        return state, proprio, goal, action
    
class VisionMaxEpisodicTrajectoryDataset(Dataset):
    """Context-length windows of (state, proprio, action) plus a later goal.

    Two sampling modes are mixed in ``__getitem__``:

    * plain: window and goal come from one episode, with episode boundaries
      taken from the ``terminations`` flags;
    * augmented (``augment_data`` and a coin flip against ``augment_prob``):
      window and goal come from one of the chunks in ``self.trajectories``,
      which are cut by a step counter without looking at ``terminations``.
    """

    def __init__(self, env, dataset_name, dataset_size, context_len, augment_data, augment_prob, nclusters=40, vision = False, data_root='data'):
        """Load the pickled dataset and build the episode and chunk indexes.

        Args:
            env: object whose ``spec.max_episode_steps``, when present and not
                ``None``, sets the chunk length; otherwise 1000 is used.
            dataset_name: base name of the pickle file.
            dataset_size: number used in the file name and as the count of
                leading rows to chunk (clamped to the rows available).
            context_len: window length returned by ``__getitem__``.
            augment_data: enables the chunk-based sampling branch.
            augment_prob: per-item probability of taking that branch.
            nclusters: not used by this class.
            vision: read from the ``pixels`` folder instead of ``state``.
            data_root: folder containing the ``pixels``/``state`` folders.
        """
        super().__init__()

        subdir = 'pixels' if vision else 'state'
        path = data_root + '/' + subdir + '/' + dataset_name + '-' + str(dataset_size) + '.pkl'

        # NOTE: pickle.load can execute arbitrary code; only open trusted files.
        with open(path, 'rb') as fp:
            self.dataset = pickle.load(fp)

        self.episode_ids = self.dataset['observations']['episode_id']
        self.observations = self.dataset['observations']['observation']
        print(self.observations.shape)

        # The first two proprio columns are dropped.
        self.proprio = self.dataset['observations']['proprio'][:, 2:]
        self.achieved_goals = self.dataset['observations']['achieved_goal']
        self.actions = self.dataset['actions']
        # Fix: rewards were read during chunking without ever being loaded.
        # They are treated as optional because nothing returned by
        # __getitem__ depends on them.
        try:
            self.rewards = self.dataset['reward']
        except KeyError:
            self.rewards = None

        terminations = self.dataset['terminations']
        (ends,) = np.where(terminations)
        self.starts = np.concatenate(([0], ends[:-1] + 1))
        self.lengths = ends - self.starts + 1           # rows per episode = actions taken + 1
        # Per-row lookup of the terminal row selected by the row's episode id
        # (assumes 0-based consecutive episode ids -- TODO confirm).
        self.ends = ends[ self.dataset['observations']['episode_id'][ : ends[-1] + 1 ] ]

        good_idxes = self.lengths > context_len
        print('Throwing away ', np.sum(self.lengths[~good_idxes] - 1), 'number of transitions')
        self.starts = self.starts[good_idxes]           # only episodes with more rows than context_len
        self.lengths = self.lengths[good_idxes]

        self.num_trajectories = len(self.starts)

        columns = {
            "observations": self.observations,
            "proprio": self.proprio,
            "achieved_goal": self.achieved_goals,
            "actions": self.actions,
        }
        if self.rewards is not None:
            columns["rewards"] = self.rewards
        columns["terminals"] = terminations

        chunks = self._chunk_by_max_steps(env, dataset_size, columns)
        # A window of context_len rows cannot be drawn from a shorter chunk.
        self.trajectories = [c for c in chunks if c["observations"].shape[0] >= context_len]

        self.goal_dim = self.achieved_goals.shape[-1]
        self.context_len = context_len
        self.augment_data = augment_data
        self.augment_prob = augment_prob
        self.dataset = None

    def _chunk_by_max_steps(self, env, dataset_size, columns):
        """Cut the leading rows of ``columns`` into consecutive chunks.

        Returns a list of dicts mapping each column name to an array with
        that chunk's rows. Rows after the last completed chunk are dropped.
        """
        # The attribute can exist and still be None, so test the value.
        max_steps = getattr(getattr(env, 'spec', None), 'max_episode_steps', None)
        if max_steps is None:
            max_steps = 1000
        else:
            print("env has max episode!!", max_steps)

        n_rows = min(dataset_size, len(self.observations))
        chunks = []
        buffer = collections.defaultdict(list)
        episode_step = 0
        for i in range(n_rows):
            final_timestep = (episode_step == max_steps - 1)
            for key, column in columns.items():
                buffer[key].append(column[i])
            if final_timestep:
                chunks.append({key: np.array(rows) for key, rows in buffer.items()})
                buffer = collections.defaultdict(list)
                # NOTE(review): the counter is reset and then incremented
                # below, so every chunk after the first holds max_steps - 1
                # rows. Kept as found; confirm whether this is intended.
                episode_step = 0
            episode_step += 1
        return chunks

    def __len__(self):
        return self.num_trajectories * 100

    def __getitem__(self, idx):
        """Return ``(timesteps, state, proprio, goal, action, traj_mask)``.

        Reminder: np.random.randint samples from the set [low, high)
        """
        C = self.context_len

        # Fix: this branch used to run a k-means relabelling loop that needs
        # variables and attributes this class never defines, followed by a
        # second coin flip that could leave every output unassigned.
        if self.augment_data and self.trajectories and np.random.uniform(0, 1) <= self.augment_prob:
            traj = self.trajectories[idx % len(self.trajectories)]
            traj_len = traj["observations"].shape[0]

            # Window start in [0, L - C]; goal row in [si + C - 1, L - 1].
            si = np.random.randint(0, traj_len - C + 1)
            gi = np.random.randint(si + C, traj_len + 1) - 1

            state = torch.from_numpy(traj["observations"][si: si + C])
            proprio = torch.from_numpy(traj["proprio"][si: si + C])
            action = torch.from_numpy(traj["actions"][si: si + C])
            goal = torch.tensor(traj["achieved_goal"][gi]).view(1, -1)
            # Chunk-relative step indices.
            timesteps = torch.arange(start=si, end=si + C, step=1)
        else:
            idx = idx % self.num_trajectories
            traj_len = self.lengths[idx] - 1            # T: actions taken in this episode
            traj_start_i = self.starts[idx]
            assert self.ends[traj_start_i] == traj_start_i + traj_len
            si = traj_start_i + np.random.randint(0, traj_len - C + 1)     # traj_start_i + [0, T - C]
            gi = np.random.randint(si + C, traj_start_i + traj_len + 1)    # [si + C, traj_start_i + T]

            goal = torch.tensor(self.achieved_goals[gi]).view(1, -1)
            state = torch.tensor(self.observations[si: si + C])
            proprio = torch.tensor(self.proprio[si: si + C])
            action = torch.tensor(self.actions[si: si + C])
            # NOTE(review): these are absolute row indices, whereas the
            # augmented branch returns chunk-relative ones -- confirm which
            # the consumer expects.
            timesteps = torch.arange(start=si, end=si + C, step=1)

        traj_mask = torch.ones(C, dtype=torch.long)
        return timesteps, state, proprio, goal, action, traj_mask
  
class VisionKMeansEpisodicDataset(Dataset):
    """Single-step samples with optional k-means based goal relabelling.

    Each item is ``(state, proprio, goal, action)`` taken from one row of an
    episode, with the goal drawn from a later row of the same episode or,
    when relabelling fires, from the data near a row sharing that later
    row's cluster label.
    """

    def __init__(self, dataset_name, dataset_size, augment_data, augment_prob, nclusters=40, vision = False):
        super().__init__()

        # Same file name in either case; only the folder differs.
        folder = 'data/pixels/' if vision else 'data/state/'
        path = folder + dataset_name + '-' + str(dataset_size) + '.pkl'

        with open(path, 'rb') as fp:
            self.dataset = pickle.load(fp)

        obs_group = self.dataset['observations']
        self.episode_ids = obs_group['episode_id']
        self.observations = obs_group['observation']
        self.proprio = obs_group['proprio']
        self.achieved_goals = obs_group['achieved_goal']
        self.actions = self.dataset['actions']

        # Rows flagged as terminations close an episode; the next row opens one.
        (terminal_rows,) = np.where(self.dataset['terminations'])
        self.starts = np.concatenate(([0], terminal_rows[:-1] + 1))
        self.lengths = terminal_rows - self.starts + 1
        self.num_trajectories = len(self.starts)

        # Per-row lookup: index of the terminal row selected by the row's episode id.
        self.ends = terminal_rows[obs_group['episode_id'][: terminal_rows[-1] + 1]]

        if augment_data:
            tic = datetime.now().replace(microsecond=0)
            print('starting kmeans ... ')
            clusterer = MiniBatchKMeans(n_clusters=nclusters, n_init="auto")
            if len(self.observations.shape) > 2:
                # Observations with more than two axes are flattened per row and rescaled.
                features = self.observations.reshape(self.observations.shape[0], -1) / 255.0 - 0.5
            else:
                features = self.observations
            kmeans = clusterer.fit(features)
            time_elapsed = str(datetime.now().replace(microsecond=0) - tic)
            print('kmeans done! time taken :', time_elapsed)

            self.discrete_goal_to_data_idx = extract_discrete_id_to_data_id_map(
                kmeans.labels_, self.dataset['terminations'], self.ends[-1]
            )
            self.achieved_discrete_goals = kmeans.labels_
            del kmeans

        self.goal_dim = self.achieved_goals.shape[-1]
        self.augment_data = augment_data
        self.augment_prob = augment_prob
        # Drop the reference to the loaded container.
        self.dataset = None

    def __len__(self):
        return 100 * self.num_trajectories

    def __getitem__(self, idx):
        """Sample one (state, proprio, goal, action) tuple from episode ``idx``."""
        episode = idx % self.num_trajectories
        n_actions = self.lengths[episode] - 1       # T: actions taken in this episode
        first_row = self.starts[episode]
        assert self.ends[first_row] == first_row + n_actions

        offset = np.random.randint(0, n_actions)    # offset in [0, T - 1]
        row = first_row + offset

        state = torch.tensor(self.observations[row])
        proprio = torch.tensor(self.proprio[row])
        action = torch.tensor(self.actions[row])

        relabel = self.augment_data and np.random.uniform(0, 1) <= self.augment_prob

        # A later row of the same episode: first_row + [offset + 1, T].
        future_row = first_row + np.random.randint(offset, n_actions) + 1
        if relabel:
            # Jump to a row sharing the future row's cluster, then pick any
            # row from there to the end index recorded for it.
            cluster = self.achieved_discrete_goals[future_row]
            anchor = np.random.choice(self.discrete_goal_to_data_idx[cluster])
            goal_row = np.random.randint(anchor, self.ends[anchor] + 1)
        else:
            goal_row = future_row

        goal = torch.tensor(self.achieved_goals[goal_row])
        return state, proprio, goal, action
    
class VisionMaxEpisodicDataset(Dataset):
    """Single-step (state, proprio, goal, action) samples.

    Two sampling modes are mixed in ``__getitem__``:

    * plain: the step and a later goal come from one episode, with episode
      boundaries taken from the ``terminations`` flags;
    * augmented (``augment_data`` and a coin flip against ``augment_prob``):
      the step and a later goal come from one of the chunks in
      ``self.trajectories``, which are cut by a step counter without looking
      at ``terminations``.
    """

    def __init__(self, env, dataset_name, dataset_size, context_len, augment_data, augment_prob, nclusters=40, vision = False, data_root='data'):
        """Load the pickled dataset and build the episode and chunk indexes.

        Args:
            env: object whose ``spec.max_episode_steps``, when present and not
                ``None``, sets the chunk length; otherwise 1000 is used.
            dataset_name: base name of the pickle file.
            dataset_size: number used in the file name and as the count of
                leading rows to chunk (clamped to the rows available).
            context_len: stored on the instance; not used for sampling here.
            augment_data: enables the chunk-based sampling branch.
            augment_prob: per-item probability of taking that branch.
            nclusters: not used by this class.
            vision: read from the ``pixels`` folder instead of ``state``.
            data_root: folder containing the ``pixels``/``state`` folders.
        """
        super().__init__()

        subdir = 'pixels' if vision else 'state'
        path = data_root + '/' + subdir + '/' + dataset_name + '-' + str(dataset_size) + '.pkl'

        # NOTE: pickle.load can execute arbitrary code; only open trusted files.
        with open(path, 'rb') as fp:
            self.dataset = pickle.load(fp)

        self.episode_ids = self.dataset['observations']['episode_id']
        self.observations = self.dataset['observations']['observation']
        self.proprio = self.dataset['observations']['proprio']
        self.achieved_goals = self.dataset['observations']['achieved_goal']
        self.actions = self.dataset['actions']
        # Fix: rewards were read during chunking without ever being loaded.
        # They are treated as optional because nothing returned by
        # __getitem__ depends on them.
        try:
            self.rewards = self.dataset['reward']
        except KeyError:
            self.rewards = None

        terminations = self.dataset['terminations']
        (ends,) = np.where(terminations)
        self.starts = np.concatenate(([0], ends[:-1] + 1))
        self.lengths = ends - self.starts + 1           # rows per episode = actions taken + 1
        self.num_trajectories = len(self.starts)

        # Per-row lookup of the terminal row selected by the row's episode id
        # (assumes 0-based consecutive episode ids -- TODO confirm).
        self.ends = ends[ self.dataset['observations']['episode_id'][ : ends[-1] + 1 ] ]

        columns = {
            "observations": self.observations,
            "proprio": self.proprio,
            "achieved_goal": self.achieved_goals,
            "actions": self.actions,
        }
        if self.rewards is not None:
            columns["rewards"] = self.rewards
        columns["terminals"] = terminations

        chunks = self._chunk_by_max_steps(env, dataset_size, columns)
        # A step and a strictly later goal need at least two rows.
        self.trajectories = [c for c in chunks if c["observations"].shape[0] >= 2]

        self.goal_dim = self.achieved_goals.shape[-1]
        self.context_len = context_len
        self.augment_data = augment_data
        self.augment_prob = augment_prob
        self.dataset = None

    def _chunk_by_max_steps(self, env, dataset_size, columns):
        """Cut the leading rows of ``columns`` into consecutive chunks.

        Returns a list of dicts mapping each column name to an array with
        that chunk's rows. Rows after the last completed chunk are dropped.
        """
        # The attribute can exist and still be None, so test the value.
        max_steps = getattr(getattr(env, 'spec', None), 'max_episode_steps', None)
        if max_steps is None:
            max_steps = 1000
        else:
            print("env has max episode!!", max_steps)

        n_rows = min(dataset_size, len(self.observations))
        chunks = []
        buffer = collections.defaultdict(list)
        episode_step = 0
        for i in range(n_rows):
            final_timestep = (episode_step == max_steps - 1)
            for key, column in columns.items():
                buffer[key].append(column[i])
            if final_timestep:
                chunks.append({key: np.array(rows) for key, rows in buffer.items()})
                buffer = collections.defaultdict(list)
                # NOTE(review): the counter is reset and then incremented
                # below, so every chunk after the first holds max_steps - 1
                # rows. Kept as found; confirm whether this is intended.
                episode_step = 0
            episode_step += 1
        return chunks

    def __len__(self):
        return self.num_trajectories * 100

    def __getitem__(self, idx):
        """Return ``(state, proprio, goal, action)`` for one sampled step."""
        # Fix: the augmented branch used to read k-means attributes that this
        # class never creates. It now samples from the chunks built in
        # __init__, which were otherwise unused.
        # NOTE(review): sampling rule chosen by analogy with the chunk-based
        # branch of the trajectory datasets -- confirm it matches the intent.
        if self.augment_data and self.trajectories and np.random.uniform(0, 1) <= self.augment_prob:
            chunk = self.trajectories[idx % len(self.trajectories)]
            chunk_len = chunk["observations"].shape[0]

            si = np.random.randint(0, chunk_len - 1)            # step row in [0, L - 2]
            gi = np.random.randint(si, chunk_len - 1) + 1       # goal row in [si + 1, L - 1]

            state = torch.tensor(chunk["observations"][si])
            proprio = torch.tensor(chunk["proprio"][si])
            action = torch.tensor(chunk["actions"][si])
            goal = torch.tensor(chunk["achieved_goal"][gi])
            return state, proprio, goal, action

        idx = idx % self.num_trajectories
        traj_len = self.lengths[idx] - 1                        # T: actions taken in this episode
        traj_start_i = self.starts[idx]
        assert self.ends[traj_start_i] == traj_start_i + traj_len

        si = np.random.randint(0, traj_len)                     # si in [0, T - 1]

        state = torch.tensor(self.observations[traj_start_i + si])
        proprio = torch.tensor(self.proprio[traj_start_i + si])
        action = torch.tensor(self.actions[traj_start_i + si])
        # Goal row in traj_start_i + [si + 1, T].
        goal = torch.tensor(self.achieved_goals[ traj_start_i + np.random.randint(si, traj_len) + 1 ])

        return state, proprio, goal, action
    
class MaxEpisodicTrajectoryDataset(Dataset):
    """Context-length windows of (state, action) plus a later goal.

    Two sampling modes are mixed in ``__getitem__``:

    * plain: window and goal come from one episode, with episode boundaries
      taken from the ``terminations`` flags;
    * augmented (``augment_data`` and a coin flip against ``augment_prob``):
      window and goal come from one of the chunks in ``self.trajectories``,
      which are cut by a step counter without looking at ``terminations``.
    """

    def __init__(self, env, dataset_name, dataset_size, context_len, augment_data, augment_prob, nclusters=40, vision = False, data_root='/cvlabdata1/leixing/data'):
        """Load the pickled dataset and build the episode and chunk indexes.

        Args:
            env: object whose ``spec.max_episode_steps``, when present and not
                ``None``, sets the chunk length; otherwise 1000 is used.
            dataset_name: base name of the pickle file (prefixed with ``v-``
                when ``vision`` is set).
            dataset_size: number used in the file name and as the count of
                leading rows to chunk (clamped to the rows available).
            context_len: window length returned by ``__getitem__``.
            augment_data: enables the chunk-based sampling branch.
            augment_prob: per-item probability of taking that branch.
            nclusters: not used by this class.
            vision: read from the ``pixels`` folder instead of ``state``.
            data_root: folder containing the ``pixels``/``state`` folders;
                the default is the location that used to be hard-coded.
        """
        super().__init__()
        if vision:
            dataset_name = "v-" + dataset_name
            subdir = 'pixels'
        else:
            subdir = 'state'
        path = data_root + '/' + subdir + '/' + dataset_name + '-' + str(dataset_size) + '.pkl'

        # NOTE: pickle.load can execute arbitrary code; only open trusted files.
        with open(path, 'rb') as fp:
            self.dataset = pickle.load(fp)

        self.observations = self.dataset['observations']['observation']
        self.achieved_goals = self.dataset['observations']['achieved_goal']
        self.actions = self.dataset['actions']
        self.rewards = self.dataset['reward']

        terminations = self.dataset['terminations']
        (ends,) = np.where(terminations)
        self.starts = np.concatenate(([0], ends[:-1] + 1))
        self.lengths = ends - self.starts + 1           # rows per episode = actions taken + 1
        # Per-row lookup of the terminal row selected by the row's episode id
        # (assumes 0-based consecutive episode ids -- TODO confirm).
        self.ends = ends[ self.dataset['observations']['episode_id'][ : ends[-1] + 1 ] ]

        good_idxes = self.lengths > context_len
        print('Throwing away ', np.sum(self.lengths[~good_idxes] - 1), 'number of transitions')
        self.starts = self.starts[good_idxes]           # only episodes with more rows than context_len
        self.lengths = self.lengths[good_idxes]

        self.num_trajectories = len(self.starts)

        # reward scale
        # NOTE(review): env is also used as an object with a .spec attribute,
        # so the string comparison below presumably never matches -- confirm
        # how the scale should be selected. The fallback guarantees the
        # attribute always exists.
        if env in ["maze2d"]:
            self.scale = 100
        else:
            self.scale = 1

        columns = {
            "observations": self.observations,
            "achieved_goal": self.achieved_goals,
            "actions": self.actions,
            "rewards": self.rewards,
            "terminals": terminations,
        }
        chunks = self._chunk_by_max_steps(env, dataset_size, columns)
        # A window of context_len rows cannot be drawn from a shorter chunk.
        self.trajectories = [c for c in chunks if c["observations"].shape[0] >= context_len]

        self.state_dim = self.observations.shape[-1]
        self.state_dtype = self.observations.dtype
        self.act_dim = self.actions.shape[-1]
        self.act_dtype = self.actions.dtype
        self.goal_dim = self.achieved_goals.shape[-1]
        self.context_len = context_len
        self.augment_data = augment_data
        self.augment_prob = augment_prob
        self.dataset = None

        print("num_trajectories:", self.num_trajectories)
        print("self.observations.shape:", self.observations.shape)
        print("actions.shape:", self.actions.shape)
        print("self.rewards.shape:", self.rewards.shape)

    def _chunk_by_max_steps(self, env, dataset_size, columns):
        """Cut the leading rows of ``columns`` into consecutive chunks.

        Returns a list of dicts mapping each column name to an array with
        that chunk's rows. Rows after the last completed chunk are dropped.
        """
        # The attribute can exist and still be None, so test the value.
        max_steps = getattr(getattr(env, 'spec', None), 'max_episode_steps', None)
        if max_steps is None:
            max_steps = 1000
        else:
            print("env has max episode!!", max_steps)

        n_rows = min(dataset_size, len(self.observations))
        chunks = []
        buffer = collections.defaultdict(list)
        episode_step = 0
        for i in range(n_rows):
            final_timestep = (episode_step == max_steps - 1)
            for key, column in columns.items():
                buffer[key].append(column[i])
            if final_timestep:
                chunks.append({key: np.array(rows) for key, rows in buffer.items()})
                buffer = collections.defaultdict(list)
                # NOTE(review): the counter is reset and then incremented
                # below, so every chunk after the first holds max_steps - 1
                # rows. Kept as found; confirm whether this is intended.
                episode_step = 0
            episode_step += 1
        return chunks

    def __len__(self):
        return self.num_trajectories * 100

    def __getitem__(self, idx):
        """Return ``(timesteps, state, goal, action, traj_mask)``.

        Reminder: np.random.randint samples from the set [low, high)
        """
        C = self.context_len

        # Without any usable chunk the plain branch is used instead of
        # failing on a modulo by zero.
        if self.augment_data and self.trajectories and np.random.uniform(0, 1) <= self.augment_prob:
            traj = self.trajectories[idx % len(self.trajectories)]
            traj_len = traj["observations"].shape[0]

            # Window start in [0, L - C]; goal row in [si + C - 1, L - 1].
            si = np.random.randint(0, traj_len - C + 1)
            gi = np.random.randint(si + C, traj_len + 1) - 1

            state = torch.from_numpy(traj["observations"][si: si + C])
            action = torch.from_numpy(traj["actions"][si: si + C])
            goal = torch.tensor(traj["achieved_goal"][gi]).view(1, -1)
            # Chunk-relative step indices.
            timesteps = torch.arange(start=si, end=si + C, step=1)
        else:
            idx = idx % self.num_trajectories
            traj_len = self.lengths[idx] - 1            # T: actions taken in this episode
            traj_start_i = self.starts[idx]
            assert self.ends[traj_start_i] == traj_start_i + traj_len
            si = traj_start_i + np.random.randint(0, traj_len - C + 1)     # traj_start_i + [0, T - C]
            gi = np.random.randint(si + C, traj_start_i + traj_len + 1)    # [si + C, traj_start_i + T]

            goal = torch.tensor(self.achieved_goals[gi]).view(1, -1)
            state = torch.tensor(self.observations[si: si + C])
            action = torch.tensor(self.actions[si: si + C])
            # NOTE(review): these are absolute row indices, whereas the
            # augmented branch returns chunk-relative ones -- confirm which
            # the consumer expects.
            timesteps = torch.arange(start=si, end=si + C, step=1)

        traj_mask = torch.ones(C, dtype=torch.long)
        return timesteps, state, goal, action, traj_mask

class KMeansEpisodicTrajectoryDataset(Dataset):
    """Context-length windows with optional k-means based goal stitching.

    Each item is ``(state, goal, action)`` where ``state`` and ``action``
    hold ``context_len`` consecutive rows and ``goal`` is one achieved goal
    reshaped to ``(1, -1)``. With ``augment_data`` on, an item may instead be
    stitched from two places in the data that share a k-means cluster label.
    """

    def __init__(self, dataset_name, dataset_size, context_len, augment_data, augment_prob, nclusters=40, vision = False, data_root='/cvlabdata1/leixing/data'):
        """Load the pickled dataset, index episodes and optionally cluster.

        Args:
            dataset_name: base name of the pickle file.
            dataset_size: number used in the file name.
            context_len: window length returned by ``__getitem__``.
            augment_data: when true, observations are clustered with
                ``MiniBatchKMeans`` and the stitching branch is enabled.
            augment_prob: per-item probability of taking that branch.
            nclusters: number of k-means clusters.
            vision: read from the ``pixels`` folder instead of ``state``.
            data_root: folder containing the ``pixels``/``state`` folders;
                the default is the location that used to be hard-coded.
        """
        super().__init__()

        subdir = 'pixels' if vision else 'state'
        path = data_root + '/' + subdir + '/' + dataset_name + '-' + str(dataset_size) + '.pkl'

        # NOTE: pickle.load can execute arbitrary code; only open trusted files.
        with open(path, 'rb') as fp:
            self.dataset = pickle.load(fp)

        self.observations = self.dataset['observations']['observation']
        self.achieved_goals = self.dataset['observations']['achieved_goal']
        self.actions = self.dataset['actions']
        self.rewards = self.dataset['reward']

        (ends,) = np.where(self.dataset['terminations'])
        self.starts = np.concatenate(([0], ends[:-1] + 1))
        self.lengths = ends - self.starts + 1           # rows per episode = actions taken + 1
        # Per-row lookup of the terminal row selected by the row's episode id
        # (assumes 0-based consecutive episode ids -- TODO confirm).
        self.ends = ends[ self.dataset['observations']['episode_id'][ : ends[-1] + 1 ] ]

        good_idxes = self.lengths > context_len
        print('Throwing away ', np.sum(self.lengths[~good_idxes] - 1), 'number of transitions')
        self.starts = self.starts[good_idxes]           # only episodes with more rows than context_len
        self.lengths = self.lengths[good_idxes]

        self.num_trajectories = len(self.starts)

        if augment_data:
            start_time = datetime.now().replace(microsecond=0)
            print('starting kmeans ... ')
            if len(self.observations.shape) > 2:
                # Observations with more than two axes are flattened per row and rescaled.
                kmeans = MiniBatchKMeans(n_clusters=nclusters, n_init="auto").fit(self.observations.reshape(self.observations.shape[0], -1) / 255.0 - 0.5)
            else:
                kmeans = MiniBatchKMeans(n_clusters=nclusters, n_init="auto").fit(self.observations)
            time_elapsed = str(datetime.now().replace(microsecond=0) - start_time)
            print('kmeans done! time taken :', time_elapsed)

            self.discrete_goal_to_data_idx = extract_discrete_id_to_data_id_map(kmeans.labels_, self.dataset['terminations'], self.ends[-1])
            self.achieved_discrete_goals = kmeans.labels_
            kmeans = None

        self.state_dim = self.observations.shape[-1]
        self.state_dtype = self.observations.dtype
        self.act_dim = self.actions.shape[-1]
        self.act_dtype = self.actions.dtype
        self.goal_dim = self.achieved_goals.shape[-1]
        self.context_len = context_len
        self.augment_data = augment_data
        self.augment_prob = augment_prob
        self.dataset = None

    def __len__(self):
        return self.num_trajectories * 100

    def __getitem__(self, idx):
        """Return ``(state, goal, action)`` for one sampled window.

        Reminder: np.random.randint samples from the set [low, high)
        """
        idx = idx % self.num_trajectories
        traj_len = self.lengths[idx] - 1                        # T: actions taken in this episode
        traj_start_i = self.starts[idx]
        assert self.ends[traj_start_i] == traj_start_i + traj_len

        if self.augment_data and np.random.uniform(0, 1) <= self.augment_prob:
            # Redraw until the head [si, gi) plus the rows from the jump
            # target to the end of its episode span more than one context.
            correct = False
            while not correct:
                si = traj_start_i + np.random.randint(0, traj_len)          # traj_start_i + [0, T - 1]
                gi = np.random.randint(si, traj_start_i + traj_len) + 1     # [si + 1, traj_start_i + T]
                dummy_discrete_goal = self.achieved_discrete_goals[ gi ]
                nearby_goal_idx = np.random.choice(self.discrete_goal_to_data_idx[dummy_discrete_goal])
                nearby_goal_idx_ends = self.ends[nearby_goal_idx]
                if (gi-si) + (nearby_goal_idx_ends - nearby_goal_idx) + 1 > self.context_len:
                    correct = True

            if gi - si < self.context_len:
                # Fill the rest of the window with rows starting at the jump target.
                tail_stop = nearby_goal_idx + self.context_len - (gi - si)
                goal = torch.tensor(self.achieved_goals[ np.random.randint(tail_stop, nearby_goal_idx_ends + 1) ]).view(1, -1)
                state = torch.tensor( np.concatenate( [ self.observations[si: gi], self.observations[nearby_goal_idx: tail_stop] ] ) )
                action = torch.tensor( np.concatenate( [ self.actions[si: gi], self.actions[nearby_goal_idx: tail_stop] ] ) )
            else:
                # The head alone fills the window; only the goal comes from the jump episode.
                goal = torch.tensor(self.achieved_goals[ np.random.randint(nearby_goal_idx, nearby_goal_idx_ends + 1) ]).view(1, -1)
                state = torch.tensor(self.observations[si: si + self.context_len])
                action = torch.tensor(self.actions[si: si + self.context_len])

        else:
            si = traj_start_i + np.random.randint(0, traj_len - self.context_len + 1)       # traj_start_i + [0, T - C]
            gi = np.random.randint(si + self.context_len, traj_start_i + traj_len + 1)      # [si + C, traj_start_i + T]

            goal = torch.tensor(self.achieved_goals[ gi ]).view(1, -1)
            state = torch.tensor(self.observations[si: si + self.context_len])
            action = torch.tensor(self.actions[si: si + self.context_len])

        return state, goal, action

class MaxEpisodicDataset(Dataset):
    """Single-step ``(state, goal, action)`` samples from a pickled transition log.

    Two views of the same flat arrays are kept:

    * episode boundaries (``starts``, ``lengths``, ``ends``) derived from the
      ``terminations`` flags; used by the non-augmented sampling branch;
    * ``trajectories``: consecutive fixed-length chunks of the first
      ``dataset_size`` transitions; used by the augmented sampling branch.
    """

    # Chunk length used when ``env.spec`` provides no usable ``max_episode_steps``.
    _DEFAULT_MAX_EPISODE_STEPS = 1000

    def __init__(self, env, dataset_name, dataset_size, context_len, augment_data, augment_prob, nclusters=40, vision = False):
        """Load the pickle and pre-compute episode boundaries and fixed-length chunks.

        Args:
            env: object whose ``spec.max_episode_steps`` (when present and not
                ``None``) sets the chunk length of ``self.trajectories``.
            dataset_name: base name of the pickle file; prefixed with ``"v-"``
                when ``vision`` is true.
            dataset_size: used in the file name and as the number of leading
                transitions that are split into chunks.
            context_len: accepted for signature parity; not used by this class.
            augment_data: enables the chunk-based sampling branch.
            augment_prob: probability of taking that branch for a given sample.
            nclusters: accepted for signature parity; not used by this class.
            vision: read from ``data/pixels/`` instead of ``data/state/``.
        """
        super().__init__()

        if vision:
            dataset_name = "v-" + dataset_name
            path = 'data/pixels/'+dataset_name+'-'+str(dataset_size)+'.pkl'
        else:
            path = 'data/state/'+dataset_name+'-'+str(dataset_size)+'.pkl'

        # NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
        with open(path, 'rb') as fp:
            self.dataset = pickle.load(fp)

        self.episode_ids = self.dataset['observations']['episode_id']
        self.observations = self.dataset['observations']['observation']
        self.achieved_goals = self.dataset['observations']['achieved_goal']
        self.actions = self.dataset['actions']
        self.rewards = self.dataset['reward']
        terminations = self.dataset['terminations']

        # Episode boundaries: each set termination flag marks the last index of an episode.
        (ends,) = np.where(terminations)
        self.starts = np.concatenate(([0], ends[:-1] + 1))
        self.lengths = ends - self.starts + 1
        self.num_trajectories = len(self.starts)

        # Per-timestep lookup: index of the last step of the episode that timestep belongs to.
        self.ends = ends[ self.episode_ids[ : ends[-1] + 1 ] ]

        # ``hasattr`` is not enough here: a spec may carry the attribute with a
        # value of None, which previously crashed on ``None - 1``.
        max_steps = getattr(env.spec, 'max_episode_steps', None)
        if max_steps is None:
            max_steps = self._DEFAULT_MAX_EPISODE_STEPS
        else:
            print("env has max episode!!", max_steps)

        # Split the first ``dataset_size`` transitions into chunks of ``max_steps``.
        # A trailing partial chunk is discarded.
        self.trajectories = []
        data_ = collections.defaultdict(list)
        episode_step = 0
        for i in range(dataset_size):
            data_["observations"].append(self.observations[i])
            data_["achieved_goal"].append(self.achieved_goals[i])
            data_["actions"].append(self.actions[i])
            data_["rewards"].append(self.rewards[i])
            data_["terminals"].append(terminations[i])
            if episode_step == max_steps - 1:
                self.trajectories.append({k: np.array(v) for k, v in data_.items()})
                data_ = collections.defaultdict(list)
                # Restart the counter at 0 for the next chunk. The old code reset
                # to 0 and then incremented unconditionally, so every chunk after
                # the first came out one step short.
                episode_step = 0
            else:
                episode_step += 1

        self.goal_dim = self.achieved_goals.shape[-1]
        self.augment_data = augment_data
        self.augment_prob = augment_prob
        self.dataset = None  # drop the reference to the loaded dict

    def __len__(self):
        """Return 100 nominal samples per termination-delimited trajectory."""
        return self.num_trajectories * 100

    def __getitem__(self, idx):
        '''
        Return one ``(state, goal, action)`` triple.

        ``idx`` only selects the trajectory (modulo the number available); the
        timestep and the goal index are drawn at random on every call.

        Reminder: np.random.randint samples from the set [low, high)
        '''
        if self.augment_data and np.random.uniform(0, 1) <= self.augment_prob:
            traj = self.trajectories[idx % len(self.trajectories)]
            traj_len = traj["observations"].shape[0]

            # Both indices must stay inside the chunk and the goal must not
            # precede the state. The previous bounds allowed si == traj_len
            # (IndexError) and gi == si - 1 (which wrapped to the last element
            # when si == 0).
            si = np.random.randint(0, traj_len)         #si can be [0, L-1]
            gi = np.random.randint(si, traj_len)        #gi can be [si, L-1]

            state = torch.from_numpy(traj["observations"][si])
            action = torch.from_numpy(traj["actions"][si])
            goal = torch.tensor(traj["achieved_goal"][gi])
        else:
            idx = idx % self.num_trajectories
            traj_len = self.lengths[idx] - 1                        #traj_len = T, traj_len is the number of actions taken in the trajectory
            traj_start_i = self.starts[idx]
            assert self.ends[traj_start_i] == traj_start_i + traj_len

            si = np.random.randint(0, traj_len)                     #si can be [0, T-1]

            state = torch.tensor(self.observations[traj_start_i + si])
            action = torch.tensor(self.actions[traj_start_i + si])
            # The goal comes from a strictly later step of the same trajectory: [si+1, T].
            goal = torch.tensor(self.achieved_goals[ traj_start_i + np.random.randint(si, traj_len) + 1 ])

        return state, goal, action
    
class KMeansEpisodicDataset(Dataset):
    """Single-step ``(state, goal, action)`` samples with optional k-means goal swapping.

    Without augmentation the goal is the achieved goal of a later step of the
    same trajectory. With augmentation, the cluster label of that later step
    is looked up, a random dataset index carrying the same label is picked,
    and the goal is drawn between that index and the end of its episode.
    """

    def __init__(self, dataset_name, dataset_size, augment_data, augment_prob, nclusters=40, vision = False):
        """Load the pickled dataset, index episode boundaries and optionally cluster observations.

        Args:
            dataset_name: base name of the pickle file; prefixed with ``"v-"``
                when ``vision`` is true.
            dataset_size: number embedded in the pickle file name.
            augment_data: when true, fit k-means on the observations and build
                the label -> dataset-index map used by ``__getitem__``.
            augment_prob: probability of using the cluster-based goal per sample.
            nclusters: number of k-means clusters.
            vision: read from ``data/pixels/`` instead of ``data/state/``.
        """
        super().__init__()

        if vision:
            dataset_name = "v-" + dataset_name
            folder = 'data/pixels/'
        else:
            folder = 'data/state/'
        path = folder + dataset_name + '-' + str(dataset_size) + '.pkl'

        with open(path, 'rb') as fp:
            self.dataset = pickle.load(fp)

        obs_group = self.dataset['observations']
        terminations = self.dataset['terminations']

        self.episode_ids = obs_group['episode_id']
        self.observations = obs_group['observation']
        self.achieved_goals = obs_group['achieved_goal']
        self.actions = self.dataset['actions']
        self.rewards = self.dataset['reward']

        # Each set termination flag marks the last index of an episode.
        (ends,) = np.where(terminations)
        self.starts = np.concatenate(([0], ends[:-1] + 1))
        self.lengths = ends - self.starts + 1
        self.num_trajectories = len(self.starts)

        # Per-timestep lookup of the last index of the owning episode.
        self.ends = ends[self.episode_ids[: ends[-1] + 1]]

        if augment_data:
            start_time = datetime.now().replace(microsecond=0)
            print('starting kmeans ... ')
            if self.observations.ndim > 2:
                # Observations with more than two dims are flattened per sample
                # and rescaled by /255 - 0.5 before clustering.
                features = self.observations.reshape(self.observations.shape[0], -1) / 255.0 - 0.5
            else:
                features = self.observations
            kmeans = MiniBatchKMeans(n_clusters=nclusters, n_init="auto").fit(features)
            del features
            time_elapsed = str(datetime.now().replace(microsecond=0) - start_time)
            print('kmeans done! time taken :', time_elapsed)

            labels = kmeans.labels_
            self.discrete_goal_to_data_idx = extract_discrete_id_to_data_id_map(labels, terminations, self.ends[-1])
            self.achieved_discrete_goals = labels
            del kmeans

        self.goal_dim = self.achieved_goals.shape[-1]
        self.augment_data = augment_data
        self.augment_prob = augment_prob
        self.dataset = None  # drop the reference to the loaded dict

    def __len__(self):
        """Return 100 nominal samples per trajectory."""
        return self.num_trajectories * 100

    def __getitem__(self, idx):
        '''
        Return one ``(state, goal, action)`` triple; ``idx`` only selects the
        trajectory (modulo the trajectory count), the timestep is random.

        Reminder: np.random.randint samples from the set [low, high)
        '''
        traj = idx % self.num_trajectories
        first = self.starts[traj]
        horizon = self.lengths[traj] - 1        # T: number of actions taken in the trajectory
        assert self.ends[first] == first + horizon

        step = np.random.randint(0, horizon)    # step is in [0, T-1]
        here = first + step

        state = torch.tensor(self.observations[here])
        action = torch.tensor(self.actions[here])

        use_cluster = self.augment_data and np.random.uniform(0, 1) <= self.augment_prob

        # A strictly later index of the same trajectory: first + [step+1, T].
        future_idx = first + np.random.randint(step, horizon) + 1

        if use_cluster:
            label = self.achieved_discrete_goals[future_idx]
            anchor = np.random.choice(self.discrete_goal_to_data_idx[label])
            # Any index from the anchor to the end of the anchor's episode, inclusive.
            goal_idx = np.random.randint(anchor, self.ends[anchor] + 1)
        else:
            goal_idx = future_idx

        goal = torch.tensor(self.achieved_goals[goal_idx])
        return state, goal, action
