import os
import os.path as osp
import numpy as np
import torch
import torch.utils.data as data
from PIL import Image
from torch.utils.data import Dataset, DataLoader

def listdir(path, suffix):
    """Recursively list every file under ``path`` whose name ends with ``suffix``.

    Symbolic links to directories are followed. Results come back in
    ``os.walk`` traversal order (not sorted).

    Args:
        path: directory to search.
        suffix: filename ending to match (plain ``str.endswith`` test).

    Returns:
        list of joined file paths.
    """
    return [
        osp.join(dirpath, filename)
        for dirpath, _subdirs, filenames in os.walk(path, followlinks=True)
        for filename in filenames
        if filename.endswith(suffix)
    ]

def get_image_dirs(root, dname, split):
    """Collect ``(image_path, class_id)`` pairs from a class-per-folder layout.

    Images are looked up under ``<root>/<dname>/<split>``. Every immediate
    subdirectory there is treated as one class, and class ids are assigned
    0, 1, 2, ... in sorted folder-name order. Inside each class folder the
    search is recursive; all ``png`` matches are listed before all ``jpg``
    matches.

    Args:
        root: dataset root directory (a trailing separator is optional).
        dname: sub-folder name under ``root``.
        split: sub-folder name under ``<root>/<dname>``.

    Returns:
        list of ``(image_path, class_id)`` tuples.
    """
    # osp.join rather than string concatenation: ``root + 'dname/split'``
    # produced a wrong path whenever ``root`` lacked a trailing separator.
    split_root = osp.join(root, dname, split)
    suffix_list = ['png', 'jpg']

    class_dirs = sorted(
        entry for entry in os.listdir(split_root)
        if osp.isdir(osp.join(split_root, entry))
    )

    impath_label_list = []
    for label_id, class_dir in enumerate(class_dirs):
        sub_folder = osp.join(split_root, class_dir)
        for suffix in suffix_list:
            impath_label_list.extend(
                (impath, label_id)
                for impath in listdir(sub_folder, suffix=suffix)
            )
    return impath_label_list

def get_image_dirs_label(root, dname, split, label):
    """Collect ``(image_path, class_id)`` pairs for a single class.

    Images are looked up under ``<root>/<dname>/<split>``. Every immediate
    subdirectory is a class, and ids are assigned 0, 1, 2, ... in sorted
    folder-name order. Only the folder whose id equals ``label`` is scanned
    (recursively), so the ids match those of an unfiltered listing. Inside
    that folder all ``png`` matches come before all ``jpg`` matches.

    Args:
        root: dataset root directory (a trailing separator is optional).
        dname: sub-folder name under ``root``.
        split: sub-folder name under ``<root>/<dname>``.
        label: class id to keep; no match yields an empty list.

    Returns:
        list of ``(image_path, label)`` tuples.
    """
    # osp.join rather than string concatenation: ``root + 'dname/split'``
    # produced a wrong path whenever ``root`` lacked a trailing separator.
    split_root = osp.join(root, dname, split)
    suffix_list = ['png', 'jpg']

    # Ids come from the FULL sorted folder list, so filtering must happen
    # after enumeration.
    class_dirs = sorted(
        entry for entry in os.listdir(split_root)
        if osp.isdir(osp.join(split_root, entry))
    )

    impath_label_list = []
    for label_id, class_dir in enumerate(class_dirs):
        if label_id == label:
            sub_folder = osp.join(split_root, class_dir)
            for suffix in suffix_list:
                impath_label_list.extend(
                    (impath, label_id)
                    for impath in listdir(sub_folder, suffix=suffix)
                )
            # Ids are unique, so nothing after the match can match too.
            break
    return impath_label_list

