import sys
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import time
from torch.utils.data import Dataset, DataLoader
import torchvision
import glob
import os
from PIL import Image

# Column indices kept from each split's attribute tensor (``att.pth``) by
# CelebAHQMaskDS: 18 columns in total.
# NOTE(review): presumably these index the standard 40-attribute CelebA
# ordering and pick face attributes that are visible in the image — confirm
# against how att.pth was built.
attr_visible = [
    4, 5, 8, 9, 11, 12, 15, 17, 18,
    20, 21, 22, 26, 28, 31, 32, 33, 35,
]

class CelebAHQMaskDS(Dataset):
    """CelebAMask-HQ images, segmentation masks and attribute vectors.

    Each split folder ``<datapath>/<split>/`` (``split`` is ``train`` or
    ``test``) is expected to contain a ``CelebA-HQ-img`` folder, a
    ``CelebAMaskHQ-mask`` folder and an ``att.pth`` tensor file.
    """

    # Values accepted for ``all_mod``.
    _MODES = ("all", "image", "mask", "attributes")

    def __init__(self, size=128, datapath='/home/******/work/mld/data/data_celba/CelebAMask-HQ/', train=True, all_mod="all"):
        """
        Args:
            size: images and masks are resized to ``size`` x ``size``.
            datapath: root folder containing the ``train`` and ``test`` split
                folders (each holding images, masks and ``att.pth``).
            train: use the ``train`` split if True, otherwise ``test``.
            all_mod: modality returned by ``__getitem__``; one of "all",
                "image", "mask", "attributes".

        Raises:
            ValueError: if ``all_mod`` is unknown, or the image and mask
                folders contain a different number of files.
        """
        super().__init__()
        if all_mod not in self._MODES:
            raise ValueError(
                "all_mod must be one of {}, got {!r}".format(self._MODES, all_mod)
            )
        self.size = size
        self.all_mod = all_mod
        self.datapath = datapath
        # NOTE(review): masks go through this same Resize (torchvision's default
        # bilinear interpolation), which may blend label values at region
        # borders — confirm that is intended.
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.Resize([size, size]),
            torchvision.transforms.ToTensor(),
        ])

        split = "train" if train else "test"
        split_dir = os.path.join(datapath, split)
        d_image = os.path.join(split_dir, "CelebA-HQ-img")
        # The original concatenation lacked the "/" after the split name.
        d_mask = os.path.join(split_dir, "CelebAMaskHQ-mask")

        # Both lists are sorted by file name, so image/mask pairing relies on
        # the two folders using matching names.
        # NOTE(review): attr_tensor rows are indexed by the same position; with
        # a lexicographic sort ("10.jpg" < "2.jpg") that only lines up if
        # att.pth was written in the same order — confirm.
        self.img_files = sorted(os.path.join(d_image, p) for p in os.listdir(d_image))
        self.mask_files = sorted(os.path.join(d_mask, p) for p in os.listdir(d_mask))
        self.attr_tensor = torch.load(os.path.join(split_dir, "att.pth"))[:, attr_visible]

        if len(self.img_files) != len(self.mask_files):
            raise ValueError(
                "found {} images but {} masks".format(len(self.img_files), len(self.mask_files))
            )

    def __len__(self):
        """Number of samples (image files) in the split."""
        return len(self.img_files)

    def _load(self, path):
        """Open the image at *path*, apply ``self.transform``, and close the file."""
        # The context manager releases the file handle instead of waiting for GC.
        with Image.open(path) as img:
            return self.transform(img)

    def __getitem__(self, index):
        """
        Return ``(sample, True)``, where ``sample`` is a dict holding the
        modalities selected by ``all_mod``: "image" and "mask" (transformed
        tensors) and/or "attributes" (row ``index`` of ``attr_tensor``).

        The trailing ``True`` is always returned; its meaning is not visible
        here (presumably consumed by the training loop).

        Raises:
            ValueError: if ``all_mod`` is not a supported mode.
        """
        mode = self.all_mod
        if mode == "all":
            sample = {
                "image": self._load(self.img_files[index]),
                "mask": self._load(self.mask_files[index]),
                "attributes": self.attr_tensor[index],
            }
        elif mode == "image":
            sample = {"image": self._load(self.img_files[index])}
        elif mode == "attributes":
            sample = {"attributes": self.attr_tensor[index]}
        elif mode == "mask":
            sample = {"mask": self._load(self.mask_files[index])}
        else:
            raise ValueError("unknown all_mod {!r}".format(mode))
        return sample, True
            



def read_tensor(file_path):
    """Deserialize and return whatever object ``torch.load`` finds at *file_path*."""
    loaded = torch.load(file_path)
    return loaded


class Dataset_latent(Dataset):
    """Precomputed image / mask / attribute tensors for one split.

    Loads ``image_<im>.pth``, ``mask_64.pth`` and ``att_16.pth`` from
    ``<folder>/train/`` or ``<folder>/test/`` via ``read_tensor``.
    """

    def __init__(self, folder="/home/******/work/mld/data/data_celba/latent/", train=True, im=128):
        split_dir = folder + ("/train/" if train else "/test/")
        self.imgs = read_tensor(split_dir + "image_{}.pth".format(im))
        self.mask = read_tensor(split_dir + "mask_64.pth")
        self.att = read_tensor(split_dir + "att_16.pth")

    def __getitem__(self, i):
        """Return entry ``i`` of each loaded tensor as an image/mask/attributes dict."""
        sample = dict(image=self.imgs[i], mask=self.mask[i], attributes=self.att[i])
        return sample

    def __len__(self):
        """Number of samples, taken from the first dimension of the image tensor."""
        n_items = self.imgs.size(0)
        return n_items
    





class Dataset_latent_2(Dataset):
    """Precomputed image / mask / attribute tensors for one split (second set).

    Loads ``image_<im>.pth``, ``mask_128.pth`` and ``att_32.pth`` from
    ``<folder>/train/`` or ``<folder>/test/`` via ``read_tensor``.
    """

    def __init__(self, folder="/home/******/work/mld/data/data_celba/latent_2/", train=True, im=256):
        subdir = "train" if train else "test"
        prefix = folder + "/" + subdir + "/"
        self.imgs = read_tensor(prefix + "image_{}.pth".format(im))
        self.mask = read_tensor(prefix + "mask_128.pth")
        self.att = read_tensor(prefix + "att_32.pth")

    def __getitem__(self, i):
        """Return a dict with entry ``i`` of the image, mask and attribute tensors."""
        image, mask, attributes = self.imgs[i], self.mask[i], self.att[i]
        return {"image": image, "mask": mask, "attributes": attributes}

    def __len__(self):
        """Size of the first dimension of the image tensor."""
        return self.imgs.size(0)