"""Implementations of algorithms for continuous control."""
import time
from typing import Optional, Sequence, Tuple, Union
from functools import partial

import flax
import jax
import jax.numpy as jnp
import numpy as np
import optax

from jax_rl.agents.actor_critic_temp import ActorCriticTemp, Gater, Pre
from jax_rl.agents.perl_factory import *
from jax_rl.agents.sac import temperature
from jax_rl.agents.pe_sac import pe_actor, pe_critic, gater, dynamics
from jax_rl.datasets import Batch
from jax_rl.networks import pe_critic_net, pe_policies, policies, critic_net
from jax_rl.networks.common import InfoDict, Model
from jax_rl.networks.pe_net import LambdaQNet, LambdaPiNet, TransitionPredictor, EquiTransitionPredictor, RunningMeanStd
from emlp.reps import Rep
from emlp.groups import Group
import collections
# Transition mini-batch container used throughout this module.
# NOTE: this assignment rebinds the `Batch` name imported from
# `jax_rl.datasets` above; every later use of `Batch` in this file refers to
# this namedtuple.
Batch = collections.namedtuple(
    'Batch',
    (
        'observations',
        'actions',
        'rewards',
        'masks',
        'next_observations',
    ),
)

# Union of the two learner-state containers handled in this module.
# NOTE(review): the identifier is spelled "Crtitic"; it is kept unchanged here
# because other modules may import it under this name.
ActorCrtiticType = Union[ActorCriticTemp, Gater]


@partial(jax.jit, static_argnums=(6, 7))
def _update_jit(sac: ActorCriticTemp, batch: Batch, pre: Optional[Pre],
                discount: float, tau: float, target_entropy: float,
                update_target: bool, fixed_temperature: bool = False,
                actor_basic_wd: float = 0, actor_equiv_wd: float = 0,
                critic_basic_wd: float = 0, critic_equiv_wd: float = 0
                ) -> Tuple[ActorCriticTemp, InfoDict]:
    """Jitted full update: critic, optional target sync, actor, temperature.

    `update_target` and `fixed_temperature` (positional indices 6 and 7) are
    static arguments, so the two `if` statements below are resolved at trace
    time and each distinct combination is compiled separately.

    Args:
        sac: Current actor/critic/temperature state.
        batch: Transition batch forwarded to the critic and actor updates.
        pre: Optional precomputed gate values, forwarded as `pre=`.
        discount: Forwarded to `pe_critic.update`.
        tau: Forwarded to `pe_critic.target_update`.
        target_entropy: Forwarded to `temperature.update`.
        update_target: When true, the target critic is updated right after
            the critic step.
        fixed_temperature: When true, the temperature update is skipped.
        actor_basic_wd, actor_equiv_wd: Forwarded to `pe_actor.update`.
        critic_basic_wd, critic_equiv_wd: Forwarded to `pe_critic.update`.

    Returns:
        The updated state and one dict merging the critic, actor and (if
        any) temperature metrics; on key clashes later stages win.
    """
    # The literal `True` is the fourth positional argument of
    # `pe_critic.update`; its meaning is defined by that function.
    sac, critic_info = pe_critic.update(
        sac, batch, discount, True, critic_basic_wd, critic_equiv_wd, pre=pre)
    if update_target:
        sac = pe_critic.target_update(sac, tau)

    sac, actor_info = pe_actor.update(
        sac, batch, actor_basic_wd, actor_equiv_wd, pre=pre)

    alpha_info = {}
    if not fixed_temperature:
        sac, alpha_info = temperature.update(
            sac, actor_info['entropy'], target_entropy)

    merged = {}
    for part in (critic_info, actor_info, alpha_info):
        merged.update(part)
    return sac, merged

@partial(jax.jit, static_argnums=(6,))
def _update_jit_critic_only(sac: ActorCriticTemp, batch: Batch, pre: Optional[Pre],
                            discount: float, tau: float, target_entropy: float,
                            update_target: bool,
                            actor_basic_wd: float = 0, actor_equiv_wd: float = 0,
                            critic_basic_wd: float = 0, critic_equiv_wd: float = 0
                            ) -> Tuple[ActorCriticTemp, InfoDict]:
    """Jitted update that performs only the critic step.

    The signature mirrors `_update_jit` minus `fixed_temperature`, but only
    `sac`, `batch`, `pre`, `discount` and the two critic weight-decay values
    are read. `tau`, `target_entropy`, `update_target` and the actor
    weight-decay values are accepted and ignored; in particular the target
    critic is NOT updated here even when `update_target` is true.

    Returns:
        The state after the critic step and the critic metrics.
    """
    new_sac, info = pe_critic.update(
        sac, batch, discount, True, critic_basic_wd, critic_equiv_wd, pre=pre)
    return new_sac, info

