import os
import torch
import numbers
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from torchvision.datasets import CIFAR10
from datasets.celeba import CelebA
from datasets.ffhq import FFHQ
from datasets.lsun import LSUN
from torch.utils.data import Subset, Dataset, ConcatDataset
from typing import Sequence
import numpy as np
from PIL import Image



class CustomSubset(Dataset):
    r"""
    Subset of a dataset at specified indices, with optional label remapping.

    Unlike ``torch.utils.data.Subset``, every label read from the wrapped
    dataset is passed through ``target_transform`` when one has been assigned,
    which lets a caller relabel an entire subset at once.

    Args:
        dataset (Dataset): The whole Dataset; items must unpack as ``(data, label)``.
        indices (sequence): Indices in the whole set selected for subset
    """

    dataset: Dataset
    indices: Sequence[int]

    def __init__(self, dataset: Dataset, indices: Sequence[int]) -> None:
        self.dataset = dataset
        self.indices = indices
        # Optional callable applied to each label; ``None`` means labels are
        # returned unchanged.
        self.target_transform = None

    def __getitem__(self, idx):
        """Return ``(data, label)`` for position ``idx`` of the subset.

        Raises:
            RuntimeError: if ``idx`` is a list (batched indexing is rejected).
        """
        if isinstance(idx, list):
            raise RuntimeError('Calling CustomSubset with list of index, may not be expected')
        data, label = self.dataset[self.indices[idx]]
        if self.target_transform is None:
            return data, label
        return data, self.target_transform(label)

    def __len__(self):
        return len(self.indices)


class Crop:
    """Cut a fixed rectangle out of an image.

    ``x1``/``x2`` bound the rows and ``y1``/``y2`` bound the columns: the call
    is forwarded to ``torchvision.transforms.functional.crop`` as
    ``top=x1, left=y1, height=x2 - x1, width=y2 - y1``.
    """

    def __init__(self, x1, x2, y1, y2):
        self.x1, self.x2 = x1, x2
        self.y1, self.y2 = y1, y2

    def __call__(self, img):
        top, left = self.x1, self.y1
        height = self.x2 - self.x1
        width = self.y2 - self.y1
        return F.crop(img, top, left, height, width)

    def __repr__(self):
        return "{}(x1={}, x2={}, y1={}, y2={})".format(
            type(self).__name__, self.x1, self.x2, self.y1, self.y2
        )


def _split_in_half_with_labels(dataset):
    """Relabel ``dataset`` as two halves: label ``tensor([0])`` then ``tensor([1])``.

    Indices are shuffled with the fixed seed 2019, so the split is
    reproducible. The first ``int(len(dataset) * 0.5)`` shuffled items form
    the label-0 half and the rest form the label-1 half. The halves are
    concatenated in that order.
    """
    num_items = len(dataset)
    indices = list(range(num_items))
    # A private legacy RandomState seeded with 2019 produces the same
    # permutation as seeding the global NumPy RNG, without touching global state.
    np.random.RandomState(2019).shuffle(indices)
    cut = int(num_items * 0.5)
    first_half = CustomSubset(dataset, indices[:cut])
    second_half = CustomSubset(dataset, indices[cut:])
    first_half.target_transform = lambda _label: torch.zeros(1).long()
    second_half.target_transform = lambda _label: torch.ones(1).long()
    return ConcatDataset([first_half, second_half])


def get_dataset_custom_split(args, config):
    """Return a training set split into two halves labelled 0 and 1.

    The training set comes from ``get_dataset(args, config)``. It is shuffled
    with a fixed seed and split in half. Items in the first half yield label
    ``tensor([0])`` and items in the second half yield ``tensor([1])``.

    Returns:
        (dataset, test_dataset): ``dataset`` is a ``ConcatDataset`` of the two
        halves. ``test_dataset`` is whatever ``get_dataset`` returned, except
        for CIFAR10, where the official test split is always (re)loaded.

    Raises:
        NotImplementedError: if ``config.data.dataset`` is not one of
            CIFAR10, CELEBA or LSUN.
    """
    name = config.data.dataset
    if name not in ("CIFAR10", "CELEBA", "LSUN"):
        raise NotImplementedError(
            "get_dataset_custom_split does not support dataset {!r}".format(name)
        )

    dataset, test_dataset = get_dataset(args, config)
    dataset = _split_in_half_with_labels(dataset)

    if name == "CIFAR10":
        # Load the official CIFAR10 test split even when get_dataset returned
        # no test set (it returns None in args.custom mode).
        test_dataset = CIFAR10(
            os.path.join(args.exp, "datasets", "cifar10_test"),
            train=False,
            download=True,
            transform=transforms.Compose(
                [transforms.Resize(config.data.image_size), transforms.ToTensor()]
            ),
        )

    return dataset, test_dataset

