import jax
import jax.numpy as jnp

def szgd_inner_jax(f, inner_var, outer_var, state_sampler, step_size, key,
                  sampler=None, n_steps=1, k=1, mu=0.1):
    """
    Jax implementation of zeroth-order stochastic gradient descent on the
    inner problem.

    The gradient of ``f`` with respect to the inner variable is never
    evaluated. At each step, ``k`` Gaussian directions ``u`` are drawn and
    the gradient is replaced by the average over these directions of
    ``u * (f(x + mu * u) - f(x - mu * u)) / (2 * mu)``.

    Parameters
    ----------
    f : callable
        Inner objective, called as ``f(inner_var, outer_var, start_idx)``.
        Each call is expected to return a scalar (or one-element array).
    inner_var : array
        Initial value of the inner variable. Expected to be
        one-dimensional: directions are drawn with shape
        ``(k, inner_var.shape[0])``.
    outer_var : array
        Value of the outer variable, forwarded to ``f`` unchanged.
    state_sampler : dict
        State of the sampler, in the form ``sampler`` accepts and returns.
    step_size : float
        Step size of the gradient descent.
    key : jax PRNG key
        Base random key. It is folded with the step index so that every
        step draws its own set of directions.
    sampler : callable
        Sampler for the inner problem, called as ``sampler(state_sampler)``.
        The first element of its result is passed to ``f`` as
        ``start_idx`` and the last element becomes the new sampler state.
        It has to be provided: the ``None`` default is not callable.
    n_steps : int
        Number of steps of the gradient descent.
    k : int
        Number of random directions averaged at each step.
    mu : float
        Half-width of the central finite difference taken along each
        direction.

    Returns
    -------
    inner_var : array
        Value of the inner variable after ``n_steps`` updates.
    state_sampler : dict
        State of the sampler after ``n_steps`` draws.
    """
    def _step(i, carry):
        state_sampler, inner_var = carry
        start_idx, *_, state_sampler = sampler(state_sampler)

        # Derive a distinct key for this step. Passing `key` itself to the
        # generator would reproduce the very same directions at every step.
        step_key = jax.random.fold_in(key, i)
        directions = jax.random.normal(step_key, (k, inner_var.shape[0]))
        points_plus = inner_var + mu * directions
        points_minus = inner_var - mu * directions

        def _fill_delta(q, deltas):
            # Both evaluations share start_idx, i.e. the same sample.
            value_plus = f(points_plus[q], outer_var, start_idx)
            value_minus = f(points_minus[q], outer_var, start_idx)
            return deltas.at[q].set(value_plus - value_minus)

        # The (k, 1) buffer accepts scalar as well as one-element outputs
        # of f, and makes the product below a plain matrix product.
        deltas = jax.lax.fori_loop(0, k, _fill_delta, jnp.zeros((k, 1)))
        deltas = deltas / (2 * mu)

        # (d, k) @ (k, 1) -> (d, 1): average of the directions weighted by
        # their directional finite differences.
        grad_estimate = (directions.T.dot(deltas) / k).squeeze(1)

        return state_sampler, inner_var - step_size * grad_estimate

    state_sampler, inner_var = jax.lax.fori_loop(
        0, n_steps, _step, (state_sampler, inner_var)
    )
    return inner_var, state_sampler

def sgd_inner_jax(inner_var, outer_var, state_sampler, step_size,
                  sampler=None, n_steps=1, grad_inner=None):
    """
    Run stochastic gradient descent on the inner problem with Jax.

    Parameters
    ----------
    inner_var : array
        Starting point for the inner variable.
    outer_var : array
        Outer variable, forwarded unchanged to ``grad_inner``.
    state_sampler : dict
        Sampler state before the first step.
    step_size : float
        Factor applied to the gradient at every update.
    sampler : callable
        Called as ``sampler(state_sampler)``. The first element of its
        result is handed to ``grad_inner`` and the last element is the
        updated sampler state.
    n_steps : int
        How many updates to perform.
    grad_inner : callable
        Gradient of the inner oracle with respect to the inner variable,
        called as ``grad_inner(inner_var, outer_var, start_idx)``.

    Returns
    -------
    inner_var : array
        Inner variable after ``n_steps`` updates.
    state_sampler : dict
        Sampler state after ``n_steps`` draws.
    """
    def _descent_step(_, carry):
        sampler_state, current = carry
        # Only the first and last outputs of the sampler are used here.
        batch_start, *_ignored, sampler_state = sampler(sampler_state)
        gradient = grad_inner(current, outer_var, batch_start)
        current = current - step_size * gradient
        return sampler_state, current

    initial_carry = (state_sampler, inner_var)
    final_state, final_var = jax.lax.fori_loop(
        0, n_steps, _descent_step, initial_carry
    )

    return final_var, final_state
