import jax
import jax.numpy as jnp
from typing import Optional, Tuple, Callable, Any
from functools import partial

from jax_rl.networks.common import InfoDict, Model


# def make_disagreement_jit(pA_apply: Callable, pB_apply: Callable):
#     @jax.jit
#     def fn(pA_params, pB_params, obs, act):        
#         next_A = pA_apply({'params': pA_params}, obs, act)
#         next_B = pB_apply({'params': pB_params}, obs, act)
#         diff = jnp.asarray(next_A - next_B, jnp.float32)
#         dis = jnp.mean(diff * diff, axis=-1)
#         # dis = jnp.linalg.norm(diff, ord=1, axis=-1) / diff.shape[-1]
#         dis = jnp.where(jnp.isfinite(dis), dis, 0.)
#         return jnp.clip(dis, 0., 1e6)
#     return fn

def make_disagreement_jit(pa_pred: Callable, pb_pred: Callable):
    """Return a jitted function scoring disagreement between two predictors.

    The returned ``fn(pA_params, pB_params, obs, act)`` evaluates
    ``pa_pred(pA_params, obs, act)`` and ``pb_pred(pB_params, obs, act)``.
    It takes their difference in float32 and averages the squared difference
    over the last axis. Non-finite scores (NaN or +/-inf) are replaced by 0,
    and every score is then clipped to ``[0, 1e6]``.

    Args:
        pa_pred: Callable ``(params, obs, act) -> prediction``.
        pb_pred: Callable with the same signature as ``pa_pred``. Its output
            must be subtractable from ``pa_pred``'s output.

    Returns:
        The jitted scoring function. Its output drops the last axis of the
        prediction difference.
    """

    def fn(pA_params, pB_params, obs, act):
        pred_a = pa_pred(pA_params, obs, act)
        pred_b = pb_pred(pB_params, obs, act)
        gap = jnp.asarray(pred_a - pred_b, jnp.float32)
        sq_err = jnp.mean(gap * gap, axis=-1)
        # A non-finite score is treated as "no disagreement" rather than propagated.
        sq_err = jnp.where(~jnp.isfinite(sq_err), 0., sq_err)
        return jnp.clip(sq_err, 0., 1e6)

    return jax.jit(fn)

def make_update_stats_from_dis_jit():
    """Return a jitted function that folds one disagreement batch into gater stats.

    ``fn(gater_stats, dis_vec, step=1)`` performs these steps:

    - casts ``dis_vec`` to float32;
    - flattens it to 1-D when it is not already 1-D;
    - replaces non-finite entries with 0;
    - returns ``gater_stats.update(dis_vec).refresh_threshold(step)``.

    ``gater_stats`` must be a JAX pytree exposing ``update`` and
    ``refresh_threshold`` (its concrete type is defined elsewhere).
    """

    def fn(gater_stats, dis_vec, step: int = 1):
        values = jnp.asarray(dis_vec, jnp.float32)
        if values.ndim != 1:
            values = jnp.ravel(values)
        finite = jnp.isfinite(values)
        values = jnp.where(finite, values, 0.)
        updated = gater_stats.update(values)
        return updated.refresh_threshold(step)

    return jax.jit(fn)

