# from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from sklearn.datasets import load_svmlight_file
import numpy as np
import torch
import scipy.io
import os
import struct
from array import array as pyarray
from sklearn.preprocessing import MaxAbsScaler, MinMaxScaler, Normalizer, StandardScaler, RobustScaler
from sklearn.model_selection import train_test_split
import pickle

class MyDataset(Dataset):
    """In-memory dataset of ``(feature, label)`` pairs.

    The pairs are materialised once at construction time by indexing ``x``
    and ``y`` position by position, so both only need to support ``len``
    (for ``x``) and integer indexing.

    Args:
        x: indexable collection of features; its length fixes the dataset size.
        y: indexable collection of labels, read at the same positions as ``x``.
        transform: optional callable applied to the feature on every access.
        target_transform: optional callable applied to the label on every
            access.
    """

    def __init__(self, x, y, transform=None, target_transform=None):
        # Index both sequences (rather than zip) so that a ``y`` shorter than
        # ``x`` still fails loudly with IndexError instead of truncating.
        self.data = [(x[i], y[i]) for i in range(len(x))]
        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Return the ``(feature, label)`` pair stored at ``index``.

        The stored pair is left untouched; transforms, when configured, are
        applied to the returned values only.
        """
        feature, label = self.data[index]
        # Previously the transforms were stored but silently ignored.
        if self.transform is not None:
            feature = self.transform(feature)
        if self.target_transform is not None:
            label = self.target_transform(label)
        return feature, label

    def __len__(self):
        """Return the number of stored pairs."""
        return len(self.data)

#### for MNIST
def load_mnist(dataset="training", digits=np.arange(10), path=".", size=60000):
    """Load an MNIST-style IDX image/label file pair and binarise the labels.

    Args:
        dataset: ``"training"`` or ``"testing"``; selects which pair of file
            names is read from ``path``.
        digits: collection of raw labels to keep; samples whose label is not
            in it are dropped.
        path: directory containing the IDX files.
        size: unused, kept only for backward compatibility. The number of
            samples is always taken from the image file header.

    Returns:
        Tuple ``(images, labels)``. ``images`` is a ``uint8`` array of shape
        ``(N, rows, cols)`` and ``labels`` is an ``int8`` array of shape
        ``(N,)`` where raw labels ``<= 4`` become ``1`` and all others ``-1``.
        ``N`` is the number of samples whose raw label is in ``digits``.

    Raises:
        ValueError: if ``dataset`` is not ``"training"`` or ``"testing"``, or
            if the label file holds fewer labels than the image header claims.
    """
    if dataset == "training":
        fname_img = os.path.join(path, 'train-images-idx3-ubyte')
        fname_lbl = os.path.join(path, 'train-labels-idx1-ubyte')
    elif dataset == "testing":
        fname_img = os.path.join(path, 't10k-images-idx3-ubyte')
        fname_lbl = os.path.join(path, 't10k-labels-idx1-ubyte')
    else:
        raise ValueError("dataset must be 'testing' or 'training'")

    # Context managers guarantee the files are closed even if parsing fails.
    with open(fname_lbl, 'rb') as flbl:
        # Header is (magic, count); the count from the image file is the one used.
        struct.unpack(">II", flbl.read(8))
        raw_labels = np.frombuffer(flbl.read(), dtype=np.int8)

    with open(fname_img, 'rb') as fimg:
        _magic, count, rows, cols = struct.unpack(">IIII", fimg.read(16))
        raw_pixels = np.frombuffer(fimg.read(), dtype=np.uint8)

    raw_labels = raw_labels[:count]
    if raw_labels.size < count:
        raise ValueError(
            "label file holds {} labels but image file declares {} samples"
            .format(raw_labels.size, count))
    all_images = raw_pixels[:count * rows * cols].reshape((count, rows, cols))

    # The original sized its output by the header count but indexed the
    # filtered list, which raised IndexError as soon as ``digits`` excluded
    # any sample. Select by mask so N is the number of kept samples.
    keep = np.isin(raw_labels, list(digits))
    images = all_images[keep]  # boolean indexing yields a fresh, writable copy
    kept_labels = raw_labels[keep]

    # Binary task: raw labels 0..4 -> +1, everything else -> -1.
    labels = np.where(kept_labels <= 4, 1, -1).astype(np.int8)
    return images, labels

#### for CIFAR
def unpickle(file):
    """Read one pickled file (e.g. a CIFAR-10 batch) and return its contents.

    The pickle is decoded with ``encoding='bytes'``.

    NOTE(review): unpickling can execute arbitrary code; only point this at
    trusted files.
    """
    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')

def _cifar_to_channels_last(flat_rows, as_float):
    """Reshape flat rows to (N, 3, 32, 32) and move the channel axis last."""
    images = flat_rows.reshape((len(flat_rows), 3, 32, 32)).transpose(0, 2, 3, 1)
    if as_float:
        images = images.astype(np.float32)
    return images


def load_cifar_10_data(data_dir, negatives=False):
    """Load the pickled CIFAR-10 batches stored in ``data_dir``.

    Reads ``batches.meta``, ``data_batch_1`` .. ``data_batch_5`` and
    ``test_batch``. Each batch is a dict with the byte-string keys
    ``b'data'``, ``b'filenames'`` and ``b'labels'``.

    Args:
        data_dir: directory holding the batch files (string or path-like).
        negatives: when True the image arrays are converted to ``float32``;
            otherwise they keep the dtype stored in the batch files
            (presumably uint8 - TODO confirm).

    Returns:
        A 7-tuple ``(train_data, train_filenames, train_labels, test_data,
        test_filenames, test_labels, label_names)``. The data arrays have
        shape ``(N, 32, 32, 3)``; all other entries are numpy arrays built
        from the lists found in the batch files.
    """
    meta_data_dict = unpickle(os.path.join(data_dir, "batches.meta"))
    cifar_label_names = np.array(meta_data_dict[b'label_names'])

    # Training data: gather every batch first and stack once, instead of
    # re-copying the growing array on each iteration.
    train_chunks = []
    cifar_train_filenames = []
    cifar_train_labels = []
    for i in range(1, 6):
        batch = unpickle(os.path.join(data_dir, "data_batch_{}".format(i)))
        train_chunks.append(batch[b'data'])
        cifar_train_filenames += batch[b'filenames']
        cifar_train_labels += batch[b'labels']

    cifar_train_data = _cifar_to_channels_last(np.vstack(train_chunks), negatives)
    cifar_train_filenames = np.array(cifar_train_filenames)
    cifar_train_labels = np.array(cifar_train_labels)

    # Test data: a single batch with the same layout as the training batches.
    test_batch = unpickle(os.path.join(data_dir, "test_batch"))
    cifar_test_data = _cifar_to_channels_last(test_batch[b'data'], negatives)
    cifar_test_filenames = np.array(test_batch[b'filenames'])
    cifar_test_labels = np.array(test_batch[b'labels'])

    return cifar_train_data, cifar_train_filenames, cifar_train_labels, \
        cifar_test_data, cifar_test_filenames, cifar_test_labels, cifar_label_names



###### MNIST
def train_test_sets_generation(dataset, Toydata_path):
    """Build the standardised binary-MNIST train and test datasets.

    Args:
        dataset: dataset identifier; only ``'mnist_binary'`` is supported.
        Toydata_path: directory containing the MNIST IDX files.

    Returns:
        Tuple ``(trainset, testset, n_pos, n_neg)`` where the sets are
        ``MyDataset`` instances and ``n_pos`` / ``n_neg`` count the training
        samples labelled ``1`` and not ``1`` respectively.

    Raises:
        ValueError: if ``dataset`` is not ``'mnist_binary'``. Previously this
            surfaced later as a NameError on an undefined variable.
    """
    if dataset != 'mnist_binary':
        raise ValueError(
            "unsupported dataset {!r}; expected 'mnist_binary'".format(dataset))

    X_train, y_train = load_mnist(dataset="training", digits=np.arange(10), path=Toydata_path)
    X_test, y_test = load_mnist(dataset="testing", digits=np.arange(10), path=Toydata_path)

    # Scale pixel values to [0, 1].
    X_train = X_train / 255.0
    X_test = X_test / 255.0

    # Standardise with training statistics only: a per-pixel mean (axis 0)
    # but one global standard deviation over the whole training array.
    means = X_train.mean(axis=0)
    std = np.std(X_train)
    X_train = (X_train - means) / std
    X_test = (X_test - means) / std

    n_pos = (y_train == 1.).sum()
    n_neg = len(y_train) - n_pos

    transform_train = None
    trainset = MyDataset(X_train, y_train, transform=transform_train)
    testset = MyDataset(X_test, y_test, transform=transform_train)
    return trainset, testset, n_pos, n_neg

###### Diabetes
def train_test_sets_generation_diabetes(dataset, Toydata_path, train_ratio):
    """Load an svmlight file, rescale it and split it into train/test sets.

    Args:
        dataset: file name of the svmlight file inside ``Toydata_path``.
        Toydata_path: directory containing the file.
        train_ratio: fraction of samples kept for training; the test share
            is ``1 - train_ratio``.

    Returns:
        Tuple ``(trainset, testset, n_pos, n_neg)`` where the sets are
        ``MyDataset`` instances and ``n_pos`` / ``n_neg`` count the training
        samples labelled ``1`` and not ``1`` respectively.
    """
    sparse_features, raw_targets = load_svmlight_file(os.path.join(Toydata_path, dataset))
    features = sparse_features.toarray()
    targets = raw_targets.astype(np.int8)

    # Min-max scale every feature into [0, 0.1]. The scaler is fitted on the
    # full data set, i.e. before the train/test split.
    features = MinMaxScaler(feature_range=(0, 0.1)).fit_transform(features)

    # Fixed random_state keeps the split reproducible.
    X_train, X_test, y_train, y_test = train_test_split(
        features, targets, test_size=1 - train_ratio, random_state=1)

    n_pos = sum(y_train == 1.)
    n_neg = len(y_train) - n_pos

    trainset = MyDataset(X_train, y_train, transform=None)
    testset = MyDataset(X_test, y_test, transform=None)
    return trainset, testset, n_pos, n_neg

###### ijcnn1
def train_test_sets_generation_ijcnn1(dataset, Toydata_path, train_ratio):
    """Load an svmlight file, scale it to [0, 1] and split into train/test.

    Args:
        dataset: file name of the svmlight file inside ``Toydata_path``.
        Toydata_path: directory containing the file.
        train_ratio: fraction of samples kept for training; the test share
            is ``1 - train_ratio``.

    Returns:
        Tuple ``(trainset, testset, n_pos, n_neg)`` where the sets are
        ``MyDataset`` instances and ``n_pos`` / ``n_neg`` count the training
        samples labelled ``1`` and not ``1`` respectively.
    """
    data_file = os.path.join(Toydata_path, dataset)
    sparse_features, raw_targets = load_svmlight_file(data_file)
    features = sparse_features.toarray()
    targets = raw_targets.astype(np.int8)

    # Min-max scale every feature into [0, 1]; the scaler sees the full data
    # set because it is fitted before the split.
    unit_scaler = MinMaxScaler(feature_range=(0, 1))
    features = unit_scaler.fit_transform(features)

    # Fixed random_state keeps the split reproducible.
    split = train_test_split(features, targets, test_size=1 - train_ratio, random_state=1)
    X_train, X_test, y_train, y_test = split

    n_pos = sum(y_train == 1.)
    n_neg = len(y_train) - n_pos

    trainset = MyDataset(X_train, y_train, transform=None)
    testset = MyDataset(X_test, y_test, transform=None)
    return trainset, testset, n_pos, n_neg

###### CIFAR-10
def train_test_sets_generation_cifar10(dataset, Toydata_path):
    """Build binary-labelled CIFAR-10 train and test datasets.

    Args:
        dataset: not used by this function.
        Toydata_path: directory that contains the ``cifar-10-batches-py``
            folder.

    Returns:
        Tuple ``(trainset, testset, n_pos, n_neg)`` where the sets are
        ``MyDataset`` instances and ``n_pos`` / ``n_neg`` count the training
        samples labelled ``1`` and not ``1`` respectively.
    """
    batches_dir = os.path.join(Toydata_path, 'cifar-10-batches-py')
    # File names and label names are loaded but not needed here.
    X_train, _, y_train, X_test, _, y_test, _ = load_cifar_10_data(batches_dir)

    # Binary task: class ids 0..4 -> +1, class ids 5 and above -> -1.
    y_train = np.where(y_train <= 4, 1, -1).astype(np.int8)
    y_test = np.where(y_test <= 4, 1, -1).astype(np.int8)

    # Scale pixel values to [0, 1]; no mean/std standardisation is applied.
    X_train = X_train / 255.0
    X_test = X_test / 255.0

    n_pos = sum(y_train == 1.)
    n_neg = len(y_train) - n_pos

    trainset = MyDataset(X_train, y_train, transform=None)
    testset = MyDataset(X_test, y_test, transform=None)
    return trainset, testset, n_pos, n_neg

###### fashionmnist
def train_test_sets_generation_fashionmnist(dataset, Toydata_path):
    """Build standardised binary train and test datasets from IDX files.

    The files are read with ``load_mnist``, so labels arrive already mapped
    to ``1`` / ``-1``.

    Args:
        dataset: not used by this function.
        Toydata_path: directory containing the IDX image and label files.

    Returns:
        Tuple ``(trainset, testset, n_pos, n_neg)`` where the sets are
        ``MyDataset`` instances and ``n_pos`` / ``n_neg`` count the training
        samples labelled ``1`` and not ``1`` respectively.
    """
    all_digits = np.arange(10)
    train_images, y_train = load_mnist(dataset="training", digits=all_digits, path=Toydata_path)
    test_images, y_test = load_mnist(dataset="testing", digits=np.arange(10), path=Toydata_path)

    # Scale pixel values to [0, 1].
    train_scaled = train_images / 255.0
    test_scaled = test_images / 255.0

    # Standardise with training statistics only: per-pixel mean (axis 0) and
    # a single global standard deviation.
    pixel_means = train_scaled.mean(axis=0)
    global_std = np.std(train_scaled)
    X_train = (train_scaled - pixel_means) / global_std
    X_test = (test_scaled - pixel_means) / global_std

    n_pos = sum(y_train == 1.)
    n_neg = len(y_train) - n_pos

    trainset = MyDataset(X_train, y_train, transform=None)
    testset = MyDataset(X_test, y_test, transform=None)
    return trainset, testset, n_pos, n_neg


