import os
import pickle
import random
import tarfile
import urllib.request

import numpy as np
import torch
import torch.utils.data as utils
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset, TensorDataset

# np.random.seed(0)
# random.seed(0)
# torch.random.manual_seed(0)
# torch.manual_seed(0)
# torch.cuda.manual_seed_all(0)


def get_epsilon():
    """Placeholder with no implementation yet; always returns ``None``."""
    return None


# def get_text_dataset(dataset_name):
#     train_dataset = torchtext_dict[dataset_name](root='../data', split="train")
#     test_dataset = torchtext_dict[dataset_name](root='../data', split="test")

#     x_tr = []
#     y_tr = []
#     for label, line in train_dataset:
#         x_tr += line.split()
#         y_tr.append(label)

#     x_te = []
#     y_te = []
#     for label, line in test_dataset:
#         x_te += line.split()
#         y_te.append(label)

#     return (x_tr, y_tr), (x_te, y_te)


def get_text_dataset(dataset_name):
    """Load precomputed BERT features for a text classification dataset.

    Reads ``../data/<dataset_name>_bert.npz``, which must contain the arrays
    ``train_X``, ``train_Y``, ``test_X`` and ``test_Y``.

    For ``yelp_full`` every sample with label 2 is removed from both splits.
    The training split is shuffled with the global NumPy RNG. For datasets
    other than ``imdb`` and ``20ng`` it is then reduced to a stratified subset
    of 50000 samples via ``train_val_split``. As a side effect, that call
    reseeds the global NumPy RNG with 0.

    Args:
        dataset_name: base name of the ``.npz`` feature file.

    Returns:
        ``((x_tr, y_tr), (x_te, y_te))`` with float32 features and int32
        labels.
    """
    npz_path = os.path.join("..", "data", dataset_name + "_bert.npz")
    # NpzFile keeps the underlying file open. Use it as a context manager so
    # the handle is released. Indexing loads each array fully into memory,
    # so the arrays remain valid after the file is closed.
    with np.load(npz_path) as dataset:
        x_tr, y_tr = dataset["train_X"], dataset["train_Y"]
        x_te, y_te = dataset["test_X"], dataset["test_Y"]

    if dataset_name == "yelp_full":
        # Drop the label-2 class from both splits.
        keep_tr = y_tr != 2
        keep_te = y_te != 2
        x_tr, y_tr = x_tr[keep_tr], y_tr[keep_tr]
        x_te, y_te = x_te[keep_te], y_te[keep_te]

    perm = np.random.permutation(len(y_tr))
    x_tr, y_tr = x_tr[perm], y_tr[perm]

    if dataset_name in ("imdb", "20ng"):
        x_tr = np.asarray(x_tr, dtype=np.float32)
        y_tr = np.asarray(y_tr, dtype=np.int32)
    else:
        # Stratified 50000-sample subset; note train_val_split reseeds the
        # global NumPy RNG, which affects every later np.random call.
        train_idx = train_val_split(y_tr, 50000)
        x_tr = np.asarray(x_tr[train_idx], dtype=np.float32)
        y_tr = np.asarray(y_tr[train_idx], dtype=np.int32)

    x_te = np.asarray(x_te, dtype=np.float32)
    y_te = np.asarray(y_te, dtype=np.int32)

    return (x_tr, y_tr), (x_te, y_te)


def train_val_split(labels, train_size, seed=0):
    """Draw a stratified subset of indices from the original training set.

    The global NumPy RNG is reseeded with ``seed``. ``train_test_split`` is
    called without ``random_state``, so it draws from that freshly seeded
    global RNG. The split is therefore reproducible, but every later
    ``np.random`` call in the program is affected as a side effect.

    Arguments:
        labels {array-like} -- labels of the original training set; used for
            stratification so class proportions are preserved.
        train_size {int or float} -- size of the subset to keep, passed
            straight to ``sklearn.model_selection.train_test_split`` (an int
            is an absolute count, a float a fraction).

    Keyword Arguments:
        seed {int} -- seed for the global NumPy RNG (default: {0})

    Returns:
        [list] -- shuffled indices (into ``labels``) of the selected subset.
            The complementary indices are discarded.
    """
    np.random.seed(seed)
    labels = np.array(labels)
    idxs = list(range(len(labels)))
    train_idxs, _ = train_test_split(idxs, train_size=train_size, stratify=labels)

    np.random.shuffle(train_idxs)

    return train_idxs


