import torch
import random
from torch.utils.data.sampler import BatchSampler, SubsetRandomSampler
from torch.utils.data import Dataset, DataLoader
from skimage.util.shape import view_as_windows
import os
import numpy as np

def reparameterize(mu, logsigma):
    """Draw a differentiable Gaussian sample via the reparameterization trick.

    Args:
        mu: mean tensor (broadcast against ``logsigma``).
        logsigma: log-variance tensor; the standard deviation used is
            ``exp(0.5 * logsigma)``.

    Returns:
        ``mu + std * eps`` with ``eps`` drawn from a standard normal shaped like
        ``logsigma``. Gradients flow into ``mu`` and ``logsigma`` only.
    """
    sigma = (logsigma * 0.5).exp()
    noise = torch.randn_like(sigma)
    return mu + sigma * noise


def obs_extract(obs):
    """Turn the channel-last ``'rgb'`` batch of an observation into a tensor.

    Args:
        obs: mapping whose ``'rgb'`` entry is a 4-D array laid out as
            (batch, height, width, channels).

    Returns:
        A tensor of layout (batch, channels, height, width) that shares memory
        with the transposed array view.
    """
    rgb_batch = obs['rgb']
    channels_first = np.transpose(rgb_batch, axes=(0, 3, 1, 2))
    return torch.from_numpy(channels_first)


def count_step(i_update, i_env, i_step, num_envs, num_steps):
    """Flatten nested (update, env, step) indices into one global step number.

    Each update spans ``num_envs * num_steps`` steps and each environment
    contributes ``num_steps`` consecutive steps within an update.
    """
    steps_per_update = num_steps * num_envs
    update_offset = i_update * steps_per_update
    env_offset = i_env * num_steps
    return update_offset + env_offset + i_step


# Dataset of stored observations, used for representation learning.
class ExpDataset(Dataset):
    """Map-style dataset that serves one sample from every source at each index.

    ``file_dir`` must contain one subdirectory per source. Each subdirectory
    holds ``num_splitted`` chunk archives named ``0.npz``, ``1.npz``, ...;
    every archive stores frames under the key ``'obs'`` and, when
    ``class_label`` is set, per-frame labels under ``'labels'``. Only one chunk
    per source is kept in memory; call :meth:`loadnext` to move every source
    to the next chunk (wrapping around after the last one).

    Args:
        file_dir: root directory whose entries are the source directories.
        game: currently unused (a name filter on it is disabled).
        num_splitted: number of chunk archives per source.
        transform: callable applied to each single frame; it must return a
            tensor because results are combined with ``torch.stack``.
        class_label: if True, items are ``(imgs, labels)`` instead of ``imgs``.
    """

    def __init__(self, file_dir, game, num_splitted, transform, class_label=False):
        super().__init__()
        self.file_dir = file_dir
        # Sorted so the source order (and hence the stacking order of items)
        # is reproducible; os.listdir returns entries in arbitrary order.
        self.files = sorted(os.listdir(file_dir))
        self.num_splitted = num_splitted
        self.data = []
        self.progress = 0
        self.transform = transform
        self.class_label = class_label
        self.loadnext()

    def __len__(self):
        """Return the number of samples in the first source's current chunk.

        Raises:
            RuntimeError: if no source data is loaded (empty ``file_dir``).
        """
        if not self.data:
            raise RuntimeError("ExpDataset has no loaded data; is file_dir empty?")
        # With labels each entry is (frames, labels); otherwise it is frames.
        first_frames = self.data[0][0] if self.class_label else self.data[0]
        return len(first_frames)

    @staticmethod
    def _source_index(idx, size):
        """Return ``idx`` if valid for a source of ``size`` samples, else a random valid index."""
        return idx if idx < size else random.randrange(size)

    def __getitem__(self, idx):
        """Return the frame at ``idx`` from every source, stacked on a new axis 0.

        Rows follow ``self.files`` order. A source that has no frame ``idx``
        contributes a uniformly random frame instead; the substitution is
        decided per source, so it never changes the index used for the others.

        Returns:
            ``imgs`` or, when ``class_label`` is set, ``(imgs, labels)``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        if self.class_label:
            img_list = []
            label_list = []
            for frames, labels in self.data:
                j = self._source_index(idx, len(frames))
                img_list.append(self.transform(frames[j]))
                # as_tensor also accepts scalar labels (from_numpy does not).
                label_list.append(torch.as_tensor(labels[j]))
            return (torch.stack(img_list), torch.stack(label_list))
        img_list = []
        for frames in self.data:
            j = self._source_index(idx, len(frames))
            img_list.append(self.transform(frames[j]))
        return torch.stack(img_list)

    def loadnext(self):
        """Load chunk ``self.progress`` from every source and advance the cursor.

        Replaces ``self.data`` with one entry per source: the ``'obs'`` array,
        or an ``(obs, labels)`` tuple when ``class_label`` is set. The cursor
        wraps to 0 after ``num_splitted`` chunks.
        """
        chunk_name = '%d.npz' % (self.progress)
        data = []
        for source in self.files:
            # Open each archive once and close it promptly; indexing an
            # NpzFile reads the member into a regular in-memory array, so the
            # arrays remain valid after the handle is closed.
            with np.load(os.path.join(self.file_dir, source, chunk_name)) as archive:
                frames = archive['obs']
                if self.class_label:
                    data.append((frames, archive['labels']))
                else:
                    data.append(frames)
        self.data = data
        self.progress = (self.progress + 1) % self.num_splitted


# Adapted from https://github.com/MishaLaskin/curl
def random_crop(imgs, output_size):
    """
    Randomly crop every image of a batch to ``output_size x output_size``.

    Vectorized: one crop offset is drawn independently per batch element and
    all crops are gathered with a single NumPy fancy-indexing operation.

    args:
        imgs: numpy array of batch images with shape (B, C, H, W); H and W
            need not be equal.
        output_size: side length of the square crop; must not exceed H or W.

    returns:
        A new numpy array of shape (B, C, output_size, output_size).

    raises:
        ValueError: if ``output_size`` is larger than H or W.
    """
    n, channels, height, width = imgs.shape
    max_top = height - output_size
    max_left = width - output_size
    if max_top < 0 or max_left < 0:
        raise ValueError(
            "output_size %d exceeds image size %dx%d" % (output_size, height, width))
    # randint's upper bound is exclusive; +1 makes the crop flush with the
    # bottom/right edge reachable and allows output_size == H or W.
    tops = np.random.randint(0, max_top + 1, n)
    lefts = np.random.randint(0, max_left + 1, n)
    offsets = np.arange(output_size)
    rows = tops[:, None] + offsets    # (n, output_size)
    cols = lefts[:, None] + offsets   # (n, output_size)
    # Index arrays broadcast to (n, C, output_size, output_size).
    return imgs[
        np.arange(n)[:, None, None, None],
        np.arange(channels)[None, :, None, None],
        rows[:, None, :, None],
        cols[:, None, None, :],
    ]