#!/usr/bin/python
# encoding: utf-8
import os
from torch.utils.data import Dataset
from PIL import Image
import torch

class GTResDataset(Dataset):
	"""Paired dataset of result images and their ground-truth counterparts.

	Every ``.jpg``/``.png`` file in ``root_path`` is paired with the file of
	the same stem in ``gt_dir``; the ground-truth file is always looked up
	with a ``.jpg`` extension. Items are ``(from_im, to_im)`` tuples where
	``from_im`` comes from ``root_path`` and ``to_im`` from ``gt_dir``.
	"""

	def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None):
		"""Index the image pairs.

		Args:
			root_path: directory containing the result images.
			gt_dir: directory containing the ground-truth images. Required in
				practice; the ``None`` default is kept only for signature
				compatibility.
			transform: optional callable applied to both images in
				``__getitem__``.
			transform_train: stored on the instance; not used by this class.

		Raises:
			ValueError: if ``gt_dir`` is None.
		"""
		if gt_dir is None:
			raise ValueError("gt_dir is required to locate ground-truth images")
		self.pairs = []
		# os.listdir order is arbitrary; sort so index -> file is deterministic.
		for fname in sorted(os.listdir(root_path)):
			if not fname.endswith((".jpg", ".png")):
				continue
			stem, _ = os.path.splitext(fname)
			# Swap only the file extension; a plain str.replace on the full
			# path would also rewrite '.png' occurring inside gt_dir.
			gt_path = os.path.join(gt_dir, stem + ".jpg")
			# The third slot is unused here; kept so each entry stays 3 long.
			self.pairs.append([os.path.join(root_path, fname), gt_path, None])
		self.transform = transform
		self.transform_train = transform_train

	def __len__(self):
		"""Return the number of indexed image pairs."""
		return len(self.pairs)

	@staticmethod
	def _load_rgb(path):
		"""Open the image at ``path``, convert it to RGB and close the file."""
		# convert() returns a new image, so the source file handle can be
		# released right away instead of lingering until garbage collection.
		with Image.open(path) as img:
			return img.convert('RGB')

	def __getitem__(self, index):
		"""Return the ``(from_im, to_im)`` pair at ``index``.

		Both images are RGB; if ``self.transform`` is set, it is applied to
		each image and its outputs are returned instead.
		"""
		from_path, to_path, _ = self.pairs[index]
		from_im = self._load_rgb(from_path)
		to_im = self._load_rgb(to_path)

		if self.transform is not None:
			# NOTE(review): the transform is called separately on each image; if
			# it contains random augmentations the pair may lose alignment.
			to_im = self.transform(to_im)
			from_im = self.transform(from_im)

		return from_im, to_im
