import torch
from torchvision import transforms
import torchvision
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from my_utils import *
import numpy as np
from PIL import Image
import sys
import os


class Dataset_loader(torch.utils.data.Dataset):
    """Image dataset that returns each image as a flat ``(pixels, c)`` tensor.

    Every item is resized to ``size`` and converted to a ``[C, H, W]`` tensor by
    ``transforms.ToTensor()``. It is then permuted to ``[H, W, C]`` and reshaped
    to ``(-1, c)``, i.e. one row per pixel in row-major order. This assumes
    ``c`` matches the channel count the transform produces.

    Args:
        dataset: Name of the dataset to load: ``'mnist'``, ``'lsun'`` or a key
            of ``_IMAGE_FOLDER_PATHS``.
        size: ``(height, width)`` passed to ``transforms.Resize``.
        c: Number of columns in each returned pixel row.
        quantize: If True, snap every pixel value down to a multiple of 8/255.

    Raises:
        ValueError: If ``dataset`` is not a recognised name.
    """

    # Datasets whose images are reduced to a single grayscale channel.
    _GRAYSCALE_DATASETS = frozenset({
        'cars_sdf', 'cars', 'ellipses_train', 'ellipses_test',
        'CT_sin_test', 'CT_x_test', 'CT_sin_train',
    })

    # Datasets loaded with torchvision's ImageFolder, keyed by name.
    _IMAGE_FOLDER_PATHS = {
        'celeba-hq_train': 'Projects/datasets/celeba_hq/celeba_hq_1024_train/',
        'celeba-hq_test': '/raid/ /Projects/datasets/celeba_hq/celeba_hq_1024_test/',
        'cars': '/raid/ /Projects/datasets/ETH/ETH_car/',
        'cars_sdf': '/raid/ /Projects/datasets/ETH/ETH_car_sdf/',
        'ellipses_train': ' /Projects/datasets/ellipses/ellipses_256_train/',
        'ellipses_test': ' /Projects/datasets/ellipses/ellipses_256_test/',
        'CT_sin_train': ' /Projects/datasets/CT_dataset/new/sin_train_uniform_snr_40/',
        'CT_sin_test': ' /Projects/datasets/CT_dataset/new/sin_test_uniform_snr_40/',
        'CT_x_test': ' /Projects/datasets/CT_dataset/new/x_test_uniform_snr_40/',
    }

    def __init__(self, dataset='mnist', size=(32, 32), c=1, quantize=False):
        # Validate first so an unknown name fails here with a clear message
        # instead of leaving ``img_dataset`` unset (which would only surface
        # later as an AttributeError in __len__/__getitem__).
        known = {'mnist', 'lsun'} | set(self._IMAGE_FOLDER_PATHS)
        if dataset not in known:
            raise ValueError(
                f"Unknown dataset {dataset!r}; expected one of {sorted(known)}")

        # Resize -> (Grayscale for single-channel datasets) -> ToTensor.
        steps = [transforms.Resize(size)]
        if dataset in self._GRAYSCALE_DATASETS:
            steps.append(transforms.Grayscale())
        steps.append(transforms.ToTensor())
        self.transform = transforms.Compose(steps)

        self.c = c
        self.dataset = dataset
        if dataset == 'mnist':
            # Built without a transform; __getitem__ applies self.transform
            # directly to the raw item for MNIST.
            self.img_dataset = torchvision.datasets.MNIST('data/MNIST', train=True,
                                                          download=True)
        elif dataset == 'lsun':
            self.img_dataset = torchvision.datasets.LSUN(
                ' /Projects/datasets/lsun-master/lsun', classes=['bedroom_val'],
                transform=self.transform)
        else:
            self.img_dataset = ImageFolder(self._IMAGE_FOLDER_PATHS[dataset],
                                           self.transform)

        # NOTE(review): get_mgrid presumably comes from my_utils (star import).
        # Only size[0] is passed, so this likely assumes square images; confirm.
        self.meshgrid = get_mgrid(size[0])
        self.im_size = size
        self.quantize = quantize

    def __len__(self):
        """Return the number of items in the underlying dataset."""
        return len(self.img_dataset)

    def __getitem__(self, item):
        """Return image ``item`` as a ``(-1, self.c)`` float tensor."""
        img = self.img_dataset[item][0]
        if self.dataset != 'mnist':
            # NOTE(review): for ImageFolder/LSUN datasets the item was already
            # passed through self.transform, so it is converted back to PIL
            # and transformed a second time here. The float->uint8 round trip
            # in ToPILImage may cost precision; consider building those
            # datasets without a transform instead (changes outputs slightly).
            img = transforms.ToPILImage()(img)

        img = self.transform(img).permute([1, 2, 0])  # [C, H, W] -> [H, W, C]
        img = img.reshape(-1, self.c)                 # one row per pixel

        if self.quantize:
            # Floor each 8-bit intensity to a multiple of 8 (32 levels when
            # values start in [0, 1]), then rescale back to [0, 1].
            img = img * 255.0
            img = 8.0 * torch.div(img, 8, rounding_mode='floor')
            img = img / 255.0

        return img
    
  


class CT_sinogram(torch.utils.data.Dataset):
    """Dataset of NumPy arrays stored one per file in a single directory.

    Presumably CT sinograms, per the class name. Each item is the array loaded
    from one file, divided by 100 and flattened into a ``(num_elements, 1)``
    float32 tensor.
    """

    def __init__(self, directory):
        """
        Args:
            directory (string): Directory whose entries are each loaded with
                ``np.load``; every entry in it is treated as one sample.
        """
        self.directory = directory
        # Sorted so the index -> file mapping is deterministic; os.listdir
        # returns entries in arbitrary, filesystem-dependent order.
        self.name_list = sorted(os.listdir(self.directory))
        # NOTE(review): not used by __getitem__ (samples are NumPy arrays, not
        # PIL images); kept so existing attribute access keeps working.
        self.transform = transforms.Compose([
            transforms.Grayscale(),
            transforms.ToTensor()])

    def __len__(self):
        """Return the number of entries found in ``directory``."""
        return len(self.name_list)

    def __getitem__(self, idx):
        """Load entry ``idx``, scale it by 1/100 and return a ``(N, 1)`` float32 tensor."""
        if torch.is_tensor(idx):
            idx = idx.tolist()

        file_name = self.name_list[idx]
        sample = np.load(os.path.join(self.directory, file_name))
        # NOTE(review): the /100 normalisation constant is not explained here;
        # confirm it matches how the stored arrays were generated.
        sample = torch.tensor(sample / 100.0, dtype=torch.float32)
        # Flatten in row-major order to one value per row.
        return sample.reshape(-1, 1)