import jax.numpy as jnp
from jax import jit, grad, hessian, jacfwd
from jax import random
import numpy as np


def sgd(fun, theta, x0, alpha, n_iter, n_samples, key, callback=None, batch_size=1):
    """
    Compute stochastic gradient descent iterates and their derivatives.

    Alongside the iterate ``x``, the derivative ``dx`` of the iterate with
    respect to ``theta`` is propagated ("piggybacked") through every update,
    starting from a zero matrix.

    Args:
    fun: function to minimize, called as ``fun(x, theta, idx)`` where ``idx``
        is an array of ``batch_size`` sample indices; it is differentiated
        with respect to its first two arguments
    theta: parameter vector (only ``theta.shape[0]`` is used for sizing)
    x0: initial iterate (only ``x0.shape[0]`` is used for sizing)
    alpha: step size policy (function of the 0-based iteration number)
    n_iter: number of iterates, counting the initial one; ``n_iter - 1``
        update steps are performed
    n_samples: number of samples; batch indices are drawn from
        ``range(n_samples)`` without replacement
    key: jax PRNG key
    callback: optional function ``callback(x, dx)`` evaluated on the initial
        iterate and after every update; its outputs are stacked row-wise
    batch_size: batch size (default: 1, true SGD)

    Returns:
    ``(x, dx)`` when ``callback`` is None, otherwise ``(x, dx, clbs)`` where
    ``clbs`` is a numpy array whose first axis has length ``n_iter`` and whose
    row ``k`` holds the callback output for the k-th iterate.
    """
    p = x0.shape[0]
    m = theta.shape[0]
    gradf = jit(grad(fun, argnums=0))
    hessf = jit(hessian(fun, argnums=0))
    crossf = jit(jacfwd(gradf, argnums=1))
    x = x0
    dx = jnp.zeros((p, m))

    # The callback history only exists when a callback was supplied; touching
    # it unconditionally used to crash the default ``callback=None`` path.
    clbs = None
    if callback is not None:
        # np.asarray lets the callback return a plain scalar or list as well.
        clb = np.asarray(callback(x, dx))
        clbs = np.zeros((n_iter, *clb.shape))
        if n_iter > 0:
            clbs[0, ...] = clb

    for i in range(n_iter - 1):
        key, subkey = random.split(key)
        idx = random.choice(subkey, n_samples, shape=(batch_size,), replace=False)
        step = alpha(i)
        # Piggyback derivatives. This must run before the gradient step so
        # that the Hessian and cross-derivative are taken at the current x.
        dx = dx - step * hessf(x, theta, idx) @ dx - step * crossf(x, theta, idx)
        # One gradient step
        x = x - step * gradf(x, theta, idx)
        if callback is not None:
            clbs[i + 1, ...] = callback(x, dx)

    if callback is not None:
        return x, dx, clbs
    return x, dx


def run_alphas_experiment(
    val_sol, true_jac, fun, theta, x0, alphas, n_iter, n_samples, key, batch_size=1
):
    """
    Run ``sgd`` once per step size policy and record two error curves.

    For every iterate produced by ``sgd`` the following are tracked:
    the objective gap ``fun(x, theta, all_indices) - val_sol`` evaluated on
    all ``n_samples`` indices, and the derivative error
    ``norm(dx - true_jac) / n_samples``.

    Args:
    val_sol: reference objective value subtracted from ``fun``
    true_jac: reference derivative compared against ``dx``
    fun: function to minimize, forwarded to ``sgd``
    theta: parameter vector
    x0: initial iterate
    alphas: sequence of step size policies, one ``sgd`` run each
    n_iter: number of iterations passed to ``sgd``
    n_samples: number of samples
    key: jax PRNG key; a fresh subkey is split off for every run
    batch_size: batch size (default: 1)

    Returns:
    ``(fs, dys)``: two numpy arrays of shape ``(n_iter, len(alphas))`` holding
    the objective gaps and the derivative errors, one column per policy.
    """
    result_shape = (n_iter, len(alphas))
    fs = np.zeros(result_shape)
    dys = np.zeros(result_shape)

    def callback(x, dx):
        # Objective gap measured on the full index range, not a mini-batch.
        every_index = jnp.arange(n_samples)
        objective_gap = fun(x, theta, every_index) - val_sol
        jac_error = np.linalg.norm(dx - true_jac) / n_samples
        return np.array([objective_gap, jac_error])

    for column, step_policy in enumerate(alphas):
        # Each policy gets its own subkey so the runs use distinct randomness.
        key, run_key = random.split(key)
        _, _, history = sgd(
            fun,
            theta,
            x0,
            step_policy,
            n_iter,
            n_samples,
            run_key,
            callback=callback,
            batch_size=batch_size,
        )
        fs[:, column] = history[:, 0]
        dys[:, column] = history[:, 1]

    return fs, dys
