import os
import csv
import random
from collections import namedtuple
import csv
import PIL
from functools import partial
import paddle.vision.transforms as transforms
import numpy as np
from typing import Any, Callable, List, Optional, Union, Tuple
import paddle
from paddle.io import Dataset


from PIL import Image



class ToNumpy:
    """Convert an image-like object to a NumPy array.

    A 2-D result (H, W) gets a trailing singleton axis so it becomes (H, W, 1);
    arrays of any other rank are returned as produced by ``np.array``.
    """

    def __call__(self, x):
        arr = np.array(x)
        return arr[:, :, np.newaxis] if arr.ndim == 2 else arr


class ProbTransform(paddle.nn.Layer):
    """Apply ``f`` to the input with probability ``p``; otherwise return it unchanged.

    @Args:
        f (callable): the transform to apply.
        p (float): probability of applying ``f`` (default 1, i.e. always).
    """

    def __init__(self, f, p=1):
        super().__init__()
        self.f = f
        self.p = p

    def forward(self, x):
        # One draw from Python's global ``random`` per call.
        should_apply = random.random() < self.p
        return self.f(x) if should_apply else x



class DictDataset(paddle.io.Dataset):
    """Map-style dataset over a dict of parallel, equally indexable sequences.

    Item ``idx`` is a dict with the same keys as ``data_map``, each mapped to
    ``data_map[key][idx]``.
    """

    def __init__(self, data_map, input_transform=None):
        """
        @Args:
            data_map (dict):
                A dictionary of datas, e.g., {'input':np.ndarray, 'target': np.ndarray, ...}
            input_transform (callable, optional):
                Applied only to the 'input' field, after converting it with
                ``paddle.to_tensor``.
        """
        super().__init__()
        self.data_map = data_map
        self.input_transform = input_transform

    def __getitem__(self, idx):
        sample = {}
        for key, values in self.data_map.items():
            item = values[idx]
            if self.input_transform and key == 'input':
                item = self.input_transform(paddle.to_tensor(item))
            sample[key] = item
        return sample

    def __len__(self):
        # The length is taken from the first field. An empty map is an empty
        # dataset, not an IndexError.
        if not self.data_map:
            return 0
        first_field = next(iter(self.data_map.values()))
        return len(first_field)
        

class DictDatasetWrapper(paddle.io.Dataset):
    def __init__(self, dataset, field_names):
        """
        Expose a dataset whose items are sequences as a dataset of dicts.

        @Args
            dataset (paddle.io.Dataset):
                an indexable dataset whose items are tuples/sequences
            field_names (list):
                names assigned, in order, to the elements of each item
        """
        self.dataset = dataset
        self.field_names = field_names

    def __getitem__(self, idx):
        fields = self.dataset[idx]
        # zip stops at the shorter of field_names / fields.
        return {name: value for name, value in zip(self.field_names, fields)}

    def __len__(self):
        return len(self.dataset)


def get_transform(opt, train=True, attack=False, pretensor_transform=False):
    """Build the preprocessing pipeline for ``opt.dataset``.

    Always resizes to (opt.input_height, opt.input_width) and converts to a
    tensor. When ``pretensor_transform`` is set for a non-attack training
    pipeline, a padded random crop (padding ``opt.random_crop``) is added, plus a
    random horizontal flip for cifar10. Normalization is applied for cifar10 and
    mnist only.

    @Raises:
        Exception: if ``opt.dataset`` is not cifar10, mnist, gtsrb or celeba.
    """
    size = (opt.input_height, opt.input_width)
    pipeline = [transforms.Resize(size)]

    augment = pretensor_transform and train and not attack
    if augment:
        pipeline.append(transforms.RandomCrop(size, padding=opt.random_crop))
        if opt.dataset == "cifar10":
            pipeline.append(transforms.RandomHorizontalFlip(prob=0.5))

    pipeline.append(transforms.ToTensor())

    # dataset -> (mean, std) for Normalize, or None for "no normalization".
    normalization = {
        "cifar10": ([0.4914, 0.4822, 0.4465], [0.247, 0.243, 0.261]),
        "mnist": ([0.5], [0.5]),
        "gtsrb": None,
        "celeba": None,
    }
    if opt.dataset not in normalization:
        raise Exception("Invalid Dataset")
    stats = normalization[opt.dataset]
    if stats is not None:
        mean, std = stats
        pipeline.append(transforms.Normalize(mean, std))
    return transforms.Compose(pipeline)


