import jax
from jax import numpy as jnp, random, jit, vmap, grad
from jax.numpy import linalg as jla
from scipy.special import factorial
from ortho_poly import get_sigma, get_smoothed_sigma
from math import ceil
import wandb
import tyro


def experiment(
    k: int,
    d: int,
    batch_multiplier: float,
    lr_multiplier: float,
    log_smoothing: float,
    noise_std: float,
    seed: int,
):
    """Run one online-SGD single-index experiment and log it to wandb.

    A unit-norm target direction ``wstar`` is drawn at random and defines the
    labels ``sigma(x @ wstar) + noise``. A unit-norm weight vector ``w`` is
    then trained by normalized gradient steps on the correlation loss
    ``mean(1 - y * smoothed_sigma(w, x, smoothing))``, drawing a fresh
    Gaussian batch at every step. Training stops as soon as the squared
    overlap ``(w @ wstar) ** 2`` exceeds 0.5, or when the sample budget is
    used up. The wandb run is finished with exit code 0 if the threshold was
    crossed and -1 otherwise.

    Args:
        k: Index of the single non-zero entry of the coefficient vector passed
            to ``get_sigma`` / ``get_smoothed_sigma``. Also sets the exponents
            of the batch size, the learning rate and the sample budget.
        d: Input dimension.
        batch_multiplier: Prefactor of the batch size
            ``batch_multiplier * d ** (k / 2 - log_smoothing * (k - 2))``.
            The result is capped at 8192 and floored at 1.
        lr_multiplier: Prefactor of the learning rate. The sample budget is
            ``d ** k / lr_multiplier``.
        log_smoothing: Smoothing exponent; the smoothing radius is
            ``sqrt(d ** (2 * log_smoothing) - 1)``.
        noise_std: Standard deviation of the Gaussian label noise.
        seed: Seed for the JAX PRNG key.
    """
    smoothing = jnp.sqrt(d ** (2 * log_smoothing) - 1)
    batch_size = batch_multiplier * d ** (k / 2 - log_smoothing * (k - 2))
    # Cap the batch at 8192 samples, but never let it truncate to 0: an empty
    # batch would make the step count below divide by zero.
    batch_size = max(1, int(min(batch_size, 8192)))
    lr = (
        lr_multiplier
        * batch_size
        * d ** (-k / 2 + log_smoothing * (2 * k - 2))
        / factorial(k)
    )
    max_n = int(d**k / lr_multiplier)

    # Spelled out explicitly (instead of `locals()`) so that helper variables
    # can never leak into the logged configuration.
    config = {
        "k": k,
        "d": d,
        "batch_multiplier": batch_multiplier,
        "lr_multiplier": lr_multiplier,
        "log_smoothing": log_smoothing,
        "noise_std": noise_std,
        "seed": seed,
        "smoothing": smoothing,
        "batch_size": batch_size,
        "lr": lr,
        "max_n": max_n,
    }

    wstar_key, init_key, batch_key = random.split(random.PRNGKey(seed), 3)

    # Coefficient vector with a single 1 at index k, rescaled so that
    # sum_i c_i**2 / i! == 1.
    # NOTE(review): presumably Hermite coefficients, going by the `he_` name
    # and the `ortho_poly` module -- confirm against `get_sigma`.
    he_coef = [0] * k + [1]
    he_coef = jnp.array(he_coef)
    norm = sum(c**2 / factorial(i) for i, c in enumerate(he_coef))
    he_coef /= jnp.sqrt(norm)
    sigma = jit(get_sigma(he_coef))
    smoothed_sigma = jit(get_smoothed_sigma(he_coef, d))

    def normalize(v):
        """Scale `v` to unit Euclidean norm."""
        return v / jla.norm(v)

    def logd(x):
        """Logarithm of `x` in base `d`."""
        return jnp.log(x) / jnp.log(d)

    wstar = normalize(random.normal(wstar_key, (d,)))

    def fstar(x):
        """Noise-free target evaluated on a batch of inputs."""
        return sigma(x @ wstar)

    @jit
    def loss_fn(w, key):
        """Correlation loss of the normalized `w` on a freshly drawn batch."""
        w = normalize(w)

        x_key, noise_key = random.split(key)
        x = random.normal(x_key, (batch_size, d))
        z = random.normal(noise_key, (batch_size,)) * noise_std
        y = fstar(x) + z
        # Map over the rows of x; w and smoothing are shared across the batch.
        out = vmap(smoothed_sigma, (None, 0, None))(w, x, smoothing)
        return jnp.mean(1 - y * out)

    @jit
    def step_fn(w, key):
        """Take one gradient step, re-normalize, and advance the PRNG key."""
        key, subkey = random.split(key)
        g = grad(loss_fn)(w, subkey)
        w = normalize(w - lr * g)
        return w, key

    def log_progress(step, overlap):
        """Log the overlap together with the number of samples consumed."""
        n = step * float(batch_size)
        wandb.log({"alpha": overlap, "n": n, "logdn": logd(n + 1)}, step=step)

    wandb.init(
        project="single_index_smoothing",
        config=config,
    )

    # Start from a unit vector whose overlap with wstar is exactly d**(-1/2):
    # take the normalized component orthogonal to wstar and add alpha0 * wstar.
    alpha0 = d ** (-1 / 2)
    w = random.normal(init_key, (d,))
    w = jnp.sqrt(1 - alpha0**2) * normalize(w - wstar * (w @ wstar)) + wstar * alpha0

    # Log on a geometric schedule: steps 1, 2, 3, 4, 5, 6, 8, 10, 12, ...
    save_at = 1
    save_multiplier = 1.2
    num_steps = max_n // batch_size
    for i in range(num_steps):
        alpha = w @ wstar
        if alpha**2 > 0.5:
            log_progress(i, alpha)
            break
        if i + 1 == save_at:
            log_progress(i, alpha)
            save_at = int(ceil(save_multiplier * save_at))
        w, batch_key = step_fn(w, batch_key)
    else:
        # Budget exhausted without an early stop (or no steps at all):
        # evaluate the final iterate so that the last update is not ignored
        # and `alpha` is always defined for the exit code below.
        alpha = w @ wstar
        log_progress(num_steps, alpha)

    wandb.run.log_code()
    exit_code = 0 if alpha**2 > 0.5 else -1
    wandb.finish(exit_code)


if __name__ == "__main__":
    # Only launch a run when executed as a script, so that importing this
    # module does not parse sys.argv or start an experiment.
    tyro.cli(experiment)
