"""
Fit a conditional TG model via SGD

"""
__date__ = "May 2025"

import numpy as np
import jax.random as jr
import jax.numpy as jnp
from jax import jit, vmap, value_and_grad, vjp
import optax
from jqdm import tqdm

from .stats import flat_stats, get_S


@jit
def loss_func(x, psi_y, A, l2_reg):
    """
    Evaluate the training loss of a single sample for coefficients ``A``.

    With ``phi = A @ psi_y``, the returned value is the sum of three terms,
    added in this order:

    * linear term: ``-inner(phi, get_S(x))``;
    * quadratic term: half the sum of squares of the vector-Jacobian
      product of ``flat_stats`` at ``x`` with cotangent ``phi``;
    * L2 penalty: ``0.5 * l2_reg * sum(A ** 2)``.

    Parameters
    ----------
    x : jnp.ndarray
        A single observation, passed to ``get_S`` and ``flat_stats``.
        Presumably shape (d,) -- TODO confirm against ``.stats``.
    psi_y : jnp.ndarray
        Conditioning vector for this observation; must be compatible with
        ``A @ psi_y``.
    A : jnp.ndarray
        Coefficient matrix mapping ``psi_y`` to ``phi``.
    l2_reg : float
        Weight of the L2 penalty on ``A``.

    Returns
    -------
    loss : jnp.ndarray
        The summed loss (a scalar when ``phi`` and ``get_S(x)`` are 1-D).
    """
    coeffs = A @ psi_y

    # Pull the coefficients back through flat_stats: this yields
    # coeffs^T J, where J is the Jacobian of flat_stats evaluated at x.
    _, pullback = vjp(flat_stats, x)
    (pulled_back,) = pullback(coeffs)
    quadratic = 0.5 * jnp.sum(pulled_back ** 2)

    linear = -jnp.inner(coeffs, get_S(x))
    ridge = 0.5 * l2_reg * jnp.sum(A ** 2)
    return linear + quadratic + ridge


vec_loss_func = vmap(loss_func, in_axes=(0, 0, None, None))
"""
Vectorized loss function.

Maps ``loss_func`` over the leading axis of its first two arguments
(``x`` and ``psi_y``) while sharing ``A`` and ``l2_reg`` across the batch,
returning one loss value per row.
"""


def estimate_conditional_params_sgd(
    key,
    X,
    y,
    A=None,
    batch_size=64,
    n_iter=100,
    alpha=0.99,
    opt_state=None,
    l2_reg=0.0,
    replace=True,
    lr=3e-2,
):
    """
    Estimate the conditional Torus Graph coefficients with minibatch SGD.

    Every iteration samples ``batch_size`` row indices, evaluates the mean
    of ``vec_loss_func`` on those rows together with its gradient with
    respect to ``A``, and applies one ``optax.sgd`` update. An
    exponentially smoothed loss is displayed in the progress bar.

    Parameters
    ----------
    key : jr.PRNGKey or int
        Random key used to draw minibatches. An ``int`` is converted with
        ``jr.PRNGKey``.
    X : (n,d)
        Observations.
    y : (n,d')
        Conditioning variables, one row per row of ``X``.
    A : None or jnp.ndarray
        Shape: (2d^2, d')
        Initial coefficients. When ``None``, every entry starts at 1e-2.
    batch_size : int, optional
        Number of row indices drawn per iteration.
    n_iter : int, optional
        Number of SGD iterations.
    alpha : float, optional
        Smoothing factor of the running loss shown in the progress bar.
        It only affects the display, not the optimization.
    opt_state : None or Optax optimizer state, optional
        State to resume from. When ``None``, it is initialized from ``A``.
    l2_reg : float, optional
        L2 regularization weight forwarded to the loss function.
    replace : bool, optional
        Whether minibatch indices are drawn with replacement.
    lr : float, optional
        Learning rate given to ``optax.sgd``.

    Returns
    -------
    A : jnp.ndarray
        Shape: (2d^2, d')
    opt_state : Optax optimizer state

    Raises
    ------
    AssertionError
        If ``X`` and ``y`` have different numbers of rows, or if ``A`` does
        not have shape (2d^2, d').
    """
    n, d = X.shape
    n_prime, d_prime = y.shape
    assert n == n_prime, f"{n} != {n_prime}"

    # Allow a plain integer seed in place of a PRNG key.
    if isinstance(key, int):
        key = jr.PRNGKey(key)

    expected_shape = (2*d**2, d_prime)
    if A is None:
        A = 1e-2 * jnp.ones(expected_shape)
    assert A.shape == expected_shape, f"{A.shape} != {(2*d**2, d_prime)}"

    sgd = optax.sgd(lr)
    if opt_state is None:
        opt_state = sgd.init(A)

    def mean_batch_loss(coeffs, X_batch, y_batch):
        """Average the per-sample losses over one minibatch."""
        return jnp.mean(vec_loss_func(X_batch, y_batch, coeffs, l2_reg))

    # Differentiates with respect to the first argument (the coefficients).
    loss_and_grad = value_and_grad(mean_batch_loss)

    progress = tqdm(range(1, n_iter + 1), desc=f"Optimizing A: {np.nan}")
    smooth_val = None
    for _ in progress:
        # Fresh subkey per step so every minibatch is drawn independently.
        key, step_key = jr.split(key)
        picks = jr.choice(
            step_key, n, shape=(batch_size,), replace=replace
        )
        batch_val, grad_A = loss_and_grad(A, X[picks], y[picks])

        # Seed the running average with the first loss, then smooth.
        smooth_val = (
            batch_val
            if smooth_val is None
            else alpha * smooth_val + (1.0 - alpha) * batch_val
        )
        progress.set_description(f"Optimizing A: {smooth_val:.2f}")

        step, opt_state = sgd.update(grad_A, opt_state)
        A = optax.apply_updates(A, step)

    return A, opt_state