def get_pacs_image_dirs(root, dname, split):
    """Load the ``(image_path, label)`` list for one PACS-style domain/split.

    The split file is ``<root>/splits/<dname>_<split>_kfold.txt``; relative
    image paths in it are resolved against ``<root>/images``. ``root`` is
    made absolute first.
    """
    base = osp.abspath(root)
    images_root = osp.join(base, 'images')
    list_file = osp.join(osp.join(base, 'splits'), dname + '_' + split + '_kfold.txt')
    return read_split_pacs(list_file, images_root)

def get_pacs_image_dirs_label(root, dname, split, label):
    """Load the PACS-style ``(image_path, label)`` list, keeping one class.

    Same file layout as the unfiltered variant: the split file is
    ``<root>/splits/<dname>_<split>_kfold.txt`` and image paths resolve
    against ``<root>/images``. Only entries whose (decremented) label equals
    ``label`` are returned.
    """
    base = osp.abspath(root)
    images_root = osp.join(base, 'images')
    list_file = osp.join(osp.join(base, 'splits'), dname + '_' + split + '_kfold.txt')
    return read_split_pacs_label(list_file, images_root, label)

def read_split_pacs(split_file, image_dir):
    """Parse a split file into ``(image_path, label)`` pairs.

    Each non-blank line holds a relative image path followed by an integer
    label, separated by whitespace. The path is joined onto ``image_dir``.
    The label is decremented by one, presumably turning 1-based file labels
    into 0-based class ids.

    Args:
        split_file: path to the split text file.
        image_dir: directory that relative image paths are resolved against.

    Returns:
        list of ``(image_path, label)`` tuples in file order.

    Raises:
        ValueError: if a non-blank line lacks a label or the label is not an
            integer.
    """
    items = []
    with open(split_file, 'r') as f:
        for line in f:
            line = line.strip()
            # Skip empty lines (e.g. an extra blank line at EOF), which
            # previously crashed the tuple unpacking.
            if not line:
                continue
            # Split on the LAST whitespace run so that paths containing
            # spaces survive, and tabs/multiple spaces are tolerated.
            impath, imlabel = line.rsplit(None, 1)
            items.append((osp.join(image_dir, impath), int(imlabel) - 1))
    return items

def read_split_pacs_label(split_file, image_dir, label):
    """Parse a split file, keeping only entries of one class.

    Each non-blank line holds a relative image path followed by an integer
    label, separated by whitespace. The path is joined onto ``image_dir``.
    The label is decremented by one (presumably 1-based to 0-based), and the
    entry is kept only if the result equals ``label``.

    Args:
        split_file: path to the split text file.
        image_dir: directory that relative image paths are resolved against.
        label: decremented class id to keep.

    Returns:
        list of ``(image_path, label)`` tuples in file order.

    Raises:
        ValueError: if a non-blank line lacks a label or the label is not an
            integer.
    """
    items = []
    with open(split_file, 'r') as f:
        for line in f:
            line = line.strip()
            # Skip empty lines (e.g. an extra blank line at EOF), which
            # previously crashed the tuple unpacking.
            if not line:
                continue
            # Split on the LAST whitespace run so that paths containing
            # spaces survive, and tabs/multiple spaces are tolerated.
            impath, imlabel = line.rsplit(None, 1)
            imlabel = int(imlabel) - 1
            if imlabel == label:
                items.append((osp.join(image_dir, impath), imlabel))
    return items

