import os
import io
import numpy as np
from PIL import Image
from imageio import imread 
import pickle


class RobonetSimpleDataset(object):
    """Data handler that loads robot pushing data.

    Expected on-disk layout::

        <data_root>/processed_data/train/<sequence_dir>/...
        <data_root>/processed_data/test/<sequence_dir>/...

    Each ``<sequence_dir>`` contains numbered frames ``0.png, 1.png, ...``
    plus ``action.npy`` and ``state.npy``.

    Args:
        data_root: Root directory that contains ``processed_data``.
        train: If true, load the ``train`` split; otherwise the ``test`` split.
        seq_len: Number of frames returned per sequence.
        image_size: Height and width each frame is reshaped to. Frames are
            reshaped to ``(image_size, image_size, 3)``.

    Raises:
        ValueError: If the selected split directory does not exist.
    """

    def __init__(self, data_root='', train=True, seq_len=20, image_size=64):
        self.root_dir = data_root
        if train:
            self.data_dir = '%s/processed_data/train' % self.root_dir
            self.ordered = False
        else:
            self.data_dir = '%s/processed_data/test' % self.root_dir
            # NOTE(review): `ordered` is not read inside this class; presumably
            # a caller/sampler uses it to decide whether to shuffle.
            self.ordered = True
        if not os.path.exists(self.data_dir):
            raise ValueError('Data not found')

        self.dirs = []
        for entry in os.listdir(self.data_dir):
            path = '%s/%s' % (self.data_dir, entry)
            # Keep only sub-directories; each one holds a single sequence.
            if os.path.isdir(path):
                self.dirs.append(path)
        # os.listdir returns entries in arbitrary order; sort in place so that
        # a given index maps to the same sequence on every run/machine.
        self.dirs.sort()

        self.seq_len = seq_len
        self.image_size = image_size

    def __len__(self):
        """Return the number of sequence directories found."""
        return len(self.dirs)

    def __getitem__(self, index):
        """Load sequence ``index`` and return ``(image_seq, actions, states)``.

        Returns:
            image_seq: Array of shape ``(seq_len, image_size, image_size, 3)``
                stacked from frames ``<start>.png .. <start + seq_len - 1>.png``.
            actions: ``action.npy`` sliced along its first axis to
                ``[start + 1 : start + seq_len]``, i.e. up to ``seq_len - 1``
                entries (presumably the actions between consecutive frames).
            states: ``state.npy`` sliced along its first axis to
                ``[start : start + seq_len]``.
        """
        d = self.dirs[index]
        # Sequences always start at frame 0; random start offsets are not used.
        start = 0

        frames = [
            imread('%s/%d.png' % (d, i)).reshape(
                self.image_size, self.image_size, 3)
            for i in range(start, start + self.seq_len)
        ]
        image_seq = np.stack(frames, axis=0)

        actions = np.load('%s/action.npy' % d)
        states = np.load('%s/state.npy' % d)
        actions = actions[start + 1:start + self.seq_len]
        states = states[start:start + self.seq_len]
        return image_seq, actions, states