import torch
from torchvision import datasets, transforms
import torch.utils.data
import matplotlib.pyplot as plt
import numpy as np
import torch
import os
from PIL import Image
import torchvision
import h5py


class CelebAHQTrain(torch.utils.data.Dataset):
    """Training images ``00001.jpg`` .. ``28000.jpg`` stored under ``data_root``.

    Each item is a dict ``{'img': tensor}``. The image is resized with
    ``transforms.Resize(image_size)``, converted to a tensor, randomly
    flipped horizontally (p=0.5) and normalized with mean/std 0.5 per channel.

    Args:
        data_root: Directory containing the zero-padded ``NNNNN.jpg`` files.
        image_size: Size argument forwarded to ``transforms.Resize``.
    """

    # Dataset index 0 maps to file '00001.jpg'.
    _FIRST_ID = 1
    _NUM_IMAGES = 28000

    def __init__(self, data_root, image_size):
        super().__init__()
        self.data_root = data_root
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def __getitem__(self, index):
        """Return ``{'img': tensor}`` for ``0 <= index < len(self)``.

        Raises:
            IndexError: If ``index`` is outside the split. Without this check
                indices >= 28000 would open files numbered above 28000, which
                are outside this split.
        """
        if not 0 <= index < self._NUM_IMAGES:
            raise IndexError(
                'index {} is out of range for {} with {} images'.format(
                    index, type(self).__name__, self._NUM_IMAGES))
        file_name = str(index + self._FIRST_ID).zfill(5) + '.jpg'
        # The context manager closes the file handle once the pixels are loaded;
        # convert('RGB') guarantees the 3 channels that Normalize expects.
        with Image.open(os.path.join(self.data_root, file_name)) as img:
            return {'img': self.transform(img.convert('RGB'))}

    def __len__(self):
        """Return the fixed number of training images (28000)."""
        return self._NUM_IMAGES


class CelebAHQTest(torch.utils.data.Dataset):
    """Test images ``28001.jpg`` .. ``30000.jpg`` stored under ``data_root``.

    Each item is a dict ``{'img': tensor}``. The image is resized with
    ``transforms.Resize(image_size)``, converted to a tensor and normalized
    with mean/std 0.5 per channel. No random augmentation is applied.

    Args:
        data_root: Directory containing the zero-padded ``NNNNN.jpg`` files.
        image_size: Size argument forwarded to ``transforms.Resize``.
    """

    # Dataset index 0 maps to file '28001.jpg'.
    _FIRST_ID = 28001
    _NUM_IMAGES = 2000

    def __init__(self, data_root, image_size):
        super().__init__()
        self.data_root = data_root
        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def __getitem__(self, index):
        """Return ``{'img': tensor}`` for ``0 <= index < len(self)``.

        Raises:
            IndexError: If ``index`` is outside the split. Without this check
                a negative index would open a file numbered below 28001,
                which is outside this split.
        """
        if not 0 <= index < self._NUM_IMAGES:
            raise IndexError(
                'index {} is out of range for {} with {} images'.format(
                    index, type(self).__name__, self._NUM_IMAGES))
        file_name = str(index + self._FIRST_ID).zfill(5) + '.jpg'
        # The context manager closes the file handle once the pixels are loaded;
        # convert('RGB') guarantees the 3 channels that Normalize expects.
        with Image.open(os.path.join(self.data_root, file_name)) as img:
            return {'img': self.transform(img.convert('RGB'))}

    def __len__(self):
        """Return the fixed number of test images (2000)."""
        return self._NUM_IMAGES


