import jax.numpy as jnp
import jax
import numpy as np
from sklearn.datasets import make_classification
from sgd import sgd, run_alphas_experiment
from utils import save_experiment

jax.config.update("jax_enable_x64", True)


# Single seed shared by every random source in the script (JAX key, synthetic
# data, initial point) so that runs are reproducible end to end.
seed = 42
key = jax.random.PRNGKey(seed)
# number of samples, number of features
n, p = 100, 10
# batch size used by the step-size experiments (also recorded by save_experiment)
bs = 1
# generate data: make_classification returns (features, class labels); the
# labels are used below as regression targets for the ridge objective
A, theta = make_classification(n_samples=n, n_features=p, random_state=seed)
# rescale so that the spectral norm (largest singular value) of A is 1
A = A / np.linalg.norm(A, 2)
A = jnp.array(A)
theta = jnp.array(theta, dtype=jnp.float64)

# initialization: unit-norm vector with non-negative entries
rng = np.random.default_rng(seed)
x0 = rng.random(p)
x0 = jnp.array(x0 / np.linalg.norm(x0, 2))

# total number of iterations
n_iter = 200000

# strength of the squared l2 penalty in the ridge objective
mu = 0.05

def ridge(x, theta, idx):
    """Ridge-regression loss on the rows ``idx`` of the module-level matrix ``A``.

    Computes ``mean((A[idx] @ x - theta[idx]) ** 2) + mu * ||x||_2 ** 2``,
    where ``A`` and ``mu`` are module-level globals.

    Args:
        x: parameter vector.
        theta: regression targets, indexed with ``idx``.
        idx: row index / index array selecting the (mini-)batch.

    Returns:
        Scalar JAX array with the loss value.
    """
    residual = A[idx, :] @ x - theta[idx]
    # ``x @ x`` equals ||x||_2 ** 2 without a sqrt round-trip; differentiating
    # through jnp.linalg.norm instead yields a NaN gradient at x = 0.
    return jnp.mean(residual ** 2) + mu * (x @ x)


# Reference run used to approximate the optimum: 10x the experiment's
# iteration budget, batch_size=n (presumably full-batch) and a constant 0.01
# as the (presumed) step-size schedule.
# NOTE(review): assumes sgd returns (final iterate, quantity compared with
# `dx` in the callbacks, callback history) — confirm against sgd.py.
reference_steps = n_iter * 10


def _constant_step(i):
    return 0.01


def _full_objective(x, dx):
    # full-data loss; `dx` is part of the callback signature but unused here
    return ridge(x, theta, jnp.arange(n))


truex, trued, trueclb = sgd(
    ridge,
    theta,
    x0,
    _constant_step,
    reference_steps,
    n,
    key,
    callback=_full_objective,
    batch_size=n,
)
# full-data objective value at the reference solution
val = ridge(truex, theta, jnp.arange(n))


def callback_logistic(x, dx):
    """Return ``[loss gap, scaled distance of dx to trued]`` as a NumPy array.

    The first entry is the full-data ridge objective at ``x`` minus the
    reference value ``val``. The second entry is ``||dx - trued|| / n``.

    NOTE(review): despite the name, this evaluates the ridge objective, and it
    is not referenced anywhere in this script.
    """
    loss_gap = ridge(x, theta, jnp.arange(n)) - val
    scaled_dist = np.linalg.norm(dx - trued) / n
    return np.array([loss_gap, scaled_dist])


# Constant step sizes to compare. `fac=fac` binds each value when the lambda is
# created; a bare `lambda _: fac` would make every schedule return the last one.
facs = [0.05, 0.005, 0.0005]
alphas = [(lambda _, fac=fac: fac) for fac in facs]

# Use the configured batch size so the runs always match the `bs` value that
# save_experiment records alongside the results.
fs, dfs = run_alphas_experiment(
    val, trued, ridge, theta, x0, alphas, n_iter, n, key, batch_size=bs
)
save_experiment(
    ridge,
    fs,
    dfs,
    val,
    trued,
    theta,
    x0,
    facs,
    n_iter,
    n,
    bs,
)