def binarize_text_polarity_class(_trainY, _testY):
    """Map polarity labels to +1 / -1: label 0 becomes -1, anything else +1.

    Returns the converted (train, test) label arrays as int32.
    """
    def _to_signed(raw):
        signed = np.ones(len(raw), dtype=np.int32)
        signed[raw == 0] = -1
        return signed

    return _to_signed(_trainY), _to_signed(_testY)


def binarize_yelp_full_class(_trainY, _testY):
    """Map Yelp-full star labels to +1 / -1: labels 0 and 1 become -1.

    Every other label becomes +1. Returns int32 (train, test) label arrays.
    """
    def _to_signed(raw):
        signed = np.ones(len(raw), dtype=np.int32)
        is_low = (raw == 0) | (raw == 1)
        signed[is_low] = -1
        return signed

    return _to_signed(_trainY), _to_signed(_testY)


def binarize_20ng_class(_trainY, _testY):
    """Map 20 Newsgroups topic ids to +1 / -1 labels.

    Topic ids 11 through 19 become -1; every other id becomes +1.

    Args:
        _trainY, _testY: integer label arrays.

    Returns:
        ``(trainY, testY)`` as int32 arrays of +1 / -1.
    """
    # One shared id list keeps the train and test mappings from drifting.
    negative_ids = np.arange(11, 20)

    def _to_signed(raw):
        signed = np.ones(len(raw), dtype=np.int32)
        signed[np.isin(raw, negative_ids)] = -1
        return signed

    return _to_signed(_trainY), _to_signed(_testY)


def _3D_to_4(x):
    '''
    :param x: For mnist, it is a tensor of shape (len, 28, 28)
    :return: a tensor of shape (len, 1, 28, 28)
    '''
    return x.reshape(x.shape[0], 1, x.shape[1], x.shape[2])


def get_fashionmnist():
    """Download (if needed) and load Fashion-MNIST via torchvision.

    Data are stored under ``../data``.

    Returns:
        ``((x_tr, y_tr), (x_te, y_te))``: images as float32 arrays of shape
        (N, 1, H, W) scaled to [0, 1], labels as int32 arrays.
    """
    # torchvision is not imported at module level, so the original
    # reference raised NameError. A local import fixes that without making
    # torchvision a hard dependency of the whole module.
    import torchvision

    train_dataset = torchvision.datasets.FashionMNIST(root='../data',
                                                      train=True,
                                                      download=True)
    test_dataset = torchvision.datasets.FashionMNIST(root='../data',
                                                     train=False,
                                                     download=True)
    x_tr = train_dataset.data
    y_tr = train_dataset.targets
    x_te = test_dataset.data
    y_te = test_dataset.targets

    # Add the channel axis and scale raw pixel values into [0, 1].
    x_tr = _3D_to_4(x_tr) / 255.
    x_te = _3D_to_4(x_te) / 255.

    x_tr = np.asarray(x_tr, dtype=np.float32)
    y_tr = np.asarray(y_tr, dtype=np.int32)
    x_te = np.asarray(x_te, dtype=np.float32)
    y_te = np.asarray(y_te, dtype=np.int32)

    return (x_tr, y_tr), (x_te, y_te)


def binarize_fashionmnist_class(_trainY, _testY):
    """Collapse Fashion-MNIST class ids into +1 / -1 labels.

    Class ids 0, 2, 3, 5, 6, 8 and 9 become -1; every other id (1, 4 and 7
    for labels 0-9) becomes +1.

    Args:
        _trainY, _testY: integer label arrays.

    Returns:
        ``(trainY, testY)`` as int32 arrays of +1 / -1.
    """
    # One shared id list keeps the train and test mappings from drifting.
    negative_ids = (0, 2, 3, 5, 6, 8, 9)

    def _to_signed(raw):
        signed = np.ones(len(raw), dtype=np.int32)
        signed[np.isin(raw, negative_ids)] = -1
        return signed

    return _to_signed(_trainY), _to_signed(_testY)