class base_dataset(data.Dataset):
    """Map-style dataset over a list of ``(image_path, class_label)`` pairs.

    Every sample also carries the fixed domain label ``dl``. When ``labeled``
    is falsy the class label is hidden and reported as ``-1``.

    Args:
        impath_label_list: list of ``(image_path, class_label)`` tuples.
        transform: callable applied to each RGB PIL image in ``__getitem__``.
        dl: domain label attached to every sample.
        labeled: whether to expose the true class labels.
    """

    def __init__(self, impath_label_list, transform, dl, labeled):
        self.impath_label_list = impath_label_list
        self.transform = transform
        self.dl = dl
        self.labeled = labeled

    def __len__(self):
        return len(self.impath_label_list)

    def _load(self, index):
        """Return ``(rgb_pil_image, class_label)`` for ``index``, label masked if unlabeled."""
        impath, label = self.impath_label_list[index]
        if not self.labeled:
            label = -1
        # The context manager closes the file handle deterministically.
        # convert() returns a new, fully loaded image, so closing the
        # source afterwards is safe.
        with Image.open(impath) as raw:
            img = raw.convert('RGB')
        return img, label

    def __getitem__(self, index):
        """Return ``(transformed_image, class_label, domain_label)``."""
        img, label = self._load(index)
        return self.transform(img), label, self.dl

    def get_raw_data(self):
        """Load every sample without applying ``transform``.

        Returns:
            ``(images, class_labels, domain_labels)``: a list of RGB PIL
            images and two numpy arrays of the same length.
        """
        images = []
        clabel = []
        for index in range(len(self.impath_label_list)):
            img, label = self._load(index)
            images.append(img)
            clabel.append(label)
        dlabel = [self.dl] * len(images)
        return images, np.array(clabel), np.array(dlabel)

class InfiniteDataLoader(DataLoader):
    """A DataLoader whose iterator never runs out: it restarts every epoch.

    ``next(loader)`` always produces a batch. When the current epoch's
    iterator is exhausted, a fresh one is obtained from the regular
    ``DataLoader.__iter__`` (with ``shuffle=True`` the sampler reshuffles).
    ``iter(loader)`` returns the loader itself.

    Note: ``__len__`` returns ``float('inf')``. The builtin ``len()`` only
    accepts integer results and will raise ``TypeError`` on this loader, so
    call ``loader.__len__()`` directly if that value is needed.
    """

    def __init__(self, dataset, batch_size=1, num_workers=0, shuffle=True):
        super().__init__(
            dataset,
            batch_size=batch_size,
            num_workers=num_workers,
            shuffle=shuffle,
        )
        self.dataset_size = len(dataset)
        # The first epoch's iterator is created eagerly, at construction.
        self.data_iterator = self._fresh_epoch()

    def _fresh_epoch(self):
        """Start a new pass over the data via the stock DataLoader iterator."""
        return super().__iter__()

    def __len__(self):
        return float('inf')

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.data_iterator)
        except StopIteration:
            pass
        # Current epoch exhausted: begin another one and take its first batch.
        self.data_iterator = self._fresh_epoch()
        return next(self.data_iterator)

class utilDataset(Dataset):
    '''
    Construct a pseudo dataset from in-memory samples.

    Each item is ``(image, class_label, domain_label, lam)``, taken by
    position from the four parallel sequences given to the constructor.

    Args:
        images_dict: sequence of images indexed by integer position
            (despite the name, not a dict; presumably a list of PIL images).
        class_labels: per-sample class labels (presumably a numpy array).
        domain_labels: per-sample domain labels (presumably a numpy array).
        lam: per-sample ``lam`` values (presumably a numpy array; meaning
            defined by the caller, so confirm before relying on it).
        transform: optional callable applied to each image in
            ``__getitem__``; when ``None`` the image is returned unchanged.
    '''

    def __init__(self, images_dict, class_labels, domain_labels, lam, transform=None):
        self.x = images_dict  # images, indexed by position
        self.labels = class_labels  # class labels
        self.dlabels = domain_labels  # domain labels
        self.lam = lam  # per-sample lam values
        self.transform = transform

    def __len__(self):
        # Dataset length is defined by the class-label sequence.
        return len(self.labels)

    def __getitem__(self, index):
        img = self.x[index]
        # ``transform`` defaults to None, so only call it when it was given;
        # calling it unconditionally raised TypeError.
        if self.transform is not None:
            img = self.transform(img)
        return img, self.labels[index], self.dlabels[index], self.lam[index]

    def get_raw_data(self):
        """Return the stored ``(images, class_labels, domain_labels, lam)`` as-is."""
        return self.x, self.labels, self.dlabels, self.lam