class CustomDataset(Dataset):
    """Flat folder of images, each paired with the constant label ``1``.

    Args:
        root_dir: directory whose regular files are loaded as RGB images.
        transform: callable applied to each PIL image. When ``None``, it
            defaults to ``transforms.ToTensor()``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        # os.listdir order is arbitrary, so sort to make the index -> file
        # mapping reproducible. Sub-directories are skipped because they
        # cannot be opened as images.
        self.images = sorted(
            name
            for name in os.listdir(root_dir)
            if os.path.isfile(os.path.join(root_dir, name))
        )
        # Use the caller-supplied transform; fall back to ToTensor().
        self.transform = transform if transform is not None else transforms.ToTensor()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(transformed_image, 1)`` for the ``idx``-th file."""
        img_name = os.path.join(self.root_dir, self.images[idx])
        # Image.open is lazy and keeps the file open. convert() forces the
        # pixel load, and the context manager then releases the handle.
        with Image.open(img_name) as img:
            image = img.convert('RGB')
        image = self.transform(image)
        return image, 1  # put an arbitrary label on all of the y's
    
def get_dataset(args, config):
    if config.data.random_flip is False:
        tran_transform = test_transform = transforms.Compose(
            [transforms.Resize(config.data.image_size), transforms.ToTensor()]
        )
    else:
        tran_transform = transforms.Compose(
            [
                transforms.Resize(config.data.image_size),
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor(),
            ]
        )
        test_transform = transforms.Compose(
            [transforms.Resize(config.data.image_size), transforms.ToTensor()]
        )
    if args.custom:

        dataset = CustomDataset(root_dir = args.indices)
        test_dataset = None

    elif config.data.dataset == "CIFAR10":
        dataset = CIFAR10(
            os.path.join(args.exp, "datasets", "cifar10"),
            train=True,
            download=True,
            transform=tran_transform,
        )
        test_dataset = CIFAR10(
            os.path.join(args.exp, "datasets", "cifar10_test"),
            train=False,
            download=True,
            transform=test_transform,
        )

    elif config.data.dataset == "CELEBA":
        cx = 89
        cy = 121
        x1 = cy - 64
        x2 = cy + 64
        y1 = cx - 64
        y2 = cx + 64
        if config.data.random_flip:
            dataset = CelebA(
                root=os.path.join(args.exp, "datasets", "celeba"),
                split="train",
                transform=transforms.Compose(
                    [
                        Crop(x1, x2, y1, y2),
                        transforms.Resize(config.data.image_size),
                        transforms.RandomHorizontalFlip(),
                        transforms.ToTensor(),
                    ]
                ),
                download=True,
            )
        else:
            dataset = CelebA(
                root=os.path.join(args.exp, "datasets", "celeba"),
                split="train",
                transform=transforms.Compose(
                    [
                        Crop(x1, x2, y1, y2),
                        transforms.Resize(config.data.image_size),
                        transforms.ToTensor(),
                    ]
                ),
                download=True,
            )

        test_dataset = CelebA(
            root=os.path.join(args.exp, "datasets", "celeba"),
            split="test",
            transform=transforms.Compose(
                [
                    Crop(x1, x2, y1, y2),
                    transforms.Resize(config.data.image_size),
                    transforms.ToTensor(),
                ]
            ),
            download=True,
        )

    elif config.data.dataset == "LSUN":
        train_folder = "{}_train".format(config.data.category)
        val_folder = "{}_val".format(config.data.category)
        if config.data.random_flip:
            dataset = LSUN(
                root=os.path.join(args.exp, "datasets", "lsun"),
                classes=[train_folder],
                transform=transforms.Compose(
                    [
                        transforms.Resize(config.data.image_size),
                        transforms.CenterCrop(config.data.image_size),
                        transforms.RandomHorizontalFlip(p=0.5),
                        transforms.ToTensor(),
                    ]
                ),
            )
        else:
            dataset = LSUN(
                root=os.path.join(args.exp, "datasets", "lsun"),
                classes=[train_folder],
                transform=transforms.Compose(
                    [
                        transforms.Resize(config.data.image_size),
                        transforms.CenterCrop(config.data.image_size),
                        transforms.ToTensor(),
                    ]
                ),
            )

        test_dataset = LSUN(
            root=os.path.join(args.exp, "datasets", "lsun"),
            classes=[val_folder],
            transform=transforms.Compose(
                [
                    transforms.Resize(config.data.image_size),
                    transforms.CenterCrop(config.data.image_size),
                    transforms.ToTensor(),
                ]
            ),
        )

    elif config.data.dataset == "FFHQ":
        if config.data.random_flip:
            dataset = FFHQ(
                path=os.path.join(args.exp, "datasets", "FFHQ"),
                transform=transforms.Compose(
                    [transforms.RandomHorizontalFlip(p=0.5), transforms.ToTensor()]
                ),
                resolution=config.data.image_size,
            )
        else:
            dataset = FFHQ(
                path=os.path.join(args.exp, "datasets", "FFHQ"),
                transform=transforms.ToTensor(),
                resolution=config.data.image_size,
            )

        num_items = len(dataset)
        indices = list(range(num_items))
        random_state = np.random.get_state()
        np.random.seed(2019)
        np.random.shuffle(indices)
        np.random.set_state(random_state)
        train_indices, test_indices = (
            indices[: int(num_items * 0.9)],
            indices[int(num_items * 0.9) :],
        )
        test_dataset = Subset(dataset, test_indices)
        dataset = Subset(dataset, train_indices)
    else:
        dataset, test_dataset = None, None


    if args.subset==True and not args.custom:
        subset_size = config.data.subset
        total_size = len(dataset)
        if args.subset_folder != None:
            #if there's not already a predefined subset
            subset_folder = os.path.join(args.exp, "logs", args.subset_folder)
            os.makedirs(subset_folder, exist_ok=True)
            subset_indices = torch.randperm(total_size)[:subset_size]
            subset_dataset = Subset(dataset, subset_indices)
            output_indices_file = os.path.join(subset_folder, 'subset_indices.txt')
            with open(output_indices_file, 'w') as f:
                for idx in subset_indices:
                    f.write(f"{idx}\n")
            output_directory = os.path.join(subset_folder, 'subset_images')
            os.makedirs(output_directory, exist_ok=True)

            for idx, (image, _) in enumerate(subset_dataset):
                image_path = os.path.join(output_directory, f'image_{idx}.png')
                image = (image.permute(1, 2, 0) * 255).numpy().astype(np.uint8)
                image = Image.fromarray(image)
                image.save(image_path)
        else:
            print('indices chosen already')
            indices_file_path = os.path.join(args.exp, "logs", args.indices)
            print(indices_file_path)
            with open(indices_file_path, 'r') as f:
                indices = [int(idx.strip()) for idx in f.readlines()]
            subset_indices = torch.tensor(indices)
        subset_dataset = Subset(dataset, subset_indices)
        return subset_dataset, test_dataset
    
    return dataset, test_dataset


