# JAX / Flax / Optax
import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState
from flax.core.frozen_dict import freeze
import optax
import distrax

# Other ML libraries
import numpy as np
import ml_collections
import wandb

# Logging utilities
from batch_logging import EpisodeLogger

# Models
from models.simple import (
    ActorCritic,
    ActorCriticRNN,
    ScannedRNN,
    AdversaryActorCritic,
    AdversaryActorCriticRNN,
)
# Custom utilities
from functools import partial

import jaxpruner
import ml_collections
import os, sys
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from pruning.api import create_updater_from_config

def setup_protagonist(rng, config):
    """Build the protagonist network, its train state and an optional pruner.

    Args:
        rng: JAX PRNG key. It is split once; one half seeds parameter
            initialisation and the other half is returned to the caller.
        config: Dict-like configuration. Always reads ``ACTION_SHAPE``,
            ``USE_RNN``, ``OBS_SHAPE``, ``MAX_GRAD_NORM`` and ``LR``.
            ``LAYER_SIZE`` is optional (default 256). ``NUM_ENVS`` is read in
            the RNN case and ``ACTIVATION`` in the feed-forward case. When
            ``USE_PRUNING`` is truthy, ``UPDATE_EPOCHS`` and
            ``NUM_MINIBATCHES`` are read, plus either ``NUM_UPDATES`` or
            ``TOTAL_TIMESTEPS`` / ``NUM_STEPS`` / ``NUM_ENVS``; the
            ``PRUNE_*`` / ``PRUNER_TYPE`` keys are optional.

    Returns:
        Tuple ``(rng, network, train_state, init_hstate, pruner)``.
        ``init_hstate`` is ``None`` for the feed-forward network and
        ``pruner`` is ``None`` when pruning is disabled.
    """
    action_dim = config["ACTION_SHAPE"]
    use_rnn = config["USE_RNN"]
    # Resolve the layer size once so the network and its initial carry can
    # never disagree (previously the network fell back to 256 while the carry
    # raised KeyError when LAYER_SIZE was absent).
    layer_size = config.get("LAYER_SIZE", 256)

    if use_rnn:
        network = ActorCriticRNN(action_dim, layer_size)
    else:
        network = ActorCritic(
            action_dim,
            layer_size=layer_size,
            activation=config["ACTIVATION"],
        )

    rng, init_rng = jax.random.split(rng)
    obs_shape = config["OBS_SHAPE"]

    if use_rnn:
        num_envs = config["NUM_ENVS"]
        # Dummy inputs carry a leading axis of length 1 followed by NUM_ENVS.
        dummy_obs = jnp.zeros((1, num_envs, *obs_shape))
        dummy_done = jnp.zeros((1, num_envs))
        init_hstate = ScannedRNN.initialize_carry(num_envs, layer_size)
        network_params = network.init(
            init_rng, init_hstate, (dummy_obs, dummy_done)
        )
    else:
        dummy_obs = jnp.zeros((1, *obs_shape))
        network_params = network.init(init_rng, dummy_obs)
        init_hstate = None

    # Base optimizer: global-norm gradient clipping followed by Adam.
    tx = optax.chain(
        optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
        optax.adam(config["LR"], eps=1e-5),
    )

    # ----- optional pruning -----
    pruner = None
    if config.get("USE_PRUNING", False):
        # Derive the schedule from config (robust if NUM_UPDATES is not
        # precomputed by the caller).
        num_updates = int(config.get(
            "NUM_UPDATES",
            config["TOTAL_TIMESTEPS"] // config["NUM_STEPS"] // config["NUM_ENVS"],
        ))
        # The schedule is expressed in gradient steps, not outer updates.
        num_grads_per_update = int(
            config["UPDATE_EPOCHS"] * config["NUM_MINIBATCHES"]
        )
        num_grad_updates = num_updates * num_grads_per_update

        sparsity_cfg = ml_collections.ConfigDict()
        sparsity_cfg.update_end_step = num_grad_updates
        sparsity_cfg.update_freq = 1
        # PRUNE_BURNIN is a fraction of the total gradient steps.
        sparsity_cfg.update_start_step = int(
            config.get("PRUNE_BURNIN", 0.0) * num_grad_updates
        )
        sparsity_cfg.algorithm = config.get("PRUNER_TYPE", "magnitude")
        sparsity_cfg.sparsity = float(config.get("PRUNE_PERCENTAGE", 0.5))
        sparsity_cfg.dist_type = config.get("PRUNE_DIST_TYPE", "erk")
        sparsity_cfg.schedule_power = config.get("PRUNE_SCHEDULE_POWER", 2)

        pruner = create_updater_from_config(sparsity_cfg)
        # Wrap the base optimizer so the pruner participates in each update.
        tx = pruner.wrap_optax(tx)
        print(f"[pruning] enabled: algo={sparsity_cfg.algorithm}, sparsity={sparsity_cfg.sparsity}, "
                f"start_step={sparsity_cfg.update_start_step}, end_step={sparsity_cfg.update_end_step}, "
                f"freq={sparsity_cfg.update_freq}, dist={sparsity_cfg.dist_type}")

    train_state = TrainState.create(
        apply_fn=network.apply,
        params=network_params,
        tx=tx,
    )

    return rng, network, train_state, init_hstate, pruner

