from functools import partial, wraps
import argparse
import os
import sys
import time

# JAX / Flax / Optax
import jax
import jax.numpy as jnp
from flax.training import orbax_utils
from flax.training.train_state import TrainState
from flax.core.frozen_dict import freeze
from orbax.checkpoint import (
    CheckpointManager,
    CheckpointManagerOptions,
    PyTreeCheckpointer,
)
import optax
import distrax

# Other ML libraries
import numpy as np
import ml_collections
import wandb


from models.losses import (
    _adv_ppo_loss_fn,
    _ppo_loss_fn,
)


def update_epoch(update_state, unused, uses_hstate, is_adv=False, config=None):
    """Run one PPO epoch: shuffle the rollout, cut it into minibatches, update on each.

    The ``(carry, x) -> (carry, y)`` shape of the signature suggests this is
    meant to be used as a ``jax.lax.scan`` body (presumably with the keyword
    arguments bound via ``functools.partial``) -- confirm against the caller.

    Args:
        update_state: 7-tuple ``(train_state, init_hstate, traj_batch,
            advantages, targets, rng, update_step)``.
        unused: Ignored.
        uses_hstate: If True, ``init_hstate`` is part of the batch, the
            shuffle is a permutation of size ``config["NUM_ENVS"]`` applied
            along axis 1, and minibatches are cut along axis 1. If False, the
            two leading axes of every leaf are flattened to ``batch_size``
            and shuffled/cut along axis 0.
        is_adv: If True, use ``_adv_ppo_loss_fn`` with the module-global
            ``adv_network``; otherwise use ``_ppo_loss_fn`` with the
            module-global ``network``. Loss/gradient metrics are only sent to
            Weights & Biases in the ``is_adv`` case, and pruning hooks are
            only applied in the non-``is_adv`` case.
        config: Mapping supporting ``[]`` and ``.get``. Reads
            ``MINIBATCH_SIZE``, ``NUM_MINIBATCHES``, ``NUM_STEPS``,
            ``NUM_ENVS`` and the optional flags ``USE_WANDB`` and
            ``USE_PRUNING``. Required despite the ``None`` default.

    Returns:
        ``(update_state, total_loss)``. ``update_state`` has the same layout
        as the input with ``train_state`` and ``rng`` replaced by their
        updated values. ``total_loss`` is the per-minibatch output stacked by
        ``jax.lax.scan``: the scalar total loss when ``is_adv`` is True,
        otherwise the complete ``(loss, aux)`` value returned by the loss
        function.

    Raises:
        ValueError: If ``config`` is None.
        AssertionError: If ``MINIBATCH_SIZE * NUM_MINIBATCHES`` differs from
            ``NUM_STEPS * NUM_ENVS``.
    """
    if config is None:
        # Previously this surfaced as an opaque AttributeError on `None.get`.
        raise ValueError("update_epoch requires a config mapping; got config=None")

    use_wandb = bool(config.get("USE_WANDB", False))

    # The network and pruner objects are not passed in; they are looked up by
    # name in this module's globals. They are resolved once here so that they
    # are plain Python constants while the inner function is traced.
    module_globals = globals()
    net = module_globals.get("adv_network" if is_adv else "network", None)
    pro_pruner = (
        module_globals.get("pruner", None) or module_globals.get("pro_pruner", None)
    )
    # NOTE(review): pruning is only ever applied to the non-adversary update;
    # a module-global `adv_pruner` is never used by this function. Confirm
    # that this is intended.
    apply_pruning = (
        (not is_adv)
        and bool(config.get("USE_PRUNING", False))
        and pro_pruner is not None
    )

    (
        train_state,
        init_hstate,
        traj_batch,
        advantages,
        targets,
        rng,
        update_step,
    ) = update_state

    loss_fn = _adv_ppo_loss_fn if is_adv else _ppo_loss_fn
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

    def _log_adv_metrics(loss, v_l, a_l, ent_v, cls_l, g_norm, g_max, g_mean):
        # Host-side callback; only registered when USE_WANDB is enabled.
        import wandb

        wandb.log(
            {
                "adv/total_loss":      float(loss),
                "adv/value_loss":      float(v_l),
                "adv/actor_loss":      float(a_l),
                "adv/entropy":         float(ent_v),
                "adv/cls_loss":        float(cls_l),
                "grad/global_l2_norm": float(g_norm),
                "grad/max_abs":        float(g_max),
                "grad/mean_abs":       float(g_mean),
            },
            commit=False,
        )

    def _update_minbatch(state, batch_info):
        """Apply one gradient step on a single minibatch (scan body)."""
        if uses_hstate:
            mb_hstate, mb_traj, mb_advantages, mb_targets = batch_info
        else:
            mb_hstate = None
            mb_traj, mb_advantages, mb_targets = batch_info

        # ---- pruning hook before the forward/backward pass ----
        if apply_pruning:
            pre_params = pro_pruner.pre_forward_update(state.params, state.opt_state)
            state = state.replace(params=pre_params)

        # ---- loss + gradients ----
        loss_out, grads = grad_fn(
            state.params,
            mb_hstate,
            mb_traj, mb_advantages, mb_targets,
            update_step, config, net,
        )

        if is_adv:
            total_loss, (v_loss, a_loss, ent, c_loss) = loss_out
            if use_wandb:
                grads_flat, _ = jax.tree.flatten(grads)
                per_leaf_l2 = [jnp.linalg.norm(g) for g in grads_flat]
                per_leaf_max = [jnp.max(jnp.abs(g)) for g in grads_flat]
                per_leaf_mean = [jnp.mean(jnp.abs(g)) for g in grads_flat]
                # L2 norm of the per-leaf L2 norms == L2 norm over all entries.
                global_norm = jnp.linalg.norm(jnp.stack(per_leaf_l2))
                max_abs_grad = jnp.max(jnp.stack(per_leaf_max))
                # Mean of per-leaf means (each leaf weighted equally,
                # regardless of its size).
                mean_abs_grad = jnp.mean(jnp.stack(per_leaf_mean))

                jax.debug.callback(
                    _log_adv_metrics,
                    total_loss, v_loss, a_loss, ent, c_loss,
                    global_norm, max_abs_grad, mean_abs_grad,
                )
        else:
            # The non-adversary path reports the whole (loss, aux) value.
            total_loss = loss_out

        # The gradient step is the only optimizer-driven weight update.
        state = state.apply_gradients(grads=grads)

        # ---- pruning hook after the gradient step ----
        if apply_pruning:
            post_params = pro_pruner.post_gradient_update(state.params, state.opt_state)
            state = state.replace(params=post_params)

        return state, total_loss

    rng, _rng = jax.random.split(rng)

    # Standard PPO batching
    num_minibatches = config["NUM_MINIBATCHES"]
    batch_size = config["MINIBATCH_SIZE"] * num_minibatches
    assert batch_size == config["NUM_STEPS"] * config["NUM_ENVS"], (
        "batch size must be equal to number of steps * number of envs"
    )

    if uses_hstate:
        # NOTE: this second split discards the key drawn just above. It is
        # kept on purpose so the random stream (and therefore the shuffles)
        # stays identical to earlier runs.
        rng, _rng = jax.random.split(rng)
        permutation = jax.random.permutation(_rng, config["NUM_ENVS"])
        batch = (init_hstate, traj_batch, advantages, targets)

        # Shuffle along axis 1, then cut axis 1 into `num_minibatches`
        # groups and move the minibatch axis to the front for scan.
        shuffled_batch = jax.tree.map(
            lambda x: jnp.take(x, permutation, axis=1), batch
        )
        minibatches = jax.tree.map(
            lambda x: jnp.swapaxes(
                jnp.reshape(
                    x,
                    [x.shape[0], num_minibatches, -1] + list(x.shape[2:]),
                ),
                1,
                0,
            ),
            shuffled_batch,
        )
    else:
        permutation = jax.random.permutation(_rng, batch_size)
        batch = (traj_batch, advantages, targets)
        # Merge the two leading axes into one of length `batch_size`.
        batch = jax.tree.map(
            lambda x: x.reshape((batch_size,) + x.shape[2:]), batch
        )
        shuffled_batch = jax.tree.map(
            lambda x: jnp.take(x, permutation, axis=0), batch
        )
        minibatches = jax.tree.map(
            lambda x: jnp.reshape(x, [num_minibatches, -1] + list(x.shape[1:])),
            shuffled_batch,
        )

    train_state, total_loss = jax.lax.scan(_update_minbatch, train_state, minibatches)

    update_state = (
        train_state,
        init_hstate,
        traj_batch,
        advantages,
        targets,
        rng,
        update_step,
    )
    return update_state, total_loss
