import jax
import jax.numpy as jnp
import optax

def _l1_loss(pred, tgt):
    diff = (pred - tgt).astype(jnp.float32)
    return jnp.mean(jnp.abs(diff))

# def update(pe, obs, act, target):
#     def pA_loss_fn(pA_params):
#         pred_A = pe.pA.apply({'params': pA_params}, obs, act)
#         loss_A = jnp.mean((pred_A - target)**2)
#         # loss_A = _l1_loss(pred_A, target)
#         return loss_A, {'pA_loss': loss_A}

#     new_pA, infoA = pe.pA.apply_gradient(pA_loss_fn)

#     def pB_loss_fn(pB_params):
#         pred_B = pe.pB.apply({'params': pB_params}, obs, act)
#         loss_B = jnp.mean((pred_B - target)**2)
#         # loss_B = _l1_loss(pred_B, target)
#         return loss_B, {'pB_loss': loss_B}

#     new_pB, infoB = pe.pB.apply_gradient(pB_loss_fn)
#     new_pe = pe.replace(pA=new_pA, pB=new_pB)
#     info = {**infoA, **infoB}
#     return new_pe, info

def update(pe, obs, act, target, pa_pred, pb_pred):
    def pA_loss_fn(pA_params):
        pred_A = pa_pred(pA_params, obs, act)
        loss_A = jnp.mean((pred_A - target)**2)
        # loss_A = _l1_loss(pred_A, target)
        return loss_A, {'pA_loss': loss_A}

    new_pA, infoA = pe.pA.apply_gradient(pA_loss_fn)

    def pB_loss_fn(pB_params):
        pred_B = pb_pred(pB_params, obs, act)
        loss_B = jnp.mean((pred_B - target)**2)
        # loss_B = _l1_loss(pred_B, target)
        return loss_B, {'pB_loss': loss_B}

    new_pB, infoB = pe.pB.apply_gradient(pB_loss_fn)
    new_pe = pe.replace(pA=new_pA, pB=new_pB)
    info = {**infoA, **infoB}
    return new_pe, info