def make_update_stats_from_dis_batched_jit():
    """Return a jitted function that folds a sequence of disagreement batches into gater stats.

    ``fn(gater_stats, dis_mat, start_step)`` scans over the leading axis of
    ``dis_mat`` (cast to float32). Each row is flattened and passed to
    ``stats.update``. After every update two thresholds are compared:

    - ``thr_used``: ``stats.get_threshold()``, the threshold currently in force;
    - ``thr_raw_now``: ``max(expm1(mean + k_std * sqrt(var + 1e-8)) / scale, 1e-6)``,
      computed from the freshly updated moments. Presumably the moments are
      tracked in a log1p-scaled space — confirm against the stats class.

    The scan carries a step counter starting at ``start_step``, incremented
    once per row. The final stats get ``refresh_threshold`` with the counter
    value after the last row, i.e. ``start_step + dis_mat.shape[0]``.

    Returns:
        ``(stats_final, metrics)``. ``metrics`` holds the last and mean
        positive rates (fraction of row entries above each threshold) and the
        last value of each threshold.
    """
    eps = 1e-8

    def _scan_step(carry, dis_row):
        stats, step = carry
        row = jnp.ravel(dis_row)
        stats = stats.update(row)

        spread = jnp.sqrt(stats.var + eps)
        thr_soft_now = stats.mean + stats.k_std * spread
        thr_raw_now = jnp.maximum(jnp.expm1(thr_soft_now) / stats.scale, 1e-6)
        thr_used = stats.get_threshold()

        rate_used = jnp.mean(row > thr_used)
        rate_now = jnp.mean(row > thr_raw_now)
        return (stats, step + 1), (rate_used, rate_now, thr_raw_now, thr_used)

    def fn(gater_stats, dis_mat, start_step: int):
        rows = jnp.asarray(dis_mat, jnp.float32)
        (stats, last_step), history = jax.lax.scan(
            _scan_step, (gater_stats, start_step), rows
        )
        rate_used, rate_now, thr_now, thr_used = history

        metrics = {
            "pos_rate_last_used": rate_used[-1],
            "pos_rate_last_now":  rate_now[-1],
            "pos_rate_mean_used": jnp.mean(rate_used),
            "pos_rate_mean_now":  jnp.mean(rate_now),
            "thr_raw_now_last":   thr_now[-1],
            "thr_raw_used_last":  thr_used[-1],
        }
        return stats.refresh_threshold(last_step), metrics

    return jax.jit(fn)


# def make_update_lam_jit(lam_update_fn: Callable, lam_tgt_update_fn: Callable):
def make_update_lam_jit(lam_update_fn: Callable, lam_tgt_update_fn: Callable, pred_lam_q, pred_lam_tgt, pred_lam_p):
    """Return a jitted lambda-network update step without policy samples.

    ``fn(hybrid, obs, act, mask, thr_raw, dis, tau, tgt_tau, update_target)``
    does the following:

    - calls ``lam_update_fn`` positionally as ``(hybrid, obs, act, mask, thr_raw,
      dis, tau, 0, None, None, pred_lam_q, pred_lam_tgt, pred_lam_p,
      update_target)``. The ``0, None, None`` fill the slots that
      ``make_update_lam_jit_with_policy`` uses for ``K, k_obs, k_act``;
    - applies ``lam_tgt_update_fn(hybrid, tgt_tau)`` when ``update_target`` is
      true. That is a ``lax.cond``, so ``update_target`` must be a scalar bool.

    Returns:
        ``(hybrid, lam_info)`` where ``lam_info`` is whatever
        ``lam_update_fn`` returned.
    """
    no_k, no_k_obs, no_k_act = 0, None, None

    def _identity(h):
        return h

    def fn(hybrid, obs, act, mask, thr_raw, dis, tau, tgt_tau, update_target):
        new_hybrid, lam_info = lam_update_fn(
            hybrid, obs, act, mask, thr_raw, dis, tau,
            no_k, no_k_obs, no_k_act,
            pred_lam_q, pred_lam_tgt, pred_lam_p,
            update_target,
        )

        def _soft_target(h):
            return lam_tgt_update_fn(h, tgt_tau)

        new_hybrid = jax.lax.cond(update_target, _soft_target, _identity, new_hybrid)
        return new_hybrid, lam_info

    return jax.jit(fn)

# def make_update_lam_jit_with_policy(lam_update_fn, lam_tgt_update_fn, actor_apply, K:int):
#     assert K > 0 and K% 2 == 0, "K should be a positive even integer."
    
#     def sample_actions_k(actor_params, obs, rng):
#         B = obs.shape[0]
#         half = K // 2
#         lam_row = jnp.concatenate(
#             [jnp.zeros((half,), jnp.float32), jnp.ones((K-half,), jnp.float32)], axis=0
#         )
#         lam = jnp.tile(lam_row[None, :], (B, 1))

#         keys = jax.random.split(rng, K * B + 1)
#         new_rng, ks = keys[0], keys[1:]

#         k_obs = jnp.repeat(obs, K, axis=0)
#         lam_flat = lam.reshape(B * K)

