import os
from PIL import Image
import pandas as pd
from torch.utils.data import Dataset
from torchvision import transforms


def image_loader(path):
    """Read the image stored at ``path`` and return it converted to RGB.

    The file is opened in binary mode and handed to ``Image.open``; the
    result is converted with ``convert('RGB')`` before the file handle is
    closed.
    """
    with open(path, 'rb') as handle:
        picture = Image.open(handle)
        # Convert while the handle is still open: Pillow documents
        # ``Image.open`` as lazy, so the pixel data is read here.
        rgb_picture = picture.convert('RGB')
    return rgb_picture


# Preprocessing pipeline used by FileListDataset when the caller does not
# pass a transform of their own.
default_transform = transforms.Compose(
    [
        # Turn the loaded image into a tensor with values in the 0-1 range.
        transforms.ToTensor(),
        # Per-channel normalization with fixed mean/std.
        # NOTE(review): these look like the commonly used ImageNet channel
        # statistics -- confirm they suit the data this is applied to.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)


class FileListDataset(Dataset):
    """Image dataset described by a space-separated file list.

    ``filelist`` is parsed with ``pd.read_csv`` using a single space as the
    separator and no header row. For each row, column 0 is joined onto
    ``root_folder`` to locate the image file, and column 1 is returned
    unchanged as the second element of the sample.
    """

    def __init__(self, filelist, root_folder, transform=default_transform):
        """Parse ``filelist`` and remember where images live.

        Args:
            filelist: Path (or anything ``pd.read_csv`` accepts) of the
                space-separated listing.
            root_folder: Directory that column 0 of each row is joined onto.
            transform: Callable applied to each loaded image, or ``None`` to
                return the loaded image as-is.
        """
        super().__init__()
        self.dataset = pd.read_csv(filelist, header=None, sep=' ')
        self.root_dir = root_folder
        self.transform = transform

    def __len__(self):
        """Return the number of rows read from the file list."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``(image, column_1_value)`` for the row at ``index``."""
        row = self.dataset.iloc[index]
        full_path = os.path.join(self.root_dir, row[0])
        sample = image_loader(full_path)
        if self.transform is None:
            return sample, row[1]
        return self.transform(sample), row[1]