def get_stl10():
    """Download (if needed) and load STL-10 via torchvision.

    The training pool is torchvision's ``"train+unlabeled"`` split, in which
    unlabeled samples carry label -1. A random subset of at most 50000
    samples is kept, drawn with the global NumPy RNG.

    Returns:
        ``((x_tr, y_tr), (x_te, y_te))``: images as float32 arrays, labels
        as int32 arrays.
    """
    # torchvision is not imported at module level, so the original
    # reference raised NameError. A local import fixes that without making
    # torchvision a hard dependency of the whole module.
    import torchvision

    train_dataset = torchvision.datasets.STL10(root='../data',
                                               split="train+unlabeled",
                                               download=True)
    test_dataset = torchvision.datasets.STL10(root='../data',
                                              split="test",
                                              download=True)
    x_tr = train_dataset.data
    y_tr = train_dataset.labels
    x_te = test_dataset.data
    y_te = test_dataset.labels

    # Same permutation draw as before, but only the first 50000 positions
    # are indexed: x[perm][:50000] == x[perm[:50000]]. This avoids copying
    # the whole train+unlabeled pool just to discard most of it.
    keep = np.random.permutation(len(y_tr))[:50000]

    # NOTE(review): unlike the MNIST / Fashion-MNIST / CIFAR-10 loaders,
    # pixel values are not divided by 255 here. Confirm this is intended.
    x_tr = np.asarray(x_tr[keep], dtype=np.float32)
    y_tr = np.asarray(y_tr[keep], dtype=np.int32)
    x_te = np.asarray(x_te, dtype=np.float32)
    y_te = np.asarray(y_te, dtype=np.int32)

    return (x_tr, y_tr), (x_te, y_te)


def binarize_stl10_class(_trainY, _testY):
    """Collapse STL-10 class ids into +1 / -1 labels.

    Ids -1 (unlabeled samples), 1, 4, 5, 6 and 7 become -1; every other id
    becomes +1.

    Args:
        _trainY, _testY: integer label arrays.

    Returns:
        ``(trainY, testY)`` as int32 arrays of +1 / -1.
    """
    # NOTE(review): id 3 is not in the negative set although its neighbours
    # 4-7 are. Confirm this grouping is intentional.
    negative_ids = (-1, 1, 4, 5, 6, 7)

    def _to_signed(raw):
        signed = np.ones(len(raw), dtype=np.int32)
        signed[np.isin(raw, negative_ids)] = -1
        return signed

    return _to_signed(_trainY), _to_signed(_testY)


def get_mnist():
    """Download (if needed) and load MNIST (``mnist_784``) from OpenML.

    The first 60000 samples form the training split and the rest form the
    test split.

    Returns:
        ``((x_tr, y_tr), (x_te, y_te))``: images as float32 arrays of shape
        (N, 1, 28, 28) scaled to [0, 1], labels as int32 arrays.
    """
    # as_frame=False: recent scikit-learn returns a pandas DataFrame by
    # default. np.reshape on a DataFrame fails when pandas tries to wrap the
    # 4-D result, so request plain NumPy arrays instead.
    mnist = fetch_openml('mnist_784', data_home="../data", as_frame=False)
    x = np.asarray(mnist.data)
    # Targets come back as strings such as '5'; cast to int32 below.
    y = np.asarray(mnist.target)
    # reshape to (#data, #channel, height, width) and scale to [0, 1]
    x = np.reshape(x, (x.shape[0], 1, 28, 28)) / 255.
    x_tr = np.asarray(x[:60000], dtype=np.float32)
    y_tr = np.asarray(y[:60000], dtype=np.int32)
    x_te = np.asarray(x[60000:], dtype=np.float32)
    y_te = np.asarray(y[60000:], dtype=np.int32)
    return (x_tr, y_tr), (x_te, y_te)


def binarize_mnist_class(_trainY, _testY):
    """Map MNIST digits to +1 / -1: odd digits become -1, even digits +1.

    Returns the converted (train, test) label arrays as int32.
    """
    def _to_signed(raw):
        signed = np.ones(len(raw), dtype=np.int32)
        is_odd = raw % 2 == 1
        signed[is_odd] = -1
        return signed

    return _to_signed(_trainY), _to_signed(_testY)


def unpickle(file):
    """Load a pickled object (e.g. a CIFAR-10 batch) from ``file``.

    ``encoding='latin1'`` lets Python 3 read pickles written by Python 2.

    Security: unpickling can execute arbitrary code. Only call this on
    trusted files, such as the official CIFAR-10 archive.

    Args:
        file: path to the pickle file.

    Returns:
        The unpickled object.
    """
    # The context manager closes the file even if pickle.load raises. The
    # original leaked the handle in that case and shadowed the builtin ``dict``.
    with open(file, 'rb') as fo:
        return pickle.load(fo, encoding='latin1')


