import os
import pickle

import numpy as np
import paddle
import csv
from PIL import Image
from collections import namedtuple
from functools import partial
from paddle.io import Dataset

from .prefetch import prefetch_transform


class CIFAR10(Dataset):
    """CIFAR-10 dataset read from the pickled batch files under ``root``.

    Args:
        root (string): Directory holding ``data_batch_1`` .. ``data_batch_5``
            and ``test_batch``. ``~`` is expanded.
        transform (dict): Mapping with the keys ``"pre"``, ``"primary"`` and
            ``"remaining"``; the stages are applied in that order to every
            sample. ``"pre"`` may be ``None`` (it is skipped), the other two
            are always called. The signature defaults to ``None``, but the
            constructor indexes this argument, so a mapping is required.
        train (bool): If True, load the five training batches, otherwise load
            the test batch (default: True).
        prefetch (bool): If True, ``transform["remaining"]`` is handed to
            ``prefetch_transform`` and the returned ``mean`` / ``std`` are
            stored on the instance. According to this class's earlier
            documentation that helper strips ``ToTensor`` and ``Normalize``
            (default: False).

    Each item is a dict ``{"img": np.ndarray, "target": np.ndarray}``.
    """

    def __init__(self, root, transform=None, train=True, prefetch=False):
        self.train = train
        self.pre_transform = transform["pre"]
        self.primary_transform = transform["primary"]
        remaining = transform["remaining"]
        if prefetch:
            # ``mean`` / ``std`` only exist on the instance in prefetch mode.
            remaining, self.mean, self.std = prefetch_transform(remaining)
        self.remaining_transform = remaining

        if train:
            batch_files = ["data_batch_{}".format(i) for i in range(1, 6)]
        else:
            batch_files = ["test_batch"]
        self.prefetch = prefetch

        root = os.path.expanduser(root)
        pixel_rows, labels = [], []
        for batch_name in batch_files:
            # NOTE: pickle can execute arbitrary code while loading; only
            # point ``root`` at batch files from a trusted source.
            with open(os.path.join(root, batch_name), "rb") as handle:
                batch = pickle.load(handle, encoding="latin1")
            pixel_rows.append(batch["data"])
            labels.extend(batch["labels"])

        # Stack every batch, view it as N x 3 x 32 x 32 and move the channel
        # axis last (N x 32 x 32 x 3) so single samples can become PIL images.
        stacked = np.vstack(pixel_rows)
        self.data = stacked.reshape(-1, 3, 32, 32).transpose((0, 2, 3, 1))
        self.targets = np.asarray(labels)

    def __getitem__(self, index):
        """Return ``{"img": ..., "target": ...}`` for the sample at ``index``."""
        pixels, label = self.data[index], self.targets[index]
        picture = Image.fromarray(pixels)
        # Stage 1: optional pre-processing.
        if self.pre_transform is not None:
            picture = self.pre_transform(picture)
        # Stage 2 then stage 3: primary and remaining transformations.
        picture = self.remaining_transform(self.primary_transform(picture))
        if self.prefetch:
            # Convert to uint8 and move axis 2 to the front (HWC -> CHW).
            picture = np.rollaxis(np.array(picture, dtype=np.uint8), 2)
        return {"img": np.array(picture), "target": np.array(label)}

    def __len__(self):
        """Number of loaded samples."""
        return len(self.data)