#         def one(o, lam_scalar, key):
#             lam_b = jnp.asarray(lam_scalar, jnp.float32)[None]
#             dist = actor_apply({'params': actor_params}, o[None, ...], lam=lam_b)
#             a, _ = dist.sample_and_log_prob(seed=key)
#             return a[0]

#         k_act = jax.vmap(one)(k_obs, lam_flat, ks)
#         return k_obs, k_act, new_rng

#     # @partial(jax.jit, static_argnums=(6,9,10))
#     @jax.jit
#     def fn(hybrid, obs, act_batch, mask, thr_raw, dis, tau, actor_params, rng, tgt_tau, update_target):
#         k_obs, k_act, new_rng = sample_actions_k(actor_params, obs, rng)
#         hybrid, lam_info = lam_update_fn(hybrid, obs, act_batch, mask, thr_raw, dis, tau, K, k_obs, k_act, update_target)
        
#         hybrid = jax.lax.cond(
#             update_target, 
#             lambda h: lam_tgt_update_fn(h, tgt_tau), 
#             lambda h: h,
#             hybrid)
#         return hybrid, lam_info, new_rng
#     return fn


# def make_update_lam_jit_with_policy(lam_update_fn, lam_tgt_update_fn,  actor_apply, K:int):
def make_update_lam_jit_with_policy(lam_update_fn, lam_tgt_update_fn, pred_lam_q, pred_lam_tgt, pred_lam_p, actor_apply, K:int):
    """Return a jitted lambda-network update that also uses K actor samples per observation.

    For each observation in the batch, ``K`` actions are drawn from the actor.
    The first ``K // 2`` draws use ``lam=0`` and the rest use ``lam=1``. The
    observations are repeated ``K`` times consecutively (``jnp.repeat`` along
    axis 0) to form ``k_obs``, aligned row-for-row with the sampled ``k_act``.
    Both are passed to ``lam_update_fn`` together with ``K``. Afterwards
    ``lam_tgt_update_fn(hybrid, tgt_tau)`` is applied when ``update_target``
    is true. That is a ``lax.cond``, so ``update_target`` must be a scalar bool.

    Args:
        lam_update_fn: Called positionally as ``(hybrid, obs, act_batch, mask,
            thr_raw, dis, tau, K, k_obs, k_act, pred_lam_q, pred_lam_tgt,
            pred_lam_p, update_target)`` and must return ``(hybrid, lam_info)``.
        lam_tgt_update_fn: Called as ``lam_tgt_update_fn(hybrid, tgt_tau)``.
        pred_lam_q, pred_lam_tgt, pred_lam_p: Forwarded unchanged to
            ``lam_update_fn``.
        actor_apply: Called as ``actor_apply({'params': actor_params}, obs,
            lam=lam)`` on a batch of one observation. The result must provide
            ``sample_and_log_prob(seed=key)``.
        K: Number of actor samples per observation; must be a positive even
            integer.

    Returns:
        Jitted ``fn(hybrid, obs, act_batch, mask, thr_raw, dis, tau,
        actor_params, rng, tgt_tau, update_target)`` returning
        ``(hybrid, lam_info, new_rng)``.

    Raises:
        ValueError: If ``K`` is not a positive even integer.
    """
    # Explicit check instead of `assert`, so validation survives `python -O`.
    if not (K > 0 and K % 2 == 0):
        raise ValueError("K should be a positive even integer.")

    def sample_actions_k(actor_params, obs, rng):
        B = obs.shape[0]
        half = K // 2
        # Per-observation lambda pattern: half the draws at lam=0, half at lam=1.
        lam_row = jnp.concatenate(
            [jnp.zeros((half,), jnp.float32), jnp.ones((K-half,), jnp.float32)], axis=0
        )
        lam = jnp.tile(lam_row[None, :], (B, 1))

        # One fresh key per (observation, sample) pair, plus one to hand back.
        keys = jax.random.split(rng, K * B + 1)
        new_rng, ks = keys[0], keys[1:]

        # Row-major flatten of `lam` matches the consecutive repeat of `obs`.
        k_obs = jnp.repeat(obs, K, axis=0)
        lam_flat = lam.reshape(B * K)

        def one(o, lam_scalar, key):
            lam_b = jnp.asarray(lam_scalar, jnp.float32)[None]
            dist = actor_apply({'params': actor_params}, o[None, ...], lam=lam_b)
            a, _ = dist.sample_and_log_prob(seed=key)
            return a[0]

        k_act = jax.vmap(one)(k_obs, lam_flat, ks)
        return k_obs, k_act, new_rng

    @jax.jit
    def fn(hybrid, obs, act_batch, mask, thr_raw, dis, tau, actor_params, rng, tgt_tau, update_target):
        k_obs, k_act, new_rng = sample_actions_k(actor_params, obs, rng)
        hybrid, lam_info = lam_update_fn(hybrid, obs, act_batch, mask, thr_raw, dis, tau, K, k_obs, k_act,
                                         pred_lam_q, pred_lam_tgt, pred_lam_p, update_target)

        hybrid = jax.lax.cond(
            update_target,
            lambda h: lam_tgt_update_fn(h, tgt_tau),
            lambda h: h,
            hybrid)
        return hybrid, lam_info, new_rng
    return fn


