import jax
import jax.numpy as jnp


def _zero_random_entries(key, x, n_null):
    """Return a copy of ``x`` with ``n_null`` entries set to zero in every row.

    The zeroed column indices are drawn without replacement, and each row
    uses its own key split from ``key``, so rows get independent patterns.
    """
    n_rows, dim = x.shape
    if n_rows == 0 or n_null == 0:
        return x
    row_keys = jax.random.split(key, n_rows)
    null_indices = jax.vmap(
        lambda k: jax.random.choice(k, dim, (n_null,), replace=False)
    )(row_keys)
    # (n_rows, 1) row index broadcasts against (n_rows, n_null) column indices.
    rows = jnp.arange(n_rows)[:, None]
    return x.at[rows, null_indices].set(0.)


def gen_l1_data(
    dimension=50,
    sparsity_level=0.1,
    noise_level=1e-0,
    n_samples_train=100,
    n_samples_test=100,
    seed=0,
):
    """Generate a noisy linear inverse problem y = A x + noise with sparse x.

    ``A`` is a (dimension, dimension) standard-normal matrix shared by the
    train and test splits. Each row of ``x`` starts standard normal, then
    ``int((1 - sparsity_level) * dimension)`` of its entries are set to zero.
    The zeroed positions are drawn independently for every sample.
    Observations are ``y = x @ A.T + noise_level * N(0, I)``.

    Args:
        dimension: size of both x and y.
        sparsity_level: fraction of entries (approximately) kept nonzero.
        noise_level: standard deviation of the additive Gaussian noise on y.
        n_samples_train: number of training samples.
        n_samples_test: number of test samples.
        seed: seed for ``jax.random.PRNGKey``.

    Returns:
        ``(train_data, test_data)``, each a tuple ``(x, A, y)`` where x and y
        have shape (n_samples, dimension).

    Raises:
        ValueError: if ``sparsity_level`` yields a negative number of zeroed
            entries or more zeroed entries than ``dimension``.
    """
    key = jax.random.PRNGKey(seed)
    _, *subkeys = jax.random.split(key, 8)
    A = jax.random.normal(subkeys[0], (dimension, dimension))
    x_train = jax.random.normal(subkeys[1], (n_samples_train, dimension,))
    x_test = jax.random.normal(subkeys[2], (n_samples_test, dimension,))

    n_null = int((1 - sparsity_level) * dimension)
    if not 0 <= n_null <= dimension:
        raise ValueError(
            f"sparsity_level={sparsity_level} zeroes {n_null} entries, "
            f"expected between 0 and {dimension}"
        )
    # Impose sparsity on every sample with an independent support per sample.
    x_train = _zero_random_entries(subkeys[3], x_train, n_null)
    x_test = _zero_random_entries(subkeys[4], x_test, n_null)

    y_train = jnp.dot(A, x_train.T).T + noise_level * jax.random.normal(subkeys[5], (n_samples_train, dimension,))
    y_test = jnp.dot(A, x_test.T).T + noise_level * jax.random.normal(subkeys[6], (n_samples_test, dimension,))
    train_data = (x_train, A, y_train)
    test_data = (x_test, A, y_test)
    return train_data, test_data


def gen_l2_data(
    sigma_data=1.0,
    **kwargs,
):
    """Generate L2 data whose per-coordinate prior scale is the single value ``sigma_data``.

    Thin wrapper around ``gen_l2_data_diag_correl``. It passes the degenerate
    range ``(sigma_data, sigma_data)`` as ``sigma_data_range`` and forwards
    every other keyword argument unchanged.
    """
    degenerate_range = (sigma_data, sigma_data)
    return gen_l2_data_diag_correl(sigma_data_range=degenerate_range, **kwargs)


