import cv2
import random
import numpy as np
from torch.utils.data import Dataset, DataLoader
import os
from torchvision import transforms

# load the yaw, pitch angles for the street-view images and yaw angles for the aerial view
import scipy.io as sio
from PIL import Image

class DiffSample(Dataset):
    """Paired (ground-truth, generated) image dataset read from a list file.

    Each non-blank line of ``datalist`` has the form ``"<gt_name>, <gen_name>"``
    (fields separated by a comma followed by one space; any further fields are
    ignored). The images are resolved as ``<root>/gt/<gt_name>`` and
    ``<root>/gen/<gen_name>``.

    Args:
        root: Directory containing the ``gt`` and ``gen`` sub-directories.
        datalist: Path to the list file described above.
        mask_root: Stored as ``self.mask_root``. Mask loading is currently
            disabled, so this value is not read anywhere in this class.
    """

    def __init__(self, root, datalist, mask_root):
        self.img_root = root
        self.mask_root = mask_root
        self.data_list = datalist

        print('InputData::__init__: load %s' % self.data_list)
        self.__cur_id = 0  # for training (not read anywhere in this class)
        self.id_list = self._load_id_list(self.img_root, self.data_list)
        # Position of each entry in id_list, i.e. 0 .. len(id_list) - 1.
        self.id_idx_list = list(range(len(self.id_list)))

        self.totensor = transforms.Compose([
            transforms.ToTensor(),
        ])

        # Only needed by the (currently disabled) mask-loading path.
        self.mask_transform = transforms.Compose([
            transforms.Resize(size=[128, 512]),
            transforms.ToTensor(),
        ])

    @staticmethod
    def _load_id_list(img_root, data_list):
        """Parse the list file into ``[[gt_path, sample_path], ...]``.

        Line terminators are stripped explicitly rather than by dropping the
        last character. As a result, a final line without a trailing newline,
        or a line that carries extra fields (such as a mask name), keeps its
        full generated-image file name. Blank lines are skipped.

        Raises:
            ValueError: if a non-blank line has fewer than two fields.
        """
        id_list = []
        with open(data_list, 'r') as file:
            for line_no, raw_line in enumerate(file, start=1):
                line = raw_line.rstrip('\r\n')
                if not line.strip():
                    continue  # tolerate empty / trailing blank lines
                data = line.split(', ')
                if len(data) < 2:
                    raise ValueError(
                        '%s:%d: expected "<gt>, <gen>", got %r'
                        % (data_list, line_no, line))
                gt = img_root + "/gt/" + data[0]
                sample = img_root + "/gen/" + data[1]
                # NOTE: mask support is disabled; it would read data[2] here:
                # mask = mask_root + "/" + data[2]
                id_list.append([gt, sample])
        return id_list

    def __len__(self):
        return len(self.id_list)

    def get_file_list(self):
        """Return the list of ``[gt_path, sample_path]`` entries."""
        return self.id_list

    def _load_rgb_tensor(self, path):
        """Open ``path``, convert it to RGB and apply ``self.totensor``."""
        with Image.open(path, 'r') as img:
            return self.totensor(img.convert('RGB'))

    def __getitem__(self, idx):
        """Return ``(gt_tensor, sample_tensor, gt_path, sample_path)`` for ``idx``.

        Both images are converted to RGB and passed through ``self.totensor``.
        No resizing is applied to either image.
        """
        path_gt, path_sample = self.id_list[idx][:2]
        gt_tensor = self._load_rgb_tensor(path_gt)
        sample_tensor = self._load_rgb_tensor(path_sample)

        # NOTE: mask loading is disabled; re-enabling it needs a third list
        # field and would look like:
        # path_mask = self.id_list[idx][2]
        # with Image.open(path_mask, 'r') as Mask:
        #     mask = Mask.convert('L')
        #     mask_resize = self.mask_transform(mask)

        return gt_tensor, sample_tensor, path_gt, path_sample
        
        
