import os
import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

from PIL import Image
import PIL
import pandas as pd



# Number of digit classes; each label is turned into a row of np.eye(NUM_CLASSES),
# so this is also the width of every target vector produced by one_hot_targets.
NUM_CLASSES = 10


def pil_loader(path):
    """Open the image file at *path* with PIL and return it converted to mode 'L'.

    The file handle is managed explicitly by a ``with`` block to avoid a
    ResourceWarning (https://github.com/python-pillow/Pillow/issues/835).
    """
    with open(path, 'rb') as handle:
        greyscale = Image.open(handle).convert('L')
        return greyscale


def accimage_loader(path):
    """Load *path* with the accimage backend, retrying with ``pil_loader`` on IOError."""
    import accimage
    try:
        image = accimage.Image(path)
    except IOError:
        # accimage failed (possibly a decoding problem): let PIL have a go instead.
        image = pil_loader(path)
    return image


def default_loader(path):
    """Load *path* with accimage if torchvision's image backend is 'accimage', else with PIL."""
    from torchvision import get_image_backend
    use_accimage = get_image_backend() == 'accimage'
    return accimage_loader(path) if use_accimage else pil_loader(path)


class DigClutDataset(Dataset):
    """Image dataset of cluttered digits paired with multi-hot digit targets.

    Sample ``i`` is the image ``image_paths + 'orig_{i}.png'`` (for ``i`` in
    ``range(data_size)``) and row ``i`` of the targets built by
    ``one_hot_targets(target_paths, train_test, n_letters).joint_targets()``.

    Args:
        image_paths: directory prefix of the images. It is concatenated directly
            with the file name, so it must end with a path separator.
        target_paths: path to the label CSV passed to ``one_hot_targets``.
        image_size: images are resized to ``(image_size, image_size)``.
        n_letters: number of letters per image, forwarded to ``one_hot_targets``.
        train_test: ``'train'`` or ``'test'``.
        data_size: number of samples (image files ``orig_0.png`` ... ).

    Raises:
        ValueError: if ``train_test`` is neither ``'train'`` nor ``'test'``.
    """

    def __init__(self, image_paths, target_paths, image_size, n_letters, train_test, data_size):
        # Fail fast: previously an unknown split silently left ``files_img`` unset.
        if train_test not in ('train', 'test'):
            raise ValueError(
                "train_test must be 'train' or 'test', got {!r}".format(train_test))

        self.image_paths = image_paths
        self.target_paths = target_paths
        self.image_size = image_size
        self.n_letters = n_letters
        self.train_test = train_test
        self.data_size = data_size

        # Train and test images share the same naming scheme; only the directory differs.
        self.files_img = [image_paths + 'orig_{}.png'.format(i) for i in range(data_size)]

        targets = one_hot_targets(self.target_paths, self.train_test, self.n_letters)
        self.digit_targets = targets.joint_targets()

        # torchvision >= 0.9 exposes InterpolationMode; older releases only accept
        # the PIL integer constant.
        interp_enum = getattr(transforms, 'InterpolationMode', None)
        nearest = interp_enum.NEAREST if interp_enum is not None else PIL.Image.NEAREST

        # Built once here rather than on every __getitem__ call.
        self.transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size), interpolation=nearest),
            # No-op for the PIL loader (already mode 'L'); kept in case the
            # accimage backend yields a non-greyscale image.
            transforms.Grayscale(),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

    def __getitem__(self, index):
        """Return ``(image, target)`` as float tensors for sample *index*."""
        x_sample = default_loader(self.files_img[index])
        y_sample = self.digit_targets[index, :]

        x = self.transform(x_sample)

        # Every pixel equal to the channel median is set to 0.5.
        # NOTE(review): presumably the median is the background level — confirm.
        channel = x[0]  # a view, so the assignment below modifies ``x`` in place
        channel[channel == torch.median(channel)] = 0.5

        return x.float(), torch.from_numpy(y_sample).float()

    def __len__(self):
        """Return the number of samples (``data_size``)."""
        return self.data_size





