from __future__ import print_function, division
import torch
import numpy as np
import cv2


class ImageDataset(object):
    """
    Dataset for loading images at run-time.

    Each item is read from disk with OpenCV when it is requested, converted
    from BGR to RGB, resized to ``width x width`` pixels and, when enabled,
    flipped horizontally with probability 0.5.

    Args:
        img_list: sequence of image file paths, indexed by ``__getitem__``.
        width: side length, in pixels, of the square output image.
        if_flip: whether random horizontal flipping is enabled.
        seed: seed for the private ``np.random.RandomState`` that drives the
            flip decision, making the flip sequence reproducible.

    NOTE(review): if this dataset is used with a multi-worker
    ``torch.utils.data.DataLoader``, each worker receives a copy of the same
    ``RandomState`` and will therefore produce the same flip sequence;
    consider reseeding per worker (e.g. via ``worker_init_fn``) — confirm
    against how callers build the loader.
    """

    def __init__(self, img_list, width=128, if_flip=True, seed=1234):
        self.img_list = img_list
        self.width = width
        self._if_flip = if_flip
        self._rng = np.random.RandomState(seed)

    def __len__(self):
        """Return the number of image paths in ``img_list``."""
        return len(self.img_list)

    def __getitem__(self, idx):
        """
        Load, resize and optionally flip the image at position ``idx``.

        Args:
            idx: index into ``img_list``; a tensor index is first converted
                with ``tolist()``.

        Returns:
            ``torch.uint8`` tensor of shape ``(3, width, width)`` with RGB
            channels first.

        Raises:
            IOError: if OpenCV cannot read the file at ``img_list[idx]``.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()

        img_name = self.img_list[idx]
        bgr = cv2.imread(img_name)
        if bgr is None:
            # cv2.imread reports a missing / unreadable / unsupported file by
            # returning None instead of raising; fail here with the path in
            # the message rather than with an opaque cv2.error below.
            raise IOError("Could not read image: {}".format(img_name))

        image = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (self.width, self.width))

        # Short-circuit: the RNG is only consulted when flipping is enabled.
        if self._if_flip and self._rng.rand() > 0.5:
            # flipCode=1 flips around the vertical axis (horizontal flip).
            image = cv2.flip(image, 1)

        # HWC numpy array -> CHW uint8 tensor.
        return torch.tensor(image, dtype=torch.uint8).permute(2, 0, 1)
