#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import os
import PIL
import numpy as np
import torch
import tqdm
import time

from torch.utils.data import Dataset
import torchvision
import torchvision.transforms.functional as TF
from skimage import io, transform
from collections.abc import Iterable
from folders import folders


# data loading
class ImageFolder(Dataset):
    """Dataset of all images below a folder, optionally returned as random crops.

    The folder is walked recursively with ``os.walk`` and every file whose
    extension is listed in ``_EXTENSIONS`` is collected. Subfolders carry no
    meaning: no categories or labels are derived from them. Items are float
    tensors with the channel axis first.

    Args:
        folder: Root directory that is searched for image files.
        im_size: Crop size. An iterable is stored and handed to ``RandomCrop``
            unchanged, a truthy scalar ``s`` becomes ``[s, s]`` and a falsy
            value disables cropping, so whole images are returned.
    """

    # Lower-case file extensions that are accepted as images.
    _EXTENSIONS = ('.jpg', '.png', '.jpeg')
    # Maximum number of images tried per item before giving up on cropping.
    _MAX_ATTEMPTS = 1000

    def __init__(self, folder, im_size):
        super().__init__()
        self.folder = folder
        # generate file list
        print('finding files')
        self.files = []
        t = tqdm.tqdm(os.walk(folder), 'folders processed')
        for root, _, files in t:
            for f in files:
                ext = os.path.splitext(f)[1].lower()
                if ext in self._EXTENSIONS:
                    self.files.append(os.path.join(root, f))
                else:
                    print('%s not accepted' % f)
                    print(ext)
            t.set_postfix({'files': len(self.files)})
        # prepare cropping
        if isinstance(im_size, Iterable):
            self.im_size = im_size
        elif im_size:
            self.im_size = [im_size, im_size]
        else:
            self.im_size = False
        # Always define ``crop`` so the attribute exists even without cropping.
        if self.im_size:
            self.crop = torchvision.transforms.RandomCrop(self.im_size)
        else:
            self.crop = None

    def __len__(self):
        """Return the number of image files that were collected."""
        return len(self.files)

    @staticmethod
    def _to_chw_tensor(image):
        """Convert an image array into a float32 tensor with channels first.

        A 3-D array has its last axis moved to the front (presumably
        height x width x channels as delivered by ``imread``). A 2-D grayscale
        array is replicated into three identical channels. Any other rank is
        reported on stdout and converted without reshaping.

        The cast to float32 happens in NumPy so that integer dtypes which
        ``torch.tensor`` does not accept on every torch release (e.g. the
        ``uint16`` of 16-bit PNGs) are still handled.
        """
        if image.ndim == 3:
            image = image.transpose(2, 0, 1)
        elif image.ndim == 2:  # 2D grayscale image
            image = np.expand_dims(image, 0).repeat(3, 0)
        else:
            print('something wrong, got image shape:')
            print(image.shape)
        # ``astype`` always copies; order='C' makes the result contiguous
        # even though the transposed view above is not.
        return torch.from_numpy(image.astype(np.float32, order='C'))

    def __getitem__(self, idx):
        """Return the image at ``idx``, randomly cropped if cropping is enabled.

        If the selected image cannot be cropped (``ValueError`` or
        ``RuntimeError`` from the crop, e.g. because the image is smaller than
        the crop), a randomly chosen other image is used instead.

        Raises:
            RuntimeError: If no croppable image was found within
                ``_MAX_ATTEMPTS`` tries.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        if not self.im_size:
            return self._to_chw_tensor(io.imread(self.files[idx]))
        # Bounded loop instead of recursion: a folder full of too-small images
        # must end in a clear error, not in a RecursionError.
        for _ in range(self._MAX_ATTEMPTS):
            image = self._to_chw_tensor(io.imread(self.files[idx]))
            try:
                pars = self.crop.get_params(image, self.im_size)
                return TF.crop(image, *pars)
            except (ValueError, RuntimeError):
                # Draw the replacement from torch's RNG (also used by
                # RandomCrop); it is seeded per DataLoader worker, whereas
                # NumPy's global state may be duplicated across workers.
                idx = int(torch.randint(len(self), (1,)).item())
        raise RuntimeError(
            'no image could be cropped to %s after %d attempts'
            % (self.im_size, self._MAX_ATTEMPTS))


class ImageFolderName(Dataset):
    """Dataset of all images below a folder, returned together with their paths.

    The folder is walked recursively with ``os.walk`` and every file whose
    extension is listed in ``_EXTENSIONS`` is collected. Items are
    ``(path, image)`` pairs where ``image`` is a float tensor with the channel
    axis first; images are returned whole, without cropping.

    Args:
        folder: Root directory that is searched for image files.
    """

    # Lower-case file extensions that are accepted as images.
    _EXTENSIONS = ('.jpg', '.png', '.jpeg')

    def __init__(self, folder):
        super().__init__()
        self.folder = folder
        # generate file list
        print('finding files')
        self.files = []
        t = tqdm.tqdm(os.walk(folder), 'folders processed')
        for root, _, files in t:
            for f in files:
                ext = os.path.splitext(f)[1].lower()
                if ext in self._EXTENSIONS:
                    self.files.append(os.path.join(root, f))
                else:
                    print('%s not accepted' % f)
                    print(ext)
            t.set_postfix({'files': len(self.files)})

    def __len__(self):
        """Return the number of image files that were collected."""
        return len(self.files)

    @staticmethod
    def _to_chw_tensor(image):
        """Convert an image array into a float32 tensor with channels first.

        A 3-D array has its last axis moved to the front (presumably
        height x width x channels as delivered by ``imread``). A 2-D grayscale
        array is replicated into three identical channels. Any other rank is
        reported on stdout and converted without reshaping.

        The cast to float32 happens in NumPy so that integer dtypes which
        ``torch.tensor`` does not accept on every torch release (e.g. the
        ``uint16`` of 16-bit PNGs) are still handled.
        """
        if image.ndim == 3:
            image = image.transpose(2, 0, 1)
        elif image.ndim == 2:  # 2D grayscale image
            image = np.expand_dims(image, 0).repeat(3, 0)
        else:
            print('something wrong, got image shape:')
            print(image.shape)
        # ``astype`` always copies; order='C' makes the result contiguous
        # even though the transposed view above is not.
        return torch.from_numpy(image.astype(np.float32, order='C'))

    def __getitem__(self, idx):
        """Return ``(path, image)`` for the file at position ``idx``."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        path = self.files[idx]
        return path, self._to_chw_tensor(io.imread(path))
