import torch
import numpy as np
from PIL import Image
from skimage.transform import resize
from torch.utils.data import Dataset
from torchvision import transforms

# Size handed to transforms.Resize in the default preprocessing pipeline
# (with an int, torchvision's Resize matches the image's smaller edge to it).
img_size = 32

# Default image preprocessing (the default ``transform`` argument of
# datareader_standard): resize to ``img_size``, then convert to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(img_size),
        transforms.ToTensor(),
    ]
)


def gray_loader(path):
    """Load the image at ``path`` and return it converted to mode 'L'.

    The source image is opened as a context manager so its file handle is
    closed as soon as the converted copy has been produced, instead of
    lingering until garbage collection (which can exhaust file descriptors
    in long-running DataLoader workers).
    """
    with Image.open(path) as img:
        # convert() returns a new, fully loaded image, so it stays valid after
        # the source image (and its file) is closed.
        return img.convert('L')


def rgb_loader(path):
    """Load the image at ``path`` and return it converted to mode 'RGB'.

    The source image is opened as a context manager so its file handle is
    closed as soon as the converted copy has been produced, instead of
    lingering until garbage collection (which can exhaust file descriptors
    in long-running DataLoader workers).
    """
    with Image.open(path) as img:
        # convert() returns a new, fully loaded image, so it stays valid after
        # the source image (and its file) is closed.
        return img.convert('RGB')


class datareader_standard(Dataset):
    """Map-style image dataset driven by a whitespace-separated list file.

    Each non-blank line of ``data_root + file_txt`` is split on whitespace.
    Field 0 is an image path relative to ``data_root``; field ``choice`` holds
    the label.

    Args:
        file_txt: list-file name, appended directly to ``data_root``.
        data_root: prefix joined by plain string concatenation to the list-file
            name and to every image path, so it normally needs a trailing
            path separator.
        transform: callable applied to each loaded image, or None to skip.
            Defaults to the module-level ``transform``.
        target_transform: optional callable applied to each label.
        loader: callable mapping an image path to an image object.
            Defaults to ``rgb_loader``.
        choice: index of the label field within each line.
        offset: when not None, the label field is parsed as a single integer
            and ``offset`` is added to it. When None, every character of the
            label field is parsed as a digit, yielding a list of ints
            (e.g. ``"0110"`` -> ``[0, 1, 1, 0]``).

    Raises:
        OSError: if the list file cannot be opened.
        IndexError: if a non-blank line has no field at index ``choice``.
        ValueError: if the label field is not numeric as required above.
    """

    def __init__(self, file_txt, data_root, transform=transform, target_transform=None, loader=rgb_loader, choice=1, offset=None):

        self.imgs = []
        self.labels = []
        self.offset = offset

        # `with` guarantees the list file is closed, even if parsing a line fails.
        with open(data_root + file_txt, 'r') as fileobj:
            for line in fileobj:
                data = line.split()
                if not data:
                    # Blank/whitespace-only line (e.g. trailing empty line):
                    # skip it rather than failing on data[0].
                    continue
                self.imgs.append(data_root + data[0])
                if self.offset is not None:
                    # Single integer class label, shifted by `offset`.
                    self.labels.append(int(data[choice]) + offset)
                else:
                    # Digit-string label: one int per character.
                    self.labels.append([int(ch) for ch in data[choice]])
        self.transform = transform
        self.target_transform = target_transform
        self.loader = loader

    def __getitem__(self, index):
        """Return ``(image, target)`` for sample ``index``.

        The image is loaded with ``self.loader`` and passed through
        ``self.transform``. The target is passed through
        ``self.target_transform``. Each transform is applied only when it is
        not None.
        """
        img_path = self.imgs[index]
        img = self.loader(img_path)
        target = self.labels[index]

        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Return the number of samples read from the list file."""
        return len(self.imgs)