from collections import namedtuple
import numpy as np
import torch

# Three-field sample record. TrajectoryDataset.__getitem__ fills it with the
# concatenated action/observation array, the rtg entry (as `returns`) and the
# ctg entry (as `costs`).
Batch = namedtuple(
    typename='Batch',
    field_names='trajectories returns costs',
)

class SegmentDataset(torch.utils.data.Dataset):
    """Map-style dataset of observation/action segments tagged with a task id.

    Indexing returns a ``SegmentBatch(trajectories, conditions)`` where
    ``trajectories`` is the segment's actions and observations concatenated
    along the last axis (actions first) and ``conditions`` is a one-hot task
    vector of length ``n_tasks``.

    Args:
        obs: Sequence of equally shaped arrays; stacked along a new axis 0
            and cast to float32. Presumably each entry is
            (horizon, observation_dim) -- TODO confirm against the caller.
        action: Sequence of equally shaped arrays, handled like ``obs``.
        task_id: Index of the task all segments belong to. It selects a row
            of ``np.eye(n_tasks)``, so a value >= ``n_tasks`` raises
            ``IndexError``. NOTE(review): a negative value would wrap around
            through NumPy indexing instead of failing -- confirm callers
            never pass one.
        n_tasks: Length of the one-hot task encoding.
    """

    # Two-field sample record. The module-level ``Batch`` has three required
    # fields (trajectories, returns, costs), so it cannot be built from the
    # two values this dataset produces.
    SegmentBatch = namedtuple('SegmentBatch', 'trajectories conditions')
    # namedtuple() sets the qualified name to the bare type name; point it at
    # the nested location so pickle can resolve the class (needed when
    # samples cross process boundaries, e.g. DataLoader workers).
    SegmentBatch.__qualname__ = 'SegmentDataset.SegmentBatch'

    def __init__(self, obs, action, task_id, n_tasks=20):
        self.obs = np.stack(obs, axis=0).astype(np.float32)
        self.action = np.stack(action, axis=0).astype(np.float32)
        # One task id per leading-two-axes position of ``obs``.
        self.task_id = np.ones(self.obs.shape[:2], dtype=np.float32) * task_id
        # Row lookup into the identity matrix turns every id into a one-hot
        # vector, adding a trailing axis of size ``n_tasks``.
        self.cond_input = np.eye(n_tasks, dtype=np.float32)[self.task_id.astype(np.int32)]
        self.observation_dim = self.obs.shape[-1]
        self.action_dim = self.action.shape[-1]
        self.n_task = n_tasks

    def __len__(self):
        """Return the number of stored segments (size of axis 0 of ``obs``)."""
        return len(self.obs)

    def __getitem__(self, idx, eps=1e-4):
        """Return the sample at ``idx`` as a ``SegmentBatch``.

        Args:
            idx: Index into the first axis of the stored arrays.
            eps: Unused; kept so the signature stays unchanged.

        Returns:
            ``SegmentBatch(trajectories, conditions)``.
        """
        observations = self.obs[idx]
        actions = self.action[idx]

        # For a scalar task id every position along the second axis holds the
        # same one-hot row, so the first one stands for the whole segment.
        conditions = self.cond_input[idx][0]
        trajectories = np.concatenate([actions, observations], axis=-1)

        return self.SegmentBatch(trajectories, conditions)


class TrajectoryDataset(torch.utils.data.Dataset):
    """Growable map-style dataset of trajectories with rtg/ctg annotations.

    The dataset starts empty and is filled through ``add_data``. Indexing
    returns ``Batch(trajectories, rtgs, ctgs)`` where ``trajectories`` is the
    stored action and observation arrays concatenated along the last axis
    (actions first).
    """

    def __init__(self):
        # Empty placeholders; the first add_data() call replaces them with
        # float32 arrays. The names match what every other method reads.
        self.obs, self.action, self.rtg, self.ctg = [], [], [], []
        self.data_init = False

    def add_data(self, obs, action, rtg, ctg, permutation=True):
        """Stack the given sequences and store or append them.

        Args:
            obs: Sequence of equally shaped arrays, stacked along a new axis 0
                and cast to float32.
            action: Sequence of equally shaped arrays, handled like ``obs``.
            rtg: Sequence of equally shaped arrays, handled like ``obs``;
                presumably returns-to-go -- TODO confirm.
            ctg: Sequence of equally shaped arrays, handled like ``obs``;
                presumably costs-to-go -- TODO confirm.
            permutation: When True and the dataset already holds data, all
                four arrays are reshuffled with one shared random permutation
                after appending. The first call never shuffles.

        Raises:
            ValueError: If the four inputs do not contain the same number of
                entries (or ``np.stack`` rejects an input). The dataset is
                left untouched in that case.
        """
        # Stack everything before touching any attribute so a bad input
        # cannot leave the dataset half-updated.
        new_obs = np.stack(obs, axis=0).astype(np.float32)
        new_action = np.stack(action, axis=0).astype(np.float32)
        new_rtg = np.stack(rtg, axis=0).astype(np.float32)
        new_ctg = np.stack(ctg, axis=0).astype(np.float32)

        counts = (len(new_obs), len(new_action), len(new_rtg), len(new_ctg))
        if len(set(counts)) != 1:
            # Rows are paired by index, so unequal counts would misalign them.
            raise ValueError(
                "obs, action, rtg and ctg must contain the same number of "
                f"entries, got {counts}"
            )

        if not self.data_init:
            self.obs, self.action = new_obs, new_action
            self.rtg, self.ctg = new_rtg, new_ctg
            self.data_init = True
            return

        self.obs = np.concatenate([self.obs, new_obs], axis=0)
        self.action = np.concatenate([self.action, new_action], axis=0)
        self.rtg = np.concatenate([self.rtg, new_rtg], axis=0)
        self.ctg = np.concatenate([self.ctg, new_ctg], axis=0)

        if permutation:
            # One shared index order keeps the four arrays row-aligned.
            order = np.random.permutation(self.obs.shape[0])
            self.obs = self.obs[order]
            self.action = self.action[order]
            self.rtg = self.rtg[order]
            self.ctg = self.ctg[order]

    def __len__(self):
        """Return the number of stored trajectories.

        Raises:
            RuntimeError: If ``add_data`` has not been called yet.
        """
        if not self.data_init:
            # Explicit raise: a bare assert would vanish under ``python -O``.
            raise RuntimeError(
                "TrajectoryDataset holds no data; call add_data() first"
            )
        return len(self.obs)

    def __getitem__(self, idx, eps=1e-4):
        """Return the sample at ``idx`` as ``Batch(trajectories, rtgs, ctgs)``.

        Args:
            idx: Index into the first axis of the stored arrays.
            eps: Unused; kept so the signature stays unchanged.
        """
        observations = self.obs[idx]
        actions = self.action[idx]
        rtgs = self.rtg[idx]
        ctgs = self.ctg[idx]
        trajectories = np.concatenate([actions, observations], axis=-1)

        return Batch(trajectories, rtgs, ctgs)