import torch
import numpy as np
from torch.nn.functional import relu
from torch.utils.data import Dataset
from sklearn import datasets


def add_one(xold):
    """Add an all-ones (intercept) column at the beginning of a matrix.

    Parameters
    ----------
    xold : 2D tensor or array-like
        Input matrix with one sample per row. Non-tensor inputs are converted
        with ``torch.as_tensor``.

    Returns
    -------
    x : 2D tensor
        ``xold`` with a column of ones prepended, on the same device as
        ``xold``. Its dtype is the promotion of ``xold``'s dtype with the
        default float dtype (float64 stays float64; integers become the
        default float dtype).
    """
    if not isinstance(xold, torch.Tensor):
        xold = torch.as_tensor(xold)
    # Make the dtype promotion explicit (same result torch.cat would give) and
    # build the ones column on xold's device so GPU inputs do not fail in cat.
    dtype = torch.promote_types(xold.dtype, torch.get_default_dtype())
    ones = torch.ones(xold.shape[0], 1, dtype=dtype, device=xold.device)
    return torch.cat((ones, xold.to(dtype)), dim=1)


def gen_friedman1(n, d, sig):
    """Generate the Friedman #1 regression dataset.

    .. math::

        y(X) = 10 * sin(pi * X[:, 0] * X[:, 1]) + 20 * (X[:, 2] - 0.5) ** 2 + 10 * X[:, 3] + 5 * X[:, 4]
        + sig * N(0, 1).

    Parameters
    ----------
    n : int
        Number of samples.
    d : int
        Number of total features; must be at least 5 because f(X) uses the
        first five columns.
    sig : float
        Noise level (multiplier of the standard normal noise).

    Returns
    -------
    X : 2D tensor (float64)
        n by d matrix produced by ``sklearn.datasets.make_friedman1``.
    y : 1D tensor (float64)
        Noisy response f(X) + sig * N(0, 1).
    mu : 1D tensor (float64)
        Noiseless f(X).

    Raises
    ------
    ValueError
        If ``d < 5``.

    Notes
    -----
    X is drawn by sklearn without an explicit ``random_state`` (numpy's global
    RNG), while the noise comes from torch's RNG; seed both for reproducibility.
    """
    if d < 5:
        raise ValueError('d not large enough in gen_friedman1()!')
    X, mu = datasets.make_friedman1(n_samples=n, n_features=d, noise=0)
    X = torch.from_numpy(X).to(torch.float64)
    mu = torch.from_numpy(mu).to(torch.float64)
    y = mu + sig * torch.normal(0, 1, size=mu.shape)
    return X, y, mu


def gen_paras(p, r, eff_p):
    r"""Generate the parameters w, a for the class

    .. math::
        f(x) = \sum_{i=1}^{r} a_i \mathrm{relu}(w_i^T x).

    Parameters
    ----------
    p : int
        Number of total features.
    r : int
        Number of hidden units (columns of ``w``, entries of ``a``).
    eff_p : int
        Number of useful features; must satisfy ``0 <= eff_p <= p``.

    Returns
    -------
    w : 2D tensor
        p by r matrix. The first eff_p rows are drawn from Uniform[0, 1)
        (``torch.rand``); all other rows are zero.
    a : 1D tensor
        Length-r tensor with each element either 1 or -1, drawn with numpy's
        global RNG. Same dtype as ``w``.

    Raises
    ------
    ValueError
        If ``eff_p`` is outside ``[0, p]``.
    """
    if not 0 <= eff_p <= p:
        raise ValueError(f'eff_p must satisfy 0 <= eff_p <= p (got eff_p={eff_p}, p={p}).')
    w = torch.rand(size=(eff_p, r))
    w = torch.cat((w, torch.zeros(p - eff_p, r)))
    # a must share w's dtype: gen_nn computes relu(X @ w) @ a, and matmul does
    # not promote dtypes (a float64 `a` with float32 `w` would raise there).
    a = torch.from_numpy(np.random.choice([-1, 1], size=r)).to(w.dtype)
    return w, a


def gen_nn(n, p, sig, w, a):
    r"""Generate data from a one-hidden-layer ReLU network.

    .. math::
        f(x) = \sum_{i=1}^{r} a_i \mathrm{relu}(w_i^T x), \qquad y = f(x) + \sigma \epsilon,

    where :math:`w_i` is the i-th column of ``w`` and :math:`\epsilon \sim N(0, 1)`.

    Parameters
    ----------
    n : int
        Number of samples.
    p : int
        Number of total features.
    sig : float
        Noise level.
    w : 2D tensor or array-like
        p by r weight matrix (e.g. from ``gen_paras``).
    a : 1D tensor or array-like
        Length-r output weights (e.g. entries in {1, -1} from ``gen_paras``).

    Returns
    -------
    X : 2D tensor
        n by p matrix, element-wise standard normal (default dtype).
    y : 1D tensor
        Response f(X) + sig * N(0, 1).
    mu : 1D tensor
        f(X).
    """
    X = torch.normal(0, 1, size=(n, p))
    # matmul does not promote dtypes, so bring w and a to X's dtype. This is a
    # no-op when they already match, and also accepts numpy arrays.
    w = torch.as_tensor(w, dtype=X.dtype)
    a = torch.as_tensor(a, dtype=X.dtype)
    mu = relu(X @ w) @ a
    y = mu + sig * torch.normal(0, 1, size=mu.shape)
    return X, y, mu