class PostTensorTransform(paddle.nn.Layer):
    """Random augmentations applied to already-tensorized inputs.

    Applies a random rotation (range ``opt.random_rotation``) with probability
    ``p``. For cifar10 it then applies a random horizontal flip with prob=0.5.
    Sub-layers run in registration order.
    """

    def __init__(self, opt, p=0.5):
        super(PostTensorTransform, self).__init__()
        self.random_rotation = ProbTransform(transforms.RandomRotation(opt.random_rotation), p)
        if opt.dataset == "cifar10":
            # RandomHorizontalFlip is a plain paddle BaseTransform, not a
            # paddle.nn.Layer. Assigned directly, it would NOT be registered
            # as a sub-layer, and the self.children() loop in forward() would
            # silently skip it. Wrapping it in ProbTransform (a Layer) registers
            # it. p=1 because the flip already draws its own prob=0.5.
            self.random_horizontal_flip = ProbTransform(transforms.RandomHorizontalFlip(prob=0.5), p=1)

    def forward(self, x):
        for module in self.children():
            x = module(x)
        return x


class GTSRB(paddle.io.Dataset):
    """GTSRB traffic-sign dataset read from ``<opt.data_root>/GTSRB``.

    Training annotations are read from ``Train/<c:05d>/GT-<c:05d>.csv`` for the
    43 classes ``c``. Test annotations are read from ``Test/GT-final_test.csv``.
    Annotation files are ';'-separated with a header row; column 0 is the image
    filename and column 7 the integer label.

    @Args:
        opt: options object; only ``opt.data_root`` is read here.
        train (bool): load the Train split if True, else the Test split.
        transforms (callable): applied to each PIL image in ``__getitem__``.
    """

    def __init__(self, opt, train, transforms):
        super(GTSRB, self).__init__()
        if train:
            self.data_folder = os.path.join(opt.data_root, "GTSRB/Train")
            self.images, self.labels = self._get_data_train_list()
        else:
            self.data_folder = os.path.join(opt.data_root, "GTSRB/Test")
            self.images, self.labels = self._get_data_test_list()

        self.transforms = transforms

    @staticmethod
    def _read_annotations(csv_path, image_dir):
        """Return (image paths, labels) listed in one ';'-separated GT file.

        Image paths are ``image_dir + "/" + <column 0>``. The file is always
        closed, even if parsing fails.
        """
        images = []
        labels = []
        with open(csv_path, newline="") as gt_file:
            gt_reader = csv.reader(gt_file, delimiter=";")
            next(gt_reader)  # skip the header row
            for row in gt_reader:
                images.append(image_dir + "/" + row[0])
                labels.append(int(row[7]))
        return images, labels

    def _get_data_train_list(self):
        images = []
        labels = []
        for c in range(0, 43):
            class_dir = self.data_folder + "/" + format(c, "05d")
            csv_path = class_dir + "/GT-" + format(c, "05d") + ".csv"
            class_images, class_labels = self._read_annotations(csv_path, class_dir)
            images.extend(class_images)
            labels.extend(class_labels)
        return images, labels

    def _get_data_test_list(self):
        csv_path = os.path.join(self.data_folder, "GT-final_test.csv")
        return self._read_annotations(csv_path, self.data_folder)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = Image.open(self.images[index])
        image = self.transforms(image)
        label = self.labels[index]
        return image, label

