import jax
import jax.numpy as jnp
from envs.cmdp import CMDP
from utils import compute_Q_h

@jax.jit
def compute_softmax_pol_h_Stradi(rQ_h, uQ_h, ent_coef, lam):
    """Return the softmax policy for a single step from combined Q-values.

    The two Q-value arrays are merged as ``rQ_h + lam * uQ_h``, scaled by
    ``1 / ent_coef`` and normalised with a softmax over the last axis.

    Args:
        rQ_h: Q-values of the first signal (the caller in this file passes
            reward Q-values for one step).
        uQ_h: Q-values of the second signal (the caller in this file passes
            utility Q-values); must broadcast against ``rQ_h``.
        ent_coef: Temperature dividing the combined Q-values.
        lam: Weight applied to ``uQ_h`` before it is added to ``rQ_h``.

    Returns:
        Array with the broadcast shape of the inputs whose entries along the
        last axis form a probability distribution.
    """
    logits = (rQ_h + lam * uQ_h) / ent_coef
    return jax.nn.softmax(logits, axis=-1)

@jax.jit
def compute_softmax_pol_Stradi(bonus: jnp.ndarray, cmdp: CMDP, pol, Cr: float, Cu: float, ent_coef: float, Clam: float) -> tuple[jnp.ndarray, jnp.ndarray]:
    """Evaluate ``pol`` with bonuses and derive a softmax policy from its Q-values.

    The function works in three stages:

    1. Backward pass over ``h = H-1, ..., 0`` computing reward and utility
       Q-tables for ``pol`` via ``compute_Q_h``, using ``Cr * bonus[h]`` and
       ``Cu * bonus[h]`` as the respective bonus terms.
    2. Compute the utility of ``pol`` from the initial distribution and pick
       the multiplier: ``0.0`` when that utility is at least ``cmdp.const``,
       otherwise ``Clam``.
    3. For every step, replace the policy with
       ``softmax((rQ[h] + lam * uQ[h]) / ent_coef)``.

    Args:
        bonus: Array of shape ``(H, S, A)``; its shape defines the horizon and
            the state/action counts used throughout.
        cmdp: Problem container; the attributes ``rew``, ``utility``, ``P``,
            ``init_dist`` and ``const`` are read.
        pol: Policy to evaluate, concatenated along axis 0 with one extra
            ``(1, S, A)`` slice (presumably of shape ``(H, S, A)`` -- confirm
            against the caller).
        Cr: Scale applied to ``bonus`` in the reward backup.
        Cu: Scale applied to ``bonus`` in the utility backup.
        ent_coef: Softmax temperature.
        Clam: Multiplier used when the utility of ``pol`` is below
            ``cmdp.const``.

    Returns:
        A pair ``(total_util, new_pol)``: the utility of the *input* policy
        under ``cmdp.init_dist`` (computed from the bonus-augmented utility
        Q-table), and the softmax policy with the padding step removed.
    """
    H, S, A = bonus.shape

    # Append a uniform policy for step H so that ``pol[h + 1]`` is defined when
    # the backward pass starts at h = H - 1.
    padded_pol = jnp.concatenate([pol, jnp.ones((1, S, A)) / A], axis=0)

    def policy_evaluation(i, q_tables):
        rQ, uQ = q_tables
        h = H - i - 1  # iterate backwards in time
        # The trailing ``0, thresh`` arguments are presumably lower/upper
        # bounds applied inside compute_Q_h -- confirm in utils.
        thresh = H - h

        rQ_h = compute_Q_h(rQ[h + 1], padded_pol[h + 1], Cr * bonus[h], cmdp.rew[h], cmdp.P[h], 0, thresh)
        uQ_h = compute_Q_h(uQ[h + 1], padded_pol[h + 1], Cu * bonus[h], cmdp.utility[h], cmdp.P[h], 0, thresh)

        return rQ.at[h].set(rQ_h), uQ.at[h].set(uQ_h)

    # Row H stays zero and acts as the terminal value for the first backup.
    init_Q = jnp.zeros((H + 1, S, A))
    rQ, uQ = jax.lax.fori_loop(0, H, policy_evaluation, (init_Q, init_Q))

    # Utility of the evaluated (input) policy from the initial distribution.
    uV = jnp.sum(uQ * padded_pol, axis=-1)
    total_util = cmdp.init_dist @ uV[0]
    # Drop the utility term when the threshold is already met.
    lam = jnp.where(total_util >= cmdp.const, 0.0, Clam)

    def policy_improvement(i, new_pol):
        h = H - i - 1
        pol_h = compute_softmax_pol_h_Stradi(rQ[h], uQ[h], ent_coef, lam)
        return new_pol.at[h].set(pol_h)

    new_pol = jax.lax.fori_loop(0, H, policy_improvement, padded_pol)

    # Strip the padding step that was appended for the backward pass.
    return total_util, new_pol[:-1]
 