import torchvision
from torch.utils import data
from torchvision import datasets
from torchvision import transforms
import torch
import os
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from natsort import natsorted
#import torchvision.datasets as dsets
from torchvision.datasets import ImageFolder
import random



## Create a custom Dataset class
class CelebADataset(Dataset):
    """Unlabeled image dataset backed by a single flat directory of image files.

    Each item is one image (no label), loaded with PIL, converted to RGB and
    optionally passed through ``transform``. File names are ordered with
    ``natsorted`` (e.g. ``2.jpg`` before ``10.jpg``), so an index always maps
    to the same file for a given directory listing.
    """

    def __init__(self, root_dir, transform=None, extensions=None):
        """
        Args:
            root_dir (string): Directory with all the images.
            transform (callable, optional): transform applied to each loaded
                RGB PIL image.
            extensions (iterable of str, optional): if given, only files whose
                lower-cased name ends with one of these suffixes
                (e.g. ``(".jpg", ".png")``) are kept. The default ``None`` keeps
                every regular file, as before.
        """
        if extensions is not None:
            # str.endswith accepts a tuple of suffixes; normalise case once.
            extensions = tuple(ext.lower() for ext in extensions)

        image_names = [
            name
            for name in os.listdir(root_dir)
            # Skip sub-directories: __getitem__ cannot open them as images and
            # would crash at a random point in a shuffled epoch.
            if os.path.isfile(os.path.join(root_dir, name))
            and (extensions is None or name.lower().endswith(extensions))
        ]

        self.root_dir = root_dir
        self.transform = transform
        self.extensions = extensions
        self.image_names = natsorted(image_names)

    def __len__(self):
        """Return the number of image files found in ``root_dir``."""
        return len(self.image_names)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it, transformed if a transform was given."""
        img_path = os.path.join(self.root_dir, self.image_names[idx])
        # The context manager closes the file handle deterministically.
        # convert() returns a new, fully loaded image, so it remains valid
        # after the underlying file has been closed.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img
def _make_transform(image_size):
    """Build the shared training transform: resize, random h-flip, to tensor, scale to [-1, 1]."""
    return transforms.Compose([
        transforms.Resize(size=(image_size, image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 per channel maps ToTensor's [0, 1] output to [-1, 1].
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])


def load_data(args, device, large_batch_size=1024, num_workers=1):
    """Build the training dataset selected by ``args`` and two shuffled loaders over it.

    Args:
        args: namespace read for ``loss``, ``dataset``, ``path_data`` and
            ``batch_size``. ``loss == "bcesamy"`` selects 64x64 images,
            anything else 32x32. ``dataset`` is matched by substring:
            "cifar" -> CIFAR-10 train split, "lsun" -> LSUN
            ``church_outdoor_train``, "stl10" -> STL-10 ``unlabeled`` split,
            anything else -> a flat image directory read by ``CelebADataset``.
            Nothing is downloaded; the data must already be under ``path_data``.
        device: unused; kept so existing call sites keep working.
        large_batch_size: batch size of the second loader (default 1024).
        num_workers: worker processes used by both loaders (default 1).

    Returns:
        ``(dataloader, dataloader2, num_examples)``. Both loaders shuffle and
        use ``drop_last=True``; ``dataloader`` uses ``args.batch_size`` and
        ``dataloader2`` uses ``large_batch_size``. Note that with
        ``drop_last=True`` a loader yields no batches at all if the dataset
        has fewer examples than its batch size.
    """
    print(args.loss)
    image_size = 64 if args.loss == "bcesamy" else 32
    # Same pipeline for every dataset. Resizing happens on the PIL image
    # (before ToTensor) in all branches, including LSUN.
    transform = _make_transform(image_size)

    if "cifar" in args.dataset:
        print("we are in cifar")
        dataset = torchvision.datasets.CIFAR10(
            root=args.path_data, train=True, download=False, transform=transform)
    elif "lsun" in args.dataset:
        print("loading church data")
        dataset = torchvision.datasets.LSUN(
            root=args.path_data, classes=["church_outdoor_train"],
            transform=transform)
    elif "stl10" in args.dataset:
        dataset = torchvision.datasets.STL10(
            args.path_data, split="unlabeled", download=False,
            transform=transform)
    else:
        # Fallback: every other dataset name is treated as a folder of images.
        dataset = CelebADataset(args.path_data, transform)
        print(len(dataset))

    num_examples = len(dataset)

    dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size, shuffle=True,
        num_workers=num_workers, drop_last=True)

    dataloader2 = torch.utils.data.DataLoader(
        dataset, batch_size=large_batch_size, shuffle=True,
        num_workers=num_workers, drop_last=True)

    print(len(dataloader))

    return dataloader, dataloader2, num_examples
 
