import torch
import numpy as np
import random
from PIL import Image
from torch.utils.data import Dataset
import os
import os.path
import torchvision

# Maps the short subset codes accepted by ``OfficeHome(subset=...)`` to the
# domain names that appear in the list-file names (``New_<Domain>.txt``).
_factory = {'CL': 'Clipart', 'PR': 'Product', 'RW': 'Real_World', 'AR': 'Art'}


class OfficeHome(object):
    """Index of one Office-Home domain, read from a ``New_<Domain>.txt`` file.

    The list file is expected under ``root`` and holds one sample per line:
    an image path followed by one integer label, or by several integers
    (stored as a NumPy vector).

    Args:
        root: Directory that contains the ``New_<Domain>.txt`` list file.
        subset: Domain code, one of the keys of ``_factory``
            (``'CL'``, ``'PR'``, ``'RW'``, ``'AR'``).
        labels: Optional 2-D array-like indexed as ``labels[i, :]``. When
            given (and non-empty) row ``i`` replaces the label of line ``i``
            and the whole stripped line is kept as the path.
        length: Stored on the instance; not otherwise used by this class.

    Attributes set by the constructor: ``root``, ``labels``, ``subset``
    (full domain name), ``length``, ``txt_dir`` (path of the list file),
    ``train`` (list of ``(path, label)`` tuples) and ``num_classes``
    (number of distinct labels found in ``train``).

    Raises:
        KeyError: If ``subset`` is not a known domain code.
        OSError: If the list file cannot be opened.
    """

    def __init__(self, root, subset='AR', labels=None, length=None):
        self.root = root
        self.labels = labels
        self.subset = _factory[subset]
        self.length = length
        self.txt_dir = os.path.join(root, 'New_' + self.subset + '.txt')
        # Context manager so the file handle is closed deterministically.
        with open(self.txt_dir) as list_file:
            image_list = list_file.readlines()
        self.train = self._process_dir(image_list)
        # Array labels are unhashable, so compare them through a tuple key.
        # NOTE(review): for vector labels this counts distinct label vectors,
        # confirm that is the intended meaning of ``num_classes``.
        self.num_classes = len({self._label_key(label) for _, label in self.train})

    @staticmethod
    def _label_key(label):
        """Return a hashable stand-in for ``label`` (tuple for NumPy arrays)."""
        if isinstance(label, np.ndarray):
            return tuple(label.ravel().tolist())
        return label

    def _process_dir(self, image_list):
        """Turn the raw lines of a list file into ``(path, label)`` tuples.

        Args:
            image_list: Lines of the list file.

        Returns:
            A list of ``(path, label)`` tuples. With external ``self.labels``
            the label is ``self.labels[i, :]``; otherwise it is an ``int``
            (two fields per line) or a NumPy integer vector (more than two
            fields, decided from the first non-blank line).
        """
        # Explicit None/length check: ``if array:`` raises ValueError for
        # multi-element NumPy arrays, which is what ``labels[i, :]`` implies.
        if self.labels is not None and len(self.labels) > 0:
            return [(line.strip(), self.labels[i, :])
                    for i, line in enumerate(image_list)]

        # Blank lines (e.g. a trailing newline) carry no sample; skip them.
        rows = [fields for fields in (line.split() for line in image_list) if fields]
        if not rows:
            return []
        if len(rows[0]) > 2:
            return [(fields[0], np.array([int(tok) for tok in fields[1:]]))
                    for fields in rows]
        return [(fields[0], int(fields[1])) for fields in rows]



def make_dataset(image_list, labels):
    """Build ``(path, label)`` tuples from the lines of an image list file.

    Args:
        image_list: Lines of the list file, each an image path followed by
            one or more integer labels.
        labels: Optional 2-D array-like indexed as ``labels[i, :]``. When
            given (and non-empty) row ``i`` is used as the label of line
            ``i`` and the whole stripped line is kept as the path.

    Returns:
        A list of ``(path, label)`` tuples. Without external ``labels`` the
        label is an ``int`` (two fields per line) or a NumPy integer vector
        (more than two fields, decided from the first non-blank line).
    """
    # Explicit None/length check: ``if array:`` raises ValueError for
    # multi-element NumPy arrays, which is what ``labels[i, :]`` implies.
    if labels is not None and len(labels) > 0:
        return [(line.strip(), labels[i, :]) for i, line in enumerate(image_list)]

    # Blank lines (e.g. a trailing newline) carry no sample; skip them.
    rows = [fields for fields in (line.split() for line in image_list) if fields]
    if not rows:
        return []
    if len(rows[0]) > 2:
        return [(fields[0], np.array([int(tok) for tok in fields[1:]]))
                for fields in rows]
    return [(fields[0], int(fields[1])) for fields in rows]


def rgb_loader(path):
    """Read the image file at ``path`` and return it converted to mode 'RGB'."""
    with open(path, 'rb') as handle, Image.open(handle) as picture:
        # convert() runs while the file is still open; the converted image
        # is what gets handed back after both contexts close.
        converted = picture.convert('RGB')
    return converted

def l_loader(path):
    """Read the image file at ``path`` and return it converted to mode 'L'."""
    with open(path, 'rb') as handle, Image.open(handle) as picture:
        # convert() runs while the file is still open; the converted image
        # is what gets handed back after both contexts close.
        converted = picture.convert('L')
    return converted

# class ImageList(Dataset):
#     def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
#         imgs = make_dataset(image_list, labels)
#         if len(imgs) == 0:
#             raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
#                                "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
#
#         self.imgs = imgs
#         self.transform = transform
#         self.target_transform = target_transform
#         if mode == 'RGB':
#             self.loader = rgb_loader
#         elif mode == 'L':
#             self.loader = l_loader
#
#     def __getitem__(self, index):
#         path, target = self.imgs[index]
#         img = self.loader(path)
#         if self.transform is not None:
#             img = self.transform(img)
#         if self.target_transform is not None:
#             target = self.target_transform(target)
#
#         return img, target
#
#     def __len__(self):
#         return len(self.imgs)
#
# class Office31(Dataset):
#     def __init__(self, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
#         imgs = make_dataset(image_list, labels)
#         if len(imgs) == 0:
#             raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
#                                "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
#
#         self.imgs = imgs
#         self.transform = transform
#         self.target_transform = target_transform
#         if mode == 'RGB':
#             self.loader = rgb_loader
#         elif mode == 'L':
#             self.loader = l_loader
#
#     def __getitem__(self, index):
#         path, target = self.imgs[index]
#         img = self.loader(path)
#         if self.transform is not None:
#             img = self.transform(img)
#         if self.target_transform is not None:
#             target = self.target_transform(target)
#
#         return img, target, index
#
#     def __len__(self):
#         return len(self.imgs)