def setup_antagonist(rng, config):
    """Build the adversary network and train state when ATLA is enabled.

    Args:
        rng: JAX PRNG key. It is always split once (even when the adversary
            is disabled) so the returned key does not depend on ``USE_ATLA``.
        config: Dict-like configuration. Reads ``OBS_SHAPE`` and the optional
            ``USE_ATLA`` flag (default ``False``). When the adversary is
            enabled it also reads ``ADV_LAYER_SIZE``, ``MAX_GRAD_NORM``,
            ``ADV_LR``, the optional ``USE_ADV_RNN`` flag (default ``True``)
            and, for the RNN adversary, ``NUM_ENVS``.

    Returns:
        Tuple ``(rng, adv_network, adv_train_state, init_adv_hstate)``.
        The last three are ``None`` when ``USE_ATLA`` is falsy;
        ``init_adv_hstate`` is also ``None`` for the non-RNN adversary.
    """
    obs_shape = config["OBS_SHAPE"]
    adv_action_dim = 4
    rng, init_rng = jax.random.split(rng)

    # Build the adversary only if enabled (ATLA).
    if not config.get("USE_ATLA", False):
        return rng, None, None, None

    adv_layer_size = config["ADV_LAYER_SIZE"]
    init_adv_hstate = None

    if config.get("USE_ADV_RNN", True):
        num_envs = config["NUM_ENVS"]
        dummy_obs_seq = jnp.zeros((1, num_envs, *obs_shape))
        dummy_done_seq = jnp.zeros((1, num_envs))
        init_adv_hstate = ScannedRNN.initialize_carry(num_envs, adv_layer_size)
        adv_network = AdversaryActorCriticRNN(
            action_dim=adv_action_dim,
            layer_size=adv_layer_size,
        )
        adv_network_params = adv_network.init(
            init_rng, init_adv_hstate, (dummy_obs_seq, dummy_done_seq)
        )
    else:
        # Unpack instead of `(1,) + obs_shape` so OBS_SHAPE may be a list as
        # well as a tuple, matching every other shape built in this file.
        dummy_obs = jnp.zeros((1, *obs_shape))
        adv_network = AdversaryActorCritic(
            adv_action_dim,
            layer_size=adv_layer_size,
        )
        adv_network_params = adv_network.init(init_rng, dummy_obs)

    # Single optimizer head (no classifier branch).
    adv_tx = optax.chain(
        optax.clip_by_global_norm(config["MAX_GRAD_NORM"]),
        optax.adam(config["ADV_LR"], eps=1e-5),
    )

    adv_train_state = TrainState.create(
        apply_fn=adv_network.apply,
        params=adv_network_params,
        tx=adv_tx,
    )

    return rng, adv_network, adv_train_state, init_adv_hstate
