import glob
import os

import numpy as np
import torch
import torch.distributed as dist
import torch.utils.data as data
import torchvision.transforms as transforms
from PIL import Image

from guided_diffusion import logger
from tada.augment_pipes import AugReg, TadaV2

# Initialise Pillow at import time. is_image_file() below consults the
# Image.EXTENSION registry, so it has to be filled in before any lookup.
# NOTE(review): per the Pillow docs, Image.init() loads all available file
# format drivers, which is what populates that registry -- confirm for the
# Pillow version in use.
Image.init()


def load_data(
    *,
    data_path,
    batch_size,
    image_size,
    diffusion,
    schedule_sampler,
    class_cond=False,
    random_crop=False,
    random_flip=True,
    seed=None,
    augment_class="none",  # "none", "augreg"/"aug-reg"/"aug_reg", "tada"/"tada-v2"
    augment_pipe_kwargs=None,
    return_original=False,
):
    """
    For a dataset, create an endless iterator over (images, kwargs) batches.

    Each images is an NCHW float tensor, and the kwargs dict contains zero or
    more keys, each of which map to a batched Tensor of their own:
    "y" (integer class labels, only when class_cond is set), "timestep" and
    "weight" (drawn from schedule_sampler), "augment_labels" (only when an
    augmentation pipe is active) and "original" (only when return_original
    is set).

    :param data_path: a dataset directory; it is searched recursively.
    :param batch_size: the batch size of each returned pair.
    :param image_size: the size to which images are resized and cropped.
    :param diffusion: object exposing `alphas_cumprod`; used to derive the
                      per-timestep SNR handed to the augmentation pipe.
    :param schedule_sampler: object whose `sample(1, "cpu")` returns a
                             (timestep, weight) pair for each sample.
    :param class_cond: if True, include a "y" key in returned dicts. The
                       class is the part of the file name before the first
                       underscore.
    :param random_crop: if True, randomly crop the images for augmentation,
                        otherwise center-crop them.
    :param random_flip: if True, randomly flip the images for augmentation.
    :param seed: integer seed for shuffling; any non-integer value (including
                 None) leaves the wrapper's default seed of 0 in place.
    :param augment_class: "none" for plain tensor conversion + normalisation,
                          "augreg" / "aug-reg" / "aug_reg" for AugReg, or
                          "tada" / "tada-v2" for TadaV2.
    :param augment_pipe_kwargs: keyword arguments for the augmentation pipe;
                                None is treated as "no extra arguments".
    :param return_original: if True, also return the un-augmented image under
                            the "original" key (for analysis).
    :return: an iterator over the DataLoader built from the dataset.
    :raises ValueError: if data_path is empty or augment_class is unknown.
    """
    if not data_path:
        raise ValueError("unspecified data directory")

    # `**None` raises TypeError, so fall back to an empty mapping.
    pipe_kwargs = {} if augment_pipe_kwargs is None else augment_pipe_kwargs

    augment_pipe = None
    if augment_class == "none":
        transform = create_transform(image_size, random_crop, random_flip, True, True)
    else:
        # The augmentation pipe receives the image before tensor conversion
        # and normalisation, so both are switched off in the base transform.
        transform = create_transform(image_size, random_crop, random_flip, False, False)
        if augment_class in ("augreg", "aug-reg", "aug_reg"):
            augment_pipe = AugReg(**pipe_kwargs)
        elif augment_class in ("tada", "tada-v2"):
            augment_pipe = TadaV2(**pipe_kwargs)
            logger.log("augmentations:")
            logger.log(augment_pipe.augmentations)
        else:
            raise ValueError(f"Unknown augment class: {augment_class}")

    dataset = ImageDataset(
        data_path,
        class_cond=class_cond,
        transform=transform,
    )
    dataset = DiffusionDataset(
        dataset=dataset,
        diffusion=diffusion,
        schedule_sampler=schedule_sampler,
        augment_pipe=augment_pipe,
        return_original=return_original,
    )
    # Infinite, shuffled stream that is sharded across ranks and workers.
    dataset = Wrapper(dataset).repeat().shuffle(seed=seed)

    # Split the CPU budget between the processes on this machine. Without an
    # initialised process group we behave like a single-process run, which is
    # the same fallback Wrapper.__iter__ uses.
    if dist.is_available() and dist.is_initialized():
        world_size = dist.get_world_size()
    else:
        world_size = 1
    cpu_count = os.cpu_count() or 1  # os.cpu_count() may return None
    dataloader_kwargs = dict(
        batch_size=batch_size,
        num_workers=min(cpu_count // world_size, 8),
        # num_workers=0, # NOTE for debugging
        pin_memory=True,
    )
    loader = data.DataLoader(dataset, **dataloader_kwargs)
    return iter(loader)


def is_image_file(filename):
    """Tell whether `filename` ends in an extension listed in ``Image.EXTENSION``.

    The comparison is case-insensitive because the name is lower-cased before
    its extension is split off.
    """
    _, suffix = os.path.splitext(filename.lower())
    return suffix in Image.EXTENSION


def get_image_files(data_dir, max_dataset_size=None):
    """Collect the image paths found anywhere below `data_dir`, sorted.

    :param data_dir: an existing directory; it is walked recursively.
    :param max_dataset_size: optional int used as the slice bound on the
                             sorted list; None keeps every path.
    :return: a sorted list of paths accepted by `is_image_file`.
    :raises AssertionError: if `data_dir` is not a directory or
                            `max_dataset_size` is neither an int nor None.
    :raises RuntimeError: if no image is found.
    """
    assert os.path.isdir(data_dir), f"{data_dir} is not a valid directory."
    assert isinstance(max_dataset_size, (int, type(None)))

    pattern = os.path.join(data_dir, "**")
    image_paths = sorted(
        candidate
        for candidate in glob.glob(pattern, recursive=True)
        if is_image_file(candidate)
    )
    if len(image_paths) == 0:
        raise RuntimeError(
            f"Found 0 images in: {data_dir}\n"
            "Supported image extensions are: " + ",".join(Image.EXTENSION)
        )
    return image_paths[:max_dataset_size]


def create_transform(resolution, random_crop, hflip, to_tensor, normalize):
    """Assemble the torchvision preprocessing pipeline for one image.

    The pipeline always starts with a bicubic resize to `resolution`
    followed by a crop to `resolution`; the remaining stages are optional.

    :param resolution: target size passed to both the resize and the crop.
    :param random_crop: if True use a random crop, otherwise a center crop.
    :param hflip: if True append a random horizontal flip.
    :param to_tensor: if True append the conversion to a tensor.
    :param normalize: if True append a per-channel normalisation with mean
                      0.5 and std 0.5 for three channels.
    :return: a ``transforms.Compose`` of the selected stages.
    """
    crop_cls = transforms.RandomCrop if random_crop else transforms.CenterCrop
    stages = [
        transforms.Resize(resolution, Image.BICUBIC),
        crop_cls(resolution),
    ]

    optional_stages = (
        (hflip, transforms.RandomHorizontalFlip, {}),
        (to_tensor, transforms.ToTensor, {}),
        (normalize, transforms.Normalize, dict(mean=[0.5] * 3, std=[0.5] * 3)),
    )
    for enabled, factory, kwargs in optional_stages:
        if enabled:
            stages.append(factory(**kwargs))

    return transforms.Compose(stages)


class ImageDataset(data.Dataset):
    """Map-style dataset over every image file found below a directory.

    Each item is an ``(image, out_dict)`` pair. ``image`` is the RGB image
    after the optional `transform`; ``out_dict`` is a fresh dict that holds
    the integer class label under "y" when `class_cond` is enabled and is
    empty otherwise.
    """

    def __init__(
        self,
        data_path,
        class_cond=False,
        transform=None,
    ):
        """
        :param data_path: directory that is searched recursively for images.
        :param class_cond: if True, derive a class label for every image from
                           its file name.
        :param transform: optional callable applied to each loaded image.
        """
        super().__init__()
        Image.init()
        self.image_paths = get_image_files(data_path)
        logger.log(f"Total images: {len(self.image_paths)}")
        self.classes = self._labels_from_paths(self.image_paths) if class_cond else None
        self.transform = transform

    @staticmethod
    def _labels_from_paths(paths):
        """Return one integer label per path, numbered by sorted class name."""
        # Assume classes are the first part of the filename,
        # before an underscore.
        class_names = [os.path.basename(path).split("_")[0] for path in paths]
        class_to_index = {name: i for i, name in enumerate(sorted(set(class_names)))}
        return np.asarray([class_to_index[name] for name in class_names])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Image.open is lazy and keeps the file open. convert() loads the
        # pixels into a new image, so the handle can be closed right away
        # instead of waiting for garbage collection.
        with Image.open(self.image_paths[idx]) as handle:
            image = handle.convert("RGB")

        if self.transform is not None:
            image = self.transform(image)

        out_dict = {}
        if self.classes is not None:
            out_dict["y"] = self.classes[idx]
        return image, out_dict


class DiffusionDataset(data.Dataset):
    """Attach diffusion training metadata (and optional augmentation) to a dataset.

    Every item of the wrapped dataset is an ``(image, out_dict)`` pair. This
    class adds a sampled "timestep" and its "weight" to ``out_dict``, runs the
    optional augmentation pipe with the SNR of the sampled timestep, and makes
    sure the returned image is a tensor.
    """

    def __init__(
        self,
        dataset,
        diffusion,
        schedule_sampler,
        augment_pipe,
        return_original=False,
    ):
        """
        :param dataset: dataset yielding ``(image, out_dict)`` pairs.
        :param diffusion: object exposing `alphas_cumprod`.
        :param schedule_sampler: object whose `sample(1, "cpu")` returns a
                                 ``(timestep, weight)`` pair.
        :param augment_pipe: callable ``pipe(image, snr=...)`` returning
                             ``(image, augment_label)``, or None to disable.
        :param return_original: if True, also store the un-augmented image
                                under the "original" key.
        """
        super().__init__()
        self.dataset = dataset
        self.schedule_sampler = schedule_sampler
        self.augment_pipe = augment_pipe

        # Equals alphas_cumprod / (1 - alphas_cumprod), one value per timestep.
        self.snr = 1.0 / (1.0-diffusion.alphas_cumprod) - 1

        # For analysis
        self.return_original = return_original
        if return_original:
            self.transform = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5]*3, std=[0.5]*3),
            ])
        else:
            self.transform = None

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, out_dict = self.dataset[index]
        if self.return_original:
            if isinstance(image, torch.Tensor):
                # The base transform already produced a tensor, and ToTensor()
                # rejects tensors. Keep a private copy so later processing of
                # `image` cannot alter the stored original.
                out_dict["original"] = image.clone()
            else:
                out_dict["original"] = self.transform(image)
        t, weight = self.schedule_sampler.sample(1, "cpu")
        out_dict["timestep"] = t
        out_dict["weight"] = weight

        if self.augment_pipe is not None:
            snr = self.snr[t]
            image, augment_label = self.augment_pipe(image, snr=snr)
            out_dict["augment_labels"] = augment_label
        if not isinstance(image, torch.Tensor):
            # assumes a non-tensor image is an HWC numpy array -- TODO confirm
            # against the augmentation pipes' return type.
            image = torch.from_numpy(image).permute(2, 0, 1).float()
            image = image.contiguous()
        return image, out_dict


