import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os


class make_3channels(object):
    """Convert a single-channel tensor into a 3-channel tensor.

    If the first dimension of ``image`` has size 1, that channel is
    concatenated three times along dim 0 (e.g. a 1xHxW grayscale image
    becomes 3xHxW). Tensors whose first dimension is not 1 are returned
    unchanged (the very same object).
    """

    def __call__(self, image):
        # Only replicate genuinely single-channel inputs; anything else
        # (presumably already RGB) passes through untouched.
        if image.shape[0] == 1:
            image = torch.cat([image, image, image], dim=0)
        return image


class ignore_background_label(object):
    """Remap one label value to another in the first channel of a mask.

    By default every element of ``segm[0]`` equal to 255 is rewritten to 19.
    255 is presumably the dataset's "ignore"/void label and 19 a dedicated
    extra class index — TODO confirm against the dataset definition.

    The tensor is modified IN PLACE and also returned, matching the
    original per-pixel implementation. Only channel 0 is touched.
    """

    def __init__(self, ignore_value=255, fill_value=19):
        # Defaults reproduce the historical hard-coded 255 -> 19 mapping.
        self.ignore_value = ignore_value
        self.fill_value = fill_value

    def __call__(self, segm):
        # segm[0] is a view, so the masked assignment writes into segm
        # itself. This is one vectorized op instead of an H*W Python loop.
        channel = segm[0]
        channel[channel == self.ignore_value] = self.fill_value
        return segm

def load_img(params, folder):
    """Open every entry of ``folder`` as a PIL image.

    Entries are visited in sorted filename order. ``os.listdir`` alone
    returns entries in arbitrary order, and callers pair the results of two
    separate folders (images and labels) by index. Sorting keeps those
    lists aligned, assuming the two folders use matching filenames.

    Args:
        params: Unused; kept so existing call sites keep working.
        folder: Directory whose entries are all passed to ``Image.open``.

    Returns:
        A list of PIL images, one per directory entry, in sorted order.
    """
    images = []
    for filename in sorted(os.listdir(folder)):
        img = Image.open(os.path.join(folder, filename))
        # Image.open is lazy and holds the file open until pixel data is
        # read. Loading now lets Pillow release the handle (for
        # single-frame formats), so large folders don't exhaust the
        # open-file limit.
        img.load()
        images.append(img)
    return images


class trainset(Dataset):
    """Dataset pairing images with label images from two directories.

    Both directories are read eagerly via ``load_img``. The i-th image is
    paired with the i-th label image, so the two folders must contain the
    same number of entries.

    Args:
        params: Forwarded to ``load_img``.
        transform: Optional callable applied to each image in ``__getitem__``.
        root_train: Directory holding the input images (required).
        root_train_label: Directory holding the label images (required).
        transform_label: Optional callable applied to each label image.

    Raises:
        ValueError: if a root directory is missing, or if the two folders
            yield a different number of entries.
    """

    def __init__(self, params, transform=None, root_train=None, root_train_label=None, transform_label=None):
        # os.listdir(None) silently lists the CWD, so reject missing roots.
        if root_train is None or root_train_label is None:
            raise ValueError("root_train and root_train_label must both be provided")
        self.train_img = load_img(params, root_train)
        self.transform = transform
        self.transform_label = transform_label
        self.train_label_img = load_img(params, root_train_label)
        # Items are paired by index; a count mismatch means misaligned pairs.
        if len(self.train_img) != len(self.train_label_img):
            raise ValueError(
                f"found {len(self.train_img)} images but "
                f"{len(self.train_label_img)} label images"
            )

    def __len__(self):
        return len(self.train_img)

    def __getitem__(self, index):
        img = self.train_img[index]
        label = self.train_label_img[index]
        # transform / transform_label default to None; skip them when unset
        # instead of raising TypeError.
        if self.transform is not None:
            img = self.transform(img)
        if self.transform_label is not None:
            label = self.transform_label(label)
        return img, label