def clipped_adam(learning_rate,clip_norm=0.5, gan_betas=False):
    """Return an optax chain: global-norm clipping, Adam scaling, -lr scaling.

    Args:
        learning_rate: Step size; the final stage scales by `-learning_rate`.
        clip_norm: Threshold handed to `optax.clip_by_global_norm`.
        gan_betas: If true use `b1 = 0.5`, otherwise `b1 = 0.9`; `b2` is
            always `0.999`.

    Side effects:
        Prints the chosen `b1` and `b2` each time it is called.
    """
    b1 = 0.5 if gan_betas else 0.9
    b2 = 0.999
    print("B1 = ", b1)
    print("B2 = ", b2)

    stages = (
        optax.clip_by_global_norm(clip_norm),
        optax.scale_by_adam(b1=b1, b2=b2, eps=1e-8, eps_root=0.0),
        optax.scale(-learning_rate),
    )
    return optax.chain(*stages)

class PESACLearner(object):
    def __init__(self,
                 seed: int,
                 observations: jnp.ndarray,
                 actions: jnp.ndarray,
                 actor_lr: float = 3e-4,
                 actor_basic_wd: float = 0.,
                 actor_equiv_wd: float = 0.,

                 critic_lr: float = 3e-4,
                 critic_basic_wd: float = 0.,
                 critic_equiv_wd: float = 0.,

                 temp_lr: float = 3e-4,
                 hidden_dims: Sequence[int] = (256, 256),
                 discount: float = 0.99,
                 tau: float = 0.005,
                 target_update_period: int = 1,
                 target_entropy: Optional[float] = None,
                 init_temperature: float = 1.0,
                 symmetry_group: Optional[Group] = None,
                 state_rep: Optional[Rep] = None,
                 action_rep: Optional[Rep] = None,
                 action_std_rep: Optional[Rep] = None,
                 state_transform=lambda x:x,
                 inv_state_transform=lambda x:x,
                 action_transform=lambda x:x,
                 inv_action_transform=lambda x:x,
                 action_space="continuous",
                 perl_value=True,
                 perl_policy=True,
                 small_init=True,
                 middle_rep=None,
                 standardizer=None,
                 clipping=0.5,
                 gan_betas=False,

                 # PE args
                 pA_middle_rep=None,
                 g_hidden_dims: Sequence[int] = (256, 256),
                 d_hidden_dims: Sequence[int] = (256, 256),
                 det_lam=True,
                 gater_q_lr: float = 3e-4,
                 gater_p_lr: float = 3e-4,
                 dyn_lr: float = 3e-4,
                 exp_tau: float = 0.8,
                 exp_samples: int = 4,
                 adaptive_k_std: bool = False,
                 lam_clipping: float = 0.5,
                 dyn_clipping: float = 2.0,
                 thr_steps: int = 1000,
                 gater_bias: float = 0.0,

                 lam_train_start: int = 1000,
                 ):
        """Create every model and jitted helper used by the learner.

        The constructor builds, in this order: two transition predictors
        (`pA` equivariant, `pB` plain), the Q-side gate with a target copy,
        the policy-side gate, the actor, the critic with a target copy, and
        the temperature model. These are then bundled into `self.sac`
        (`ActorCriticTemp`) and `self.pe` (`Gater`), and the jitted helper
        closures are created from the factory functions.

        Args:
            seed: Seed for `jax.random.PRNGKey`; the key is split into the
                per-model initialisation keys.
            observations: Example observations used as model-init inputs.
            actions: Example actions used as model-init inputs;
                `actions.shape[-1]` is taken as the action dimension.
            actor_lr, critic_lr, gater_q_lr, gater_p_lr, dyn_lr: Learning
                rates passed to `clipped_adam` for the respective models.
            temp_lr: Learning rate of the temperature model. `None` selects
                `temperature.FixedTemperature` with no optimizer and sets
                `self.fixed_temp = True`.
            actor_basic_wd, actor_equiv_wd, critic_basic_wd, critic_equiv_wd:
                Stored on the instance and forwarded to the jitted update.
            hidden_dims: Hidden sizes for actor/critic; also used as the
                channel spec when `middle_rep` / `pA_middle_rep` is `None`.
            discount, tau, target_update_period: Stored for the updates.
            target_entropy: Defaults to `-action_dim / 2` when `None`.
            init_temperature: Initial value given to the temperature model.
            symmetry_group, state_rep, action_rep, action_std_rep: Passed to
                the equivariant network definitions. `state_rep + action_rep`
                is always evaluated for `pA`, so both are needed even when
                `perl_policy` / `perl_value` are false.
            state_transform, action_transform, inv_action_transform:
                Passed to the network definitions.
            inv_state_transform: Accepted but not used in this constructor.
            action_space: Must be 'continuous' (checked with `assert`).
            perl_value, perl_policy: Select the PE critic / PE policy
                definitions instead of the plain ones.
            small_init: Passed to the actor definition.
            middle_rep, pA_middle_rep: If given, repeated
                `len(hidden_dims)` times as the channel spec.
            standardizer: Optional callable applied to observations before
                they are used by the sampling and update methods.
            clipping, lam_clipping, dyn_clipping: Gradient clip norms for
                actor/critic/temperature, the gates, and the dynamics models.
            gan_betas: Forwarded to `clipped_adam`.
            g_hidden_dims, d_hidden_dims: Hidden sizes of the gate networks
                and of the plain transition predictor.
            det_lam: Stored, and passed to the PE policy definition.
            exp_tau, exp_samples, adaptive_k_std, lam_train_start: Stored on
                the instance; `exp_samples` also selects the lambda updater.
            thr_steps: Passed as `thr_update_steps` to `RunningMeanStd.init`.
            gater_bias: Passed as `bias_value` to both gate definitions.
        """

        self.standardizer= standardizer
        action_dim = actions.shape[-1]

        if target_entropy is None:
            self.target_entropy = -action_dim / 2
        else:
            self.target_entropy = target_entropy

        self.tau = tau
        print("tau = ", tau)
        self.target_update_period = target_update_period
        self.discount = discount

        self.actor_basic_wd = actor_basic_wd
        self.actor_equiv_wd = actor_equiv_wd

        self.critic_basic_wd = critic_basic_wd
        self.critic_equiv_wd = critic_equiv_wd

        self.perl_policy = perl_policy
        self.perl_value = perl_value

        self.det_lam = det_lam
        # assert self.det_lam, "only det_lam=True is supported for now."
        self.gater_bias = gater_bias
        self.exp_tau = exp_tau
        self.exp_samples = exp_samples
        self.adaptive_k_std = adaptive_k_std
        self.lam_train_start = lam_train_start

        self.state_tf = state_transform

        # One sub-key per model; the leftover `rng` becomes the learner's
        # running key (stored in `self.sac` below).
        rng = jax.random.PRNGKey(seed)
        rng, actor_key, critic_key, temp_key, gate_q_key, gate_p_key, pA_key, pB_key = jax.random.split(rng, 8)

        # --- Dynamics model A: equivariant transition predictor.
        pA_ch = hidden_dims if pA_middle_rep is None else len(hidden_dims)*[pA_middle_rep]
        pA_def = EquiTransitionPredictor(
            rep_in=state_rep + action_rep,
            rep_out=state_rep,
            group=symmetry_group,
            ch=pA_ch,
            state_transform=self.state_tf, #state_transform,
            action_transform=action_transform,
            small_init=False,
        )
        pA = Model.create(
            pA_def, 
            inputs=[pA_key, observations, actions],
            tx=clipped_adam(learning_rate=dyn_lr,clip_norm=dyn_clipping,gan_betas=gan_betas)
        )
        # --- Dynamics model B: plain predictor whose output size is the
        # last dimension of the transformed observations.
        state_dim = state_transform(observations).shape[-1]
        pB_def = TransitionPredictor(
            hidden=d_hidden_dims,
            output_dim=state_dim,
            state_transform=state_transform,
            action_transform=action_transform, 
            small_init=False,
        )
        pB = Model.create(
            pB_def, 
            inputs=[pB_key, observations, actions],
            tx=clipped_adam(learning_rate=dyn_lr,clip_norm=dyn_clipping,gan_betas=gan_betas)
        )

        # --- Q-side gate (takes observations and actions).
        gater_q_def = LambdaQNet(
            hidden=g_hidden_dims,
            state_transform=state_transform,
            action_transform=action_transform,
            bias_value=self.gater_bias
        )

        gater_q = Model.create(
            gater_q_def,
            inputs=[gate_q_key, observations, actions],
            tx=clipped_adam(learning_rate=gater_q_lr,clip_norm=lam_clipping,gan_betas=gan_betas)
        )

        # Target copy of the Q-side gate: no optimizer, and its params are
        # overwritten with the online gate's params right after creation.
        gater_q_tgt = Model.create(
            gater_q_def,
            inputs=[gate_q_key, observations, actions],
            tx=None
        )
        gater_q_tgt = gater_q_tgt.replace(params=gater_q.params)

        # --- Policy-side gate (takes observations only).
        gater_p_def = LambdaPiNet(
            hidden=g_hidden_dims,
            state_transform=state_transform,
            bias_value=self.gater_bias
        )
        gater_p = Model.create(
            gater_p_def,
            inputs=[gate_p_key, observations],
            tx=clipped_adam(learning_rate=gater_p_lr,clip_norm=lam_clipping,gan_betas=gan_betas)
        )

        # Example gate sample; below it is only used as an init input for the
        # PE actor/critic. NOTE(review): `gate_q_key` was already consumed by
        # the gate initialisation above and is reused here.
        lam_prob = jax.nn.sigmoid(gater_q(observations, actions))
        lam = jax.random.bernoulli(gate_q_key, lam_prob)

        # Initial running statistics for the gate threshold. Later updates
        # replace `self.pe.gater_stats` (see `update_gater_stats_batch`), so
        # this attribute keeps the initial object.
        self.gater_stats = RunningMeanStd.init(scale=1., k_std=1.5, eps=1e-8, thr_update_steps=thr_steps)

        print("perl policy is", perl_policy)

        assert action_space == 'continuous', "only continuous action space supported for PE-SAC"
        if perl_policy: # Use PE-MLP policy
            ch = hidden_dims if middle_rep is None else len(hidden_dims)*[middle_rep]
            actor_def = pe_policies.PENormalTanhPolicy(state_rep,action_rep,action_std_rep,symmetry_group,ch, hidden_dims, action_dim, 
                                    state_transform, inv_action_transform, small_init=small_init, det_lam=det_lam)
        else:
            actor_def = policies.NormalTanhPolicy(hidden_dims, action_dim,small_init=small_init)
        # The PE policy takes the gate value as an extra init input.
        actor = Model.create(actor_def,
                             inputs=[actor_key, observations, lam] if perl_policy else [actor_key, observations],
                             tx=clipped_adam(learning_rate=actor_lr,clip_norm=clipping,
                                            gan_betas=gan_betas))

        # Debug snippet (disabled): count the actor's parameters.
        # def count_params(params):
        #     return sum([np.prod(p.shape) for p in jax.tree_util.tree_leaves(params)])

        # num_params = count_params(actor.params)
        # # print(pA_middle_rep)
        # print(f"actor total parameters: {num_params}")
        # raise

        if perl_value:
            ch = hidden_dims if middle_rep is None else len(hidden_dims)*[middle_rep]
            critic_def = pe_critic_net.PEDoubleCritic(state_rep, action_rep, symmetry_group, ch, hidden_dims, 
                        state_transform=state_transform, action_transform=action_transform)
        else:
            critic_def = critic_net.DoubleCritic(hidden_dims)

        critic = Model.create(critic_def,
                              inputs=[critic_key, observations, actions, lam] if perl_value else [critic_key, observations, actions],
                              tx=clipped_adam(learning_rate=critic_lr,clip_norm=clipping,
                                             gan_betas=gan_betas))

        # Target critic: created from the same definition, key and inputs as
        # the online critic, but without an optimizer.
        target_critic = Model.create(
            critic_def, inputs=[critic_key, observations, actions, lam] if perl_value else [critic_key, observations, actions])

        # `temp_lr is None` selects a fixed (non-trained) temperature.
        if temp_lr is None:
            self.fixed_temp = True
            temp = Model.create(
                temperature.FixedTemperature(value=init_temperature),
                inputs=[temp_key],   
                tx=None
            )
        else:
            self.fixed_temp = False
            temp = Model.create(temperature.Temperature(init_temperature),
                            inputs=[temp_key],
                            tx=clipped_adam(learning_rate=temp_lr,clip_norm=clipping,gan_betas=gan_betas))

        self.sac = ActorCriticTemp(actor=actor,
                                   critic=critic,
                                   target_critic=target_critic,
                                   temp=temp,
                                   rng=rng)
        self.pe = Gater(gater_q=gater_q,
                        gater_q_tgt=gater_q_tgt,
                        gater_p=gater_p,
                        pA=pA,
                        pB=pB,
                        gater_stats=self.gater_stats)

        # The raw apply functions are stored once and wrapped in jitted
        # closures, so the same callables are reused everywhere
        # (original note: avoids multiple call paths in JIT).
        self.pA_apply = self.pe.pA.apply_fn.apply
        self.pB_apply = self.pe.pB.apply_fn.apply
        self.pa_pred = self._make_pa_pred_jit()
        self.pb_pred = self._make_pb_pred_jit()

        self.lam_q_apply = self.pe.gater_q.apply_fn.apply
        self.lam_tgt_apply = self.pe.gater_q_tgt.apply_fn.apply
        self.lam_p_apply = self.pe.gater_p.apply_fn.apply

        self.lam_q_pred = self._make_lam_q_pred_jit()
        self.lam_tgt_pred = self._make_lam_tgt_pred_jit()
        self.lam_p_pred = self._make_lam_p_pred_jit()


        # PERL JIT functions    
        # helper functions
        self.disagreement_jit = make_disagreement_jit(self.pa_pred, self.pb_pred)
        # NOTE: `precompute_lambda_p` is built from the raw apply function,
        # whereas the pq/tgt helpers below are built from the jitted wrappers.
        # self.precompute_lambda_p = make_precompute_lambda_p(self.pe.gater_p.apply)
        self.precompute_lambda_p = make_precompute_lambda_p(self.lam_p_apply)
        # Earlier variants, kept for reference:
        # self.precompute_lambda_pq = make_precompute_lambda_pq(self.pe.gater_q.apply, self.pe.gater_p.apply)
        # self.precompute_lambda_tgt = make_precompute_lambda_tgt(self.pe.gater_q_tgt.apply, self.sac.actor.apply)
        # self.disagreement_jit = make_disagreement_jit(self.pa_pred, self.pb_pred)
        # self.precompute_lambda_p = make_precompute_lambda_p(self.lam_p_pred)
        self.precompute_lambda_pq = make_precompute_lambda_pq(self.lam_q_pred, self.lam_p_pred)
        self.precompute_lambda_tgt = make_precompute_lambda_tgt(self.lam_tgt_pred, self.sac.actor.apply)

        # Update functions
        # self.update_stats_from_dis_jit = make_update_stats_from_dis_jit()
        self.update_stats_from_dis_jit = make_update_stats_from_dis_batched_jit()
        # self.update_dyn_jit = make_update_dyn_jit(dynamics.update, self.state_tf)
        self.update_dyn_jit = make_update_dyn_jit(dynamics.update, self.pa_pred, self.pb_pred, self.state_tf)
        # self.update_pA_jit = make_update_dyn_single_jit('pA')
        # self.update_pB_jit = make_update_dyn_single_jit('pB')
        # Cache of lambda updaters keyed by K; filled by `get_update_lam`.
        self._lam_updaters = {}
        self.update_lam_jit = self.get_update_lam(self.exp_samples)

        # Update counter; starts at 1 and is incremented in `update` when
        # the policy is updated.
        self.step = 1
        
        
    def get_update_lam(self, K:int):
        if K not in self._lam_updaters:
            if K == 0:
                self._lam_updaters[K] = make_update_lam_jit(gater.update, gater.target_update)
            else:
                self._lam_updaters[K] = make_update_lam_jit_with_policy(
                    # gater.update, gater.target_update, actor_apply=self.sac.actor.apply_fn.apply, K=K
                    gater.update, gater.target_update, self.lam_q_pred, self.lam_tgt_pred, self.lam_p_pred,
                    actor_apply=self.sac.actor.apply_fn.apply, K=K
                )
        return self._lam_updaters[K]
    

    def sample_actions(self,
                       observations: np.ndarray,
                       temperature: float = 1.0) -> jnp.ndarray:

        obs = self.standardizer(observations) if self.standardizer is not None else observations
        rng, lam_key, policy_key = jax.random.split(self.sac.rng, 3)
        
        lam = self.precompute_lambda_p(self.pe.gater_p.params, obs, lam_key, self.det_lam, self.step < self.lam_train_start) if self.perl_policy else None

        new_rng, actions = pe_policies.sample_actions(policy_key,
                                               self.sac.actor.apply_fn,
                                               self.sac.actor.params,
                                               self.standardizer(observations) if self.standardizer is not None else observations,
                                               lam,
                                               temperature)
        self.sac = self.sac.replace(rng=new_rng)

        actions = np.asarray(actions)
        return np.clip(actions, -1, 1)

    def update(self, batch: Batch,update_policy=True) -> InfoDict:
        'observations', 'actions', 'rewards', 'masks', 'next_observations'
        if self.standardizer is not None:
            standardized_batch = Batch(self.standardizer(batch.observations),batch.actions,
                batch.rewards,batch.masks,self.standardizer(batch.next_observations))
        else:
            standardized_batch = batch
            
        info = {}

        lam_cur_key, lam_tgt_key, sac_key = jax.random.split(self.sac.rng, 3)
        
        if self.perl_policy and self.perl_value:
            # B = standardized_batch.observations.shape[0]
            
            # if self.step < self.lam_train_start:
            #     lam_p_key, lam_q_key = jax.random.split(lam_cur_key, 2)
            #     lam_p = jax.random.bernoulli(lam_p_key, 0.5, (B,)).astype(jnp.float32)
            #     lam_q = jax.random.bernoulli(lam_q_key, 0.5, (B,)).astype(jnp.float32)
            #     lam_tgt = jax.random.bernoulli(lam_tgt_key, 0.5, (B,)).astype(jnp.float32)
                
            #     dist = self.sac.actor.apply_fn.apply(
            #         {'params': self.sac.actor.params}, standardized_batch.next_observations, lam_p
            #     )
            #     next_actions, next_log_probs = dist.sample_and_log_prob(seed=lam_tgt_key)
            
            # else:    
            lam_p, lam_q = self.precompute_lambda_pq(
                self.pe.gater_q.params, self.pe.gater_p.params,
                standardized_batch.observations, standardized_batch.actions,
                lam_cur_key, self.det_lam, self.step < self.lam_train_start
            )
            lam_tgt, next_actions, next_log_probs = self.precompute_lambda_tgt(
                self.pe.gater_q_tgt.params, self.sac.actor.params,
                standardized_batch.next_observations,
                lam_tgt_key, lam_p, self.det_lam, self.step < self.lam_train_start
            )
            pre = Pre(lam_p=lam_p, lam_q=lam_q, lam_tgt=lam_tgt, next_actions=next_actions, next_log_probs=next_log_probs)
        self.sac = self.sac.replace(rng=sac_key)
        
        if update_policy:
            self.step += 1
            self.sac, p_info = _update_jit(
                self.sac, standardized_batch, pre,
                self.discount, self.tau, self.target_entropy,
                self.step % self.target_update_period == 0, self.fixed_temp,
                self.actor_basic_wd, self.actor_equiv_wd,
                self.critic_basic_wd, self.critic_equiv_wd)
        else:
            self.sac, p_info = _update_jit_critic_only(
                self.sac, standardized_batch, pre,
                self.discount, self.tau, self.target_entropy,
                self.step % self.target_update_period == 0,
                self.actor_basic_wd, self.actor_equiv_wd,
                self.critic_basic_wd, self.critic_equiv_wd)
            
        if pre is not None:
            info.update({
                'lam/prob_tgt_mean': lam_tgt.mean(),
                'lam/prob_tgt_frac_0.5': (lam_tgt > 0.5).mean(),
                'lam/lam_q_mean': lam_q.mean(),
                'lam/lam_p_mean': lam_p.mean(),
            })
        info.update(p_info)
        
        return info
    
    
    def update_dynamics(self, batch: Batch) -> InfoDict:
        if self.standardizer is not None:
            standardized_batch = Batch(self.standardizer(batch.observations),batch.actions,
                batch.rewards,batch.masks,self.standardizer(batch.next_observations))
        else:
            standardized_batch = batch
        self.pe, d_info = self.update_dyn_jit(
            self.pe, 
            standardized_batch.observations,
            standardized_batch.actions,
            standardized_batch.next_observations
        )
        
        return d_info
    
    # def update_lambda(self, batch) -> InfoDict:        
    #     if self.standardizer is not None:
    #         standardized_batch = Batch(self.standardizer(batch.observations),batch.actions,
    #             batch.rewards,batch.masks,self.standardizer(batch.next_observations))
    #     else:
    #         standardized_batch = batch

    #     dis = self.disagreement_jit(
    #         self.pe.pA.params, self.pe.pB.params,
    #         standardized_batch.observations, standardized_batch.actions)

    #     thr_raw = self.pe.gater_stats.get_threshold().astype(standardized_batch.observations.dtype)

    #     K = self.exp_samples
    #     if K == 0:
    #         self.pe, lam_info = self.update_lam_jit(
    #             self.pe, 
    #             standardized_batch.observations, 
    #             standardized_batch.actions, 
    #             standardized_batch.masks, 
    #             thr_raw, dis, self.exp_tau,
    #             self.tau, self.step % self.target_update_period == 0
    #         )
    #     else:
    #         rng, lam_key = jax.random.split(self.sac.rng, 2)
    #         self.sac = self.sac.replace(rng=rng)

    #         self.pe, lam_info, new_rng = self.update_lam_jit(
    #             self.pe,
    #             standardized_batch.observations,
    #             standardized_batch.actions,
    #             standardized_batch.masks,
    #             thr_raw, dis, self.exp_tau,
    #             self.sac.actor.params, lam_key,
    #             self.tau, self.step % self.target_update_period == 0
    #         )
    #         self.sac = self.sac.replace(rng=new_rng)

    #     lam_info['k_std'] = self.pe.gater_stats.k_std

    #     return lam_info

    # def update_gater_stats(self, observations, actions):
    #     dis = self.disagreement_jit(
    #         self.pe.pA.params, self.pe.pB.params,
    #         observations if observations.ndim >1 else observations[None, :],
    #         actions if actions.ndim >1 else actions[None, :]
    #     )
    #     new_stats = self.update_stats_from_dis_jit(self.pe.gater_stats, dis, jnp.int32(self.step))
    #     self.pe = self.pe.replace(gater_stats=new_stats)
    #     return {
    #         "k_std": self.pe.gater_stats.k_std,
    #         "thr": self.pe.gater_stats.get_threshold(),
    #         'disagreement_mean': self.pe.gater_stats.mean,
    #         'disagreement_std': jnp.sqrt(self.pe.gater_stats.var),
    #     }
    
    def update_lambda(self, batch) -> InfoDict:
        if self.standardizer is not None:
            standardized_batch = Batch(self.standardizer(batch.observations),batch.actions,
                batch.rewards,batch.masks,self.standardizer(batch.next_observations))
        else:
            standardized_batch = batch
        dis = self.disagreement_jit(
            self.pe.pA.params, self.pe.pB.params,
            standardized_batch.observations, standardized_batch.actions)

        thr_raw_used = self.pe.gater_stats.get_threshold().astype(standardized_batch.observations.dtype)

        # --- build the same invariants at *label time*
        x_batch = jnp.log1p(self.pe.gater_stats.scale * jnp.ravel(dis))
        mu_x    = self.pe.gater_stats.mean
        sig_x   = jnp.sqrt(self.pe.gater_stats.var + 1e-8)
        thr_soft = mu_x + self.pe.gater_stats.k_std * sig_x
        thr_raw_now = jnp.maximum(jnp.expm1(thr_soft)/self.pe.gater_stats.scale, 1e-6)

        pos_rate_raw_label  = jnp.mean(jnp.ravel(dis) > thr_raw_used)
        pos_rate_soft_label = jnp.mean(x_batch > thr_soft)

        # --- call your existing JIT update
        K = self.exp_samples
        if K == 0:
            self.pe, lam_info = self.update_lam_jit(
                self.pe,
                standardized_batch.observations,
                standardized_batch.actions,
                standardized_batch.masks,
                thr_raw_used, dis, self.exp_tau,
                self.tau, self.step % self.target_update_period == 0
            )
        else:
            rng, lam_key = jax.random.split(self.sac.rng, 2)
            self.sac = self.sac.replace(rng=rng)
            self.pe, lam_info, new_rng = self.update_lam_jit(
                self.pe,
                standardized_batch.observations,
                standardized_batch.actions,
                standardized_batch.masks,
                thr_raw_used, dis, self.exp_tau,
                self.sac.actor.params, lam_key,
                self.tau, self.step % self.target_update_period == 0
            )
            self.sac = self.sac.replace(rng=new_rng)

        # existing
        lam_info['k_std'] = self.pe.gater_stats.k_std

        # new: threshold/pos invariants at label time
        lam_info.update({
            "thr_used_at_lambda": thr_raw_used,
            "thr_raw_now_at_lambda": thr_raw_now,
            "thr_soft_at_lambda": thr_soft,
            "pos_rate_raw_at_lambda": pos_rate_raw_label,
            "pos_rate_soft_at_lambda": pos_rate_soft_label,
        })

        return lam_info

    
    # def update_gater_stats(self, observations, actions):
    #     # disagreement for this stats batch (raw d = ||Δ||^2)
    #     dis = self.disagreement_jit(
    #         self.pe.pA.params, self.pe.pB.params,
    #         observations if observations.ndim >1 else observations[None, :],
    #         actions if actions.ndim >1 else actions[None, :]
    #     )

    #     thr_prev = self.pe.gater_stats.get_threshold()

    #     new_stats = self.update_stats_from_dis_jit(self.pe.gater_stats, dis, jnp.int32(self.step))
    #     self.pe = self.pe.replace(gater_stats=new_stats)

    #     # -------- wandb-friendly metrics (scalars) --------
    #     # log-domain view for consistency checks
    #     x_batch = jnp.log1p(self.pe.gater_stats.scale * jnp.ravel(dis))
    #     mu_x    = self.pe.gater_stats.mean
    #     sig_x   = jnp.sqrt(self.pe.gater_stats.var + 1e-8)
    #     thr_soft = mu_x + self.pe.gater_stats.k_std * sig_x             # threshold in x-space

    #     # raw threshold: "now" from stats and "used" (EMA'd)
    #     thr_raw_now  = jnp.maximum(jnp.expm1(thr_soft) / self.pe.gater_stats.scale, 1e-6)
    #     thr_raw_used = self.pe.gater_stats.get_threshold()

    #     # consistency: pos-rate in both domains (should be close)
    #     pos_rate_raw  = jnp.mean(jnp.ravel(dis) > thr_raw_used)
    #     pos_rate_soft = jnp.mean(x_batch > thr_soft)

    #     # batch summaries (raw & log)
    #     mean_d = jnp.mean(dis)
    #     std_d  = jnp.sqrt(jnp.mean((dis - mean_d)**2) + 1e-8)
    #     mean_x = jnp.mean(x_batch)
    #     std_x  = jnp.sqrt(jnp.mean((x_batch - mean_x)**2) + 1e-8)

    #     # multiplicative jump of the *instantaneous* raw thr (pre-EMA) vs previous used thr
    #     thr_jump = thr_raw_now / jnp.maximum(thr_prev, 1e-12)

    #     return {
    #         # existing
    #         "k_std": self.pe.gater_stats.k_std,
    #         "thr": thr_raw_used,
    #         "disagreement_mean": mu_x,                 # mean in log domain
    #         "disagreement_std": sig_x,                 # std in log domain

    #         # new: invariants & diagnostics
    #         "thr_raw_used": thr_raw_used,
    #         "thr_raw_now": thr_raw_now,
    #         "thr_soft": thr_soft,
    #         "thr_jump": thr_jump,
    #         "pos_rate_raw": pos_rate_raw,
    #         "pos_rate_soft": pos_rate_soft,
    #         "mean_d": mean_d,
    #         "std_d": std_d,
    #         "mean_x": mean_x,
    #         "std_x": std_x,
    #         "stats_count": self.pe.gater_stats.count,
    #         "thr_beta": self.pe.gater_stats.beta,
    #         "thr_update_steps": self.pe.gater_stats.thr_update_steps,
    #     }

    
    def update_gater_stats_batch(self, batch, step):
        if self.standardizer is not None:
            standardized_batch = Batch(self.standardizer(batch.observations),batch.actions,
                batch.rewards,batch.masks,self.standardizer(batch.next_observations))
        else:
            standardized_batch = batch

        dis_vec = self.disagreement_jit(self.pe.pA.params, self.pe.pB.params,
                                  standardized_batch.observations, standardized_batch.actions)
        dis_mat = dis_vec[None, :]

        new_stats, m = self.update_stats_from_dis_jit(self.pe.gater_stats, dis_mat, jnp.int32(step))
        self.pe = self.pe.replace(gater_stats=new_stats)

        # Coherent snapshot (last row)
        snap = gate_snapshot(new_stats, dis_mat[-1])

        return {
            # core threshold state
            "stats_count": snap["stats_count"],
            "k_std": snap["k_std"],
            "thr": snap["thr_used"],
            "thr_raw_now": snap["thr_raw_now"],
            "thr_soft": snap["thr_soft"],
            # agreement checks
            "pos_rate_raw": snap["pos_rate_raw"],
            "pos_rate_soft": snap["pos_rate_soft"],
            # scan summaries
            "pos_rate_last_used": m["pos_rate_last_used"],
            "pos_rate_last_now":  m["pos_rate_last_now"],
            "pos_rate_mean_used": m["pos_rate_mean_used"],
            "pos_rate_mean_now":  m["pos_rate_mean_now"],
            "thr_raw_now_last":   m["thr_raw_now_last"],
            "thr_raw_used_last":  m["thr_raw_used_last"],
    }
        
    def _make_pa_pred_jit(self):
        pA_apply = self.pA_apply
        @jax.jit
        def fn(pA_params, obs, act):
            return pA_apply({'params': pA_params}, obs, act)
        return fn
    
    def _make_pb_pred_jit(self):
        pB_apply = self.pB_apply
        @jax.jit
        def fn(pB_params, obs, act):
            return pB_apply({'params': pB_params}, obs, act)
        return fn
    
    def _make_lam_q_pred_jit(self):
        gater_q_apply = self.lam_q_apply
        @jax.jit
        def fn(gater_q_params, obs, act):
            return gater_q_apply({'params': gater_q_params}, obs, act)
        return fn
    
    def _make_lam_tgt_pred_jit(self):
        gater_tgt_apply = self.lam_tgt_apply
        @jax.jit
        def fn(gater_tgt_params, obs, act):
            return gater_tgt_apply({'params': gater_tgt_params}, obs, act)
        return fn
    
    def _make_lam_p_pred_jit(self):
        gater_p_apply = self.lam_p_apply
        @jax.jit
        def fn(gater_p_params, obs):
            return gater_p_apply({'params': gater_p_params}, obs)
        return fn

    # def log_invariants(self, batch, prefix="dbg/"):
    #     ### Sanity checking, erase later
    #     if self.standardizer is not None:
    #         obs_s = self.standardizer(batch.observations)
    #         nxt_s = self.standardizer(batch.next_observations)
    #     else:
    #         obs_s, nxt_s = batch.observations, batch.next_observations
            
    #     stf = self.state_tf
    #     target = stf(nxt_s) - stf(obs_s)

    #     predA = self.pa_pred(self.pe.pA.params, obs_s, batch.actions)
    #     predB = self.pb_pred(self.pe.pB.params, obs_s, batch.actions)

    #     lossA = jnp.mean((target - predA)**2)
    #     lossB = jnp.mean((target - predB)**2)
        
    #     dis_per = jnp.mean((predA - predB)**2, axis=-1)
    #     dis_mean = jnp.mean(dis_per)        
        
    #     ineq_ratio = dis_mean / jnp.maximum(2 * (lossA + lossB), 1e-6)

    #     return {
    #         f'{prefix}lossA': lossA,
    #         f'{prefix}lossB': lossB,
    #         f'{prefix}dis_mean': dis_mean,
    #         f'{prefix}ineq_ratio': ineq_ratio, # if ineq ratio >> 1, something is wrong
    #     }
        
    # def log_lambdas(self, batch, prefix="dbg/"):
    #     logits_q = self.pe.gater_q.apply_fn.apply(
    #         {'params': self.pe.gater_q.params}, batch.observations, batch.actions)
    #     logits_p = self.pe.gater_p.apply_fn.apply(
    #         {'params': self.pe.gater_p.params}, batch.observations)
        
    #     return {
    #         f'{prefix}lam_q_mean': jnp.mean(jax.nn.sigmoid(logits_q)),
    #         f'{prefix}lam_p_mean': jnp.mean(jax.nn.sigmoid(logits_p)),
    #         f'{prefix}p_diff_L1': jnp.mean(jnp.abs(jax.nn.sigmoid(logits_q) - jax.nn.sigmoid(logits_p))),
    #     }
            
        
        

