#from __future__ import print_function, division

import torch
import numpy as np
import random
from PIL import Image
from torch.utils.data import Dataset
import os
import os.path

import cv2
import torchvision

def make_dataset(image_list, labels):
    """Build the list of ``(path, label)`` records for a dataset.

    Args:
        image_list: Sequence of text lines. When ``labels`` is supplied each
            line is used (stripped) as the path. Otherwise each line is
            whitespace-split into a path followed by one or more integer
            labels.
        labels: Optional 2-D array-like indexed as ``labels[i, :]``; row ``i``
            becomes the label of line ``i``. ``None`` or an empty container
            means "parse the labels from the lines".

    Returns:
        A list of ``(path, label)`` tuples. ``label`` is a row of ``labels``,
        an ``int`` (single-label lines) or a NumPy integer array (multi-label
        lines). An empty ``image_list`` yields an empty list.
    """
    # Nothing to parse; let the caller decide how to report an empty dataset
    # instead of failing on image_list[0] below.
    if len(image_list) == 0:
        return []

    # `if labels:` is ambiguous for multi-element NumPy arrays (ValueError),
    # so test for presence explicitly. Empty containers still fall through to
    # line parsing, as they did before.
    if labels is not None and len(labels) > 0:
        return [(line.strip(), labels[i, :]) for i, line in enumerate(image_list)]

    # The first line decides the format for the whole list.
    if len(image_list[0].split()) > 2:
        records = []
        for line in image_list:
            fields = line.split()
            records.append((fields[0], np.array([int(value) for value in fields[1:]])))
        return records

    records = []
    for line in image_list:
        fields = line.split()
        records.append((fields[0], int(fields[1])))
    return records


def rgb_loader(path):
    """Read the image file at ``path`` and return it converted to ``'RGB'`` mode."""
    # Both the raw file handle and the PIL image are released when the
    # ``with`` block exits; the converted copy is what gets returned.
    with open(path, 'rb') as stream, Image.open(stream) as picture:
        rgb_image = picture.convert('RGB')
    return rgb_image

def l_loader(path):
    """Read the image file at ``path`` and return it converted to ``'L'`` mode."""
    # Both the raw file handle and the PIL image are released when the
    # ``with`` block exits; the converted copy is what gets returned.
    with open(path, 'rb') as stream, Image.open(stream) as picture:
        gray_image = picture.convert('L')
    return gray_image