def make_update_dyn_jit(dyn_update_fn: Callable, pa_pred: Callable, pb_pred: Callable, state_tf: Callable):
    """Return a jitted dynamics-model update step.

    ``fn(hybrid, obs, act, next_obs)`` builds the regression target
    ``state_tf(next_obs) - state_tf(obs)``, the state change in
    ``state_tf``-transformed space. It then delegates to
    ``dyn_update_fn(hybrid, obs, act, target, pa_pred, pb_pred)`` and returns
    that call's result unchanged.
    """

    def fn(hybrid, obs, act, next_obs):
        transformed_next = state_tf(next_obs)
        transformed_now = state_tf(obs)
        delta = transformed_next - transformed_now
        return dyn_update_fn(hybrid, obs, act, delta, pa_pred, pb_pred)

    return jax.jit(fn)


def make_precompute_lambda_p(lam_p_apply):
    """Return a jitted function computing per-observation policy lambdas.

    ``fn(lam_p_params, obs, rng, det_lam, bootstrap)`` performs these steps:

    - computes ``logits = lam_p_apply({'params': lam_p_params}, obs)``;
    - maps them through a sigmoid clipped to ``[1e-6, 1 - 1e-6]``;
    - if ``det_lam`` is truthy, returns a Bernoulli sample (0./1.) drawn with
      those probabilities; otherwise returns the clipped probabilities.

    NOTE(review): despite the name, the truthy ``det_lam`` branch is the
    *sampled* one — confirm the intended meaning with callers.

    ``bootstrap`` is accepted for interface compatibility but is currently
    ignored. The former constant-0.5 bootstrap branch is disabled. The result
    is float32 and wrapped in ``stop_gradient``. ``rng`` is consumed internally
    and no new key is returned.
    """

    @jax.jit
    def fn(lam_p_params, obs, rng, det_lam: bool, bootstrap: bool):
        del bootstrap  # Intentionally unused (bootstrap branch disabled).
        _, key = jax.random.split(rng, 2)
        logits = lam_p_apply({'params': lam_p_params}, obs)
        probs = jnp.clip(jax.nn.sigmoid(logits), 1e-6, 1. - 1e-6)

        lam = jax.lax.cond(
            det_lam,
            lambda _: jax.random.bernoulli(key, probs).astype(jnp.float32),
            lambda _: probs.astype(jnp.float32),
            operand=None,
        )
        return jax.lax.stop_gradient(lam)

    return fn

