import math
import os
import random
import logging
import blobfile as bf
import numpy as np
from PIL import Image
#from mpi4py import MPI
from torch.utils.data import DataLoader, Dataset


def load_data(
        *,
        data_dir,
        batch_size,
        image_size,
        class_cond=False,
        deterministic=False,
        random_crop=False,
        random_flip=True,
        in_channels=3
):
    """
    For a dataset, create an endless generator over (images, kwargs) pairs.

    Each images is an NCHW float tensor, and the kwargs dict contains zero or
    more keys, each of which map to a batched Tensor of their own.
    The kwargs dict can be used for class labels, in which case the key is "y"
    and the values are integer tensors of class labels.

    The generator never terminates: once the underlying loader is exhausted
    it is iterated again. Incomplete trailing batches are dropped.

    :param data_dir: a dataset directory (searched recursively for images).
    :param batch_size: the batch size of each returned pair.
    :param image_size: the target resolution handed to ImageDataset.
    :param class_cond: if True, include a "y" key in returned dicts for class
                       label. The class name is taken to be the part of the
                       file name before the first underscore; names are
                       sorted and mapped to consecutive integer ids.
    :param deterministic: if True, yield results in a deterministic order
                          (no shuffling).
    :param random_crop: if True, randomly crop the images for augmentation.
    :param random_flip: if True, randomly flip the images for augmentation.
    :param in_channels: new parameter in DDIBs as we experimented with grayscale
                        images
    :raises ValueError: if data_dir is empty/None, or if the dataset cannot
                        fill even a single batch of size batch_size.
    """
    if not data_dir:
        raise ValueError("unspecified data directory")
    all_files = _list_image_files_recursively(data_dir)
    classes = None
    if class_cond:
        # Assume classes are the first part of the filename,
        # before an underscore.
        class_names = [bf.basename(path).split("_")[0] for path in all_files]
        sorted_classes = {x: i for i, x in enumerate(sorted(set(class_names)))}
        classes = [sorted_classes[x] for x in class_names]
        # Log a summary only; the per-image label list can be huge.
        logging.info(
            "load_data: %d images, %d classes", len(all_files), len(sorted_classes)
        )
    dataset = ImageDataset(
        image_size,
        all_files,
        classes=classes,
        random_crop=random_crop,
        random_flip=random_flip,
        in_channels=in_channels
    )
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=not deterministic,
        num_workers=1,
        drop_last=True,
    )
    # With drop_last=True a dataset smaller than batch_size produces zero
    # batches, and the endless loop below would spin forever without ever
    # yielding. Fail loudly instead.
    if len(loader) == 0:
        raise ValueError(
            f"found {len(all_files)} image(s) in {data_dir!r}, "
            f"not enough for one batch of size {batch_size}"
        )
    while True:
        yield from loader


def load_source_data_for_domain_translation(
        *,
        batch_size,
        image_size,
        data_dir="./experiments/imagenet",
        in_channels=3,
        class_cond=True
):
    """
    This function is new in DDIBs: loads the source dataset for translation.
    For the dataset, create a generator over (images, kwargs) pairs.
    No random cropping, flipping or shuffling; the dataset is traversed
    exactly once, in sorted file order, and the final batch may be smaller
    than batch_size.

    Files whose name contains "translated" are skipped. Each kwargs dict
    carries a "filepath" entry with the source paths of the batch and, when
    class_cond is True, a "y" entry with the class labels.

    :param batch_size: the batch size of each returned pair.
    :param image_size: the target resolution handed to ImageDataset.
    :param data_dir: directory whose immediate image files are loaded
                     (not searched recursively).
    :param in_channels: 3 for RGB images; any other value loads grayscale.
    :param class_cond: if True, parse an integer class label from the part
                       of each file name before the first underscore,
                       e.g. "291_1.png" -> 291.
    :raises ValueError: if data_dir is empty/None, or if class_cond is True
                        and a file name prefix is not an integer.
    """
    if not data_dir:
        raise ValueError("unspecified data directory")
    # Match on the file name only: testing the full path would discard every
    # image whenever the directory path itself contains "translated".
    all_files = [
        path for path in list_image_files(data_dir)
        if "translated" not in bf.basename(path)
    ]
    # Classes are the first part of the filename, before an underscore: e.g. "291_1.png"
    classes = None
    if class_cond:
        classes = [int(bf.basename(path).split("_")[0]) for path in all_files]
    dataset = ImageDataset(
        image_size,
        all_files,
        in_channels=in_channels,
        random_flip=False,
        classes=classes,
        filepaths=all_files,
    )
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=1)
    yield from loader