class ImageList(Dataset):
    """Dataset of ``(path, label)`` records yielding ``(img, target, domain_id)``.

    Records are produced by ``make_dataset(image_list, labels)``. Images are
    read lazily in ``__getitem__`` with the loader selected by ``mode``.

    Args:
        image_list: Sequence of text lines describing the samples (see
            ``make_dataset``).
        idx_to_domain: Container indexed by a sample's position in
            ``image_list``; the value found there is returned as
            ``domain_id``.
            NOTE(review): assumed to be keyed by the position in the original
            (pre-subsampling) list -- confirm against callers.
        labels: Optional labels forwarded to ``make_dataset``.
        transform: Optional callable applied to the loaded image.
        target_transform: Optional callable applied to the label.
        mode: ``'RGB'`` or ``'L'``; selects ``rgb_loader`` or ``l_loader``.
        subsample: When true, thin out part of the classes (see
            ``subsample``).

    Raises:
        RuntimeError: If ``image_list`` is empty.
        ValueError: If ``mode`` is neither ``'RGB'`` nor ``'L'``.
    """

    def __init__(self, image_list, idx_to_domain, labels=None, transform=None, target_transform=None, mode='RGB', subsample=False):
        # The previous message referenced undefined names (root,
        # IMG_EXTENSIONS) and therefore raised NameError instead.
        if len(image_list) == 0:
            raise RuntimeError("Found 0 images: image_list is empty")
        imgs = make_dataset(image_list, labels)

        if mode == 'RGB':
            loader = rgb_loader
        elif mode == 'L':
            loader = l_loader
        else:
            # Fail here rather than with an AttributeError on first access.
            raise ValueError("Unsupported mode: {!r} (expected 'RGB' or 'L')".format(mode))

        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.idx_to_domain = idx_to_domain
        self.loader = loader

        # Position of every kept record in the original list; None means
        # "identity" (no subsampling took place).
        self._source_indices = None

        self.labels_to_idx = self.get_labels_to_idx(self.imgs)
        if subsample:
            kept = self._subsample_indices(self.labels_to_idx)
            self.imgs = [self.imgs[i] for i in kept]
            # Remember where each kept record came from so idx_to_domain
            # lookups stay aligned with the surviving samples.
            self._source_indices = kept
            self.labels_to_idx = self.get_labels_to_idx(self.imgs)

    @staticmethod
    def _label_key(label):
        """Return a hashable stand-in for ``label`` (arrays become tuples)."""
        if isinstance(label, np.ndarray):
            return tuple(label.ravel().tolist())
        return label

    def get_labels_to_idx(self, data):
        """Group record positions by label.

        Args:
            data: Sequence of ``(path, label)`` records.

        Returns:
            Dict mapping each label to the list of positions in ``data`` that
            carry it, in ascending order. NumPy array labels are keyed by the
            tuple of their values because arrays are not hashable.
        """
        labels_to_idx = {}
        for idx, record in enumerate(data):
            labels_to_idx.setdefault(self._label_key(record[1]), []).append(idx)
        return labels_to_idx

    def _subsample_indices(self, labels_to_idx):
        """Return the sorted positions that survive subsampling."""
        # NOTE: reseeds NumPy's global RNG with a fixed seed, so the same
        # positions are selected on every call.
        np.random.seed(0)
        num_classes = len(labels_to_idx)
        keep = []
        for label in sorted(labels_to_idx):
            members = labels_to_idx[label]
            if label < num_classes // 2:
                # Labels below half the class count keep 30% of their samples.
                picked = np.random.choice(members, int(0.3 * len(members)), replace=False)
                keep.extend(picked.tolist())
            else:
                keep.extend(members)
        return sorted(set(keep))

    def subsample(self, data, labels_to_idx):
        """Return ``data`` with part of the classes thinned out.

        For every label smaller than ``len(labels_to_idx) // 2`` a random 30%
        (rounded down) of its records is kept; records of all other labels
        are kept entirely. Labels must be comparable with integers.

        Args:
            data: Sequence of ``(path, label)`` records.
            labels_to_idx: Mapping as returned by ``get_labels_to_idx(data)``.

        Returns:
            List of the surviving records in their original order.
        """
        keep = set(self._subsample_indices(labels_to_idx))
        return [record for i, record in enumerate(data) if i in keep]

    def __getitem__(self, index):
        """Return ``(img, target, domain_id)`` for the record at ``index``."""
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        # Translate the position back to the original list when records were
        # dropped by subsampling; otherwise the position is used directly.
        source_indices = getattr(self, '_source_indices', None)
        source_index = index if source_indices is None else source_indices[index]
        domain_id = self.idx_to_domain[source_index]
        return img, target, domain_id

    def __len__(self):
        """Return the number of records currently held."""
        return len(self.imgs)