def gen_quad(n, p, sig):
    r"""Generate data from the quadratic function $f(x)=10(x_1-0.5)^2$, $y=f(x)+\sigma \epsilon$.

    Only the first feature enters f; the remaining p - 1 features are noise.

    Parameters
    ----------
    n : int
        Number of samples.
    p : int
        Number of total features.
    sig : float
        Noise level.

    Returns
    -------
    X : 2D tensor
        n by p matrix, element-wise standard normal except the first column,
        which is Uniform[0, 1) (``torch.rand``).
    y : 1D tensor
        Response f(X) + sig * N(0, 1).
    mu : 1D tensor
        f(X) = 10 * (X[:, 0] - 0.5) ** 2.
    """
    X = torch.normal(0, 1, size=(n, p))
    X[:, 0] = torch.rand(n)
    mu = 10 * (X[:, 0] - 0.5) ** 2
    y = mu + sig * torch.normal(0, 1, size=mu.shape)
    return X, y, mu


def gen_linear0(n, sig):
    r"""
    Generate data from linear func $f(x)=x^T\beta, y=f(x)+\sigma \epsilon$, with $\beta$ and covariance matrix same
    as in Zou, 2006.

    Here $\beta = (5.6, 5.6, 5.6, 0)$. The features have unit variance, pairwise
    correlation -0.39 among the first three, and correlation 0.23 between the
    fourth feature and each of the others.

    Parameters
    ----------
    n : int
        Number of samples.
    sig : float
        Noise level.

    Returns
    -------
    X : 2D tensor (float64)
        n by 4 matrix whose rows are drawn i.i.d. from N(0, cov) using numpy's
        global RNG.
    y : 1D tensor (float64)
        Response f(X) + sig * N(0, 1).
    mu : 1D tensor (float64)
        f(X).
    """
    # float64 to match X (from numpy); torch matmul does not promote dtypes, so
    # a default-dtype (float32) beta would make `X @ w` raise.
    w = torch.tensor([5.6, 5.6, 5.6, 0], dtype=torch.float64)
    # Equicorrelated block (rho1 = -0.39) for x1..x3; x4 has correlation
    # rho2 = 0.23 with each of them.
    cov = np.array([[1., -0.39, -0.39, 0.23],
                    [-0.39, 1., -0.39, 0.23],
                    [-0.39, -0.39, 1., 0.23],
                    [0.23, 0.23, 0.23, 1.]])
    X = torch.from_numpy(np.random.multivariate_normal([0] * 4, cov, n))
    mu = X @ w
    y = mu + sig * torch.normal(0, 1, size=mu.shape)
    return X, y, mu


def gen_linear1(n, sig):
    r"""
    Generate data from linear func $f(x)=x^T\beta, y=f(x)+\sigma \epsilon$, with $\beta$ and covariance matrix same
    as in Zou, 2006.

    Here $\beta = (3, 1.5, 0, 0, 2, 0, 0, 0)$ and $\mathrm{cov}(x_i, x_j) = 0.5^{|i-j|}$.

    Parameters
    ----------
    n : int
        Number of samples.
    sig : float
        Noise level.

    Returns
    -------
    X : 2D tensor (float64)
        n by 8 matrix whose rows are drawn i.i.d. from N(0, cov) using numpy's
        global RNG.
    y : 1D tensor (float64)
        Response f(X) + sig * N(0, 1).
    mu : 1D tensor (float64)
        f(X).
    """
    # float64 to match X (from numpy); torch matmul does not promote dtypes.
    w = torch.tensor([3, 1.5, 0, 0, 2, 0, 0, 0], dtype=torch.float64)
    p = len(w)
    idx = np.arange(p)
    cov = 0.5 ** np.abs(np.subtract.outer(idx, idx))
    X = torch.from_numpy(np.random.multivariate_normal([0] * p, cov, n))
    mu = X @ w
    y = mu + sig * torch.normal(0, 1, size=mu.shape)
    return X, y, mu


