"""
Fit a conditional model with SGD.

"""

__date__ = "February 2024"


import numpy as np
import jax.numpy as jnp
import jax.random as jr
from jax import jit, vmap, value_and_grad
import optax
from tqdm import tqdm


from .condition import condition_phi_multiple_pieces
from .von_mises import vm_log_pdf

# Batched version of ``condition_phi_multiple_pieces``: the first five
# positional arguments are shared across the batch and only the sixth one is
# mapped over its leading axis.
vectorized_condition = vmap(
    condition_phi_multiple_pieces,
    in_axes=(None,) * 5 + (0,),
)
# Batched version of ``vm_log_pdf``: every argument is mapped over axis 0.
vectorized_vm_log_pdf = vmap(vm_log_pdf, in_axes=0)


def loss_function(flat_phi, X, y, d, indices_a, indices_b, l2_reg=0.0):
    """
    Return the regularized negative mean log-density of ``y`` given ``X``.

    Parameters
    ----------
    flat_phi : array
        Flat parameter vector. The first ``2 * (d - 1)`` entries are reshaped
        to ``(d - 1, 1, 1, 2)``, the next ``2 * (d - 1)`` entries to
        ``(1, d - 1, 1, 2)`` and the remaining entries to ``(1, 1, 1, 2)``
        (so a total length of ``4 * d - 2`` is expected).
    X : array
        Conditioning values; mapped over the leading axis.
    y : array
        Targets, passed as the third argument of ``vm_log_pdf``; mapped over
        the leading axis.
    d : int
        Dimension used to size the parameter pieces.
    indices_a, indices_b :
        Forwarded unchanged to ``condition_phi_multiple_pieces``.
    l2_reg : float, optional
        Weight of the penalty ``0.5 * sum(flat_phi ** 2)``.

    Returns
    -------
    Scalar loss: ``-mean(log_pdf) + l2_reg * 0.5 * sum(flat_phi ** 2)``.
    """
    n_a = d - 1
    first_cut = 2 * n_a
    second_cut = 4 * n_a

    # Unpack the flat vector into the three pieces the conditioning routine
    # takes as its first three arguments.
    piece_ab = flat_phi[:first_cut].reshape(n_a, 1, 1, 2)
    piece_ba = flat_phi[first_cut:second_cut].reshape(1, n_a, 1, 2)
    piece_bb = flat_phi[second_cut:].reshape(1, 1, 1, 2)

    conditioned = vectorized_condition(
        piece_ab, piece_ba, piece_bb, indices_a, indices_b, X
    )  # [n,1,1,1,2]
    pair = conditioned[:, 0, 0, 0]  # [n,2]

    # The two columns are the first two arguments of ``vm_log_pdf``
    # (presumably the conditional von Mises parameters -- confirm against
    # ``von_mises.vm_log_pdf``).
    log_density = vectorized_vm_log_pdf(pair[:, 0], pair[:, 1], y)

    data_term = -jnp.mean(log_density)
    penalty_term = l2_reg * 0.5 * jnp.sum(flat_phi**2)
    return data_term + penalty_term


def fit_conditional_sgd(
    key,
    X,
    y,
    flat_phi=None,
    indices_a=None,
    indices_b=None,
    batch_size=64,
    n_iter=100,
    alpha=0.99,
    opt_state=None,
    replace=True,
    lr=3e-2,
    l2_reg=0.0,
    verbose=False,
):
    """
    Fit the flat parameter vector of the conditional model with minibatch SGD.

    Parameters
    ----------
    key : int or JAX PRNG key
        Source of randomness for minibatch sampling. An ``int`` is turned into
        a key with ``jax.random.PRNGKey``.
    X : array
        Conditioning values with at least two axes; ``d`` is taken to be
        ``X.shape[1] + 1``. Rows are sampled along the leading axis.
    y : array
        Targets; must have the same length as ``X``.
    flat_phi : array, optional
        Initial parameters. Defaults to ``1e-2 * ones(4 * d - 2)``.
    indices_a, indices_b : array, optional
        Forwarded to the loss. If *either* one is ``None``, both are replaced
        by ``arange(d)[:-1]`` and ``arange(d)[-1:]`` respectively.
    batch_size : int, optional
        Number of rows drawn per iteration.
    n_iter : int, optional
        Number of SGD steps.
    alpha : float, optional
        Smoothing factor of the running loss shown in the progress bar; only
        used when ``verbose`` is true.
    opt_state : optax state, optional
        Optimizer state to continue from. Defaults to a fresh state for
        ``flat_phi``.
    replace : bool, optional
        Whether minibatch rows are drawn with replacement.
    lr : float, optional
        Learning rate given to ``optax.sgd``.
    l2_reg : float, optional
        L2 penalty weight forwarded to ``loss_function``.
    verbose : bool, optional
        Show a ``tqdm`` progress bar with the smoothed minibatch loss.

    Returns
    -------
    flat_phi : array
        Parameters after ``n_iter`` updates.
    opt_state : optax state
        Final optimizer state (can be passed back in to resume).
    """
    # Kept as an ``assert`` so the raised exception type is unchanged.
    assert len(X) == len(y), "X and y must have the same length"
    d = X.shape[1] + 1
    if isinstance(key, int):
        key = jr.PRNGKey(key)
    if flat_phi is None:
        flat_phi = 1e-2 * jnp.ones(4 * d - 2)
    if indices_a is None or indices_b is None:
        idx = jnp.arange(d)
        indices_a, indices_b = idx[:-1], idx[-1:]
    optimizer = optax.sgd(lr)
    if opt_state is None:
        opt_state = optimizer.init(flat_phi)

    def batch_loss(phi, X_batch, y_batch):
        # Bind everything that stays fixed during optimization.
        return loss_function(
            phi, X_batch, y_batch, d, indices_a, indices_b, l2_reg
        )

    # Compile the loss together with its gradient: jitting the outer
    # ``value_and_grad`` puts the forward and backward passes in a single
    # compiled function instead of only compiling the loss.
    batch_val_grad = jit(value_and_grad(batch_loss, argnums=0))

    # Loop-invariant pool of row indices to sample minibatches from.
    row_pool = jnp.arange(len(X))

    # Loop with minibatches ...
    if verbose:
        pbar = tqdm(range(1, n_iter + 1), desc=f"Optimizing phi: {np.nan}")
    else:
        pbar = range(1, n_iter + 1)
    smooth_val = None
    for i in pbar:
        key, subkey = jr.split(key)
        idx = jr.choice(subkey, row_pool, shape=(batch_size,), replace=replace)
        Xb, yb = X[idx], y[idx]
        val, grads = batch_val_grad(flat_phi, Xb, yb)
        if verbose:
            # Exponential moving average of the minibatch loss, seeded with
            # the first value.
            if i == 1:
                smooth_val = val
            else:
                smooth_val = alpha * smooth_val + (1.0 - alpha) * val
            pbar.set_description(f"Optimizing phi: {smooth_val:.2f}")
        updates, opt_state = optimizer.update(grads, opt_state)
        flat_phi = optax.apply_updates(flat_phi, updates)

    return flat_phi, opt_state


# Placeholder entry point: the guarded body is intentionally a no-op.
if __name__ == "__main__":
    pass


###