class Wrapper(data.IterableDataset):
    """Turn a map-style dataset into a repeatable, shuffled, sharded stream.

    The index range of the wrapped dataset is split across distributed ranks
    and DataLoader workers: each (rank, worker) pair is one shard and yields
    every `mod`-th sample of the epoch order, starting at its own offset.
    `repeat` and `shuffle` return ``self`` so they can be chained.
    """

    def __init__(self, dataset, drop_last=True):
        """
        :param dataset: the map-style dataset to stream from.
        :param drop_last: if True, drop the trailing samples that do not
                          divide evenly among the shards; otherwise pad the
                          epoch by repeating samples from its start.
        """
        super().__init__()
        assert isinstance(dataset, data.Dataset)
        self.dataset = dataset
        self.drop_last = drop_last

        self._count = 1
        self._seed = 0
        self._shuffle = False

    def __len__(self):
        if self._count == float("inf"):
            # __len__ must return an int, so an endless stream has no length.
            raise TypeError("Wrapper has no length while it repeats indefinitely")
        return len(self.dataset) * self._count

    def _shard_layout(self):
        """Return ``(mod, shift)``: total shard count and this shard's offset."""
        if dist.is_available() and dist.is_initialized():
            world_size = dist.get_world_size()
            rank = dist.get_rank()
        else:
            world_size = 1
            rank = 0

        mod = world_size
        shift = rank
        worker_info = data.get_worker_info()
        if worker_info:
            mod *= worker_info.num_workers
            shift = shift * worker_info.num_workers + worker_info.id
        return mod, shift

    def _epoch_order(self, indices, mod):
        """Trim or pad `indices` so its length is a multiple of `mod`."""
        remainder = len(indices) % mod
        if remainder == 0:
            return indices
        if self.drop_last:
            return indices[:-remainder]
        # np.resize fills the extra slots with repeated copies of `indices`,
        # so the padding is complete even when the dataset is smaller than
        # the number of missing slots.
        return np.resize(indices, len(indices) + mod - remainder)

    def __iter__(self):
        mod, shift = self._shard_layout()
        indices = np.arange(len(self.dataset))

        # The epoch length is a multiple of `mod` and does not depend on the
        # shuffle, so an empty epoch means this shard would never yield.
        # Combined with an endless repeat that is a silent busy loop.
        if self._count == float("inf") and len(self._epoch_order(indices, mod)) == 0:
            raise ValueError(
                f"dataset with {len(indices)} samples cannot feed {mod} shards "
                f"(ranks x workers) with drop_last={self.drop_last}"
            )

        epoch = 0
        while epoch < self._count:
            if self._shuffle:
                # Seeded by epoch so every shard applies the same permutation.
                rng = np.random.default_rng(seed=self._seed + epoch)
                rng.shuffle(indices)

            for index in self._epoch_order(indices, mod)[shift::mod]:
                yield self.dataset[index]

            epoch += 1

    def repeat(self, count=float("inf")):
        """Set the number of epochs to stream (endless by default)."""
        self._count = count
        return self

    def shuffle(self, mode=True, seed=None):
        """Enable or disable shuffling; a non-integer `seed` keeps the current seed."""
        if isinstance(seed, (int, np.integer)):
            self._seed = int(seed)
        self._shuffle = mode
        return self