# Parsed annotation file: the header row tokens, the first column of every data
# row, and the remaining columns.
CSV = namedtuple("CSV", "header index data")
class CelebA(paddle.io.Dataset):
    """CelebA dataset backed by the annotation text files under ``<root>/celeba``.

    Images are read from ``<root>/celeba/img_align_celeba``. The partition,
    identity, bbox, aligned-landmark and attribute files are all loaded eagerly
    and filtered down to the requested split.

    @Args:
        root (str): directory that contains the ``celeba`` folder.
        split (str): ``"train"``, ``"valid"`` or ``"test"`` (case-insensitive).
        target_type (str or list of str): any of ``"attr"``, ``"identity"``,
            ``"bbox"``, ``"landmarks"``. With several, the target is a tuple.
        transform (callable, optional): applied to the PIL image.
        target_transform (callable, optional): applied to the target.

    @Raises:
        KeyError: if ``split`` is not one of the names above.
        ValueError: from ``__getitem__`` when a ``target_type`` entry is unknown.
    """

    def __init__(
            self,
            root: str,
            split: str = "train",
            target_type: Union[List[str], str] = "attr",
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
    ) -> None:
        self.list_attributes = [18, 31, 21]
        self.root = root
        self.base_folder = 'celeba'
        self.transform = transform
        self.target_transform = target_transform
        self.split = split
        if isinstance(target_type, list):
            self.target_type = target_type
        else:
            self.target_type = [target_type]

        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
        }
        split_ = split_map[split.lower()]
        splits = self._load_csv("list_eval_partition.txt")
        identity = self._load_csv("identity_CelebA.txt")
        bbox = self._load_csv("list_bbox_celeba.txt", header=1)
        landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1)
        attr = self._load_csv("list_attr_celeba.txt", header=1)

        # Boolean row mask for the requested split. reshape([-1]) (rather than
        # squeeze) keeps it 1-D even when the partition file has a single row.
        mask = (splits.data == split_).reshape([-1])

        # Select filenames on the host in one pass. Iterating a paddle tensor
        # element by element costs one tensor op + __index__ call per row.
        selected = mask.numpy().tolist()
        self.filename = [name for name, keep in zip(splits.index, selected) if keep]
        self.identity = identity.data[mask]
        self.bbox = bbox.data[mask]
        self.landmarks_align = landmarks_align.data[mask]
        self.attr = attr.data[mask]
        # (x + 1) // 2 maps -1 -> 0 and 1 -> 1. This presumably assumes the
        # attribute entries are +/-1 in list_attr_celeba.txt; confirm against the data.
        self.attr = paddle.floor_divide(self.attr + 1, paddle.to_tensor([2], dtype=paddle.int64))
        self.attr_names = attr.header

    def _load_csv(
        self,
        filename: str,
        header: Optional[int] = None,
    ) -> CSV:
        """Parse a space-separated annotation file under ``<root>/celeba``.

        @Args:
            filename (str): file name relative to ``<root>/celeba``.
            header (int, optional): index of the header row. Rows before it are
                dropped and rows after it are data. None means no header.

        @Returns:
            CSV: ``header`` (header row tokens, ``[]`` when ``header`` is None),
            ``index`` (first column of each data row) and ``data`` (remaining
            columns parsed as ints, as a paddle tensor).
        """
        path = os.path.join(self.root, self.base_folder, filename)
        with open(path) as csv_file:
            rows = list(csv.reader(csv_file, delimiter=' ', skipinitialspace=True))

        headers = []
        if header is not None:
            headers = rows[header]
            rows = rows[header + 1:]

        indices = [row[0] for row in rows]
        data_int = [[int(value) for value in row[1:]] for row in rows]
        return CSV(headers, indices, paddle.to_tensor(data_int))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

        target: Any = []
        for t in self.target_type:
            if t == "attr":
                target.append(self.attr[index, :])
            elif t == "identity":
                target.append(self.identity[index, 0])
            elif t == "bbox":
                target.append(self.bbox[index, :])
            elif t == "landmarks":
                target.append(self.landmarks_align[index, :])
            else:
                # TODO: refactor with utils.verify_str_arg
                raise ValueError("Target type \"{}\" is not recognized.".format(t))

        if self.transform is not None:
            X = self.transform(X)

        if target:
            # A single requested target is returned bare, several as a tuple.
            target = tuple(target) if len(target) > 1 else target[0]

            if self.target_transform is not None:
                target = self.target_transform(target)
        else:
            target = None
        return X, target

    def __len__(self) -> int:
        return len(self.attr)


class CelebA_attr(paddle.io.Dataset):
    """CelebA wrapper whose label is built from three attribute columns.

    The attributes at indices 18, 31 and 21 are packed, most significant
    first, into one integer label.

    @Args:
        opt: options object; only ``opt.data_root`` is read here.
        split (str): split name forwarded to ``CelebA``.
        transforms (callable): applied to each image.
    """

    def __init__(self, opt, split, transforms):
        self.dataset = CelebA(root=opt.data_root, split=split, target_type="attr")
        self.list_attributes = [18, 31, 21]
        self.transforms = transforms
        self.split = split

    def _convert_attributes(self, bool_attributes):
        # Weights 4/2/1. Assuming each entry is 0 or 1, the result is in [0, 7].
        high, mid, low = bool_attributes[0], bool_attributes[1], bool_attributes[2]
        return high * 4 + mid * 2 + low

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, attributes = self.dataset[index]
        image = self.transforms(image)
        label = self._convert_attributes(attributes[self.list_attributes])
        return image, label