class MNIST(Dataset):
    """MNIST samples obtained through ``paddle.vision.datasets.MNIST``.

    Args:
        root (string): Only passed through ``os.path.expanduser``; it is not
            forwarded to the paddle loader, which is constructed with
            ``mode`` alone.
        transform (dict): Mapping with the keys ``"pre"``, ``"primary"`` and
            ``"remaining"``, applied in that order. ``"pre"`` may be ``None``.
            The signature defaults to ``None``, but the constructor indexes
            this argument, so a mapping is required.
        train (bool): Selects paddle's ``"train"`` split when true, otherwise
            ``"test"`` (default: True).
        prefetch (bool): If True, ``transform["remaining"]`` is handed to
            ``prefetch_transform`` and the returned ``mean`` / ``std`` are
            stored on the instance (default: False).

    Each item is a dict ``{"img": np.ndarray, "target": np.ndarray}``.
    """

    def __init__(self, root, transform=None, train=True, prefetch=False):
        self.train = train
        self.pre_transform = transform["pre"]
        self.primary_transform = transform["primary"]
        remaining = transform["remaining"]
        if prefetch:
            # ``mean`` / ``std`` only exist on the instance in prefetch mode.
            remaining, self.mean, self.std = prefetch_transform(remaining)
        self.remaining_transform = remaining

        # The expanded path is not used any further in this class.
        root = os.path.expanduser(root)
        self.prefetch = prefetch

        if train:
            subset = 'train'
        else:
            subset = 'test'
        self.dataset = paddle.vision.datasets.MNIST(mode=subset)
        # Labels are collapsed with ``squeeze``; images are viewed as
        # N x 28 x 28.
        self.targets = np.array(self.dataset.labels).squeeze()
        self.data = np.reshape(self.dataset.images, [-1, 28, 28])

    def __len__(self):
        """Number of samples reported by the wrapped paddle dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``{"img": ..., "target": ...}`` for the sample at ``index``."""
        sample, label = self.data[index], self.targets[index]
        # Stage 1: optional pre-processing.
        if self.pre_transform is not None:
            sample = self.pre_transform(sample)
        # Stage 2 then stage 3: primary and remaining transformations.
        sample = self.remaining_transform(self.primary_transform(sample))
        if self.prefetch:
            # Moves axis 2 to the front, so the transforms must have produced
            # a 3-D (HWC) image by now -- TODO confirm for single-channel data.
            sample = np.rollaxis(np.array(sample, dtype=np.uint8), 2)
        return {"img": np.array(sample), "target": np.array(label)}

class GTSRB(Dataset):
    """GTSRB dataset loaded eagerly from its ``;``-separated annotation files.

    All images are opened and run through ``transform["pre"]`` once, inside
    the constructor; ``__getitem__`` only applies the ``"primary"`` and
    ``"remaining"`` stages.

    Args:
        root (string): Directory containing the ``Train`` and ``Test``
            folders. ``~`` is expanded.
        transform (dict): Mapping with the keys ``"pre"``, ``"primary"`` and
            ``"remaining"``. All three are called unconditionally, so none of
            them may be ``None``. The signature defaults to ``None``, but the
            constructor indexes this argument, so a mapping is required.
        train (bool): If True, read ``Train/<class>/GT-<class>.csv`` for every
            class, otherwise read ``Test/GT-final_test.csv`` (default: True).
        prefetch (bool): If True, ``transform["remaining"]`` is handed to
            ``prefetch_transform`` and the returned ``mean`` / ``std`` are
            stored on the instance (default: False).

    Each item is a dict ``{"img": np.ndarray, "target": np.ndarray}``.
    """

    # Number of class folders (``00000`` .. ``00042``) scanned in ``Train``.
    _NUM_CLASSES = 43

    def __init__(self, root, transform=None, train=True, prefetch=False):
        self.train = train
        self.pre_transform = transform["pre"]
        self.primary_transform = transform["primary"]
        if prefetch:
            # ``mean`` / ``std`` only exist on the instance in prefetch mode.
            self.remaining_transform, self.mean, self.std = prefetch_transform(
                transform["remaining"]
            )
        else:
            self.remaining_transform = transform["remaining"]
        root = os.path.expanduser(root)
        self.prefetch = prefetch
        if train:
            self.data_folder = os.path.join(root, "Train")
            self.data, self.targets = self._get_data_train_list()
        else:
            self.data_folder = os.path.join(root, "Test")
            self.data, self.targets = self._get_data_test_list()
        self.targets = np.asarray(self.targets)

    def _read_annotations(self, csv_path, image_dir):
        """Load every image/label pair listed in one annotation CSV.

        Args:
            csv_path: Path of the ``;``-separated annotation file. Its first
                row is treated as a header and skipped.
            image_dir: Directory that the file names in column 0 are
                relative to.

        Returns:
            ``(images, labels)`` as two lists: the pre-transformed images and
            the integer labels taken from column 7 (presumably the class id
            column of the GTSRB format -- confirm against the data files).
        """
        images, labels = [], []
        # The context manager closes the file even if a row fails to parse or
        # an image cannot be opened; ``newline=""`` is what the csv module
        # asks for when it is given a file object.
        with open(csv_path, newline="") as gt_file:
            reader = csv.reader(gt_file, delimiter=";")
            next(reader)  # skip the header row
            for row in reader:
                image = Image.open(os.path.join(image_dir, row[0]))
                images.append(self.pre_transform(image))
                labels.append(int(row[7]))
        return images, labels

    def _get_data_train_list(self):
        """Collect the training images and labels from every class folder.

        Returns:
            ``(images, labels)`` where ``images`` is ``np.array`` over all
            pre-transformed images and ``labels`` is a list of ints.
        """
        images, labels = [], []
        for class_id in range(self._NUM_CLASSES):
            class_name = format(class_id, "05d")
            class_dir = os.path.join(self.data_folder, class_name)
            csv_path = os.path.join(class_dir, "GT-" + class_name + ".csv")
            class_images, class_labels = self._read_annotations(csv_path, class_dir)
            images.extend(class_images)
            labels.extend(class_labels)
        return np.array(images), labels

    def _get_data_test_list(self):
        """Collect the test images and labels listed in ``GT-final_test.csv``.

        Returns:
            ``(images, labels)`` where ``images`` is ``np.array`` over all
            pre-transformed images and ``labels`` is a list of ints.
        """
        csv_path = os.path.join(self.data_folder, "GT-final_test.csv")
        images, labels = self._read_annotations(csv_path, self.data_folder)
        return np.array(images), labels

    def __len__(self):
        """Number of loaded samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``{"img": ..., "target": ...}`` for the sample at ``index``."""
        target = self.targets[index]
        img = self.data[index]
        # The "pre" stage already ran at load time, so only the primary and
        # remaining stages are applied per sample.
        img = self.primary_transform(img)
        img = self.remaining_transform(img)
        if self.prefetch:
            # Convert to uint8 and move axis 2 to the front (HWC -> CHW).
            img = np.rollaxis(np.array(img, dtype=np.uint8), 2)
        img = np.array(img)
        target = np.array(target)
        item = {"img": img, "target": target}

        return item

