import jax
import jax.numpy as jnp
import os
from functools import partial, wraps
import orbax.checkpoint as ocp
from flax.training import checkpoints
from typing import NamedTuple, Optional, Any, Dict, Tuple
from craftax.craftax_env import make_craftax_env_from_name
from wrappers import (
    AutoResetEnvWrapper,
    BatchEnvWrapper,
    BraxGymnaxWrapper,
    ClipAction,
    LogWrapper,
    NormalizeVecObservation,
    NormalizeVecReward,
    OptimisticResetVecEnvWrapper,
    VecEnv,
)
from toy_envs import FortEnv
import numpy as np
from flax.training import orbax_utils
from orbax.checkpoint import (
    CheckpointManager,
    CheckpointManagerOptions,
    PyTreeCheckpointer,
)
import shutil


def _device_get_tree(x):
    """Copy every leaf of pytree `x` to host memory as a NumPy array.

    Meant for small trees (normalization stats, configs, ...); the tree
    structure is preserved and only the leaves are converted.
    """
    def _leaf_to_host(leaf):
        return np.array(jax.device_get(leaf))

    return jax.tree.map(_leaf_to_host, x)


def build_checkpoint_payload(config, runner_state):
    """
    Assemble the single dict that gets written to a checkpoint.

    Returned keys:
      - 'runner_state': the training state, passed through untouched
      - 'norm_stats':   {'obs': stats-dict or None, 'rew': stats-dict or None}
      - 'config':       a plain-dict copy of `config`
    """
    def _host_or_none(stats):
        # Stats are tiny, so keep them as host-side NumPy arrays.
        return None if stats is None else _device_get_tree(stats)

    # runner_state layout: (train_state, adv_train_state, env_state, ...)
    obs_stats, rew_stats = _extract_norm_stats_from_env_state(runner_state[2])

    return {
        "runner_state": runner_state,
        "norm_stats": {
            "obs": _host_or_none(obs_stats),
            "rew": _host_or_none(rew_stats),
        },
        "config": dict(config),
    }


def save_checkpoint_with_norm(config, runner_state, step, ckpt_root="./checkpoints"):
    """
    Write one Orbax checkpoint holding runner_state, normalization stats and config.

    The run directory is `ckpt_root/<name derived from config>`. Any existing
    directory at that path is deleted first, so only the step saved by this
    call is present afterwards.

    Returns the absolute path of the run directory.
    """
    run_dir = os.path.abspath(os.path.join(ckpt_root, _format_ckpt_dir(config)))

    # Start from an empty run directory: drop whatever an earlier save left.
    if os.path.isdir(run_dir):
        shutil.rmtree(run_dir)
    os.makedirs(run_dir, exist_ok=True)

    payload = build_checkpoint_payload(config, runner_state)
    save_args = orbax_utils.save_args_from_target(payload)

    manager = CheckpointManager(
        run_dir,
        PyTreeCheckpointer(),
        CheckpointManagerOptions(max_to_keep=1, create=True),
    )
    manager.save(int(step), payload, save_kwargs={"save_args": save_args})
    return run_dir


def load_checkpoint_with_norm(ckdir, step=None):
    """
    Restore the latest (or a specific) checkpoint from `ckdir`.

    Args:
        ckdir: local checkpoint directory, e.g. the path returned by
            `save_checkpoint_with_norm`.
        step: step to restore; when None the latest available step is used.

    Returns:
        A dict with keys: 'runner_state', 'norm_stats', 'config'.

    Raises:
        FileNotFoundError: if `ckdir` does not exist, or if `step` is None and
            the directory contains no checkpoints.
    """
    # Fail fast on a missing directory. The manager below is built with
    # create=True, so without this guard a mistyped path would be created as
    # an empty directory on disk before the "no checkpoints" error is raised.
    if not os.path.isdir(ckdir):
        raise FileNotFoundError(f"Checkpoint directory does not exist: {ckdir}")

    checkpointer = PyTreeCheckpointer()
    mgr = CheckpointManager(ckdir, checkpointer, CheckpointManagerOptions(max_to_keep=1, create=True))

    if step is None:
        step = mgr.latest_step()
        if step is None:
            raise FileNotFoundError(f"No checkpoints found in {ckdir}")

    return mgr.restore(step)


# ----------------------------
# Data structures
# ----------------------------

class Transition(NamedTuple):
    """One rollout record: the per-step quantities collected while acting.

    `calculate_gae` / `calculate_rnn_gae` in this file read only `done`,
    `value` and `reward`; the remaining fields are carried for consumers
    outside this file. The last three fields are optional and default to None.
    """
    done: jnp.ndarray  # episode-termination flag; the GAE helpers use it as (1 - done)
    action: jnp.ndarray  # action taken at this step
    value: jnp.ndarray  # value estimate for this step (used as the baseline in GAE)
    reward: jnp.ndarray  # reward for this step
    log_prob: jnp.ndarray  # presumably log-probability of `action` under the acting policy -- confirm
    obs: jnp.ndarray  # observation -- presumably the one the action was computed from; confirm
    info: Optional[jnp.ndarray] = None  # optional env info; NOTE(review): may be a dict pytree rather than one array -- confirm
    action_hist: Optional[jnp.ndarray] = None  # optional action history; exact semantics not visible in this file
    next_obs: Optional[jnp.ndarray] = None  # optional follow-up observation


# ----------------------------
# Environment loading
# ----------------------------

def load_env(config, norm_stats=None):
    """
    Build the wrapped, vectorized environment selected by config["ENV_NAME"].

    Dispatch is by case-insensitive substring of the name:
      - contains "craftax": Craftax env + LogWrapper, vectorized with either
        optimistic resets or auto-reset + batching (USE_OPTIMISTIC_RESETS).
      - contains "fort": FortEnv + LogWrapper + VecEnv.
      - otherwise: BraxGymnaxWrapper + LogWrapper + ClipAction + VecEnv, with
        optional observation/reward normalization (NORMALIZE_ENV).

    `norm_stats` (a mapping with optional 'obs' / 'rew' entries) is consulted
    only in the last branch, to seed the observation normalizer.

    Returns (env, env_params); env_params is None for non-Craftax envs.
    """
    name = config["ENV_NAME"].lower()

    if "craftax" in name:
        optimistic = config["USE_OPTIMISTIC_RESETS"]
        base = make_craftax_env_from_name(config["ENV_NAME"], not optimistic)
        env_params = base.default_params
        env = LogWrapper(base)
        if optimistic:
            env = OptimisticResetVecEnvWrapper(
                env,
                num_envs=config["NUM_ENVS"],
                reset_ratio=min(config["OPTIMISTIC_RESET_RATIO"], config["NUM_ENVS"]),
            )
        else:
            env = BatchEnvWrapper(AutoResetEnvWrapper(env), num_envs=config["NUM_ENVS"])
        return env, env_params

    if "fort" in name:
        env = VecEnv(LogWrapper(FortEnv(max_countdown=3)))
        return env, None

    env = BraxGymnaxWrapper(config["ENV_NAME"])
    env = LogWrapper(env)
    env = ClipAction(env)
    env = VecEnv(env)
    if not config["NORMALIZE_ENV"]:
        return env, None

    if norm_stats is None:
        obs_init = None
        rew_init = None
    else:
        obs_init = norm_stats.get("obs", None)
        rew_init = norm_stats.get("rew", None)

    env = NormalizeVecObservation(env, stats=obs_init)
    # NOTE(review): the reward normalizer is only added when no stats were
    # supplied, so `rew_init` is always None here and saved reward stats are
    # never restored. Confirm this is intentional (e.g. raw rewards at eval).
    if norm_stats is None:
        env = NormalizeVecReward(env, config["GAMMA"], stats=rew_init)
    return env, None


# ----------------------------
# Preprocessing helpers
# ----------------------------

@partial(jax.jit, static_argnums=[1])
def _filter_adv_jit(obs: jnp.ndarray, use_state_history: bool) -> jnp.ndarray:
    """Jitted core of `filter_adv`; the flag is a hashable static argument."""
    if use_state_history:
        return obs[:, -1]
    return obs


def filter_adv(
    obs: jnp.ndarray,
    config: Dict[str, Any]
) -> jnp.ndarray:
    """
    Preprocess obs for the adversary. If using history buffers, strip to latest.

    Only the boolean `config["USE_STATE_HISTORY"]` is forwarded to the jitted
    core as a static argument. Marking the whole `config` static requires it to
    be hashable, which a plain dict is not, so the previous form raised for
    ordinary dict configs.

    Args:
        obs: observation array; when history is enabled it is indexed as
            `obs[:, -1]`, so it must have at least two axes.
        config: mapping providing the "USE_STATE_HISTORY" flag.

    Returns:
        `obs[:, -1]` when the flag is truthy, otherwise `obs` unchanged in value.
    """
    return _filter_adv_jit(obs, bool(config["USE_STATE_HISTORY"]))


@partial(jax.jit, static_argnums=[2])
def _filter_pro_jit(
    obs: jnp.ndarray,
    done: jnp.ndarray,
    use_state_history: bool
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Jitted core of `filter_pro`; the flag is a hashable static argument."""
    if use_state_history:
        return obs[:, -1], done[:, -1]
    return obs, done


def filter_pro(
    obs: jnp.ndarray,
    done: jnp.ndarray,
    config: Dict[str, Any]
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """
    Preprocess obs and done for the protagonist policy.
    If history is buffered, return the last frame; otherwise pass-through.

    Only the boolean `config["USE_STATE_HISTORY"]` is forwarded to the jitted
    core as a static argument. Marking the whole `config` static requires it to
    be hashable, which a plain dict is not, so the previous form raised for
    ordinary dict configs.

    Args:
        obs: observation array; indexed as `obs[:, -1]` when history is enabled.
        done: done-flag array; indexed as `done[:, -1]` when history is enabled.
        config: mapping providing the "USE_STATE_HISTORY" flag.

    Returns:
        `(obs[:, -1], done[:, -1])` when the flag is truthy, else `(obs, done)`.
    """
    return _filter_pro_jit(obs, done, bool(config["USE_STATE_HISTORY"]))


# ----------------------------
# GAE
# ----------------------------

def calculate_rnn_gae(traj_batch, last_val, last_done, config):
    """
    Generalized Advantage Estimation via a reverse scan over `traj_batch`.

    The scan walks the leading axis of `traj_batch` backwards, carrying the
    running advantage together with the value and done flag of the step that
    follows the current one; `last_val` / `last_done` seed that carry for the
    final step. Uses config["GAMMA"] and config["GAE_LAMBDA"].

    Returns (advantages, targets) where targets = advantages + traj_batch.value.
    """
    gamma = config["GAMMA"]
    gae_decay = config["GAMMA"] * config["GAE_LAMBDA"]

    def _backward_step(carry, step):
        running_gae, value_after, done_after = carry
        # Zero out bootstrapping and accumulation across an episode boundary.
        keep = 1 - done_after
        td_error = step.reward + gamma * value_after * keep - step.value
        running_gae = td_error + gae_decay * keep * running_gae
        return (running_gae, step.value, step.done), running_gae

    initial_carry = (jnp.zeros_like(last_val), last_val, last_done)
    _, advantages = jax.lax.scan(
        _backward_step,
        initial_carry,
        traj_batch,
        reverse=True,
        unroll=16,
    )
    targets = advantages + traj_batch.value
    return advantages, targets


def calculate_gae(traj_batch, last_val, last_done, config):
    """
    Compute GAE advantages and value targets for a trajectory batch.

    Runs a reverse `jax.lax.scan` over the leading axis of `traj_batch`.
    Each step sees the value and done flag of the following step through the
    carry, which is seeded with `last_val` and `last_done`. Discounting uses
    config["GAMMA"] and config["GAE_LAMBDA"].

    Returns a pair (advantages, advantages + traj_batch.value).
    """
    def _step(carry, transition):
        prev_gae, next_value, next_done = carry
        not_done = 1 - next_done
        bootstrap = config["GAMMA"] * next_value * not_done
        delta = transition.reward + bootstrap - transition.value
        gae = delta + config["GAMMA"] * config["GAE_LAMBDA"] * not_done * prev_gae
        new_carry = (gae, transition.value, transition.done)
        return new_carry, gae

    seed = (jnp.zeros_like(last_val), last_val, last_done)
    scan_out = jax.lax.scan(_step, seed, traj_batch, reverse=True, unroll=16)
    advantages = scan_out[1]
    return advantages, advantages + traj_batch.value


# ----------------------------
# bind: inject env and networks
# ----------------------------

def bind(fn, config, env, env_params, networks, pruner=None):
    """
    Returns a wrapped version of `fn` with `config` passed as a keyword and
    `env`, `env_params`, `network`, `adv_network`, `pruner` injected into the
    globals of the module where `fn` was defined.

    Args:
        fn: function to wrap; it must accept a `config` keyword argument.
        config: forwarded to every call as `config=config`.
        env, env_params: injected as globals `env` / `env_params`.
        networks: either a length-2 sequence `(network, adv_network)` or a
            single network object, in which case `adv_network` is None.
        pruner: injected as global `pruner` (default None).

    Note:
        The injection mutates `fn.__globals__`, i.e. the namespace shared by
        every function in `fn`'s module; a later `bind` call on a function
        from the same module overwrites these names.
    """
    try:
        is_pair = len(networks) == 2
    except TypeError:
        # A bare network object has no __len__; previously this raised before
        # the single-network branch could ever be reached.
        is_pair = False

    if is_pair:
        network, adv_network = networks
    else:
        network, adv_network = networks, None

    fn.__globals__.update({
        'env': env,
        'env_params': env_params,
        'network': network,
        'adv_network': adv_network,
        'pruner': pruner,
    })

    @wraps(fn)
    def wrapped(*args, **kwargs):
        return fn(*args, config=config, **kwargs)
    return wrapped


def _arch_str(config):
    """Return the architecture tag: "rnn" when config["USE_RNN"] is truthy, else "mlp"."""
    if config.get("USE_RNN", False):
        return "rnn"
    return "mlp"


def _prune_str(config):
    """Return the lower-cased pruner type, or "none" when pruning is disabled or unset."""
    if config.get("USE_PRUNING", False):
        return str(config.get("PRUNER_TYPE", "none")).lower()
    return "none"


def _format_ckpt_dir(config):
    """
    Build the run-directory name from config.

    Fields, joined by "-": env name, seed, pruner type, prune percentage
    (two decimals; PRUNE_PERCENTAGE defaults to 0.0), architecture.
    Example: hopper-seed42-prune=magnitude-0.50-mlp
    """
    parts = (
        f"{config['ENV_NAME']}",
        f"seed{config['SEED']}",
        f"prune={_prune_str(config)}",
        f"{float(config.get('PRUNE_PERCENTAGE', 0.0)):.2f}",
        f"{_arch_str(config)}",
    )
    return "-".join(parts)


def _extract_norm_stats_from_env_state(env_state):
    """
    Pull normalization statistics out of a (possibly wrapped) env_state.

    Looks at most two levels deep: the given state is checked for reward
    normalizer fields, then that state's `env_state` (or the given state
    itself if no reward normalizer was found) is checked for observation
    normalizer fields.

    Returns (obs_stats, rew_stats); each is a dict with keys
    mean, var, count (host NumPy arrays and a Python float) or None.
    """
    def _has_all(state, *names):
        return all(hasattr(state, name) for name in names)

    def _snapshot(state):
        return {
            "mean": np.array(jax.device_get(state.mean)),
            "var":  np.array(jax.device_get(state.var)),
            "count": float(np.array(jax.device_get(state.count))),
        }

    level = env_state

    # Reward normalizer state is recognized by its extra `return_val` field
    # (outermost wrapper if NORMALIZE_ENV=True).
    rew_stats = None
    if _has_all(level, "mean", "var", "count", "return_val"):
        rew_stats = _snapshot(level)
        # Step inside only when there is a wrapped state to step into.
        level = getattr(level, "env_state", level)

    # Observation normalizer state: running stats plus a wrapped `env_state`.
    obs_stats = None
    if _has_all(level, "mean", "var", "count", "env_state"):
        obs_stats = _snapshot(level)

    return obs_stats, rew_stats


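# ----------------------------
# Lottery Ticket Hypothesis (LTH) / pruning helpers
# ----------------------------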

def lth_rewind(init_params, opt_state):
    """Rewind weights to their initial values under the masks in opt_state.

    Each leaf of `init_params` is multiplied by its mask (cast to the leaf's
    dtype): entries where the mask is 1 get their initial value back, entries
    where it is 0 become zero.

    A mask entry of None is treated as "this parameter is not pruned" and the
    initial value is returned unchanged; previously such an entry crashed on
    `None.astype`.

    Args:
        init_params: pytree of initial parameters.
        opt_state: object exposing a `.masks` pytree matching `init_params`
            (SparseState has `.masks`).

    Returns:
        Pytree with the structure of `init_params`.
    """
    def _rewind_leaf(p0, m):
        if m is None:
            return p0
        return p0 * (m.astype(p0.dtype))

    return jax.tree.map(_rewind_leaf, init_params, opt_state.masks)


def clamp_to_mask(params, opt_state):
    """Clamp params to the current masks in opt_state (prevents regrowth).

    Each leaf of `params` is multiplied by its mask (cast to the leaf's
    dtype), zeroing every entry whose mask is 0.

    A mask entry of None is treated as "this parameter is not pruned" and the
    leaf is returned unchanged; previously such an entry crashed on
    `None.astype`.

    Args:
        params: pytree of current parameters.
        opt_state: object exposing a `.masks` pytree matching `params`.

    Returns:
        Pytree with the structure of `params`.
    """
    def _clamp_leaf(p, m):
        if m is None:
            return p
        return p * (m.astype(p.dtype))

    return jax.tree.map(_clamp_leaf, params, opt_state.masks)


def param_sparsity(params):
    """Return the fraction of entries in `params` that are exactly zero.

    Args:
        params: pytree whose leaves are arrays (each leaf needs `.size`).

    Returns:
        A scalar array in [0, 1]. A tree with no entries at all (no leaves, or
        only empty leaves) has sparsity 0 by convention instead of dividing
        by zero.
    """
    leaves = jax.tree_util.tree_leaves(params)
    total = sum(leaf.size for leaf in leaves)
    if total == 0:
        # Nothing to count: avoid ZeroDivisionError / NaN.
        return jnp.zeros(())
    zeros = sum(jnp.sum(leaf == 0) for leaf in leaves)
    return zeros / total


def save_sarsa(config: dict, params, step: int, ckpt_root: str = "./checkpoints", dir_suffix: str = "_rscritic"):
    """
    Write SARSA critic parameters to their own Orbax checkpoint directory.

    The directory is `ckpt_root/<name derived from config><dir_suffix>`; the
    suffix keeps it separate from the main run directory. Any existing
    directory at that path is deleted before saving.

    Args:
        config: run configuration used to derive the directory name.
        params: parameter pytree, stored under the key "params".
        step: checkpoint step (converted with `int`).
        ckpt_root: parent directory for checkpoints.
        dir_suffix: appended to the derived directory name.

    Returns:
        Absolute path of the checkpoint directory.
    """
    target_dir = os.path.abspath(
        os.path.join(ckpt_root, _format_ckpt_dir(config) + dir_suffix)
    )

    # Always save into a fresh directory.
    if os.path.isdir(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir, exist_ok=True)

    payload = {"params": params}
    save_args = orbax_utils.save_args_from_target(payload)

    options = CheckpointManagerOptions(max_to_keep=1, create=True)
    manager = CheckpointManager(target_dir, PyTreeCheckpointer(), options)
    manager.save(int(step), payload, save_kwargs={"save_args": save_args})
    return target_dir


def load_sarsa(ckpt_path: str):
    """
    Load SARSA critic parameters from a checkpoint directory.

    Args:
        ckpt_path: local directory containing the checkpoint (e.g. the path
            returned by `save_sarsa`).

    Returns:
        params: the "params" entry of the latest checkpoint step.

    Raises:
        FileNotFoundError: if the directory does not exist or holds no
            checkpoints.
    """
    ckdir = os.path.abspath(ckpt_path)

    # Fail fast on a missing directory. The manager below is built with
    # create=True, so without this guard a mistyped path would be created as
    # an empty directory on disk before the "no checkpoints" error is raised.
    if not os.path.isdir(ckdir):
        raise FileNotFoundError(f"Checkpoint directory does not exist: {ckdir}")

    checkpointer = PyTreeCheckpointer()
    mgr = CheckpointManager(
        ckdir, checkpointer, CheckpointManagerOptions(max_to_keep=1, create=True)
    )

    step = mgr.latest_step()
    if step is None:
        raise FileNotFoundError(f"No checkpoints found in {ckdir}")

    payload = mgr.restore(step)
    return payload["params"]