def gen_l2_data_diag_correl(
    dimension=50,
    dimension_y=50,
    sigma_data_range=(0, 10),
    noise_level=1e-3,
    n_samples_train=100,
    n_samples_test=100,
    seed=0,
):
    """Generate a noisy linear inverse problem y = A x + noise with a diagonal prior on x.

    A per-coordinate scale vector ``sigma_data`` of length ``dimension`` is
    drawn uniformly from ``sigma_data_range``. Each sample is
    ``x = sigma_data * N(0, I)``, and observations are
    ``y = x @ A.T + noise_level * N(0, I)``. ``A`` is a standard-normal matrix
    of shape (dimension_y, dimension) shared by both splits.

    Args:
        dimension: size of x.
        dimension_y: size of y.
        sigma_data_range: ``(min, max)`` bounds of the uniform draw of the
            per-coordinate scales.
        noise_level: standard deviation of the additive Gaussian noise on y.
        n_samples_train: number of training samples.
        n_samples_test: number of test samples.
        seed: seed for ``jax.random.PRNGKey``.

    Returns:
        ``(train_data, test_data)``, each a tuple ``(x, A, y)`` with x of shape
        (n_samples, dimension) and y of shape (n_samples, dimension_y).
    """
    key = jax.random.PRNGKey(seed)
    _, *subkeys = jax.random.split(key, 6)
    A = jax.random.normal(subkeys[0], (dimension_y, dimension))
    sigma_data = jax.random.uniform(
        subkeys[1],
        (dimension,),
        minval=sigma_data_range[0],
        maxval=sigma_data_range[1]
    )[None]  # shape (1, dimension): broadcasts over the sample axis
    x_train = sigma_data * jax.random.normal(subkeys[2], (n_samples_train, dimension,))
    # The test inputs need their own key. Reusing subkeys[2] would make x_test
    # the leading rows of x_train. fold_in leaves the training draws unchanged.
    x_test_key = jax.random.fold_in(subkeys[2], 1)
    x_test = sigma_data * jax.random.normal(x_test_key, (n_samples_test, dimension,))

    y_train = jnp.dot(A, x_train.T).T + noise_level * jax.random.normal(subkeys[3], (n_samples_train, dimension_y,))
    y_test = jnp.dot(A, x_test.T).T + noise_level * jax.random.normal(subkeys[4], (n_samples_test, dimension_y,))
    train_data = (x_train, A, y_train)
    test_data = (x_test, A, y_test)
    return train_data, test_data


def gen_quadratic_data(
    dimension_x=10,
    dimension_latent=10,
    dimension_theta=10,
    make_x_in_range=False,
    range_budget=None,
    noise_x_sigma=0.1,
    return_theta_0=False,
    return_sqrt_H=False,
    seed=0,
):
    """Generate random data (x, H, B, c) for a quadratic problem in x.

    ``H = sqrt_H^T sqrt_H``, where ``sqrt_H`` is standard normal of shape
    (dimension_latent, dimension_x). H is therefore positive semi-definite,
    of shape (dimension_x, dimension_x) and rank at most dimension_latent.
    ``c = sqrt_H^T kappa`` with kappa of shape (dimension_latent,), and
    ``B = sqrt_H^T U`` with U of shape (dimension_latent, dimension_theta).
    Both c and B therefore lie in the row space of sqrt_H.

    Args:
        dimension_x: size of x.
        dimension_latent: number of rows of sqrt_H (bounds the rank of H).
        dimension_theta: number of columns of B.
        make_x_in_range: if True, set
            ``x = -lstsq(H, B theta_0 + c) + epsilon`` with
            ``theta_0 ~ N(0, I)`` of shape (dimension_theta,) and
            ``epsilon ~ N(0, noise_x_sigma^2 I)``. Otherwise ``x ~ N(0, I)``.
        range_budget: currently unused.
        noise_x_sigma: scale of epsilon when ``make_x_in_range`` is True.
        return_theta_0: also return theta_0. Only honored when
            ``make_x_in_range`` is True.
        return_sqrt_H: also return the factors ``(sqrt_H, U, kappa)``.
            Ignored when theta_0 is returned.
        seed: seed for ``jax.random.PRNGKey``.

    Returns:
        ``(x, H, B, c)``, ``(x, H, B, c, theta_0)`` or
        ``(x, H, B, c, (sqrt_H, U, kappa))`` depending on the flags above.
    """
    key = jax.random.PRNGKey(seed)
    _, *subkeys = jax.random.split(key, 5)
    sqrt_H = jax.random.normal(subkeys[1], (dimension_latent, dimension_x))
    H = sqrt_H.T @ sqrt_H
    kappa = jax.random.normal(subkeys[2], (dimension_latent,))
    c = sqrt_H.T @ kappa
    U = jax.random.normal(subkeys[3], (dimension_latent, dimension_theta))
    B = sqrt_H.T @ U
    if make_x_in_range:
        # x is written as a potential solution of the inner problem at
        # convergence: the t -> infinity limit of
        # (e^{-tH} - I) H^{-1} (B theta_0 + c), plus noise. lstsq returns the
        # minimum-norm least-squares solution, i.e. the pseudo-inverse of H
        # applied to (B theta_0 + c).
        theta_0 = jax.random.normal(subkeys[0], (dimension_theta,))
        Btheta_c = B @ theta_0 + c
        # epsilon needs its own key. Drawing it from subkeys[0] like theta_0
        # correlates the two: they are identical up to scale when
        # dimension_x == dimension_theta.
        epsilon_key = jax.random.fold_in(subkeys[0], 1)
        epsilon = noise_x_sigma * jax.random.normal(epsilon_key, (dimension_x,))
        x = - jnp.linalg.lstsq(H, Btheta_c)[0] + epsilon
    else:
        x = jax.random.normal(subkeys[0], (dimension_x,))
    if make_x_in_range and return_theta_0:
        return x, H, B, c, theta_0
    elif return_sqrt_H:
        return x, H, B, c, (sqrt_H, U, kappa)
    else:
        return x, H, B, c