def list_image_files(data_dir):
    """
    Return the image files located directly inside ``data_dir``.

    Sub-directories are not descended into. An entry counts as an image when
    its name contains a dot and the text after the last dot is, ignoring
    case, one of jpg / jpeg / png / gif. Paths are returned joined onto
    ``data_dir``, ordered by sorted entry name.
    """
    image_exts = ["jpg", "jpeg", "png", "gif"]
    return [
        bf.join(data_dir, name)
        for name in sorted(bf.listdir(data_dir))
        if "." in name and name.split(".")[-1].lower() in image_exts
    ]

# Recursively list all image files under a directory (e.g. data_dir="val").
def _list_image_files_recursively(data_dir):
    """
    Collect every image file under ``data_dir``, descending into directories.

    Entries are visited in sorted name order. An entry whose name has a
    jpg / jpeg / png / gif extension (case-insensitive) is recorded as an
    image; any other entry that is a directory is searched recursively and
    its images are appended in place.
    """
    image_exts = ["jpg", "jpeg", "png", "gif"]
    found = []
    for name in sorted(bf.listdir(data_dir)):
        child = bf.join(data_dir, name)
        looks_like_image = "." in name and name.split(".")[-1].lower() in image_exts
        if looks_like_image:
            found.append(child)
            continue
        if bf.isdir(child):
            found += _list_image_files_recursively(child)
    return found


class ImageDataset(Dataset):
    """
    Map-style dataset that loads image files and returns (array, kwargs) pairs.

    The full file list is split across ``num_shards`` workers; this instance
    keeps every ``num_shards``-th entry starting at index ``shard``.
    """

    def __init__(
            self,
            resolution,
            image_paths,
            classes=None,
            shard=0,
            num_shards=1,
            random_crop=False,
            random_flip=True,
            in_channels=3,
            filepaths=None
    ):
        """
        :param resolution: target size passed to the crop helpers.
        :param image_paths: sequence of paths of all images (before sharding).
        :param classes: optional sequence of integer labels, parallel to
                        image_paths.
        :param shard: index of the shard served by this instance.
        :param num_shards: total number of shards.
        :param random_crop: if True use random_crop_arr, else center_crop_arr.
        :param random_flip: if True, flip each image left-right with
                            probability 0.5.
        :param in_channels: 3 loads images as RGB; any other value loads
                            them as single-channel grayscale ("L").
        :param filepaths: optional sequence parallel to image_paths; when
                          given, each item reports its entry under the
                          "filepath" key.
        """
        super().__init__()
        self.resolution = resolution
        self.local_images = image_paths[shard:][::num_shards]
        self.local_classes = None if classes is None else classes[shard:][::num_shards]
        self.random_crop = random_crop
        self.random_flip = random_flip
        self.in_channels = in_channels
        # Shard filepaths exactly like the images so that filepaths[idx]
        # still refers to local_images[idx] when shard/num_shards are set.
        self.filepaths = None if filepaths is None else filepaths[shard:][::num_shards]

    def __len__(self):
        """Return the number of images in this shard."""
        return len(self.local_images)

    def __getitem__(self, idx):
        """
        Load, crop, optionally flip and normalise the image at ``idx``.

        :return: a tuple (arr, out_dict) where arr is a float32 array in
                 channel-first layout scaled by ``x / 127.5 - 1`` (i.e.
                 [-1, 1] for 8-bit pixel values), and out_dict holds "y"
                 (int64 label) and/or "filepath" when those were provided.
        """
        path = self.local_images[idx]
        with bf.BlobFile(path, "rb") as f:
            pil_image = Image.open(f)
            # Force the pixel data to be read while the file is still open.
            pil_image.load()
        if self.in_channels == 3:
            pil_image = pil_image.convert("RGB")
        else:
            pil_image = pil_image.convert("L")

        # NOTE(review): the crop helpers are defined outside this class;
        # they are presumably returning an HxW or HxWxC array - confirm.
        if self.random_crop:
            arr = random_crop_arr(pil_image, self.resolution)
        else:
            arr = center_crop_arr(pil_image, self.resolution)
        if len(arr.shape) < 3:
            arr = arr[:, :, np.newaxis]  # Adds a single channel

        if self.random_flip and random.random() < 0.5:
            # Reverse the second axis (horizontal mirror for an HxWxC array).
            arr = arr[:, ::-1]

        arr = arr.astype(np.float32) / 127.5 - 1

        out_dict = dict()
        if self.local_classes is not None:
            out_dict["y"] = np.array(self.local_classes[idx], dtype=np.int64)
        if self.filepaths is not None:
            out_dict["filepath"] = self.filepaths[idx]
        # Move the channel axis first: (H, W, C) -> (C, H, W).
        return np.transpose(arr, [2, 0, 1]), out_dict