import torch
from PIL import Image
from torch.utils.data import Dataset
import torch.nn as nn

def make_dataset(args, image_list, loader, transform=None, labels=None, netF_list=None, netC_list=None, device="cuda"):
    """Load every listed image and optionally precompute per-source softmax predictions.

    Each entry of ``image_list`` is a whitespace-separated string. Its first field is
    passed to ``loader`` and its second field is parsed with ``int`` as the label.

    Args:
        args: Namespace. When networks are given it must provide ``src``
            (only its length is used), ``src_num`` and ``class_num``.
        image_list: Iterable of ``"<path> <label> ..."`` strings.
        loader: Callable mapping a path to an image.
        transform: Callable applied to every loaded image. ``None`` means identity.
        labels: Unused. Kept for backward compatibility; labels are always
            parsed from ``image_list``.
        netF_list: Sequence of feature-extractor callables, one per source.
            If empty or ``None``, no predictions are computed.
        netC_list: Sequence of classifier callables, paired index-wise with ``netF_list``.
        device: Device each image is moved to before the forward pass.
            Defaults to ``"cuda"``, which matches the original ``.cuda()`` call.

    Returns:
        ``(images, labels)`` when no networks are given. Otherwise
        ``(images, labels, source_preds)``, where ``source_preds`` is a float32
        tensor of shape ``(len(images), args.src_num, args.class_num)``. Slots for
        sources beyond ``len(args.src)`` stay zero.
    """
    apply_transform = transform if transform is not None else (lambda img: img)

    # Split each line once and reuse the fields for both path and label.
    entries = [line.split() for line in image_list]
    images = [apply_transform(loader(fields[0])) for fields in entries]
    parsed_labels = [int(fields[1]) for fields in entries]

    if not netF_list:
        return images, parsed_labels

    softmax = nn.Softmax(dim=1)
    per_image = []
    # The predictions are only needed as fixed values (the original detached them).
    # Without no_grad, the autograd graph of every forward pass stays alive.
    with torch.no_grad():
        for img in images:
            # Allocate a fresh buffer per image. Reusing one buffer aliased the
            # first stored row, because Tensor.float() returns self for float32.
            output = torch.zeros([1, args.src_num, args.class_num])
            batch = img.to(device).unsqueeze(0)
            for i in range(len(args.src)):
                output[:, i, :] = softmax(netC_list[i](netF_list[i](batch)))  # (batch, num_src, num_cls)
            per_image.append(output)

    if per_image:
        # Concatenate once at the end instead of torch.cat inside the loop (quadratic).
        source_preds = torch.cat(per_image, 0)
    else:
        source_preds = torch.zeros([0, args.src_num, args.class_num])
    return images, parsed_labels, source_preds



def rgb_loader(path):
    """Read the image stored at ``path`` and return it converted to mode ``'RGB'``.

    The file is opened in binary mode. The conversion is performed while both
    the file handle and the PIL image are still open; both are closed on return.
    """
    with open(path, 'rb') as handle, Image.open(handle) as image:
        converted = image.convert('RGB')
    return converted

def l_loader(path):
    """Read the image stored at ``path`` and return it converted to mode ``'L'``.

    The file is opened in binary mode. The conversion is performed while both
    the file handle and the PIL image are still open; both are closed on return.
    """
    with open(path, 'rb') as handle, Image.open(handle) as image:
        grayscale = image.convert('L')
    return grayscale

class ImageList(Dataset):
    """Dataset of (image, label) pairs built eagerly from a list of text entries.

    All images are loaded and transformed once in ``__init__`` via ``make_dataset``.
    ``__getitem__`` only indexes the cached lists and applies ``target_transform``.
    """

    def __init__(self, args, image_list, labels=None, transform=None, target_transform=None, mode='RGB'):
        """Build the dataset.

        Args:
            args: Forwarded to ``make_dataset``.
            image_list: Iterable of ``"<path> <label> ..."`` strings.
            labels: Forwarded to ``make_dataset``.
            transform: Applied to every image at load time.
            target_transform: Applied to the label on each ``__getitem__``.
            mode: ``'RGB'`` or ``'L'``; selects the image loader.

        Raises:
            ValueError: If ``mode`` is not ``'RGB'`` or ``'L'``.
            RuntimeError: If ``image_list`` yields no images.
        """
        if mode == 'RGB':
            self.loader = rgb_loader
        elif mode == 'L':
            self.loader = l_loader
        else:
            # Previously self.loader was left unset, which failed later with an
            # unrelated AttributeError.
            raise ValueError(f"Unsupported image mode {mode!r}; expected 'RGB' or 'L'")

        self.transform = transform
        self.target_transform = target_transform

        self.imgs, self.targets = make_dataset(args, image_list, self.loader, self.transform, labels)

        if not self.imgs:
            # The old message referenced undefined names (root, IMG_EXTENSIONS)
            # and raised NameError instead of this error.
            raise RuntimeError("Found 0 images in image_list")

    def __getitem__(self, index):
        """Return ``(image, target)``, with ``target_transform`` applied if set."""
        target = self.targets[index]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return self.imgs[index], target

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.imgs)

class ImageList_idx(Dataset):
    """Dataset of (image, label, source_preds) triples built eagerly from a text list.

    In ``__init__``, ``make_dataset`` loads every image and runs it through each
    source ``(netF, netC)`` pair to precompute softmax predictions.
    ``__getitem__`` only indexes the cached results.
    """

    def __init__(self, args, image_list, netF_list, netC_list, labels=None, transform=None, target_transform=None, mode='RGB'):
        """Build the dataset and precompute per-source predictions.

        Args:
            args: Forwarded to ``make_dataset``.
            image_list: Iterable of ``"<path> <label> ..."`` strings.
            netF_list: Non-empty sequence of feature extractors, one per source.
            netC_list: Sequence of classifiers paired index-wise with ``netF_list``.
            labels: Forwarded to ``make_dataset``.
            transform: Applied to every image at load time.
            target_transform: Applied to the label on each ``__getitem__``.
            mode: ``'RGB'`` or ``'L'``; selects the image loader.

        Raises:
            ValueError: If ``netF_list`` is empty, or ``mode`` is not ``'RGB'``/``'L'``.
            RuntimeError: If ``image_list`` yields no images.
        """
        # Without networks, make_dataset returns only two values and the 3-way
        # unpack below would fail with an opaque message.
        if not netF_list:
            raise ValueError("ImageList_idx requires at least one network in netF_list")

        if mode == 'RGB':
            self.loader = rgb_loader
        elif mode == 'L':
            self.loader = l_loader
        else:
            raise ValueError(f"Unsupported image mode {mode!r}; expected 'RGB' or 'L'")

        self.transform = transform
        self.target_transform = target_transform
        self.netF_list = netF_list
        self.netC_list = netC_list

        self.imgs, self.targets, self.source_preds = make_dataset(args, image_list, self.loader, self.transform, labels, self.netF_list, self.netC_list)

        if not self.imgs:
            # The old message referenced undefined names (root, IMG_EXTENSIONS)
            # and raised NameError instead of this error.
            raise RuntimeError("Found 0 images in image_list")

    def __getitem__(self, index):
        """Return ``(image, target, source_preds[index])``, with ``target_transform`` applied if set."""
        target = self.targets[index]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return self.imgs[index], target, self.source_preds[index]

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.imgs)