import os
import io
import numpy as np
from PIL import Image
from imageio import imread 
import pickle
from torch.utils.data import Dataset


class BairSimpleDataset(Dataset):
    """Data handler that loads pre-processed robot-pushing sequences.

    Expected layout: ``<data_root>/processed_data/{train,test}/<d1>/<d2>/``.
    Each leaf directory holds frames ``0.png``, ``1.png``, ... and, optionally,
    ``action.npy`` and ``pos.npy``.

    Args:
        data_root: Root directory that contains ``processed_data``.
        train: Use the ``train`` split if truthy, otherwise the ``test`` split.
        seq_len: Number of consecutive frames returned per sample.
        image_size: Side length in pixels that each frame is reshaped to.
            Frames are expected to be square 3-channel images of this size.

    Raises:
        ValueError: If the selected split directory does not exist.
    """

    def __init__(self, data_root='', train=True, seq_len=20, image_size=64):
        self.root_dir = data_root
        split = 'train' if train else 'test'
        self.data_dir = '%s/processed_data/%s' % (self.root_dir, split)
        # NOTE(review): ``ordered`` is not read inside this class; presumably
        # callers/samplers use it to decide whether to shuffle — confirm.
        self.ordered = not train
        if not os.path.exists(self.data_dir):
            raise ValueError('Data not found')
        self.dirs = self._list_sequence_dirs(self.data_dir)
        self._write_index(self.dirs)
        self.seq_len = seq_len
        self.image_size = image_size

    @staticmethod
    def _list_sequence_dirs(data_dir):
        """Return every ``<data_dir>/<d1>/<d2>`` path in sorted order.

        ``os.listdir`` order is arbitrary, so both levels are sorted to make
        the index -> sequence mapping reproducible across machines and runs.
        """
        dirs = []
        for d1 in sorted(os.listdir(data_dir)):
            group = '%s/%s' % (data_dir, d1)
            if os.path.isdir(group):
                dirs.extend('%s/%s' % (group, d2) for d2 in sorted(os.listdir(group)))
        return dirs

    def _write_index(self, dirs):
        """Best-effort dump of the sequence list to ``data_dict.pkl``.

        The file is only written here, never read back by this class. A
        read-only dataset location therefore must not prevent loading.
        """
        try:
            with open(f'{self.data_dir}/data_dict.pkl', 'wb') as f:
                pickle.dump(dirs, f)
        except OSError as err:
            import warnings
            warnings.warn(f'Could not write sequence index: {err}')

    def __len__(self):
        """Return the number of sequence directories found."""
        return len(self.dirs)

    def _load_or_zeros(self, path, dim):
        """Load ``path`` via ``np.load``, or return ``zeros((seq_len, dim))``."""
        if os.path.exists(path):
            return np.load(path)
        return np.zeros((self.seq_len, dim))

    def __getitem__(self, index):
        """Return ``(image_seq, actions, states)`` for sequence ``index``.

        ``image_seq`` stacks ``seq_len`` frames into an array of shape
        ``(seq_len, image_size, image_size, 3)``. ``actions`` is sliced from
        frame ``start + 1`` (one fewer row than frames); ``states`` is aligned
        with the frames. Missing ``.npy`` files are replaced by zeros.
        """
        d = self.dirs[index]
        start = 0  # fixed start frame; random offsets are currently disabled
        stop = start + self.seq_len
        size = self.image_size
        frames = [
            imread('%s/%d.png' % (d, i)).reshape(1, size, size, 3)
            for i in range(start, stop)
        ]
        image_seq = np.concatenate(frames, axis=0)
        actions = self._load_or_zeros('%s/action.npy' % d, 4)[start + 1:stop]
        states = self._load_or_zeros('%s/pos.npy' % d, 5)[start:stop]
        return image_seq, actions, states
