"""dataset.py"""

import os
import PIL
import numpy as np
import csv
from collections import namedtuple

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
from PIL import Image
from torchvision.datasets.utils import download_file_from_google_drive, check_integrity, verify_str_arg, extract_archive
from typing import Any, Callable, List, Optional, Union, Tuple


# Parsed annotation file: ``header`` = column names (may be empty),
# ``index`` = first column of every row, ``data`` = remaining columns as a tensor.
CSV = namedtuple("CSV", "header index data")


def is_power_of_2(num):
    """Return True if the integer ``num`` is a (positive) power of two.

    ``num & (num - 1)`` clears the lowest set bit, so it is zero only when
    ``num`` has at most one set bit; zero itself is excluded explicitly.
    """
    lowest_bit_cleared = num & (num - 1)
    if lowest_bit_cleared != 0:
        return False
    return num != 0


class CustomImageFolder(ImageFolder):
    """``ImageFolder`` whose items are just the (optionally transformed) image.

    The class index that ``ImageFolder`` normally pairs with each image is
    dropped.
    """

    def __init__(self, root, transform=None):
        super().__init__(root, transform)

    def __getitem__(self, index):
        sample_path = self.imgs[index][0]
        sample = self.loader(sample_path)
        if self.transform is None:
            return sample
        return self.transform(sample)


class CustomTensorDataset(Dataset):
    """Map-style dataset serving slices of one tensor along its first dimension."""

    def __init__(self, data_tensor):
        # Items are ``data_tensor[i]``; the dataset length is its first-dim size.
        self.data_tensor = data_tensor

    def __len__(self):
        return self.data_tensor.size(dim=0)

    def __getitem__(self, index):
        sample = self.data_tensor[index]
        return sample


def return_data(args):
    """Build a shuffled training DataLoader for the dataset named by ``args``.

    Args:
        args: namespace-like object providing ``dataset`` (``'3dchairs'``,
            ``'celeba'`` or ``'dsprites'``, case-insensitive), ``dset_dir``,
            ``batch_size``, ``num_workers`` and ``image_size``.

    Returns:
        A ``DataLoader`` with ``shuffle=True``, ``pin_memory=True`` and
        ``drop_last=True``.

    Raises:
        ValueError: if ``args.image_size`` is not 64.
        NotImplementedError: if ``args.dataset`` is not a supported name.
    """
    name = args.dataset.lower()
    dset_dir = args.dset_dir
    batch_size = args.batch_size
    num_workers = args.num_workers
    image_size = args.image_size
    # Explicit check instead of ``assert`` so it still fires under ``python -O``.
    if image_size != 64:
        raise ValueError('currently only image size of 64 is supported')

    # Image-folder datasets share one transform; only the sub-directory differs.
    image_folders = {'3dchairs': '3DChairs', 'celeba': 'celeba'}

    if name in image_folders:
        transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])
        root = os.path.join(dset_dir, image_folders[name])
        train_data = CustomImageFolder(root=root, transform=transform)

    elif name == 'dsprites':
        root = os.path.join(dset_dir, 'dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz')
        if not os.path.exists(root):
            import subprocess
            print('Now download dsprites-dataset')
            subprocess.call(['./download_dsprites.sh'])
            print('Finished')
        # np.load on an .npz returns an NpzFile holding an open file handle;
        # the context manager closes it once the array has been read.
        with np.load(root, encoding='bytes') as npz:
            imgs = npz['imgs']
        # unsqueeze(1) inserts a singleton channel axis after the sample axis.
        data = torch.from_numpy(imgs).unsqueeze(1).float()
        train_data = CustomTensorDataset(data_tensor=data)

    else:
        raise NotImplementedError(f'unsupported dataset: {args.dataset!r}')

    return DataLoader(train_data,
                      batch_size=batch_size,
                      shuffle=True,
                      num_workers=num_workers,
                      pin_memory=True,
                      drop_last=True)