class one_hot_targets:
    """Read per-image digit labels from a header-less CSV and build multi-hot targets.

    (The lower-case class name is kept for backward compatibility.)

    Args:
        csv_path: path to a CSV file without a header row, one row per image.
        train_test: ``'train'`` or ``'test'``. Selects which split-specific
            attributes are populated: ``train_data``, ``train_data_size``,
            ``digt_list_train`` (and ``cols_train``), or their ``test`` counterparts.
        n_letters: number of letters per image; only 2 and 3 are supported.

    Raises:
        ValueError: if ``train_test`` or ``n_letters`` is unsupported.
    """

    # CSV columns holding each letter's digit label, keyed by n_letters.
    # NOTE(review): columns 6/7, 22/23 and 38/39 appear to hold per-letter x/y
    # positions; they are not used here — confirm the full CSV layout.
    _DIGIT_COLUMNS = {2: [5, 21], 3: [5, 21, 37]}
    # Extra columns read (divided by 255, truncated to int) when n_letters == 2.
    # NOTE(review): presumably per-letter colour intensities — confirm.
    _COLOUR_COLUMNS = [12, 32]

    def __init__(self, csv_path, train_test, n_letters):
        # Validate before touching the file. Previously an unknown split left the
        # attributes unset, and an unknown n_letters made joint_targets() return None.
        if train_test not in ('train', 'test'):
            raise ValueError(
                "train_test must be 'train' or 'test', got {!r}".format(train_test))
        if n_letters not in self._DIGIT_COLUMNS:
            raise ValueError(
                "n_letters must be one of {}, got {!r}".format(
                    sorted(self._DIGIT_COLUMNS), n_letters))

        self.data = pd.read_csv(csv_path, header=None)
        self.n_letters = n_letters
        self.train_test = train_test

        digits = self.data.iloc[:, self._DIGIT_COLUMNS[n_letters]].values.astype(int)
        colours = None
        if n_letters == 2:
            colours = (self.data.iloc[:, self._COLOUR_COLUMNS].values / 255).astype(int)

        # Keep the split-specific attribute names that callers may rely on.
        if train_test == 'train':
            self.train_data = self.data
            self.train_data_size = self.data.shape[0]
            self.digt_list_train = digits
            if colours is not None:
                self.cols_train = colours
        else:
            self.test_data = self.data
            self.test_data_size = self.data.shape[0]
            self.digt_list_test = digits
            if colours is not None:
                self.cols_test = colours

    def joint_targets(self):
        """Return an ``(N, NUM_CLASSES)`` multi-hot array of the digits in each row.

        Entry ``[i, d]`` is 1 if digit ``d`` occurs in any letter slot of row
        ``i`` and 0 otherwise; repeated digits are not counted twice. The dtype
        is float32 for ``n_letters == 3`` and int64 for ``n_letters == 2``,
        matching the previous implementation.
        """
        digits = self.digt_list_train if self.train_test == 'train' else self.digt_list_test
        # np.eye(...)[digits] has shape (N, n_letters, NUM_CLASSES). Summing over the
        # letter axis counts occurrences, and minimum(., 1) turns counts into flags.
        counts = np.eye(NUM_CLASSES)[digits].sum(axis=1)
        presence = np.minimum(counts, 1)
        dtype = np.float32 if self.n_letters == 3 else np.int64
        return presence.astype(dtype)






def _count_orig_pngs(directory):
    """Return how many ``orig_*.png`` files sit directly inside *directory*."""
    # Only these files are read by DigClutDataset; counting every directory entry
    # (as before) let stray files inflate data_size past the last real image.
    return sum(
        1 for name in os.listdir(directory)
        if name.startswith('orig_') and name.endswith('.png')
    )


def _build_digclut_split(dset_dir, split, csv_name, n_letters, image_size):
    """Construct the DigClutDataset for one split ('train' or 'test') under *dset_dir*."""
    # Trailing slash required: DigClutDataset concatenates file names onto this prefix.
    image_dir = "{}/digts/{}/".format(dset_dir, split)
    return DigClutDataset(
        image_paths=image_dir,
        target_paths="{}/digts/{}".format(dset_dir, csv_name),
        image_size=image_size,
        n_letters=n_letters,
        train_test=split,
        data_size=_count_orig_pngs(image_dir),
    )


def return_data_digclut(dset_dir, n_letters, image_size):
    """Build the train and test DigClutDataset objects found under *dset_dir*.

    Expected layout::

        {dset_dir}/digts/train/orig_{i}.png   labels: {dset_dir}/digts/digts.csv
        {dset_dir}/digts/test/orig_{i}.png    labels: {dset_dir}/digts/digts_test.csv

    Args:
        dset_dir: dataset root directory.
        n_letters: number of letters per image, forwarded to DigClutDataset.
        image_size: side length images are resized to.

    Returns:
        ``(train_data, test_data)``.
    """
    train_data = _build_digclut_split(dset_dir, 'train', 'digts.csv', n_letters, image_size)
    test_data = _build_digclut_split(dset_dir, 'test', 'digts_test.csv', n_letters, image_size)
    return train_data, test_data
