from jax import random
import jax.numpy as np
from jax.config import config
#config.update("jax_enable_x64", True)


def pseudo_abs(x, mu=0.25):
    """Smooth approximation of ``|x|``, applied elementwise.

    Where ``|x| > mu`` the result is exactly ``|x|``. Inside ``[-mu, mu]``
    the kink at 0 is replaced by the quartic
    ``(-x**4 + 6 x**2 mu**2 + 3 mu**4) / (8 mu**3)``. The quartic equals
    ``mu`` with slope ``+-1`` at ``x = +-mu``, so the two pieces join
    smoothly and the result is differentiable at 0.
    """
    smooth_core = (-x**4 + 6. * x**2 * mu**2 + 3. * mu**4) / (8. * mu**3)
    magnitude = np.abs(x)
    outside_band = magnitude > mu
    return np.where(outside_band, magnitude, smooth_core)


def pseudo_l1(x, mu=0.25):
    """Smooth surrogate of the L1 norm: ``sum(pseudo_abs(x, mu))``.

    Args:
        x: array of values to penalise.
        mu: half-width of the smoothed band around 0; forwarded to
            ``pseudo_abs`` so callers can actually change it.

    Returns:
        Scalar sum of the elementwise smoothed absolute values.
    """
    return np.sum(pseudo_abs(x, mu=mu))


def sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` elementwise.

    Computed in a numerically stable way. Each branch only exponentiates
    a non-positive number, so neither the value nor its ``jax.grad``
    overflows for large ``|x|``. The naive form gives NaN gradients once
    ``exp(-x)`` overflows, roughly ``x < -88`` in float32.

    The inner ``np.where`` calls ("double where") feed a harmless 0 into
    the branch that is not selected. That branch therefore stays finite,
    and its zero cotangent cannot turn into NaN during differentiation.
    """
    nonneg = x >= 0.
    # Clamp the argument seen by each branch so exp() never gets a positive input.
    x_pos = np.where(nonneg, x, 0.)
    x_neg = np.where(nonneg, 0., x)
    pos_branch = 1. / (1. + np.exp(-x_pos))
    exp_neg = np.exp(x_neg)
    neg_branch = exp_neg / (1. + exp_neg)
    return np.where(nonneg, pos_branch, neg_branch)

def logistic(param, X):
    """Return logistic-model probabilities ``sigmoid(X . param)``."""
    scores = np.dot(X, param)
    return sigmoid(scores)


def predict(param, X):
    """Return hard 0/1 labels from the logistic model.

    An entry is labelled 0 when its predicted probability is below 0.5,
    and 1 otherwise. A probability of exactly 0.5 therefore maps to 1.
    """
    probabilities = logistic(param, X)
    below_threshold = probabilities < 0.5
    return np.where(below_threshold, 0, 1)


#def accuracy(param, X, y):
#    y_pred = predict(param, X)
#    n_correct = np.sum(y_pred == y)
#    return n_correct / len(y)


def np_log(x):
    """Natural log of ``x`` after clipping it into ``[1e-10, 1e+10]``.

    The lower bound keeps the result finite when a probability is 0, for
    example ``log(1 - p)`` with a saturated ``p == 1``. The bounds are
    passed positionally because the ``a`` / ``a_min`` / ``a_max`` keywords
    of ``jax.numpy.clip`` are deprecated in recent JAX releases.
    """
    return np.log(np.clip(x, 1e-10, 1e+10))


def log_loss(param, X, y):
    """Mean binary cross-entropy of the logistic model on labels ``y``.

    Computes ``-(y . log p + (1 - y) . log(1 - p)) / len(y)`` with
    ``p = logistic(param, X)``. ``np_log`` clips its argument, so
    saturated probabilities do not hit ``log(0)``.
    """
    p = logistic(param, X)
    positive_term = np.dot(y, np_log(p))
    negative_term = np.dot(1 - y, np_log(1. - p))
    n_samples = len(y)
    return - (positive_term + negative_term) / n_samples


#def elastic_net(param):
#    return (pseudo_l1(param) + np.linalg.norm(param) ** 2) / len(param)


#def train_loss(param, X, y, alpha=0.1):
#    return log_loss(param, X, y) + alpha * elastic_net(param)


def pseudo_l1_diff(x, mu=0.25):
    """Elementwise derivative of ``pseudo_abs(x, mu)``.

    Returns 1 where ``x > mu`` and -1 where ``x < -mu``. Everywhere else
    it returns the derivative of the quartic band,
    ``(-x**3 + 3 x mu**2) / (2 mu**3)``, which reaches exactly +-1 at
    ``x = +-mu``.
    """
    cubic = (-x**3 + 3 * x * mu**2) / (2 * mu**3)
    left_or_band = np.where(x < -mu, -1, cubic)
    return np.where(x > mu, 1, left_or_band)

# NOTE(review): the triple-quoted string below is never executed; it only keeps
# alternative loop-based versions of the pseudo-|x| derivatives for reference.
# Its pseudo_abs_diff uses the same cubic as the vectorised pseudo_l1_diff
# defined above. Its pseudo_l1_diffdiff builds a diagonal matrix of second
# derivatives, but calls pseudo_abs_diffdiff(x[i]) without forwarding `mu`.
# Fix that before reviving it.
'''
def pseudo_abs_diff(x, mu=0.25):
    if x > mu:
        return 1
    elif x < -mu:
        return -1
    else:
        return (-x**3 + 3 * x * mu**2) / (2 * mu**3)


def pseudo_abs_diffdiff(x, mu=0.25):
    x_ = abs(x)
    if x_ > mu:
        return 0
    else:
        return (-3 * x**2 + 3 * mu**2) / (2 * mu**3)


def pseudo_l1_diff(x, mu=0.25):
    return np.array([pseudo_abs_diff(x_, mu) for x_ in x])


def pseudo_l1_diffdiff(x, mu=0.25):
    result = []
    d = len(x)
    for i in range(d):
        result_row = [0.] * d
        result_row[i] = pseudo_abs_diffdiff(x[i])
        result.append(result_row)
    return np.array(result)
'''


def generate_datasets(seed=111, size=10, X_dim=10, min_val=-10, max_val=10,
                      target_fn=None, noise_scale=5.):
    """Build a noisy 1-D polynomial-regression dataset.

    ``size`` points ``t`` are spaced evenly on ``[min_val, max_val]``. Each
    point becomes the feature row ``[t**0, t**1, ..., t**(X_dim - 1)]``.
    Targets are ``target_fn(*row)`` plus Gaussian noise with standard
    deviation ``noise_scale``, drawn from ``random.PRNGKey(seed)``.

    Args:
        seed: seed for the noise PRNG key.
        size: number of samples.
        X_dim: number of polynomial features (powers 0 .. X_dim - 1).
        min_val, max_val: endpoints of the evenly spaced input grid.
        target_fn: callable receiving the ``X_dim`` features of one row as
            positional arguments and returning a scalar target. When
            omitted, the module-level ``f_true`` is used; it must be
            defined by then.
        noise_scale: standard deviation of the additive Gaussian noise.

    Returns:
        ``(X, y)`` with ``X`` of shape ``(size, X_dim)`` and ``y`` of
        shape ``(size,)``.
    """
    if target_fn is None:
        # Looked up lazily so the default still uses the global ``f_true``.
        target_fn = f_true
    key_noise = random.PRNGKey(seed)
    grid = np.linspace(min_val, max_val, size)
    X = np.array([[t**n for n in range(X_dim)] for t in grid])
    y = np.array([target_fn(*row) for row in X])
    y = y + random.normal(key_noise, shape=[size]) * noise_scale
    return X, y


def generate_datasets_adv(
        seed=100,
        train_size=10, test_size=100,
        d_param=4, x_min=-10, x_max=10):
    """Sample train/test sets from a noisy linear model ``y = X @ w + eps``.

    ``w`` (length ``d_param``) is standard normal. Features are uniform on
    ``[x_min, x_max)`` and ``eps`` is standard normal. ``w`` and the train
    split are seeded from ``random.PRNGKey(seed)``; the test split is seeded
    from ``random.PRNGKey(seed**2 + 1)``.

    Each root key is split into one subkey per draw. JAX keys must not be
    reused, otherwise the parameters, features and noise drawn from one key
    are not independent.

    Returns:
        ``(X_train, y_train, X_test, y_test)`` with shapes
        ``(train_size, d_param)``, ``(train_size,)``,
        ``(test_size, d_param)`` and ``(test_size,)``.
    """
    key_param, key_train = random.split(random.PRNGKey(seed))
    key_test = random.PRNGKey(seed**2 + 1)
    param_true = random.normal(key_param, [d_param])

    def _sample_split(key, size):
        # Separate subkeys for this split's features and its noise.
        key_x, key_eps = random.split(key)
        X = random.uniform(key_x, [size, d_param], minval=x_min, maxval=x_max)
        y = np.dot(X, param_true) + random.normal(key_eps, [size])
        return X, y

    X_train, y_train = _sample_split(key_train, train_size)
    X_test, y_test = _sample_split(key_test, test_size)
    return X_train, y_train, X_test, y_test


def generate_datasets_dda(
        seed=100,
        train_size=10, validation_size=100, test_size=100,
        d_param=4, x_min=-10, x_max=10):
    """Sample train/validation/test sets from a noisy linear model.

    Targets follow ``y = X @ param_true + eps``. ``param_true`` (length
    ``d_param``) is standard normal, features are uniform on
    ``[x_min, x_max)`` and ``eps`` is standard normal. The splits are
    seeded from ``random.PRNGKey(seed)`` (train, plus ``param_true``),
    ``random.PRNGKey(seed**2 + 1)`` (validation) and
    ``random.PRNGKey(seed**3 + 2)`` (test).

    Each root key is split into one subkey per draw. JAX keys must not be
    reused, otherwise the parameters, features and noise drawn from one key
    are not independent.

    Returns:
        ``(X_train, y_train, X_validation, y_validation, X_test, y_test)``;
        each ``X_*`` has shape ``(size, d_param)`` and each ``y_*`` has
        shape ``(size,)``.
    """
    key_param, key_train = random.split(random.PRNGKey(seed))
    key_validation = random.PRNGKey(seed**2 + 1)
    key_test = random.PRNGKey(seed**3 + 2)
    param_true = random.normal(key_param, [d_param])

    def _sample_split(key, size):
        # Separate subkeys for this split's features and its noise.
        key_x, key_eps = random.split(key)
        X = random.uniform(key_x, [size, d_param], minval=x_min, maxval=x_max)
        y = np.dot(X, param_true) + random.normal(key_eps, [size])
        return X, y

    X_train, y_train = _sample_split(key_train, train_size)
    X_validation, y_validation = _sample_split(key_validation, validation_size)
    X_test, y_test = _sample_split(key_test, test_size)
    return X_train, y_train, X_validation, y_validation, X_test, y_test