from PIL import Image
import os.path
import torch
import warnings
import torch.utils.data as data
from torchvision import transforms
import numpy as np

def load_image_path(key, out_field, d):
    """Open the image at ``d`` and return it as a grayscale (mode ``'L'``) PIL image.

    Args:
        key: Unused; kept so existing callers keep working.
        out_field: Unused (the original overwrote it immediately); kept so
            existing callers keep working.
        d: Filename or file object accepted by ``PIL.Image.open``.

    Returns:
        A new in-memory PIL image in mode ``'L'``.
    """
    # The context manager closes the source file deterministically. convert()
    # loads the pixels and returns a new image object, so the result remains
    # usable after the file is closed.
    with Image.open(d) as img:
        return img.convert('L')

def convert_tensor(key, d):
    """Convert the single-channel image ``d[key]`` into an inverted float tensor.

    Each pixel ``p`` is mapped to ``(255 - p) / 255``, so 0 becomes 1.0 and
    255 becomes 0.0.

    Args:
        key: Key under which the image is stored in ``d``.
        d: Mapping holding the image; it is only read, never modified.

    Returns:
        A float32 tensor of shape ``(1, size[0], size[1])``. For a PIL image
        ``size`` is ``(width, height)``, so the tensor is the transpose of
        NumPy's ``(rows, cols)`` pixel layout.
    """
    img = d[key]
    # np.asarray copies only when required. The previous
    # np.array(..., copy=False) raises ValueError on NumPy >= 2 whenever a
    # copy is needed, e.g. casting 8-bit pixels to float32.
    pixels = torch.from_numpy(np.asarray(img, dtype=np.float32))
    # Swap the two axes; contiguous() is required before view() on a
    # transposed tensor. Assumes a single-channel (2-D) image.
    pixels = pixels.transpose(0, 1).contiguous().view(1, img.size[0], img.size[1])
    return (255.0 - pixels) / 255.0

def scale_image(key, height, width, d):
    """Replace ``d[key]`` with a resized copy of itself and return ``d``.

    NOTE(review): the size tuple handed to ``resize`` is ``(height, width)``.
    PIL's ``Image.resize`` interprets its argument as ``(width, height)``, so
    for a PIL image the result's width equals ``height`` and vice versa. The
    two agree only for square sizes. Confirm what callers expect before
    changing the order, because they may rely on the current behavior.
    """
    new_size = (height, width)
    resized = d[key].resize(new_size)
    d[key] = resized
    return d

def convert_dict(k, v):
    """Return a new single-entry dict mapping ``k`` to ``v``."""
    wrapped = dict()
    wrapped[k] = v
    return wrapped

class FAMNIST(data.Dataset):
    """Digit-image dataset read from a ``root/<split>/<digit>/<file>`` folder tree.

    Sample file paths are collected at construction time and opened in
    ``__getitem__``. A sample's target is the position of its digit folder
    in ``'0'`` .. ``'9'``.

    Args:
        root: Directory containing ``train`` and ``test`` sub-directories,
            each holding one sub-directory per digit ``'0'`` .. ``'9'``.
        train: If True, index ``root/train``; otherwise ``root/test``.
        transform: Optional callable applied to the loaded PIL image.
        target_transform: Optional callable applied to the integer target.
        download: Accepted for signature compatibility; not used.
        color: If True, images are returned in the mode they are stored in;
            otherwise they are converted to grayscale (``'L'``).
    """

    classes = ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four',
               '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']

    # On-disk folder names in label order: the folder at position i holds class i.
    _class_dirs = ('0', '1', '2', '3', '4', '5', '6', '7', '8', '9')

    # Deprecated aliases kept for backward compatibility; each emits a warning
    # attributed to the caller's line (stacklevel=2).
    @property
    def train_labels(self):
        warnings.warn("train_labels has been renamed targets", stacklevel=2)
        return self.targets

    @property
    def test_labels(self):
        warnings.warn("test_labels has been renamed targets", stacklevel=2)
        return self.targets

    @property
    def train_data(self):
        warnings.warn("train_data has been renamed data", stacklevel=2)
        return self.data

    @property
    def test_data(self):
        warnings.warn("test_data has been renamed data", stacklevel=2)
        return self.data

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, color=False):
        self.root = root
        self.transform = transform
        self.target_transform = target_transform
        self.train = train  # training set or test set
        self.color = color

        # Only file paths are collected here; pixels are read in __getitem__.
        if self.train:
            self.data, self.targets = self.generate_ds(self.root)
        else:
            self.data, self.targets = self.generate_ds_test(self.root)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is index of the target class.
        """
        path, target = self.data[index], int(self.targets[index])

        # Return a PIL Image, consistent with other torchvision-style datasets.
        img = self._load_image(path)

        if self.transform is not None:
            img = self.transform(img)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def _load_image(self, path):
        """Open ``path`` and return a fully loaded PIL image, closing the file."""
        with Image.open(path) as im:
            # copy() and convert() both load the pixels and return a new image,
            # so the result stays valid after the file handle is closed. The
            # previous bare Image.open() kept the file open (lazy load) in the
            # color branch, which can exhaust file descriptors.
            if self.color:
                return im.copy()
            return im.convert('L')

    def __len__(self):
        return len(self.data)

    @property
    def raw_folder(self):
        """``<root>/<class name>/raw``; not used elsewhere in this class."""
        return os.path.join(self.root, self.__class__.__name__, 'raw')

    @property
    def processed_folder(self):
        """``<root>/data/processed``; not used elsewhere in this class.

        NOTE(review): unlike ``raw_folder`` this path does not include the
        class name; confirm that is intended.
        """
        return os.path.join(self.root, 'data', 'processed')

    @property
    def class_to_idx(self):
        """Map each entry of ``classes`` to its integer label."""
        return {_class: i for i, _class in enumerate(self.classes)}

    def generate_ds(self, root):
        """Index the training split under ``root/train``; returns ``(paths, targets)``."""
        return self._collect_split(root, 'train')

    def generate_ds_test(self, root):
        """Index the test split under ``root/test``; returns ``(paths, targets)``."""
        return self._collect_split(root, 'test')

    def _collect_split(self, root, split):
        """Return ``(paths, targets)`` for every entry in ``root/<split>/<digit>/``."""
        split_dir = os.path.join(root, split)
        paths = []
        targets = []
        for label, class_dir in enumerate(self._class_dirs):
            class_path = os.path.join(split_dir, class_dir)
            # os.listdir order is filesystem-dependent; sorting keeps the
            # index -> sample mapping reproducible across machines and runs.
            for img_name in sorted(os.listdir(class_path)):
                paths.append(os.path.join(class_path, img_name))
                targets.append(label)
        return paths, targets