import copy
import csv
import os
import random
from collections import namedtuple
from typing import Any, Callable, List, Optional, Tuple, Union

import PIL
import numpy as np
import torch

from torchvision.datasets import CelebA
from torchvision.transforms import Compose

# Lightweight record type with ``header``, ``index`` and ``data`` fields.
# NOTE(review): not referenced anywhere in the visible code -- presumably kept
# to mirror the CSV helper used by torchvision's CelebA loader; confirm before
# removing.
CSV = namedtuple("CSV", ["header", "index", "data"])


class CelebAOwn(CelebA):
    """Truncated CelebA split whose target is one integer packed from three attributes.

    The dataset keeps only the first ``size`` samples of the chosen split.
    ``__getitem__`` returns ``(image, label)`` where ``label`` packs the
    attribute columns listed in ``_LABEL_ATTR_IDX`` into a single number
    (first listed column = most significant bit).

    NOTE(review): assuming attribute values are stored as 0/1 (as torchvision's
    CelebA loader appears to do), the label lies in ``range(8)`` -- confirm
    against the installed torchvision version.
    """

    # Columns of ``self.attr`` packed into the label, most significant first.
    # NOTE(review): in the standard CelebA attribute ordering these are
    # presumably Heavy_Makeup, Mouth_Slightly_Open and Smiling -- verify
    # against ``list_attr_celeba.txt``.
    _LABEL_ATTR_IDX = np.array([18, 21, 31])

    # Target types accepted in ``target_type``; anything else raises ValueError
    # when an item is fetched.
    _KNOWN_TARGET_TYPES = ("attr", "identity", "bbox", "landmarks")

    def __init__(
        self,
        root: str,
        split: str = "train",
        target_type: Union[List[str], str] = "attr",
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        download: bool = False,
        size: Optional[int] = 1000,
    ) -> None:
        """Load the CelebA split and keep only its first ``size`` samples.

        Args:
            root: Dataset root directory, forwarded to ``CelebA``.
            split: Split name, forwarded to ``CelebA``.
            target_type: Target type(s), forwarded to ``CelebA``. The returned
                label is always derived from the attribute table, whatever is
                requested here; the value is only validated in ``__getitem__``.
            transform: Optional callable applied to the loaded image.
            target_transform: Optional callable applied to the integer label.
            download: Forwarded to ``CelebA``.
            size: Number of leading samples to keep. It is used as a slice
                bound, so ``None`` keeps the whole split.
        """
        super().__init__(root, split, target_type, transform, target_transform, download)
        # Only the file list and the attribute table are truncated; they are
        # the only per-sample tables read by ``__getitem__``.
        self.filename = self.filename[:size]
        self.attr = self.attr[:size, :]

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Return the ``(image, label)`` pair for sample ``index``.

        Args:
            index: Position of the sample inside the truncated split.

        Returns:
            The image (after ``transform``, if set) and the packed attribute
            label (after ``target_transform``, if set).

        Raises:
            ValueError: If ``target_type`` contains an unrecognized entry.
        """
        # ``import PIL`` alone does not load the ``Image`` submodule; import it
        # explicitly instead of relying on torchvision having done so.
        from PIL import Image

        X = Image.open(os.path.join(self.root, self.base_folder, "img_align_celeba", self.filename[index]))

        for t in self.target_type:
            if t not in self._KNOWN_TARGET_TYPES:
                # TODO: refactor with utils.verify_str_arg
                raise ValueError(f'Target type "{t}" is not recognized.')

        # Read the attributes straight from ``self.attr`` so the label does not
        # depend on "attr" being the first requested target type.
        label_bits = self.attr[index, :][self._LABEL_ATTR_IDX]

        # Horner-style packing: the first selected attribute ends up as the
        # most significant bit, the last one as bit 0.
        target = 0
        for bit in label_bits:
            target = target * 2 + bit.item()

        if self.transform is not None:
            X = self.transform(X)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return X, target