def conv_data2image(data):
    """Turn one flat 3*32*32 row (channel-major) into a (32, 32, 3) image."""
    channels_first = data.reshape((3, 32, 32))
    return np.transpose(channels_first, (1, 2, 0))


def get_cifar10(path="../data/cifar10"):
    """Load CIFAR-10 (python version), downloading and extracting on first use.

    Args:
        path: directory that holds, or will hold, the downloaded archive and
            the extracted ``cifar-10-batches-py`` folder.

    Returns:
        ``((x_tr, y_tr), (x_te, y_te))``: images as float32 arrays of shape
        (N, 3, 32, 32) scaled to [0, 1], labels as integer arrays.
    """
    # makedirs also creates missing parents such as ../data, where os.mkdir
    # failed, and tolerates an existing directory.
    os.makedirs(path, exist_ok=True)
    url = "http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz"
    file_name = os.path.basename(url)
    full_path = os.path.join(path, file_name)
    folder = os.path.join(path, "cifar-10-batches-py")
    # if cifar-10-batches-py folder doesn't exist, download from website
    if not os.path.isdir(folder):
        print("download the dataset from {} to {}".format(url, path))
        urllib.request.urlretrieve(url, full_path)
        with tarfile.open(full_path) as f:
            # The "data" filter (Python 3.12+ and security backports) rejects
            # archive members that would be written outside ``path``.
            if hasattr(tarfile, "data_filter"):
                f.extractall(path=path, filter="data")
            else:
                f.extractall(path=path)
        urllib.request.urlcleanup()

    # Collect the five training batches, then concatenate once instead of
    # re-stacking the growing array on every iteration.
    batch_x, batch_y = [], []
    for i in range(1, 6):
        data_dict = unpickle(os.path.join(folder, "data_batch_{}".format(i)))
        batch_x.append(data_dict['data'])
        batch_y.extend(data_dict['labels'])
    x_tr = np.concatenate(batch_x, axis=0)
    y_tr = np.asarray(batch_y)

    data_dict = unpickle(os.path.join(folder, 'test_batch'))
    x_te = data_dict['data']
    y_te = np.array(data_dict['labels'])

    # reshape to (#data, #channel, height, width)
    x_tr = np.reshape(x_tr, (np.shape(x_tr)[0], 3, 32, 32)).astype(np.float32)
    x_te = np.reshape(x_te, (np.shape(x_te)[0], 3, 32, 32)).astype(np.float32)
    # normalize to [0, 1]
    x_tr /= 255.
    x_te /= 255.
    return (x_tr, y_tr), (x_te, y_te)


def binarize_cifar10_class(_trainY, _testY):
    """Collapse CIFAR-10 class ids into +1 / -1 labels.

    Class ids 2 through 7 become -1; every other id (0, 1, 8, 9 for labels
    0-9) becomes +1.

    Args:
        _trainY, _testY: integer label arrays.

    Returns:
        ``(trainY, testY)`` as int32 arrays of +1 / -1.
    """
    # One shared id list keeps the train and test mappings from drifting.
    negative_ids = (2, 3, 4, 5, 6, 7)

    def _to_signed(raw):
        signed = np.ones(len(raw), dtype=np.int32)
        signed[np.isin(raw, negative_ids)] = -1
        return signed

    return _to_signed(_trainY), _to_signed(_testY)


def to_dataloader(my_x, my_y, batchsize):
    """Wrap paired NumPy arrays in a sequential DataLoader with pin_memory on."""
    tensors = (torch.from_numpy(my_x), torch.from_numpy(my_y))
    return utils.DataLoader(utils.TensorDataset(*tensors),
                            batch_size=batchsize,
                            pin_memory=True)