def gen_biquadratic_data(
    dimension_x=10,
    dimension_latent_inner=10,
    dimension_latent_outer=10,
    dimension_theta=10,
    return_factors=False,
    seed=0,
):
    """Generate random data (G, w, H, B, c) built from two low-rank PSD quadratics.

    Inner quantities share the standard-normal factor ``sqrt_H`` of shape
    (dimension_latent_inner, dimension_x):
        H = sqrt_H^T sqrt_H            shape (dimension_x, dimension_x)
        c = sqrt_H^T kappa             shape (dimension_x,)
        B = sqrt_H^T U                 shape (dimension_x, dimension_theta)
    Outer quantities share the standard-normal factor ``sqrt_G`` of shape
    (dimension_latent_outer, dimension_x):
        G = sqrt_G^T sqrt_G            shape (dimension_x, dimension_x)
        w = sqrt_G^T omega             shape (dimension_x,)

    Returns:
        ``(G, w, H, B, c)``. When ``return_factors`` is True, a sixth element
        ``(sqrt_H, U, kappa, sqrt_G, omega)`` is appended.
    """
    # Six keys are split. The first is not used, and the other five feed the
    # five random factors below.
    (_, key_sqrt_h, key_kappa, key_u, key_sqrt_g, key_omega) = jax.random.split(
        jax.random.PRNGKey(seed), 6
    )

    # Inner problem: every quantity lies in the row space of sqrt_H.
    sqrt_H = jax.random.normal(key_sqrt_h, (dimension_latent_inner, dimension_x))
    kappa = jax.random.normal(key_kappa, (dimension_latent_inner,))
    U = jax.random.normal(key_u, (dimension_latent_inner, dimension_theta))
    H, c, B = sqrt_H.T @ sqrt_H, sqrt_H.T @ kappa, sqrt_H.T @ U

    # Outer problem: every quantity lies in the row space of sqrt_G.
    sqrt_G = jax.random.normal(key_sqrt_g, (dimension_latent_outer, dimension_x))
    omega = jax.random.normal(key_omega, (dimension_latent_outer,))
    G, w = sqrt_G.T @ sqrt_G, sqrt_G.T @ omega

    data = (G, w, H, B, c)
    if not return_factors:
        return data
    return (*data, (sqrt_H, U, kappa, sqrt_G, omega))