def get_dataloader(opt, train=True, attack=False, pretensor_transform=False):
    """Build a shuffled DataLoader of {'input', 'target'} dicts for ``opt.dataset``.

    cifar10 always uses the pre-tensor (PIL-level) augmentations. Batch size and
    worker count come from ``opt.bs`` and ``opt.num_workers``.

    @Raises:
        Exception: if ``opt.dataset`` is not a supported dataset.
    """
    pretensor_transform = True if opt.dataset == 'cifar10' else pretensor_transform
    mode = 'train' if train else 'test'
    transform = get_transform(opt, train, attack, pretensor_transform)

    name = opt.dataset
    if name == "gtsrb":
        base = GTSRB(opt, train, transform)
    elif name == "mnist":
        base = paddle.vision.datasets.MNIST(mode=mode, transform=transform, download=True)
    elif name == "cifar10":
        archive = os.path.join(opt.data_root, 'cifar-10-python.tar.gz')
        base = paddle.vision.datasets.Cifar10(data_file=archive, mode=mode, transform=transform, download=True)
    elif name == "celeba":
        # CelebA split names coincide with the mode strings ("train"/"test").
        base = CelebA_attr(opt, mode, transform)
    else:
        raise Exception("Invalid dataset")

    wrapped = DictDatasetWrapper(base, field_names=['input', 'target'])
    return paddle.io.DataLoader(wrapped, batch_size=opt.bs, num_workers=opt.num_workers, shuffle=True)


def get_dataset(opt, train=True):
    """Build the raw (un-normalized) dataset for ``opt.dataset``.

    Images are converted to NumPy arrays via ``ToNumpy``. gtsrb and celeba are
    also resized to (opt.input_height, opt.input_width) first.

    @Args:
        opt: options object (reads ``dataset``, ``data_root`` and, for
            gtsrb/celeba, ``input_height``/``input_width``).
        train (bool): training split if True, else the test split.

    @Raises:
        Exception: if ``opt.dataset`` is not a supported dataset.
    """
    # paddle.vision.datasets take mode='train'/'test'. The original passed the
    # bool ``train`` directly, which their mode check rejects.
    mode = "train" if train else "test"

    def resize_to_numpy():
        # Built lazily so input_height/input_width are only read when needed.
        return transforms.Compose([transforms.Resize((opt.input_height, opt.input_width)), ToNumpy()])

    if opt.dataset == "gtsrb":
        dataset = GTSRB(opt, train, transforms=resize_to_numpy())
    elif opt.dataset == "mnist":
        dataset = paddle.vision.datasets.MNIST(mode=mode, transform=ToNumpy())
    elif opt.dataset == "cifar10":
        dataset = paddle.vision.datasets.Cifar10(
            data_file=os.path.join(opt.data_root, 'cifar-10-python.tar.gz'),
            mode=mode,
            transform=ToNumpy(),
            download=True,
        )
    elif opt.dataset == "celeba":
        dataset = CelebA_attr(opt, mode, transforms=resize_to_numpy())
    else:
        raise Exception("Invalid dataset")
    return dataset



if __name__ == "__main__":
    # Smoke test: build the training loader, push every batch through
    # PostTensorTransform, and print the dataset size.
    import config

    opt = config.get_arguments().parse_args()

    opt.dataset = 'mnist'
    opt.data_root = '/root/projects/AttackDefence/data/AttackDefence/'

    # dataset -> (num_classes, input_height, input_width, input_channel)
    dataset_specs = {
        "mnist": (10, 28, 28, 1),
        "cifar10": (10, 32, 32, 3),
        "gtsrb": (43, 32, 32, 3),
        "celeba": (8, 64, 64, 3),
    }
    if opt.dataset not in dataset_specs:
        raise Exception("Invalid Dataset")
    opt.num_classes, opt.input_height, opt.input_width, opt.input_channel = dataset_specs[opt.dataset]

    train_dl = get_dataloader(opt, True)
    post_transform = PostTensorTransform(opt, p=0.5)
    for batch in train_dl:
        post_transform(batch["input"])
    print(len(train_dl.dataset))