class CelebA(Dataset):
    """CelebA images paired with selected binary attribute values.

    Reads annotation files from ``<root>/celeba/`` and images from
    ``<root>/celeba/img_align_celeba/``.

    Args:
        root: directory that contains the ``celeba`` folder.
        image_size: images are resized to ``(image_size, image_size)``.
        split: ``"train"``, ``"valid"``, ``"test"`` or ``"all"``
            (case-insensitive).
        attr_label: ``None`` or a dict with ``'main_attr'`` (one attribute
            name) and ``'sub_attr'`` (a list of attribute names). Names must
            appear in the header row of ``list_attr_celeba.txt``.
        target_type: ``"attr"``, ``"identity"``, ``"bbox"``, ``"landmarks"``,
            or a list of them. Attribute labels are only returned by
            ``__getitem__`` when this is non-empty.

    Raises:
        ValueError: unknown split or target type, or an ``attr_label`` name
            missing from the attribute header.
    """

    base_folder = "celeba"
    _target_types = ("attr", "identity", "bbox", "landmarks")

    def __init__(self, root, image_size, split, attr_label, target_type="attr"):
        super(CelebA, self).__init__()
        self.root = root
        self.target_type = target_type if isinstance(target_type, list) else [target_type]
        # Validate up front so a bad name fails at construction, not mid-epoch.
        for t in self.target_type:
            if t not in self._target_types:
                raise ValueError(f'Target type "{t}" is not recognized.')
        self.split = split

        split_map = {
            "train": 0,
            "valid": 1,
            "test": 2,
            "all": None,
        }
        split_ = split_map[verify_str_arg(split.lower(), "split", ("train", "valid", "test", "all"))]
        splits = self._load_csv("list_eval_partition.txt")  # partition id (0/1/2) per image
        attr = self._load_csv("list_attr_celeba.txt", header=1)

        if split_ is None:  # split == "all": keep every image
            mask = slice(None)
            self.filename = splits.index
        else:
            # reshape(-1) rather than squeeze() keeps the mask 1-D even if only one row exists.
            mask = (splits.data == split_).reshape(-1)
            self.filename = [name for name, keep in zip(splits.index, mask.tolist()) if keep]

        # map from {-1, 1} to {0, 1}
        self.attr = torch.div(attr.data[mask] + 1, 2, rounding_mode="floor")
        self.attr_names = attr.header  # the header row holds the attribute names

        # Other annotations are loaded only when requested (same files as
        # torchvision's CelebA); they were previously referenced but never created.
        if "identity" in self.target_type:
            self.identity = self._load_csv("identity_CelebA.txt").data[mask]
        if "bbox" in self.target_type:
            self.bbox = self._load_csv("list_bbox_celeba.txt", header=1).data[mask]
        if "landmarks" in self.target_type:
            self.landmarks_align = self._load_csv("list_landmarks_align_celeba.txt", header=1).data[mask]

        self.attr_label = attr_label
        if attr_label is not None:
            # Resolve attribute names to column indices once, not on every item.
            self._main_attr_index = self.attr_names.index(attr_label['main_attr'])
            self._sub_attr_indices = [self.attr_names.index(n) for n in attr_label['sub_attr']]

        self.transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
        ])

        self.length = len(self.attr)

    def _load_csv(
        self,
        filename: str,
        header: Optional[int] = None,
    ) -> CSV:
        """Parse a space-delimited annotation file under ``root/base_folder``.

        If ``header`` is given, row ``header`` supplies the column names and
        data starts on the following row. Otherwise the header is empty and
        every row is data. The first column of each data row becomes
        ``index``, and the remaining columns are parsed as ints into ``data``.
        """
        with open(os.path.join(self.root, self.base_folder, filename)) as csv_file:
            data = list(csv.reader(csv_file, delimiter=" ", skipinitialspace=True))

        if header is not None:
            headers = data[header]
            data = data[header + 1 :]
        else:
            headers = []

        indices = [row[0] for row in data]
        data_int = [[int(value) for value in row[1:]] for row in data]

        return CSV(headers, indices, torch.tensor(data_int))

    def __getitem__(self, index):
        """Return ``(image, main_attr, sub_attr, filename)`` for item ``index``.

        ``main_attr`` is a 0-dim tensor and ``sub_attr`` is a list of 0-dim
        tensors taken from this image's 0/1 attribute row. Both are ``None``
        when ``target_type`` is empty or ``attr_label`` is ``None``.
        """
        filename = self.filename[index]
        X = PIL.Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", filename))
        if self.transform is not None:
            X = self.transform(X)

        main_attr = None
        sub_attr = None
        if self.target_type and self.attr_label is not None:
            # Read the attribute row directly so the result does not depend on
            # which other target types were requested.
            row = self.attr[index]
            main_attr = row[self._main_attr_index]
            sub_attr = [row[i] for i in self._sub_attr_indices]

        return X, main_attr, sub_attr, filename

    def __len__(self):
        return self.length


if __name__ == '__main__':
    # Manual smoke test: load one batch of CelebA images, then open a debugger.
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
    ])

    dset = CustomImageFolder('data/CelebA', transform)
    loader = DataLoader(dset,
                        batch_size=32,
                        shuffle=True,
                        num_workers=1,
                        pin_memory=False,
                        drop_last=True)

    # DataLoader iterators implement __next__ only; the Python-2 style
    # ``.next()`` alias no longer exists in current PyTorch.
    images1 = next(iter(loader))
    # breakpoint() honours PYTHONBREAKPOINT (e.g. PYTHONBREAKPOINT=ipdb.set_trace)
    # without making ipdb a hard dependency.
    breakpoint()