class SimData(Dataset):
    """Map-style torch dataset pairing each row of ``X`` with its target in ``y``.

    Parameters
    ----------
    X : 2D tensor (or anything supporting ``X[idx, :]`` and ``len``)
        Feature matrix, one sample per row.
    y : 1D tensor (or anything supporting ``y[idx]``)
        Targets aligned with the rows of ``X``.
    """

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        # The dataset size is the number of rows in the feature matrix.
        n_rows = len(self.X)
        return n_rows

    def __getitem__(self, idx):
        features = self.X[idx, :]
        target = self.y[idx]
        return features, target


def gen_data_loader(n_train, n_test, p, sig, w=None, a=None, batch_size=None, gen='gen_nn', validation=False):
    """Simulate a dataset and wrap its train/test splits in DataLoaders.

    The design of test loader in this way (instead of a single dataset)
    is to accommodate large test data such as images.

    Parameters
    ----------
    n_train, n_test : int
        Sizes of the splits. The first ``n_train`` simulated rows form the
        training set and the remaining ``n_test`` rows form the test set.
    p : int
        Number of features for 'gen_nn', 'gen_friedman1' and 'gen_quad';
        ignored by the other generators.
    sig : float
        Noise level passed to the generator.
    w, a : tensor, optional
        Network parameters (e.g. from ``gen_paras``); required for 'gen_nn'.
    batch_size : int or None
        Passed to ``torch.utils.data.DataLoader`` (``None`` disables automatic
        batching, so samples are yielded one at a time).
    gen : str or callable
        One of 'gen_nn', 'gen_friedman1', 'gen_quad', 'gen_linear0',
        'gen_linear1', or a callable ``gen(n, sig) -> (X, y, mu)``.
        An all-ones intercept column is prepended to X for every generator
        except 'gen_nn'.
    validation : bool
        If True, test targets are the noisy responses ``y``; otherwise they are
        the noiseless means ``mu``.

    Returns
    -------
    train_loader : DataLoader
        Shuffled loader over (X_train, y_train).
    test_loader : DataLoader
        Unshuffled loader over the test split.
    snr : 0-dim tensor
        mean(mu ** 2) / sig ** 2 over all n_train + n_test samples.

    Raises
    ------
    ValueError
        If ``gen`` is an unknown name, or 'gen_nn' is requested without w and a.
    """
    n = n_train + n_test
    if callable(gen):
        X, y, mu = gen(n, sig)
    elif gen == 'gen_nn':
        if w is None or a is None:
            raise ValueError("gen='gen_nn' requires both w and a (see gen_paras()).")
        X, y, mu = gen_nn(n, p, sig, w, a)
    elif gen == 'gen_friedman1':
        X, y, mu = gen_friedman1(n, p, sig)
    elif gen == 'gen_quad':
        X, y, mu = gen_quad(n, p, sig)
    elif gen == 'gen_linear0':
        X, y, mu = gen_linear0(n, sig)
    elif gen == 'gen_linear1':
        X, y, mu = gen_linear1(n, sig)
    else:
        raise ValueError(f'Unknown generator {gen!r} in gen_data_loader().')

    if gen != 'gen_nn':
        X = add_one(X)
    X_train, y_train = X[:n_train, :], y[:n_train]
    X_test, y_test = X[n_train:, :], y[n_train:]
    train_set = SimData(X_train, y_train)
    # Test targets: noisy y for validation, otherwise the noiseless f(X).
    test_set = SimData(X_test, y_test) if validation else SimData(X_test, mu[n_train:])
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=0)
    test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=0)
    snr = torch.mean(mu ** 2) / sig ** 2
    return train_loader, test_loader, snr


def gen_data_loader_from_data(X_train, y_train, X_test, y_test, batch_size=None, intercept=True):
    """Wrap given train/test data in DataLoaders.

    Parameters
    ----------
    X_train, y_train : tensor or array-like
        Training features (one sample per row) and targets.
    X_test, y_test : tensor or array-like
        Test features and targets.
    batch_size : int or None
        Passed to ``torch.utils.data.DataLoader``.
    intercept : bool
        If True, prepend an all-ones column to both feature matrices via ``add_one``.

    Returns
    -------
    train_loader : DataLoader
        Shuffled loader over the training data.
    test_loader : DataLoader
        Unshuffled loader over the test data.
    0
        Constant; presumably a placeholder mirroring the ``snr`` slot of
        ``gen_data_loader`` (verify against callers).
    """
    if intercept:
        X_train, X_test = add_one(X_train), add_one(X_test)

    def _loader(dataset, shuffle):
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=0)

    train_set = SimData(X_train, y_train)
    test_set = SimData(X_test, y_test)
    return _loader(train_set, shuffle=True), _loader(test_set, shuffle=False), 0


if __name__ == '__main__':
    # Nothing to run: this module only defines data-generation helpers.
    ...
