import os
import torch

from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed
from torchvision.io import read_image
from torch.utils.data import DataLoader, random_split, Dataset
from torchvision import transforms
from lightning import LightningDataModule


class CelebA_Dataset(Dataset):
    """Image dataset yielding 256x256 tensors scaled with ``(x - 127.5) / 128``.

    Every regular file in the directory is treated as an image. Each image is
    read with ``read_image``, resized to 256x256, randomly flipped
    horizontally (re-sampled on every access) and normalized.

    Args:
        data_dir: Directory containing the image files. If ``None``, falls
            back to ``root_dir`` and then to the placeholder
            ``'your_save_path'``.
        hard_load: If True, decode, resize and normalize every image up front
            (16 worker threads) and keep them in ``self.data`` as one stacked
            tensor; otherwise read each image lazily in ``__getitem__``.
        root_dir: Alias for ``data_dir`` (the keyword ``CelebA_Loader`` used).
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 data_dir=None,
                 hard_load=True,
                 root_dir=None,
                 **kwargs):
        if data_dir is None:
            data_dir = root_dir
        data_dir = 'your_save_path' if data_dir is None else data_dir

        self.hard_load = hard_load
        # Resize is deterministic and may be cached; the flip is the random
        # augmentation and must be re-drawn per access, so keep it separate.
        self.resize = transforms.Resize((256, 256))
        self.flip = transforms.RandomHorizontalFlip()
        self.transform = transforms.Compose([self.resize, self.flip])

        # Sorted because os.listdir order is arbitrary (keeps a seeded split
        # reproducible); non-files such as subdirectories are skipped.
        entries = (os.path.join(data_dir, name) for name in os.listdir(data_dir))
        self.image_files = sorted(path for path in entries if os.path.isfile(path))

        if hard_load:
            with ThreadPoolExecutor(max_workers=16) as executor:
                # executor.map yields results in image_files order, so
                # self.data[i] corresponds to self.image_files[i].
                data_list = list(tqdm(executor.map(self._load, self.image_files),
                                      total=len(self.image_files)))
            self.data = torch.stack(data_list)

    @staticmethod
    def _normalize(img):
        """Map uint8 pixel values [0, 255] to floats in roughly [-1, 1]."""
        return (img - 127.5) / 128

    def _load(self, img_path):
        """Read, resize and normalize one image without the random flip."""
        return self._normalize(self.resize(read_image(img_path)))

    def preprocess(self, img_path):
        """Read, resize, randomly flip and normalize one image."""
        return self._normalize(self.transform(read_image(img_path)))

    def __len__(self):
        # In hard-load mode the cached tensor is what __getitem__ indexes.
        return len(self.data) if self.hard_load else len(self.image_files)

    def __getitem__(self, idx):
        if self.hard_load:
            # Flip at access time so the augmentation is re-sampled every
            # epoch instead of being frozen once at load time.
            return self.flip(self.data[idx])
        return self.preprocess(self.image_files[idx])
    

class CelebA_Loader(LightningDataModule):
    """Lightning data module splitting a ``CelebA_Dataset`` into train/val.

    Args:
        data_dir: Image directory forwarded to ``CelebA_Dataset``.
        val_rate: Fraction of images held out for validation, in [0, 1];
            the count is ``round(num * val_rate)``.
        batch_size: Batch size for both loaders (coerced to ``int``).
        num_workers: DataLoader worker processes. Worker-only options
            (persistent workers, prefetching) are enabled only when > 0.
        hard_load: Forwarded to ``CelebA_Dataset``.
        seed: Optional seed for a reproducible train/val split. ``None``
            uses torch's default generator.

    Raises:
        ValueError: If ``val_rate`` is outside [0, 1].
    """

    def __init__(self, data_dir=None, val_rate=0.05, batch_size=64, num_workers=0, hard_load=False,
                 seed=None):
        super().__init__()
        if not 0 <= val_rate <= 1:
            raise ValueError(f'val_rate must be in [0, 1], got {val_rate!r}')
        self.batch_size = int(batch_size)
        self.num_workers = num_workers
        dataset = CelebA_Dataset(data_dir=data_dir, hard_load=hard_load)
        num = len(dataset)
        n_valid = round(num * val_rate)
        n_train = num - n_valid

        print('\nBatch size: {}'.format(batch_size))
        print('Total number of images {}.'.format(num))
        print('\tTraining files:', n_train)
        print('\tValidation files:', n_valid)

        split_kwargs = {}
        if seed is not None:
            split_kwargs['generator'] = torch.Generator().manual_seed(int(seed))
        self.trSamples, self.vlSamples = random_split(dataset, lengths=[n_train, n_valid],
                                                      **split_kwargs)

    def _loader_kwargs(self):
        """Return the DataLoader keyword arguments shared by both splits."""
        kwargs = dict(batch_size=self.batch_size,
                      num_workers=self.num_workers,
                      pin_memory=True,
                      drop_last=True)
        # DataLoader raises ValueError for persistent_workers=True or an
        # explicit prefetch_factor when num_workers == 0 (this class's
        # default), so only pass them when worker processes exist.
        if self.num_workers > 0:
            kwargs.update(persistent_workers=True, prefetch_factor=8)
        return kwargs

    def train_dataloader(self):
        return DataLoader(self.trSamples, shuffle=True, **self._loader_kwargs())

    def val_dataloader(self):
        return DataLoader(self.vlSamples, **self._loader_kwargs())
