import jax
import jax.numpy as jnp
import optax


def expectile_loss(u, tau=0.8):
    """Return the element-wise asymmetric squared loss of residual `u`.

    Strictly positive residuals are weighted by `tau`; zero and negative
    residuals are weighted by `1 - tau`.
    """
    is_positive = u > 0
    weight = jnp.where(is_positive, tau, 1.0 - tau)
    return weight * (u**2)

def target_update(hybrid, tau: float):
    """Blend the online gater-Q parameters into the target gater-Q parameters.

    Every leaf of the target parameter tree becomes
    `tau * online + (1 - tau) * target`. The input `hybrid` is not mutated;
    a copy produced via `.replace(...)` with the refreshed `gater_q_tgt`
    is returned.
    """
    def _blend(online_leaf, target_leaf):
        return online_leaf * tau + target_leaf * (1 - tau)

    target_state = hybrid.gater_q_tgt
    blended_params = jax.tree_util.tree_map(
        _blend, hybrid.gater_q.params, target_state.params
    )
    return hybrid.replace(gater_q_tgt=target_state.replace(params=blended_params))

# def update(hybrid, obs, act_batch, mask, thr_raw, dis, tau=0.8, K=0, 
#            k_obs=None, k_act=None, update_target: bool=False
# ):
def update(hybrid, obs, act_batch, mask, thr_raw, dis, tau=0.8, K=0,
           k_obs=None, k_act=None, pred_lam_q=None, pred_lam_tgt=None, pred_lam_p=None, update_target: bool=False
):
    """Run one gradient step on the Q-gater and one on the P-gater.

    The Q-gater is trained with label-smoothed binary cross-entropy against
    the hard target `dis > thr_raw`. The P-gater is trained with an expectile
    loss on `sigmoid(target-Q logits) - sigmoid(P logits)`.

    Args:
        hybrid: container exposing `gater_q`, `gater_q_tgt` and `gater_p`
            plus a `.replace(...)` method. `gater_q` and `gater_p` must offer
            `apply_gradient(loss_fn)` returning `(new_state, metrics)`;
            `gater_q_tgt` must expose `.params`.
        obs: observation batch handed to the predictors.
        act_batch: action batch; its leading dimension defines the batch
            size B used by the shape check.
        mask: per-example validity; entries above 0.5 count as valid.
        thr_raw: threshold compared against `dis`.
        dis: per-example value; the hard target is 1 where `dis > thr_raw`.
        tau: expectile level used for the P-gater loss.
        K: when > 0, the target-Q logits are evaluated on `k_obs`/`k_act`
            and the P logits and mask are repeated K times along axis 0.
            NOTE(review): this assumes `k_obs`/`k_act` keep each example's K
            entries contiguous, matching `jnp.repeat` — confirm with caller.
        k_obs, k_act: inputs for the target-Q evaluation; required if K > 0.
        pred_lam_q: callable `(params, obs, act) -> logits` for the Q-gater.
        pred_lam_tgt: callable `(params, obs, act) -> logits` for the target
            Q-gater.
        pred_lam_p: callable `(params, obs) -> logits` for the P-gater.
            When any predictor is None, the corresponding gater's
            `.apply({'params': params}, ...)` is used instead.
            NOTE(review): the fallback assumes the gater states expose a
            Flax-style `.apply` — confirm.
        update_target: accepted for interface compatibility; this function
            does not act on it (no target refresh happens here).

    Returns:
        `(hybrid, metrics)`: a copy of `hybrid` with updated `gater_q` and
        `gater_p`, and the merged metric dicts of both losses.

    Raises:
        ValueError: if `K > 0` but `k_obs` or `k_act` is missing.
    """
    # Fail early with a clear message instead of passing None to a predictor.
    if K > 0 and (k_obs is None or k_act is None):
        raise ValueError("k_obs and k_act must both be provided when K > 0")

    # Default predictors, used only when the caller did not inject one.
    def _default_q(params, o, a):
        return hybrid.gater_q.apply({'params': params}, o, a)

    def _default_tgt(params, o, a):
        return hybrid.gater_q_tgt.apply({'params': params}, o, a)

    def _default_p(params, o):
        return hybrid.gater_p.apply({'params': params}, o)

    q_fn = _default_q if pred_lam_q is None else pred_lam_q
    tgt_fn = _default_tgt if pred_lam_tgt is None else pred_lam_tgt
    p_fn = _default_p if pred_lam_p is None else pred_lam_p

    # Binarise the mask once; it is reused by both losses.
    msk_f = (jnp.asarray(mask, jnp.float32) > 0.5).astype(jnp.float32)
    tau = jnp.asarray(tau, jnp.float32)

    # Hard 0/1 target for the Q-gater.
    lam_tar = (jnp.asarray(dis > thr_raw).astype(jnp.float32))

    den = jnp.sum(msk_f) + 1e-8
    pos = jnp.sum(lam_tar * msk_f)
    pos_rate = pos / den

    # Positive-class re-weighting is disabled: every valid example weighs 1.
    # (The disabled variant used
    #  pos_w = clip((den - pos) / max(pos, 1), 1, 20) when 0 < pos < den and
    #  per_ex_w = (1 + (pos_w - 1) * lam_tar) * msk_f.)
    pos_w = 1.
    per_ex_w = msk_f

    denom = jnp.sum(per_ex_w) + 1e-8

    assert per_ex_w.shape == lam_tar.shape == msk_f.shape == (act_batch.shape[0],), (
        "mask and (dis > thr_raw) must both have shape (B,) with B = act_batch.shape[0]"
    )

    def lq_loss_fn(lqparams):
        # Label smoothing: pull the hard targets towards 0.5 by eps.
        eps = 0.05
        y = lam_tar * (1 - eps) + 0.5 * eps
        logits_q = q_fn(lqparams, obs, act_batch).reshape(-1)
        bce_vec = optax.sigmoid_binary_cross_entropy(logits_q, y)

        # Weighted mean over valid examples.
        loss = jnp.sum(bce_vec * per_ex_w) / denom

        metrics = {
            'gater_q_loss': loss,
            'lam_target_mean': lam_tar.mean(),
            'pos_rate': pos_rate,
            'pos_weight': pos_w
        }
        return loss, metrics

    new_gater_q, metrics_q = hybrid.gater_q.apply_gradient(lq_loss_fn)

    # Select the inputs for the target-Q evaluation and the matching mask.
    if K > 0:
        obs_p, act_p = k_obs, k_act
        msk_p = jnp.repeat(msk_f, K, axis=0)
    else:
        obs_p, act_p = obs, act_batch
        msk_p = msk_f

    # Independent of the P parameters, so computed once outside the loss.
    denom_p = jnp.sum(msk_p) + 1e-8

    # The target network only supplies regression targets; block gradients.
    logits_q_target = tgt_fn(hybrid.gater_q_tgt.params, obs_p, act_p)
    logits_q_target = jax.lax.stop_gradient(logits_q_target.reshape(-1))

    def lp_loss_fn(lpparams):
        logits_p = p_fn(lpparams, obs).reshape(-1)
        if K > 0:
            # Align the per-observation P logits with the K-fold targets.
            logits_p = jnp.repeat(logits_p, K, axis=0)

        # Compare in probability space rather than logit space.
        diff = jax.nn.sigmoid(logits_q_target) - jax.nn.sigmoid(logits_p)
        eloss = expectile_loss(diff, tau=tau)
        loss = jnp.sum(eloss * msk_p) / denom_p
        metrics = {
            'gater_p_loss': loss,
            'lam_p_mean': jax.nn.sigmoid(logits_p).mean(),
            'lam_q_soft_mean': jax.nn.sigmoid(logits_q_target).mean(),
            'qp_gap_mean': jnp.mean(diff),
        }
        return loss, metrics

    new_gater_p, metrics_p = hybrid.gater_p.apply_gradient(lp_loss_fn)

    hybrid = hybrid.replace(gater_q=new_gater_q, gater_p=new_gater_p)
    return hybrid, {**metrics_q, **metrics_p}