def gate_snapshot(stats, dis_batch):
    """Summarise the gate threshold state against one disagreement batch.

    Reads `mean`, `var`, `k_std`, `scale`, `count` and `get_threshold()` from
    `stats`. The soft threshold is `mean + k_std * sqrt(var + 1e-8)`; it is
    compared against `log1p(scale * dis)`. The raw threshold "now" is the
    soft threshold mapped back with `expm1(.) / scale` and floored at 1e-6,
    while the threshold "used" is whatever `stats.get_threshold()` returns
    and is compared against the raw disagreement values.

    Returns:
        Dict with the count, k_std, scale, the three thresholds, both
        positive rates, and the stats' mean / std under the keys 'dis_mean'
        and 'dis_std'.
    """
    scale = stats.scale
    k_std = stats.k_std
    mean_x = stats.mean
    std_x = jnp.sqrt(stats.var + 1e-8)

    soft_threshold = mean_x + k_std * std_x
    raw_threshold_now = jnp.maximum(jnp.expm1(soft_threshold) / scale, 1e-6)
    raw_threshold_used = stats.get_threshold()

    flat = jnp.ravel(dis_batch)
    log_flat = jnp.log1p(scale * flat)

    snapshot = {}
    snapshot["stats_count"] = stats.count
    snapshot["k_std"] = k_std
    snapshot["scale"] = scale
    snapshot["thr_soft"] = soft_threshold
    snapshot["thr_raw_now"] = raw_threshold_now
    snapshot["thr_used"] = raw_threshold_used
    snapshot["pos_rate_raw"] = jnp.mean(flat > raw_threshold_used)
    snapshot["pos_rate_soft"] = jnp.mean(log_flat > soft_threshold)
    snapshot["dis_mean"] = mean_x
    snapshot["dis_std"] = std_x
    return snapshot