# One parsed annotation file: the ``header`` row, the per-row ``index`` key
# (first column) and the remaining ``data`` columns.
CSV = namedtuple("CSV", ("header", "index", "data"))
class CelebA(Dataset):
    """CelebA dataset backed by the plain-text annotation files under ``root``.

    Args:
        root: Directory containing the ``list_*.txt`` / ``identity_CelebA.txt``
            annotation files and the ``img_align_celeba`` image folder.
        split: One of ``"train"``, ``"valid"`` or ``"test"`` (case
            insensitive). Any other value raises ``KeyError``.
        target_type: A target name or a list of them, chosen from ``"attr"``,
            ``"identity"``, ``"bbox"`` and ``"landmarks"``. The names are
            checked lazily, in ``__getitem__``.
        transform: Optional callable applied to the opened PIL image.
        target_transform: Optional callable applied to the assembled target.

    ``__getitem__`` returns ``(image, target)``.
    """

    def __init__(
            self,
            root,
            split="train",
            target_type="attr",
            transform=None,
            target_transform=None,
    ):
        self.list_attributes = [18, 31, 21]
        self.root = root
        self.base_folder = ''
        self.transform = transform
        self.target_transform = target_transform
        self.split = split
        if isinstance(target_type, list):
            self.target_type = target_type
        else:
            self.target_type = [target_type]

        # Partition ids as they appear in ``list_eval_partition.txt``.
        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
        }
        split_ = split_map[split.lower()]
        splits = self._load_csv("list_eval_partition.txt")
        identity = self._load_csv("identity_CelebA.txt")
        bbox = self._load_csv("list_bbox_celeba.txt", header=1)
        landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1)
        attr = self._load_csv("list_attr_celeba.txt", header=1)

        # Boolean row mask selecting the requested partition.
        mask = (splits.data == split_).squeeze()

        self.filename = [splits.index[i] for i in paddle.squeeze(paddle.nonzero(mask))]
        self.identity = identity.data[mask]
        self.bbox = bbox.data[mask]
        self.landmarks_align = landmarks_align.data[mask]
        self.attr = attr.data[mask]
        # (attr + 1) // 2: turns a -1/1 encoding into 0/1 (presumably how the
        # attribute file stores its flags -- confirm against the data).
        self.attr = paddle.floor_divide(self.attr + 1, paddle.to_tensor([2], dtype=paddle.int64))
        self.attr_names = attr.header

    def _load_csv(
        self,
        filename,
        header=None,
    ):
        """Parse one space-separated annotation file into a ``CSV`` record.

        Args:
            filename: File name relative to ``root`` / ``base_folder``.
            header: Index of the header row, or ``None`` when the file has
                none. Rows up to and including the header are dropped from
                the data.

        Returns:
            ``CSV(header, index, data)``: the header row (empty list when
            ``header`` is ``None``), the first column of every data row, and
            the remaining columns converted to ints in a paddle tensor.
        """
        headers = []

        path = os.path.join(self.root, self.base_folder, filename)
        with open(path) as csv_file:
            # ``skipinitialspace`` makes runs of spaces act as one delimiter.
            data = list(csv.reader(csv_file, delimiter=' ', skipinitialspace=True))

        if header is not None:
            headers = data[header]
            data = data[header + 1:]

        indices = [row[0] for row in data]
        data = [row[1:] for row in data]
        data_int = [list(map(int, i)) for i in data]

        return CSV(headers, indices, paddle.to_tensor(data_int))

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``.

        ``target`` is a single value when one target type was requested, a
        tuple when several were, and ``None`` when ``target_type`` is empty.

        Raises:
            ValueError: If ``target_type`` contains an unknown name.
        """
        X = Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

        # A plain list; the previous annotation used ``Any``, which this
        # module never imports.
        target = []
        for t in self.target_type:
            if t == "attr":
                target.append(self.attr[index, :])
            elif t == "identity":
                target.append(self.identity[index, 0])
            elif t == "bbox":
                target.append(self.bbox[index, :])
            elif t == "landmarks":
                target.append(self.landmarks_align[index, :])
            else:
                # TODO: refactor with utils.verify_str_arg
                raise ValueError("Target type \"{}\" is not recognized.".format(t))

        if self.transform is not None:
            X = self.transform(X)

        if target:
            target = tuple(target) if len(target) > 1 else target[0]

            if self.target_transform is not None:
                target = self.target_transform(target)
        else:
            target = None
        return X, target

    def __len__(self):
        """Number of samples in the selected split."""
        return len(self.attr)

class CelebA_attr(paddle.io.Dataset):
    """CelebA reduced to a single integer label built from three attributes.

    The whole split is read through ``CelebA`` in the constructor and kept in
    memory. The attribute columns listed in ``list_attributes`` are combined
    as ``a[0] * 4 + a[1] * 2 + a[2]``.

    Args:
        root (string): Dataset directory handed to ``CelebA``. ``~`` is
            expanded.
        transform (dict): Mapping with the keys ``"pre"``, ``"primary"`` and
            ``"remaining"``. ``"pre"`` is given to ``CelebA`` as its image
            transform and therefore runs once at load time; the other two run
            in ``__getitem__``. The signature defaults to ``None``, but the
            constructor indexes this argument, so a mapping is required.
        train (bool): Selects the ``"train"`` split when true, otherwise
            ``"test"`` (default: True).
        prefetch (bool): If True, ``transform["remaining"]`` is handed to
            ``prefetch_transform`` and the returned ``mean`` / ``std`` are
            stored on the instance (default: False).

    Each item is a dict ``{"img": np.ndarray, "target": np.ndarray}``.
    """

    def __init__(self, root, transform=None, train=True, prefetch=False):
        if train:
            split = "train"
        else:
            split = "test"
        self.pre_transform = transform["pre"]
        self.primary_transform = transform["primary"]
        remaining = transform["remaining"]
        if prefetch:
            # ``mean`` / ``std`` only exist on the instance in prefetch mode.
            remaining, self.mean, self.std = prefetch_transform(remaining)
        self.remaining_transform = remaining

        root = os.path.expanduser(root)
        self.prefetch = prefetch
        self.train = train
        self.dataset = CelebA(root=root, split=split, target_type="attr", transform=self.pre_transform)
        self.list_attributes = [18, 31, 21]
        self.split = split

        # Materialise every sample up front: the pre-transformed image and
        # the label derived from the selected attribute columns.
        pictures, labels = [], []
        for position in range(len(self.dataset)):
            picture, attributes = self.dataset[position]
            selected = attributes[self.list_attributes].cpu().numpy()
            pictures.append(picture)
            labels.append(self._convert_attributes(selected))
        self.data = np.array(pictures)
        self.targets = np.array(labels)

    def _convert_attributes(self, bool_attributes):
        """Fold three attribute values into one label, first value weighted highest."""
        return bool_attributes[0] * 4 + bool_attributes[1] * 2 + bool_attributes[2]

    def __len__(self):
        """Number of samples in the wrapped ``CelebA`` split."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return ``{"img": ..., "target": ...}`` for the sample at ``index``."""
        label = self.targets[index]
        picture = self.data[index]
        # The "pre" stage already ran at load time, so only the primary and
        # remaining stages are applied per sample.
        picture = self.remaining_transform(self.primary_transform(picture))
        if self.prefetch:
            # Convert to uint8 and move axis 2 to the front (HWC -> CHW).
            picture = np.rollaxis(np.array(picture, dtype=np.uint8), 2)
        return {"img": np.array(picture), "target": np.array(label)}