def make_precompute_lambda_pq(lam_q_apply, lam_p_apply):
    """Return a jitted function computing policy- and critic-side lambdas.

    ``fn(lam_q_params, lam_p_params, obs, actions, rng, det_lam, bootstrap)``
    evaluates two logit functions:

    - ``lam_q_apply(lam_q_params, obs, actions)``;
    - ``lam_p_apply(lam_p_params, obs)``.

    Each result goes through a sigmoid clipped to ``[1e-6, 1 - 1e-6]``. If
    ``det_lam`` is truthy, each lambda is a Bernoulli sample (0./1.), drawn
    with its own key split from ``rng``. Otherwise the clipped probabilities
    are returned.

    NOTE(review): the truthy ``det_lam`` branch is the *sampled* one — confirm
    the intended meaning of the name with callers.

    ``bootstrap`` is accepted for interface compatibility but is currently
    ignored. The former constant-0.5 bootstrap branch is disabled.

    Returns:
        ``(lam_p, lam_q)``, both float32 and wrapped in ``stop_gradient``.
    """

    def _to_probs(logits):
        # Keep probabilities strictly inside (0, 1).
        return jnp.clip(jax.nn.sigmoid(logits), 1e-6, 1. - 1e-6)

    def _lam_from_probs(draw, key, probs):
        # Truthy `draw` -> Bernoulli(probs) sample; otherwise the probabilities themselves.
        return jax.lax.cond(
            draw,
            lambda _: jax.random.bernoulli(key, probs).astype(jnp.float32),
            lambda _: probs.astype(jnp.float32),
            operand=None,
        )

    @jax.jit
    def fn(lam_q_params, lam_p_params, obs, actions, rng, det_lam: bool, bootstrap: bool):
        del bootstrap  # Intentionally unused (bootstrap branch disabled).
        _, lam_p_key, lam_q_key = jax.random.split(rng, 3)

        lam_q_probs = _to_probs(lam_q_apply(lam_q_params, obs, actions))
        lam_p_probs = _to_probs(lam_p_apply(lam_p_params, obs))

        lam_p = _lam_from_probs(det_lam, lam_p_key, lam_p_probs)
        lam_q = _lam_from_probs(det_lam, lam_q_key, lam_q_probs)
        return jax.lax.stop_gradient(lam_p), jax.lax.stop_gradient(lam_q)

    return fn

def make_precompute_lambda_tgt(lam_q_apply, actor_apply):
    """Return a jitted function computing target-side critic lambdas.

    ``fn(lam_q_params, actor_params, next_obs, rng, lam_p, det_lam, bootstrap)``
    performs these steps:

    - samples ``(next_actions, next_log_probs)`` from
      ``actor_apply({'params': actor_params}, next_obs, lam=lam_p)``, with
      ``lam_p`` cast to float32;
    - evaluates ``lam_q_apply(lam_q_params, next_obs, next_actions)`` and maps
      the logits through a sigmoid clipped to ``[1e-6, 1 - 1e-6]``;
    - if ``det_lam`` is truthy, draws a Bernoulli (0./1.) target lambda from
      those probabilities; otherwise returns the clipped probabilities.

    NOTE(review): the truthy ``det_lam`` branch is the *sampled* one — confirm
    the intended meaning of the name with callers.

    ``bootstrap`` is accepted for interface compatibility but is currently
    ignored. The former constant-0.5 overrides for ``lam_p`` and the target
    lambda are disabled.

    Returns:
        ``(lam_tgt, next_actions, next_log_probs)``. Only ``lam_tgt`` is
        wrapped in ``stop_gradient``.
    """

    @jax.jit
    def fn(lam_q_params, actor_params, next_obs, rng, lam_p, det_lam: bool, bootstrap: bool):
        del bootstrap  # Intentionally unused (bootstrap branch disabled).
        _, actor_key, lam_q_key = jax.random.split(rng, 3)

        lam_p_eff = lam_p.astype(jnp.float32)
        dist = actor_apply({'params': actor_params}, next_obs, lam=lam_p_eff)
        next_actions, next_log_probs = dist.sample_and_log_prob(seed=actor_key)

        lam_q_logits = lam_q_apply(lam_q_params, next_obs, next_actions)
        lam_q_probs = jnp.clip(jax.nn.sigmoid(lam_q_logits), 1e-6, 1. - 1e-6)

        lam_tgt = jax.lax.cond(
            det_lam,
            lambda _: jax.random.bernoulli(lam_q_key, lam_q_probs).astype(jnp.float32),
            lambda _: lam_q_probs.astype(jnp.float32),
            operand=None,
        )
        return jax.lax.stop_gradient(lam_tgt), next_actions, next_log_probs

    return fn
        
