import jax
import jax.numpy as jnp
from envs.cmdp import CMDP
from utils import compute_Q_h

@jax.jit
def compute_softmax_pol_F_APMPO(old_policy: jnp.ndarray, rQ_h: jnp.ndarray, uQ_h: jnp.ndarray, eta: float, alpha: float) -> jnp.ndarray:
    """Compute the KL-regularized softmax policy update for reward + eta * utility.

    Solves ``argmax_pi <F, pi> - alpha * KL(pi || old_policy)`` along the
    action (last) axis with ``F = rQ_h + eta * uQ_h``. The closed form is
    ``pi ∝ old_policy * exp(F / alpha)``.

    Args:
        old_policy (jnp.ndarray): (..., A) previous policy, probabilities
            along the last axis.
        rQ_h (jnp.ndarray): reward Q function, broadcast-compatible with
            ``old_policy``.
        uQ_h (jnp.ndarray): utility Q function, broadcast-compatible with
            ``old_policy``.
        eta (float): dual variable for the constraint.
        alpha (float): temperature (weight on the KL term); must be nonzero.

    Returns:
        jnp.ndarray: new policy, normalized along the last axis.
    """
    F = rQ_h + eta * uQ_h
    # Floor probabilities before the log so that zero-probability actions get a
    # finite logit instead of -inf; they can regain mass if F favors them.
    log_old_policy = jnp.log(jnp.clip(old_policy, 1e-8, 1.0))
    # jax.nn.softmax already normalizes over the last axis (and subtracts the
    # max internally), so no extra re-normalization is needed.
    return jax.nn.softmax(log_old_policy + F / alpha, axis=-1)


@jax.jit
def compute_softmax_pol_APMPO(bonus: jnp.ndarray, cmdp: CMDP, pol, Cr: float, Cu: float, cumulative_eta: float, init_s: int, lam: float, err_vio) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    """Run one bonus-augmented evaluation pass followed by one softmax policy update.

    The current policy is evaluated backwards over the horizon, building
    reward and utility Q tables with exploration bonuses added. The step-0
    utility value under ``cmdp.init_dist`` gives ``total_util``. Its shortfall
    against ``cmdp.const``, scaled by ``lam``, is the dual variable ``eta``.
    Every step's policy is then updated with ``compute_softmax_pol_F_APMPO``
    at temperature ``alpha = sqrt(cumulative_eta)``.

    Args:
        bonus (jnp.ndarray): (HxSxA) bonus; scaled by ``Cr`` for the reward Q
            and by ``Cu`` for the utility Q.
        cmdp (CMDP): supplies per-step ``rew``, ``utility`` and ``P`` (indexed
            by step h), plus ``init_dist`` and ``const``.
        pol (jnp.ndarray): current policy, assumed (HxSxA).
        Cr (float): bonus coefficient for the reward Q function.
        Cu (float): bonus coefficient for the utility Q function.
        cumulative_eta (float): running accumulator; this call adds
            ``eta**2 + 1`` to it.
        init_s (int): unused.
        lam (float): scale applied to the constraint shortfall to form ``eta``.
        err_vio: unused.

    Returns:
        tuple: ``(total_util, new_pol, cumulative_eta, eta, alpha)`` where
        ``new_pol`` is the updated policy (same leading H steps as ``pol``)
        and ``cumulative_eta`` is the updated accumulator.
    """
    H, S, A = bonus.shape

    # Append a uniform policy as step H so that pol[h + 1] is defined at
    # h = H - 1; the extra step is stripped again before returning.
    padded_pol = jnp.concatenate([pol, jnp.ones((1, S, A)) / A], axis=0)

    def evaluate_step(i, carry):
        # Backward induction: i = 0..H-1 maps to h = H-1..0.
        rQ, uQ = carry
        h = H - i - 1
        # compute_Q_h receives 0 and H - h as its last two arguments
        # (presumably value clip bounds — confirm in utils.compute_Q_h).
        thresh = H - h
        rQ_h = compute_Q_h(rQ[h + 1], padded_pol[h + 1], Cr * bonus[h], cmdp.rew[h], cmdp.P[h], 0, thresh)
        uQ_h = compute_Q_h(uQ[h + 1], padded_pol[h + 1], Cu * bonus[h], cmdp.utility[h], cmdp.P[h], 0, thresh)
        return rQ.at[h].set(rQ_h), uQ.at[h].set(uQ_h)

    # Row H is never written by the loop, so it stays zero.
    zeros = jnp.zeros((H + 1, S, A))
    rQ, uQ = jax.lax.fori_loop(0, H, evaluate_step, (zeros, zeros))

    uV = jnp.sum(uQ * padded_pol, axis=-1)
    total_util = cmdp.init_dist @ uV[0]
    # Dual variable: only a shortfall below cmdp.const produces a positive eta.
    eta = lam * jnp.maximum(cmdp.const - total_util, 0)
    cumulative_eta = cumulative_eta + jnp.power(eta, 2) + 1.0
    alpha = jnp.sqrt(cumulative_eta)

    # eta and alpha are defined above, so the closure below reads final values.
    def improve_step(i, new_pol):
        h = H - i - 1
        pol_h = compute_softmax_pol_F_APMPO(new_pol[h], rQ[h], uQ[h], eta, alpha)
        return new_pol.at[h].set(pol_h)

    padded_pol = jax.lax.fori_loop(0, H, improve_step, padded_pol)
    return total_util, padded_pol[:-1], cumulative_eta, eta, alpha