class MyDataset(Dataset):
    """Map-style dataset over parallel data / label sequences.

    Each item is returned as a ``(torch.tensor(data[i]), torch.tensor(labels[i]))``
    pair; the length is the number of labels.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __getitem__(self, index):
        sample = torch.tensor(self.data[index])
        target = torch.tensor(self.labels[index])
        return sample, target

    def __len__(self):
        return len(self.labels)


def to_train_dataloader(my_x, my_y, batchsize):
    """Build a shuffling training DataLoader from paired NumPy arrays.

    The final incomplete batch is dropped, so every yielded batch holds
    exactly ``batchsize`` samples.
    """
    features = torch.from_numpy(my_x)
    targets = torch.from_numpy(my_y)
    return DataLoader(TensorDataset(features, targets),
                      batch_size=batchsize,
                      shuffle=True,
                      drop_last=True)


def to_val_dataloader(my_x, my_y, batchsize):
    """Build an evaluation DataLoader that keeps samples in their original order."""
    pairs = TensorDataset(*(torch.from_numpy(arr) for arr in (my_x, my_y)))
    return DataLoader(pairs, batch_size=batchsize, shuffle=False)


def make_dataset(dataset,
                 n_labeled,
                 n_unlabeled,
                 with_bias=False,
                 resample_model=""):
    """Build PU (positive-unlabeled) and PN splits from a binarized dataset.

    Args:
        dataset: ``((trainX, trainY), (testX, testY))``. Labels are expected
            to take exactly two values. The larger value (``np.unique(y)[1]``)
            is treated as positive and the smaller as negative.
        n_labeled: number of labeled positive training samples to draw.
        n_unlabeled: size of the unlabeled training set. It must satisfy
            either ``n_labeled + n_unlabeled == len(trainX)`` or
            ``n_unlabeled == len(trainX)``; otherwise ``ValueError`` is raised.
        with_bias: if true, the labeled positives are replaced by the output
            of ``resampling.resample`` (a module imported lazily; not shown
            here).
        resample_model: forwarded to ``resample`` when ``with_bias`` is set.

    Returns:
        ``(XYPtrain, XYUtrain, XYtrain2, XYtest, prior, prior2, dim)``:
        - ``XYPtrain``: [labeled positives, float64 array of ones]
        - ``XYUtrain``: [unlabeled samples, float64 array of -1]
        - ``XYtrain2``: [shuffled fully labeled training set, int32 +1/-1]
        - ``XYtest``: [shuffled test set, int32 +1/-1]
        - ``prior``: number of positives in U divided by ``n_unlabeled``
        - ``prior2``: fraction of positives in the whole training set
        - ``dim``: number of values per labeled-positive sample. This raises
          ZeroDivisionError if no labeled positives were drawn.

    Note: every helper below shuffles with the global NumPy RNG. Results
    therefore depend on the RNG state and on the order of the calls at the
    end of this function.
    """
    def make_PU_dataset_from_binary_dataset(x,
                                            y,
                                            labeled=n_labeled,
                                            unlabeled=n_unlabeled,
                                            bias=with_bias,
                                            resamplemodel=resample_model):
        """Split (x, y) into labeled positives and an unlabeled set.

        In the ``labeled + unlabeled == len(X)`` case, the labeled positives
        are excluded from U. In the ``unlabeled == len(X)`` case, U contains
        every positive, including those that were also drawn as labeled.

        Returns:
            ``(Xp, Xu, prior)``: labeled positives, float32 unlabeled samples
            (remaining/all positives followed by all negatives), and
            ``n_up / unlabeled``.
        """
        labels = np.unique(y)
        positive, negative = labels[1], labels[0]
        X, Y = np.asarray(x, dtype=np.float32), np.asarray(y, dtype=np.int32)
        assert (len(X) == len(Y))
        perm = np.random.permutation(len(Y))
        X, Y = X[perm], Y[perm]
        # number of positives in the whole set
        n_p = (Y == positive).sum()
        # number of labeled positives requested
        n_lp = labeled
        # number of negatives (computed but not used below)
        n_n = (Y == negative).sum()
        # requested size of the unlabeled set
        n_u = unlabeled
        if labeled + unlabeled == len(X):
            # Labeled positives are removed from U.
            n_up = n_p - n_lp
        elif unlabeled == len(X):
            # U is the whole set, so it keeps all positives.
            n_up = n_p
        else:
            raise ValueError("Only support |P|+|U|=|X| or |U|=|X|.")
        # With binary labels, len(Xu) == n_u in both branches above.
        prior = float(n_up) / float(n_u)
        Xlp = X[Y == positive][:n_lp]
        # Unused positives first, then the labeled ones; truncating to n_up
        # drops the labeled ones in the first case and keeps all in the second.
        Xup = np.concatenate((X[Y == positive][n_lp:], Xlp), axis=0)[:n_up]
        Xun = X[Y == negative]

        if bias:
            # Project-local module; imported only when biased sampling is used.
            from resampling import resample
            Xlp = resample(Xlp, Xup, Xun, resamplemodel)

        # X = np.asarray(np.concatenate((Xlp, Xup, Xun), axis=0),
        #                dtype=np.float32)
        # print(X.shape)
        # Y = np.asarray(np.concatenate((np.ones(n_lp), -np.ones(n_u))),
        #                dtype=np.int32)
        # perm = np.random.permutation(len(Y))
        # X, Y = X[perm], Y[perm]

        Xu = np.asarray(np.concatenate((Xup, Xun), axis=0), dtype=np.float32)
        Xp = Xlp
        return Xp, Xu, prior

    def make_PNT_dataset_from_binary_dataset(x,
                                             y,
                                             xlp,
                                             prior,
                                             labeled=n_labeled):
        """Build a PN set from given labeled positives plus sampled negatives.

        The number of negatives is ``round(((1 - prior) / (2 * prior))**2 *
        labeled)``. Not called within ``make_dataset``.

        NOTE(review): if fewer than that many negatives exist, ``Xln`` is
        shorter than ``n_ln`` while ``Y`` still has ``n_ln`` negative labels,
        so ``X[perm]`` can index out of range. ``prior == 0`` raises
        ZeroDivisionError.
        """
        labels = np.unique(y)
        positive, negative = labels[1], labels[0]
        X, Y = np.asarray(x, dtype=np.float32), np.asarray(y, dtype=np.int32)
        assert (len(X) == len(Y))
        assert (len(xlp) == labeled)
        perm = np.random.permutation(len(Y))
        X, Y = X[perm], Y[perm]
        Xlp = xlp
        n_lp = labeled
        n_ln = round(((1. - prior) / (2 * prior))**2 * labeled)
        Xln = X[Y == negative][:n_ln]
        X = np.asarray(np.concatenate((Xlp, Xln), axis=0), dtype=np.float32)
        print(X.shape)
        Y = np.asarray(np.concatenate((np.ones(n_lp), -np.ones(n_ln))),
                       dtype=np.int32)
        perm = np.random.permutation(len(Y))
        X, Y = X[perm], Y[perm]
        return X, Y, prior

    def make_PN_dataset_from_binary_dataset1(x, xlp, y):
        """Build a shuffled PN set whose positives are x's positives plus ``xlp``.

        ``prior`` counts both sources of positives. Not called within
        ``make_dataset``.
        """
        labels = np.unique(y)
        positive, negative = labels[1], labels[0]
        X, Xlp, Y = np.asarray(x, dtype=np.float32), np.asarray(
            xlp, dtype=np.float32), np.asarray(y, dtype=np.int32)
        n_p = (Y == positive).sum()
        n_n = (Y == negative).sum()
        n_lp = len(Xlp)
        prior = float(n_p + n_lp) / float(n_p + n_lp + n_n)
        Xp = X[Y == positive][:n_p]
        Xn = X[Y == negative][:n_n]
        X = np.asarray(np.concatenate((Xp, Xlp, Xn)), dtype=np.float32)
        Y = np.asarray(np.concatenate((np.ones(n_p + n_lp), -np.ones(n_n))),
                       dtype=np.int32)
        perm = np.random.permutation(len(Y))
        X, Y = X[perm], Y[perm]
        return X, Y, prior

    def make_PN_dataset_from_binary_dataset(x, y):
        """Relabel (x, y) as int32 +1 / -1 and shuffle it.

        Returns:
            ``(X, Y, prior)`` where ``prior`` is the fraction of positives.
        """
        labels = np.unique(y)
        positive, negative = labels[1], labels[0]
        X, Y = np.asarray(x, dtype=np.float32), np.asarray(y, dtype=np.int32)
        n_p = (Y == positive).sum()
        n_n = (Y == negative).sum()
        prior = float(n_p) / float(n_p + n_n)
        # The [:n_p] / [:n_n] slices keep every sample of each class.
        Xp = X[Y == positive][:n_p]
        Xn = X[Y == negative][:n_n]
        X = np.asarray(np.concatenate((Xp, Xn)), dtype=np.float32)
        Y = np.asarray(np.concatenate((np.ones(n_p), -np.ones(n_n))),
                       dtype=np.int32)
        perm = np.random.permutation(len(Y))
        X, Y = X[perm], Y[perm]
        return X, Y, prior

    (_trainX, _trainY), (_testX, _testY) = dataset
    # Call order matters: each helper consumes draws from the global RNG.
    trainXp, trainXu, prior = make_PU_dataset_from_binary_dataset(
        _trainX, _trainY)
    trainX2, trainY2, prior2 = make_PN_dataset_from_binary_dataset(
        _trainX, _trainY)
    testX, testY, _ = make_PN_dataset_from_binary_dataset(_testX, _testY)
    print("training:{}\t{}".format(trainXp.shape, trainXu.shape))
    print("test:{}".format(testX.shape))

    # P labels are +1 and U labels are -1 (float64, from np.ones).
    XYPtrain = [trainXp, np.ones(len(trainXp))]
    XYUtrain = [trainXu, -np.ones(len(trainXu))]
    XYtrain2 = [trainX2, trainY2]
    XYtest = [testX, testY]

    # Last element: values per sample (product of the non-batch dimensions).
    return (XYPtrain, XYUtrain, XYtrain2, XYtest, prior, prior2,
            trainXp.size // len(trainXp))


def load_dataset(dataset_name,
                 n_labeled,
                 n_unlabeled,
                 batchsize,
                 with_bias=False,
                 resample_model="",
                 *,
                 p_batchsize=256,
                 u_batchsize=512,
                 eval_batchsize=1000):
    """Load, binarize and split a dataset into PyTorch DataLoaders.

    Args:
        dataset_name: one of "mnist", "cifar10", "fashionmnist", "stl10",
            "imdb", "yelp", "amazon", "yelp_full", "20ng".
        n_labeled, n_unlabeled, with_bias, resample_model: forwarded to
            ``make_dataset``.
        batchsize: batch size of the shuffled, fully labeled training loader.
        p_batchsize: batch size of the labeled-positive loader. Defaults to
            256, the value that was previously hard-coded.
        u_batchsize: batch size of the unlabeled loader. Defaults to 512,
            previously hard-coded.
        eval_batchsize: batch size of the validation and test loaders.
            Defaults to 1000, previously hard-coded.

    Returns:
        ``(P_loader, U_loader, PN_train_loader, valid_loader, test_loader,
        prior, prior2, dim)``. ``valid_loader`` iterates the same fully
        labeled training data as ``PN_train_loader``, in order; no separate
        held-out split is created here.

    Raises:
        ValueError: if ``dataset_name`` is not recognized.
    """
    print("==================")
    print("loading data...")
    # Image datasets: name -> (zero-argument loader, binarizer).
    image_datasets = {
        "mnist": (get_mnist, binarize_mnist_class),
        "cifar10": (get_cifar10, binarize_cifar10_class),
        "fashionmnist": (get_fashionmnist, binarize_fashionmnist_class),
        "stl10": (get_stl10, binarize_stl10_class),
    }
    # Text datasets all load through get_text_dataset(name): name -> binarizer.
    text_datasets = {
        "imdb": binarize_text_polarity_class,
        "yelp": binarize_text_polarity_class,
        "amazon": binarize_text_polarity_class,
        "yelp_full": binarize_yelp_full_class,
        "20ng": binarize_20ng_class,
    }
    if dataset_name in image_datasets:
        fetch, binarize = image_datasets[dataset_name]
        (trainX, trainY), (testX, testY) = fetch()
    elif dataset_name in text_datasets:
        binarize = text_datasets[dataset_name]
        (trainX, trainY), (testX, testY) = get_text_dataset(dataset_name)
    else:
        raise ValueError("dataset name {} is unknown.".format(dataset_name))
    trainY, testY = binarize(trainY, testY)

    XYPtrain, XYUtrain, XYtrain2, XYtest, prior, prior2, dim = make_dataset(
        ((trainX, trainY), (testX, testY)), n_labeled, n_unlabeled, with_bias,
        resample_model)

    XYPtrainLoader = to_train_dataloader(XYPtrain[0], XYPtrain[1], p_batchsize)
    XYUtrainLoader = to_train_dataloader(XYUtrain[0], XYUtrain[1], u_batchsize)
    XYtrainLoader2 = to_train_dataloader(XYtrain2[0], XYtrain2[1], batchsize)

    # The "validation" loader reuses the fully labeled training data.
    XYvalidLoader = to_val_dataloader(XYtrain2[0], XYtrain2[1], eval_batchsize)

    XYtestLoader = to_val_dataloader(XYtest[0], XYtest[1], eval_batchsize)

    print("load data success!")
    print("==================")
    return (XYPtrainLoader, XYUtrainLoader, XYtrainLoader2, XYvalidLoader,
            XYtestLoader, prior, prior2, dim)
