#! -*- coding: utf-8
import os.path as path
import typing
# from concurrent.futures import ProcessPoolExecutor
from glob import glob

import numpy as np
import torch
from PIL import Image

CLASSES = ["30", "31", "32", "33", "34", "35", "36", "37", "38", "39",
           "41", "42", "43", "44", "45", "46", "47", "48", "49", "4a", "4b", "4c", "4d", "4e", "4f",
           "50", "51", "52", "53", "54", "55", "56", "57", "58", "59", "5a",
           "61", "62", "63", "64", "65", "66", "67", "68", "69", "6a", "6b", "6c", "6d", "6e", "6f",
           "70", "71", "72", "73", "74", "75", "76", "77", "78", "79", "7a",]
WRITERS = ["hsf_0", "hsf_1", "hsf_2", "hsf_3",
           "hsf_4", "hsf_5", "hsf_6", "hsf_7", ]


# Public API: only the dataset class is exported by ``from <module> import *``.
__all__ = ["FEMNIST"]


class FEMNIST(torch.utils.data.Dataset):
    """Map-style dataset over the ``by_class`` PNG directory tree.

    Sample ``i`` is ``(x, y)``. ``y`` is a 0-dim int64 tensor holding the index
    of the image's class directory in ``CLASSES``. ``x`` is the decoded image as
    a float64 tensor with pixel values scaled to ``[0, 1]``, then optionally
    normalized.

    Images are decoded lazily on first access, and the resulting tensor is
    cached in ``self.datas``, so memory grows until every sample has been read
    once. Repeated accesses return the *same* tensor object. Callers must not
    modify it in place, or the cached copy changes too.
    """

    def __init__(self, datadir: str, train: bool = True,
                 resize: typing.Optional[typing.Tuple[int, int]] = None,
                 normalize: typing.Optional[typing.Tuple[float, float]] = None,
                 channel_first: bool = True):
        """Index the image files; no image is decoded here.

        Args:
            datadir: root directory containing ``by_class/<class>/...``.
            train: if True, use files matching ``by_class/<class>/train_*/*.png``;
                otherwise ``by_class/<class>/hsf_*/*.png``.
            resize: optional ``(width, height)`` passed to ``PIL.Image.resize``.
            normalize: optional ``(mean, scale)``. After scaling to [0, 1],
                pixels become ``(x - mean) / scale``. ``scale`` is a divisor,
                so pass a standard deviation, not a variance.
            channel_first: if True, images are returned as ``(C, H, W)``; a
                single-channel image becomes ``(1, H, W)``. Otherwise the
                array layout from PIL/NumPy is kept: ``(H, W, C)`` or ``(H, W)``.
        """
        super().__init__()
        self.targets, self.imgfiles = self.filelists(datadir, train=train)
        # Per-sample tensor cache, filled on first access in __getitem__.
        self.datas: typing.List[typing.Optional[torch.Tensor]] = [None] * len(self.targets)
        self.resize = resize
        self.normalize = normalize
        self.channel_first = channel_first

    def filelists(self, datadir: str, train: bool = True) -> typing.Tuple[torch.Tensor, typing.List[str]]:
        """Collect image paths and their integer labels.

        Classes are visited in ``CLASSES`` order, and the label of a file is
        the index of its class in ``CLASSES``. Within a class, paths are sorted.

        Returns:
            ``(targets, files)``, where ``targets[i]`` (int64) is the label of
            ``files[i]``. Both are empty if no file matches.
        """
        split_pattern = "train_*" if train else "hsf_*"
        targets: typing.List[int] = []
        files: typing.List[str] = []
        for label, cls in enumerate(CLASSES):
            imgfiles = sorted(glob(path.join(datadir, "by_class", cls, split_pattern, "*.png")))
            targets.extend([label] * len(imgfiles))
            files.extend(imgfiles)

        # An explicit dtype keeps an empty result an integer tensor. Without it,
        # torch.tensor([]) would default to float32.
        return torch.tensor(targets, dtype=torch.long), files

    def __len__(self) -> int:
        return len(self.targets)

    def __getitem__(self, idx: int) -> typing.Tuple[torch.Tensor,
                                                    torch.Tensor]:
        """Return ``(image_tensor, label)``, decoding and caching on first access."""
        y = self.targets[idx]
        x = self.datas[idx]
        if x is None:
            x = self._preprocess(self._load_image(self.imgfiles[idx]))
            self.datas[idx] = x

        return x, y

    def _load_image(self, imgfile: str) -> np.ndarray:
        """Decode one image file, optionally resized, into a raw pixel array."""
        # The context manager closes the underlying file even if decoding fails.
        # Rebinding ``img`` below does not affect which object ``with`` closes.
        with Image.open(imgfile) as img:
            if self.resize is not None:
                # PIL takes (width, height); the array below is (H, W[, C]).
                img = img.resize(self.resize)
            return np.array(img)

    def _preprocess(self, arr: np.ndarray) -> torch.Tensor:
        """Scale pixels to [0, 1], reorder axes, normalize, and wrap as a tensor."""
        if arr.dtype == np.bool_:
            # Bilevel (mode "1") images decode to bool arrays. Map True to full
            # intensity so the division below yields {0, 1}, not {0, 1/255}.
            arr = arr.astype(np.uint8) * 255
        x = arr.astype(float) / 255
        if self.channel_first:
            if x.ndim == 2:
                # Single-channel image: add a leading channel axis -> (1, H, W).
                x = x[np.newaxis]
            else:
                # (H, W, C) -> (C, H, W)
                x = x.transpose(2, 0, 1)
        if self.normalize is not None:
            mean, scale = self.normalize
            x = (x - mean) / scale
        return torch.tensor(x)
