import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn import preprocessing
import sklearn.datasets
import urllib.request
import os
from datasets_directory.utils_datasets.non_private_optimizer import newton
from datasets_directory.adult import adult_download_preprocess
from datasets_directory.covertype import covertype_download_preprocess
import torch
import torchvision
from torchvision import datasets, transforms
import tarfile
import requests

# Root directory (relative to the current working directory) under which the
# dataset methods below store their data ('/data') and cache
# ('/cache_datasets') sub-directories.
path_prefix = './datasets_directory'

class Mydatasets:
    """Represents the datasets we use for experiments.

    Every ``*_dataset`` method returns a ``(features, labels, w_opt)`` triple,
    where ``w_opt`` is the weight vector computed by
    ``find_optimal_classifier`` for that dataset.
    """
    def __init__(self):
        """Make sure the data and cache directories exist under ``path_prefix``.

        Creates ``<path_prefix>/data`` and ``<path_prefix>/cache_datasets``
        when they are missing; existing directories are left untouched.
        """
        for sub_dir in ('/data', '/cache_datasets'):
            # makedirs(exist_ok=True) also creates a missing parent directory
            # and does not fail if the directory appears between a check and
            # the creation (e.g. two processes starting at the same time).
            os.makedirs(path_prefix + sub_dir, exist_ok=True)


    def find_optimal_classifier(self, dataset, bias=True):
        """find the optimal weight vector for the logistic regression
            for the problems with real datasets.

        A scikit-learn ``LogisticRegression`` fit provides a starting weight
        vector, which is then handed to ``newton`` together with the dataset;
        the value returned by ``newton`` is the result.

        dataset = training dataset, a pair ``(X, y)`` of features and labels
        bias = whether the logistic model has an intercept (bias) term; when
               set, the fitted intercept is placed in front of the coefficients

        Returns the weight vector produced by ``newton``.
        """
        X, y = dataset
        reg = 1e-9  # regularization strength; sklearn's C is its inverse
        use_bias = bool(bias)
        model_lr = LogisticRegression(
            max_iter=200, fit_intercept=use_bias, C=1 / reg).fit(X, y)
        w_init = np.squeeze(model_lr.coef_)
        if use_bias:
            # intercept first, so it lines up with a leading all-ones column
            w_init = np.concatenate([model_lr.intercept_, w_init])
        # presumably newton refines the sklearn solution -- see
        # utils_datasets.non_private_optimizer for the exact contract
        return newton(dataset, w_init, bias)

    def adult_dataset(self):
        """adult dataset

        Loads the preprocessed Adult feature/label arrays from the data
        directory, running ``adult_download_preprocess`` first when either
        file is missing.

        Returns ``(X, labels, w_opt)``: ``w_opt`` is computed on the features
        as stored on disk, while the returned ``X`` has an extra leading
        column of ones for the bias term.
        """
        print(os.getcwd())
        features_file = path_prefix + '/data/adult_processed_x.npy'
        labels_file = path_prefix + '/data/adult_processed_y.npy'
        cached = os.path.exists(features_file) and os.path.exists(labels_file)
        if not cached:
            print("downloading Adult dataset ....")
            adult_download_preprocess()
        features = np.load(features_file)
        labels = np.load(labels_file)
        w_opt = self.find_optimal_classifier((features, labels))
        # dummy all-ones dimension in front, matching the bias entry of w_opt
        ones_column = np.ones(shape=(features.shape[0], 1))
        features = np.concatenate((ones_column, features), axis=1)
        return features, labels, w_opt

    def fmnist_dataset(self):
        """
          fmnist dataset

          Downloads (if needed) the FashionMNIST training split, keeps only
          the samples of classes 0 and 3, flattens each image to a vector and
          relabels class 0 as -1 and class 3 as +1.

          Returns ``(x_train, labels, w_opt)`` where ``w_opt`` comes from
          ``find_optimal_classifier`` without a bias term. Samples of class 0
          come first, followed by those of class 3.
        """
        # Normalize takes per-channel sequences; the images have one channel.
        transform_data = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

        train_data_T = datasets.FashionMNIST(root=path_prefix+'/data', download=True,  # set True to download the data
                                             train=True, transform=transform_data)

        # One batch that holds the whole training split.
        train_loader = torch.utils.data.DataLoader(train_data_T, batch_size=len(train_data_T))

        # Fetch that batch once: every next(iter(train_loader)) re-reads and
        # re-transforms the complete dataset.
        images, targets = next(iter(train_loader))
        x_train = images.numpy()
        x_train = x_train.reshape(len(x_train), -1)
        y_train = targets.numpy()
        label0 = 0
        label1 = 3
        indx0 = np.nonzero(y_train == label0)[0]
        indx1 = np.nonzero(y_train == label1)[0]
        labels = y_train.copy()
        labels[indx0] = -1
        labels[indx1] = 1
        # keep only the two selected classes
        indx = np.concatenate((indx0, indx1))
        x_train = x_train[indx]
        labels = labels[indx]
        dataset = x_train, labels
        w_opt = self.find_optimal_classifier(dataset, bias=False)
        return x_train, labels, w_opt

    def a1a_dataset(self):
        """
          a1a dataset

          Downloads the LIBSVM file ``a1a.t`` into the data directory when it
          is not there yet, loads it, converts the features to a dense array
          and standardizes them with ``StandardScaler``.

          Returns ``(X, labels, w_opt)``: ``w_opt`` is computed on the
          standardized features, while the returned ``X`` has an extra leading
          column of ones for the bias term.
        """
        a1a_url = "https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/binary/a1a.t"
        data_path = path_prefix + '/data/a1a'
        if not os.path.exists(data_path):
            # Download under a temporary name and rename on success, so that
            # an interrupted download is never mistaken for a cached copy.
            tmp_path = data_path + '.part'
            urllib.request.urlretrieve(a1a_url, tmp_path)
            os.replace(tmp_path, data_path)
        X, labels = sklearn.datasets.load_svmlight_file(data_path)
        X = X.toarray()  # the loader returns a sparse matrix
        scaler = preprocessing.StandardScaler().fit(X)
        X = scaler.transform(X)
        labels = labels.astype(float)
        dataset = X, labels
        w_opt = self.find_optimal_classifier(dataset)
        X = np.hstack((np.ones(shape=(np.shape(X)[0], 1)), X))  # adding a dummy dimension for the bias term.
        return X, labels, w_opt

    def covertype_dataset(self):
        """
          covertype dataset

          Loads the preprocessed binary covertype feature/label arrays from
          the data directory, running ``covertype_download_preprocess`` first
          when either file is missing.

          Returns ``(X, labels, w_opt)``: ``w_opt`` is computed on the
          features as stored on disk, while the returned ``X`` has an extra
          leading column of ones for the bias term.
        """
        features_file = path_prefix + '/data/covertype_binary_processed_x.npy'
        labels_file = path_prefix + '/data/covertype_binary_processed_y.npy'
        cached = os.path.exists(features_file) and os.path.exists(labels_file)
        if not cached:
            covertype_download_preprocess()
        features = np.load(features_file)
        labels = np.load(labels_file)
        w_opt = self.find_optimal_classifier((features, labels))
        # dummy all-ones dimension in front, matching the bias entry of w_opt
        ones_column = np.ones(shape=(features.shape[0], 1))
        features = np.concatenate((ones_column, features), axis=1)
        return features, labels, w_opt

    def protein_dataset(self):
        """
          protein dataset

          Downloads and extracts the KDD Cup 2004 archive when the training
          file ``bio_train.dat`` is not present yet, then builds a binary
          problem from it: column 2 holds the label (0 is mapped to -1, 1
          stays 1) and columns 3 onwards hold the features. A fixed (seeded)
          subsample of 50000 rows is drawn and each feature is standardized
          to zero mean and unit standard deviation over that subsample.

          Returns ``(x_train, labels, w_opt)``: ``w_opt`` is computed on the
          standardized features, while the returned ``x_train`` has an extra
          leading column of ones for the bias term.

          Raises ``requests.HTTPError`` when the download is answered with an
          error status.
        """
        path_protein = path_prefix + '/data/protein/'
        train_file = path_protein + 'bio_train.dat'
        # Key the cache on the extracted file, not on the directory: a failed
        # download/extraction must not make later calls skip the download.
        if not os.path.exists(train_file):
            os.makedirs(path_protein, exist_ok=True)
            protein_url = 'https://kdd.org/cupfiles/KDDCupData/2004/data_kddcup04.tar.gz'
            protein_file = path_protein + 'data_kddcup04.tar.gz'
            if not os.path.exists(protein_file):
                # Stream to a temporary name and rename on success so that a
                # partial archive is never left behind under the final name.
                tmp_file = protein_file + '.part'
                with requests.get(protein_url, stream=True) as response:
                    response.raise_for_status()
                    with open(tmp_file, 'wb') as f:
                        # raw bytes in 1 MiB chunks instead of one big read
                        for chunk in iter(lambda: response.raw.read(1 << 20), b''):
                            f.write(chunk)
                os.replace(tmp_file, protein_file)
            with tarfile.open(protein_file, "r:gz") as tar:
                # The archive comes from the network: where the interpreter
                # supports extraction filters, refuse members that would be
                # written outside of the target directory.
                if hasattr(tarfile, 'data_filter'):
                    tar.extractall(path_protein, filter='data')
                else:
                    tar.extractall(path_protein)
        # Parse the (large) text file once and slice labels/features from it.
        train_data = np.loadtxt(train_file)
        x_train = train_data[:, 3:]
        y_train = train_data[:, 2]
        num_samples_full = len(x_train)
        label0 = 0
        label1 = 1
        indx0 = np.nonzero(y_train == label0)[0]
        indx1 = np.nonzero(y_train == label1)[0]
        labels = y_train.copy()
        labels[indx0] = -1
        labels[indx1] = 1
        indx = np.arange(num_samples_full)
        # A private RandomState seeded with 3000 draws the same subsample as
        # seeding the global generator would, without disturbing the caller's
        # global numpy random state.
        rng = np.random.RandomState(3000)
        indx_sample = rng.choice(indx, 50000, replace=False)
        x_train = x_train[indx_sample]
        labels = labels[indx_sample]
        feature_mean = np.mean(x_train, axis=0)
        feature_std = np.std(x_train, axis=0)
        # NOTE(review): a constant feature column would have std 0 and turn
        # into NaN/inf here -- confirm the dataset has none.
        x_train = (x_train - feature_mean) / feature_std
        dataset = x_train, labels
        w_opt = self.find_optimal_classifier(dataset)
        # adding a dummy dimension for the bias term.
        x_train = np.hstack((np.ones(shape=(np.shape(x_train)[0], 1)), x_train))
        return x_train, labels, w_opt

    def synthetic_dataset(self, n=10000, d=100, cov=None, w=None):
        """Generates a synthetic dataset for logistic regression.

        n = number of samples
        d = dimension
        w = true coefficient vector (optional, default = first standard basis vector)
        cov = covariance of the data (optional, default = identity)

        Features are unit vectors (by default uniformly random).
        Labels are sampled from logistic distribution, so w is the "true" solution.
        """
        mean = np.zeros(d)
        if cov is None:
            cov = np.eye(d)
        X_un = np.random.multivariate_normal(mean, cov, n)
        nrm = np.linalg.norm(X_un, axis=1)
        X = X_un * 1 / nrm[:, None]
        if w is None:
            w = np.ones(d)
            w[0] = 1
        inner_prod = np.dot(X, w)
        params = np.exp(inner_prod) / (1 + np.exp(inner_prod))
        labels = 2 * np.random.binomial(1, params) - 1
        dataset = X, labels
        w_opt = self.find_optimal_classifier(dataset, bias=False)
        return X, labels, w_opt