class ImageList_idx(Dataset):
    """Dataset of ``(path, label)`` records yielding ``(img, target, index, domain_id)``.

    Records are produced by ``make_dataset(image_list, labels)``. Images are
    read lazily in ``__getitem__`` with the loader selected by ``mode``. The
    returned ``index`` is the position that was requested from this dataset.

    Args:
        image_list: Sequence of text lines describing the samples (see
            ``make_dataset``).
        idx_to_domain: Container indexed by a sample's position in
            ``image_list``; the value found there is returned as
            ``domain_id``.
            NOTE(review): assumed to be keyed by the position in the original
            (pre-subsampling) list -- confirm against callers.
        labels: Optional labels forwarded to ``make_dataset``.
        transform: Optional callable applied to the loaded image.
        target_transform: Optional callable applied to the label.
        mode: ``'RGB'`` or ``'L'``; selects ``rgb_loader`` or ``l_loader``.
        subsample: When true, thin out part of the classes (see
            ``subsample``).

    Raises:
        RuntimeError: If ``image_list`` is empty.
        ValueError: If ``mode`` is neither ``'RGB'`` nor ``'L'``.
    """

    def __init__(self, image_list, idx_to_domain, labels=None, transform=None, target_transform=None, mode='RGB', subsample=False):
        # The previous message referenced undefined names (root,
        # IMG_EXTENSIONS) and therefore raised NameError instead.
        if len(image_list) == 0:
            raise RuntimeError("Found 0 images: image_list is empty")
        imgs = make_dataset(image_list, labels)

        if mode == 'RGB':
            loader = rgb_loader
        elif mode == 'L':
            loader = l_loader
        else:
            # Fail here rather than with an AttributeError on first access.
            raise ValueError("Unsupported mode: {!r} (expected 'RGB' or 'L')".format(mode))

        self.imgs = imgs
        self.transform = transform
        self.target_transform = target_transform
        self.idx_to_domain = idx_to_domain
        self.loader = loader

        # Position of every kept record in the original list; None means
        # "identity" (no subsampling took place).
        self._source_indices = None

        self.labels_to_idx = self.get_labels_to_idx(self.imgs)
        if subsample:
            kept = self._subsample_indices(self.labels_to_idx)
            self.imgs = [self.imgs[i] for i in kept]
            # Remember where each kept record came from so idx_to_domain
            # lookups stay aligned with the surviving samples.
            self._source_indices = kept
            self.labels_to_idx = self.get_labels_to_idx(self.imgs)

    @staticmethod
    def _label_key(label):
        """Return a hashable stand-in for ``label`` (arrays become tuples)."""
        if isinstance(label, np.ndarray):
            return tuple(label.ravel().tolist())
        return label

    def get_labels_to_idx(self, data):
        """Group record positions by label.

        Args:
            data: Sequence of ``(path, label)`` records.

        Returns:
            Dict mapping each label to the list of positions in ``data`` that
            carry it, in ascending order. NumPy array labels are keyed by the
            tuple of their values because arrays are not hashable.
        """
        labels_to_idx = {}
        for idx, record in enumerate(data):
            labels_to_idx.setdefault(self._label_key(record[1]), []).append(idx)
        return labels_to_idx

    def _subsample_indices(self, labels_to_idx):
        """Return the sorted positions that survive subsampling."""
        # NOTE: reseeds NumPy's global RNG with a fixed seed, so the same
        # positions are selected on every call.
        np.random.seed(0)
        num_classes = len(labels_to_idx)
        keep = []
        for label in sorted(labels_to_idx):
            members = labels_to_idx[label]
            if label < num_classes // 2:
                # Labels below half the class count keep 30% of their samples.
                picked = np.random.choice(members, int(0.3 * len(members)), replace=False)
                keep.extend(picked.tolist())
            else:
                keep.extend(members)
        return sorted(set(keep))

    def subsample(self, data, labels_to_idx):
        """Return ``data`` with part of the classes thinned out.

        For every label smaller than ``len(labels_to_idx) // 2`` a random 30%
        (rounded down) of its records is kept; records of all other labels
        are kept entirely. Labels must be comparable with integers.

        Args:
            data: Sequence of ``(path, label)`` records.
            labels_to_idx: Mapping as returned by ``get_labels_to_idx(data)``.

        Returns:
            List of the surviving records in their original order.
        """
        keep = set(self._subsample_indices(labels_to_idx))
        return [record for i, record in enumerate(data) if i in keep]

    def __getitem__(self, index):
        """Return ``(img, target, index, domain_id)`` for the record at ``index``."""
        path, target = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        # Translate the position back to the original list when records were
        # dropped by subsampling; otherwise the position is used directly.
        source_indices = getattr(self, '_source_indices', None)
        source_index = index if source_indices is None else source_indices[index]
        domain_id = self.idx_to_domain[source_index]
        return img, target, index, domain_id

    def __len__(self):
        """Return the number of records currently held."""
        return len(self.imgs)