def logit_transform(image, lam=1e-6):
    """Return ``logit(lam + (1 - 2 * lam) * image)`` elementwise.

    Squeezing the input affinely toward the middle by ``lam`` keeps both log
    terms finite at ``image == 0`` and ``image == 1``.
    """
    squeezed = image * (1 - 2 * lam) + lam
    log_p = torch.log(squeezed)
    log_one_minus_p = torch.log1p(-squeezed)
    return log_p - log_one_minus_p


def data_transform(config, X):
    """Map image tensors into the model's input space.

    Each step is gated by a config flag, applied in this order:
      1. uniform dequantisation noise, then Gaussian noise (std 0.01);
      2. ``2 * X - 1`` if ``config.data.rescaled``; otherwise
         ``logit_transform`` if ``config.data.logit_transform``;
      3. subtraction of ``config.image_mean`` (broadcast over a new leading
         axis) when the config has that attribute.
    """
    opts = config.data

    if opts.uniform_dequantization:
        X = X / 256.0 * 255.0 + torch.rand_like(X) / 256.0
    if opts.gaussian_dequantization:
        X = X + torch.randn_like(X) * 0.01

    if opts.rescaled:
        X = 2 * X - 1.0
    elif opts.logit_transform:
        X = logit_transform(X)

    if hasattr(config, "image_mean"):
        mean = config.image_mean.to(X.device)
        X = X - mean[None, ...]
    return X


def inverse_data_transform(config, X):
    """Undo ``data_transform`` (except the dequantisation noise) and clamp to [0, 1].

    ``config.image_mean`` is added back when present. The branch order mirrors
    ``data_transform``: ``rescaled`` takes priority over ``logit_transform``.
    A config with both flags set is therefore inverted with the same operation
    the forward pass applied (previously it applied a sigmoid instead).
    """
    if hasattr(config, "image_mean"):
        X = X + config.image_mean.to(X.device)[None, ...]

    if config.data.rescaled:
        X = (X + 1.0) / 2.0
    elif config.data.logit_transform:
        X = torch.sigmoid(X)

    return torch.clamp(X, 0.0, 1.0)