class CelebAWildTrain(torch.utils.data.Dataset):
    """Images and landmarks from the 'train_img' / 'train_landmark' datasets of ``celeba_wild.h5``.

    Both arrays are read fully into memory at construction time. Items are
    dicts with keys ``'img'`` (augmented image) and ``'keypoints'``.

    NOTE(review): the random horizontal flip is applied to the image only;
    the keypoints returned alongside it are not mirrored. Confirm whether
    any consumer relies on the keypoints of this split.
    """

    def __init__(self, data_root, image_size):
        super().__init__()
        h5_path = os.path.join(data_root, 'celeba_wild.h5')
        with h5py.File(h5_path, 'r') as hf:
            images = hf['train_img'][...]
            landmarks = hf['train_landmark'][...]
        self.imgs = torch.from_numpy(images)
        self.keypoints = torch.from_numpy(landmarks)

        # Applied to tensors that __getitem__ has already scaled by 1/255.
        steps = [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Resize(image_size),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        self.transform = transforms.Compose(steps)

    def __getitem__(self, idx):
        """Return the transformed image and its stored keypoints for ``idx``."""
        scaled = self.imgs[idx].float() / 255
        return {'img': self.transform(scaled), 'keypoints': self.keypoints[idx]}

    def __len__(self):
        """Return the number of images (size of the first dimension of ``imgs``)."""
        return self.imgs.size(0)


class MAFLWildTrain(torch.utils.data.Dataset):
    """Images and landmarks from 'mafl_train_img' / 'mafl_train_landmark' of ``celeba_wild.h5``.

    Both arrays are read fully into memory at construction time. Items are
    dicts with keys ``'img'`` (resized, normalized image) and ``'keypoints'``.
    No random augmentation is applied.
    """

    def __init__(self, data_root, image_size):
        super().__init__()
        h5_path = os.path.join(data_root, 'celeba_wild.h5')
        with h5py.File(h5_path, 'r') as hf:
            images = hf['mafl_train_img'][...]
            landmarks = hf['mafl_train_landmark'][...]
        self.imgs = torch.from_numpy(images)
        self.keypoints = torch.from_numpy(landmarks)

        # Applied to tensors that __getitem__ has already scaled by 1/255.
        steps = [
            transforms.Resize(image_size),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        self.transform = transforms.Compose(steps)

    def __getitem__(self, idx):
        """Return the transformed image and its stored keypoints for ``idx``."""
        scaled = self.imgs[idx].float() / 255
        return {'img': self.transform(scaled), 'keypoints': self.keypoints[idx]}

    def __len__(self):
        """Return the number of images (size of the first dimension of ``imgs``)."""
        return self.imgs.size(0)


class MAFLWildTest(torch.utils.data.Dataset):
    """Images and landmarks from 'mafl_test_img' / 'mafl_test_landmark' of ``celeba_wild.h5``.

    Both arrays are read fully into memory at construction time. Items are
    dicts with keys ``'img'`` (resized, normalized image) and ``'keypoints'``.
    No random augmentation is applied.
    """

    def __init__(self, data_root, image_size):
        super().__init__()
        h5_path = os.path.join(data_root, 'celeba_wild.h5')
        with h5py.File(h5_path, 'r') as hf:
            images = hf['mafl_test_img'][...]
            landmarks = hf['mafl_test_landmark'][...]
        self.imgs = torch.from_numpy(images)
        self.keypoints = torch.from_numpy(landmarks)

        # Applied to tensors that __getitem__ has already scaled by 1/255.
        steps = [
            transforms.Resize(image_size),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        self.transform = transforms.Compose(steps)

    def __getitem__(self, idx):
        """Return the transformed image and its stored keypoints for ``idx``."""
        scaled = self.imgs[idx].float() / 255
        return {'img': self.transform(scaled), 'keypoints': self.keypoints[idx]}

    def __len__(self):
        """Return the number of images (size of the first dimension of ``imgs``)."""
        return self.imgs.size(0)


class Bedroom(torch.utils.data.Dataset):
    """All ``.jpg`` files found directly inside ``data_root``.

    Each item is a dict ``{'img': tensor}``. The image is resized with
    ``transforms.Resize(image_size)``, converted to a tensor, randomly
    flipped horizontally (p=0.5) and normalized with mean/std 0.5 per channel.

    Args:
        data_root: Directory that is listed (non-recursively) for ``.jpg`` files.
        image_size: Size argument forwarded to ``transforms.Resize``.
    """

    def __init__(self, data_root, image_size):
        super().__init__()
        self.file_paths = self._list_jpgs(data_root)

        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    @staticmethod
    def _list_jpgs(directory):
        """Return the paths of the ``.jpg`` files in ``directory``, sorted by file name.

        ``os.listdir`` returns entries in arbitrary order, so the listing is
        sorted to keep the index -> file mapping identical across runs and
        machines.
        """
        return [os.path.join(directory, file_name)
                for file_name in sorted(os.listdir(directory))
                if file_name.endswith('.jpg')]

    def __getitem__(self, index):
        """Return ``{'img': tensor}`` for the file at position ``index``."""
        # The context manager closes the file handle once the pixels are loaded;
        # convert('RGB') guarantees the 3 channels that Normalize expects.
        with Image.open(self.file_paths[index]) as img:
            return {'img': self.transform(img.convert('RGB'))}

    def __len__(self):
        """Return the number of ``.jpg`` files that were found."""
        return len(self.file_paths)


class FFHQHQ(torch.utils.data.Dataset):
    """Thin wrapper around ``torchvision.datasets.ImageFolder`` rooted at ``data_root``.

    Items are dicts ``{'img': tensor}``; the second element of each
    ``ImageFolder`` sample is discarded. Images are resized with
    ``transforms.Resize(image_size)``, converted to tensors, randomly flipped
    horizontally (p=0.5) and normalized with mean/std 0.5 per channel.
    """

    def __init__(self, data_root, image_size):
        steps = [
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        pipeline = transforms.Compose(steps)

        self.imgs = torchvision.datasets.ImageFolder(root=data_root, transform=pipeline)

    def __getitem__(self, index):
        """Return ``{'img': ...}`` holding the first element of the wrapped sample."""
        sample = self.imgs[index]
        return {'img': sample[0]}

    def __len__(self):
        """Return the length of the wrapped image collection."""
        return len(self.imgs)


class BBCPoseTrain(torch.utils.data.Dataset):
    """All ``.jpg`` files found directly inside ``<data_root>/train_images``.

    Each item is a dict ``{'img': tensor}``. The image is resized with
    ``transforms.Resize(image_size)``, converted to a tensor and normalized
    with mean/std 0.5 per channel. No random augmentation is applied.

    Args:
        data_root: Directory containing the ``train_images`` sub-directory.
        image_size: Size argument forwarded to ``transforms.Resize``.
    """

    def __init__(self, data_root, image_size):
        super().__init__()
        self.file_paths = self._list_jpgs(os.path.join(data_root, 'train_images'))

        self.transform = transforms.Compose([
            transforms.Resize(image_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    @staticmethod
    def _list_jpgs(directory):
        """Return the paths of the ``.jpg`` files in ``directory``, sorted by file name.

        ``os.listdir`` returns entries in arbitrary order, so the listing is
        sorted to keep the index -> file mapping identical across runs and
        machines.
        """
        return [os.path.join(directory, file_name)
                for file_name in sorted(os.listdir(directory))
                if file_name.endswith('.jpg')]

    def __getitem__(self, index):
        """Return ``{'img': tensor}`` for the file at position ``index``."""
        # The context manager closes the file handle once the pixels are loaded;
        # convert('RGB') guarantees the 3 channels that Normalize expects.
        with Image.open(self.file_paths[index]) as img:
            return {'img': self.transform(img.convert('RGB'))}

    def __len__(self):
        """Return the number of ``.jpg`` files that were found."""
        return len(self.file_paths)


class CelebA(torch.utils.data.Dataset):
    """Images from the 'celeba_wo_mafl' dataset of ``celeba.h5``.

    The array is read fully into memory at construction time. Items are dicts
    ``{'img': tensor}``; the image is scaled by 1/255, randomly flipped
    horizontally (p=0.5), resized and normalized with mean/std 0.5 per channel.
    """

    def __init__(self, data_root, image_size):
        super().__init__()
        h5_path = os.path.join(data_root, 'celeba.h5')
        with h5py.File(h5_path, 'r') as hf:
            images = hf['celeba_wo_mafl'][...]
        self.imgs = torch.from_numpy(images)

        # Applied to tensors that __getitem__ has already scaled by 1/255.
        steps = [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.Resize(image_size),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        self.transform = transforms.Compose(steps)

    def __getitem__(self, idx):
        """Return ``{'img': tensor}`` with the transformed image at ``idx``."""
        scaled = self.imgs[idx].float() / 255
        return {'img': self.transform(scaled)}

    def __len__(self):
        """Return the number of images (size of the first dimension of ``imgs``)."""
        return self.imgs.size(0)


class MAFLTrain(torch.utils.data.Dataset):
    """Images and labels from 'mafl_train_data' / 'mafl_train_label' of ``celeba.h5``.

    Both arrays are read fully into memory at construction time. The stored
    labels are multiplied by ``image_size`` once, up front (presumably they are
    normalized coordinates; confirm against the file that produced them).
    Items are dicts with keys ``'img'`` and ``'keypoints'``.
    """

    def __init__(self, data_root, image_size):
        super().__init__()
        h5_path = os.path.join(data_root, 'celeba.h5')
        with h5py.File(h5_path, 'r') as hf:
            images = hf['mafl_train_data'][...]
            labels = hf['mafl_train_label'][...]
        self.imgs = torch.from_numpy(images)
        self.keypoints = torch.from_numpy(labels) * image_size

        # Applied to tensors that __getitem__ has already scaled by 1/255.
        steps = [
            transforms.Resize(image_size),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        self.transform = transforms.Compose(steps)

    def __getitem__(self, idx):
        """Return the transformed image and its scaled keypoints for ``idx``."""
        scaled = self.imgs[idx].float() / 255
        return {'img': self.transform(scaled), 'keypoints': self.keypoints[idx]}

    def __len__(self):
        """Return the number of images (size of the first dimension of ``imgs``)."""
        return self.imgs.size(0)


class MAFLTest(torch.utils.data.Dataset):
    """Images and labels from 'mafl_test_data' / 'mafl_test_label' of ``celeba.h5``.

    Both arrays are read fully into memory at construction time. The stored
    labels are multiplied by ``image_size`` once, up front (presumably they are
    normalized coordinates; confirm against the file that produced them).
    Items are dicts with keys ``'img'`` and ``'keypoints'``.
    """

    def __init__(self, data_root, image_size):
        super().__init__()
        h5_path = os.path.join(data_root, 'celeba.h5')
        with h5py.File(h5_path, 'r') as hf:
            images = hf['mafl_test_data'][...]
            labels = hf['mafl_test_label'][...]
        self.imgs = torch.from_numpy(images)
        self.keypoints = torch.from_numpy(labels) * image_size

        # Applied to tensors that __getitem__ has already scaled by 1/255.
        steps = [
            transforms.Resize(image_size),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        self.transform = transforms.Compose(steps)

    def __getitem__(self, idx):
        """Return the transformed image and its scaled keypoints for ``idx``."""
        scaled = self.imgs[idx].float() / 255
        return {'img': self.transform(scaled), 'keypoints': self.keypoints[idx]}

    def __len__(self):
        """Return the number of images (size of the first dimension of ``imgs``)."""
        return self.imgs.size(0)


def get_dataset(data_root, image_size, class_name='all'):
    """Build the dataset selected by ``class_name``.

    Args:
        data_root: Root directory passed to the dataset constructor.
        image_size: Image size passed to the dataset constructor.
        class_name: One of 'celebaHQ', 'celebaHQTest', 'ffhqHQ', 'bedroom',
            'bbcpose', 'celeba_wild', 'mafl_wild_train', 'mafl_wild_test',
            'celeba', 'mafl_train', 'mafl_test'. The default 'all' is not a
            supported name and therefore raises.

    Returns:
        The constructed dataset instance.

    Raises:
        ValueError: If ``class_name`` is not one of the supported names.
    """
    if class_name == 'celebaHQ':
        dataset_cls = CelebAHQTrain
    elif class_name == 'celebaHQTest':
        dataset_cls = CelebAHQTest
    elif class_name == 'ffhqHQ':
        dataset_cls = FFHQHQ
    elif class_name == 'bedroom':
        dataset_cls = Bedroom
    elif class_name == 'bbcpose':
        dataset_cls = BBCPoseTrain
    elif class_name == 'celeba_wild':
        dataset_cls = CelebAWildTrain
    elif class_name == 'mafl_wild_train':
        dataset_cls = MAFLWildTrain
    elif class_name == 'mafl_wild_test':
        dataset_cls = MAFLWildTest
    elif class_name == 'celeba':
        dataset_cls = CelebA
    elif class_name == 'mafl_train':
        dataset_cls = MAFLTrain
    elif class_name == 'mafl_test':
        dataset_cls = MAFLTest
    else:
        # Name the offending value and the accepted ones instead of raising a
        # bare ValueError, so a misconfigured run is diagnosable from the log.
        raise ValueError(
            'Unknown dataset class_name {!r}; expected one of: celebaHQ, celebaHQTest, '
            'ffhqHQ, bedroom, bbcpose, celeba_wild, mafl_wild_train, mafl_wild_test, '
            'celeba, mafl_train, mafl_test'.format(class_name))
    return dataset_cls(data_root, image_size)


def get_dataloader(data_root, class_name, image_size, batch_size, num_workers=6, pin_memory=True, drop_last=True,
                   shuffle=True):
    """Build a ``torch.utils.data.DataLoader`` over the dataset named ``class_name``.

    Args:
        data_root: Root directory forwarded to ``get_dataset``.
        class_name: Dataset name forwarded to ``get_dataset``.
        image_size: Image size forwarded to ``get_dataset``.
        batch_size: Number of samples per batch.
        num_workers: Number of loader worker processes.
        pin_memory: Forwarded to ``DataLoader``.
        drop_last: Forwarded to ``DataLoader``; when True the final incomplete
            batch is dropped.
        shuffle: Forwarded to ``DataLoader``. Defaults to True, which matches
            the previously hard-coded behaviour; pass False for a
            deterministic pass over an evaluation split.

    Returns:
        The configured ``DataLoader``.

    Raises:
        ValueError: Propagated from ``get_dataset`` for an unknown ``class_name``.
    """
    dataset = get_dataset(data_root=data_root, image_size=image_size, class_name=class_name)
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
    )
