from functools import partial

import jax
import jax.numpy as jnp

import wandb

from ..impl import hooks as hooks_lib
from ..impl import transformer
from ..impl.caching import TransformerCache
from ..impl.hooks import HookRequest
from ..tools import train_utils

def run_and_capture(
    run_function,
    inputs,
    positions,
    lm_params,
    lm_config,
    hooks,
):
    """Run the model once and return the hook activations it captured.

    A fresh ``TransformerCache`` (bfloat16, ``dynamic=False``) is built from
    ``positions`` and ``lm_config`` for this single call.

    Args:
        run_function: Callable invoked as
            ``run_function(inputs, cache, lm_params, lm_config)``. Its result
            is unpacked and only the final element (the raw capture) is kept.
        inputs: Model inputs forwarded unchanged to ``run_function``.
        positions: Forwarded to ``TransformerCache.create``.
        lm_params: Model parameters forwarded to ``run_function``.
        lm_config: Model config, used for the cache and forwarded to
            ``run_function``.
        hooks: Hook spec passed to ``hooks_lib.unpack_captured`` alongside the
            raw capture.

    Returns:
        Whatever ``hooks_lib.unpack_captured(hooks, raw_capture)`` returns.
    """
    fresh_cache = TransformerCache.create(
        positions,
        lm_config,
        dtype=jnp.bfloat16,
        dynamic=False,
    )
    model_outputs = run_function(inputs, fresh_cache, lm_params, lm_config)
    # Only the trailing element (the captured hook data) is needed here.
    *_, captured = model_outputs
    return hooks_lib.unpack_captured(hooks, captured)


@partial(jax.jit, static_argnames=("pad_id", "bos_id", "eos_id", "dtype"))
def build_token_mask(
    token_ids,
    pad_id,
    bos_id,
    eos_id,
    dtype=jnp.float32,
):
    """Return a mask that is 1 for ordinary tokens and 0 for special tokens.

    Positions whose id equals ``pad_id``, ``bos_id`` or ``eos_id`` are masked
    out. Any of those ids may be ``None``, for example when a tokenizer
    defines no pad token; ``None`` ids are ignored rather than crashing
    ``jnp.array``.

    Args:
        token_ids: Integer array of token ids (any shape).
        pad_id: Padding token id, or ``None``. Static under ``jit``.
        bos_id: Beginning-of-sequence token id, or ``None``. Static.
        eos_id: End-of-sequence token id, or ``None``. Static.
        dtype: Output dtype of the mask. Static.

    Returns:
        An array with the same shape as ``token_ids`` and dtype ``dtype``.
    """
    # The ids are static jit arguments, so this filtering happens at trace time.
    ignore = tuple(i for i in (pad_id, bos_id, eos_id) if i is not None)
    if not ignore:
        # Nothing to exclude: every position is valid.
        return jnp.ones_like(token_ids, dtype=dtype)
    ignore_ids = jnp.array(ignore)
    valid_positions = ~jnp.isin(token_ids, ignore_ids)
    return valid_positions.astype(dtype)


def get_hook_tools(train_config):
    """Build the jitted model runner, the token-mask function and the hook set.

    One ``HookRequest`` is made per entry in ``train_config.sae_configs``,
    using that entry's ``layer_id`` and ``placement``.

    Args:
        train_config: Training config; reads ``sae_configs`` and
            ``tokenizer.{pad,bos,eos}_token_id``.

    Returns:
        A ``(run_function, mask_function, hooks)`` tuple:
        ``run_function`` is ``transformer.run`` with the requested hooks bound
        and jitted, with positional argument 3 marked static;
        ``mask_function`` is ``build_token_mask`` with the tokenizer's special
        ids bound; ``hooks`` is the first value from ``hooks_lib.capture``.
    """
    requests = [
        HookRequest(sae_cfg.layer_id, sae_cfg.placement)
        for sae_cfg in train_config.sae_configs
    ]
    hooks, _ = hooks_lib.capture(*requests)

    bound_run = partial(
        transformer.run,
        hooks_to_return=hooks,
        hooks_to_stream=frozenset(),
        editor=None,
    )
    # Positional arg 3 is the LM config in the (inputs, cache, params, config)
    # call order used by run_and_capture; it is treated as static by jit.
    run_function = jax.jit(bound_run, static_argnums=(3,))

    tokenizer = train_config.tokenizer
    mask_function = partial(
        build_token_mask,
        pad_id=tokenizer.pad_token_id,
        bos_id=tokenizer.bos_token_id,
        eos_id=tokenizer.eos_token_id,
    )

    return run_function, mask_function, hooks


# ---------- sae_utils.py ----------
def maybe_log(
    step: int, per_sae_metrics: list[dict], elapsed_tokens, train_config
) -> None:
    """Accumulate per-SAE metrics and log their running means every N steps.

    State is kept on the function object across calls:
    ``maybe_log._sums`` holds, per SAE, the running sum of each metric;
    ``maybe_log._counts`` holds, per SAE, how many times each metric was
    reported; ``maybe_log._count`` holds the number of calls since the last
    flush.

    When ``step % train_config.log_every == 0``, everything is sent in a single
    ``wandb.log`` call at ``step``: ``"tokens"`` (``elapsed_tokens``) plus
    ``"sae_{i}/{metric}"`` means. The accumulators are then reset.

    Each metric's mean is divided by the number of times that metric was
    actually reported, so metrics missing on some steps are not diluted.

    Args:
        step: Current training step.
        per_sae_metrics: One ``{name: scalar}`` dict per SAE; values must be
            convertible with ``float()``.
        elapsed_tokens: Token counter logged under ``"tokens"``.
        train_config: Config providing ``log_every``.
    """
    if not hasattr(maybe_log, "_sums"):
        maybe_log._sums = []
        maybe_log._counts = []
        maybe_log._count = 0

    # Grow the accumulators if more SAEs are reported than on earlier calls.
    while len(maybe_log._sums) < len(per_sae_metrics):
        maybe_log._sums.append({})
        maybe_log._counts.append({})

    for sums, counts, metrics in zip(
        maybe_log._sums, maybe_log._counts, per_sae_metrics
    ):
        for key, value in metrics.items():
            # float() pulls the value to host so sums are plain Python floats.
            sums[key] = sums.get(key, 0.0) + float(value)
            counts[key] = counts.get(key, 0) + 1
    maybe_log._count += 1

    if step % train_config.log_every != 0:
        return

    payload = {"tokens": elapsed_tokens}
    for sae_id, (sums, counts) in enumerate(zip(maybe_log._sums, maybe_log._counts)):
        for key, total in sums.items():
            payload[f"sae_{sae_id}/{key}"] = total / counts[key]
        sums.clear()
        counts.clear()
    wandb.log(payload, step=step)
    maybe_log._count = 0


def maybe_save(step, states, train_config):
    """Checkpoint every SAE state on checkpoint steps or at the end of training.

    Saving happens when ``step`` is the string ``"final"`` or when
    ``step % train_config.checkpoint_every == 0``. SAE ``i`` is written to
    ``{train_config.checkpoint_root}/sae_{i}`` via
    ``train_utils.save_checkpoint(params, opt_state, step, dir)``.

    Args:
        step: Integer step, or ``"final"`` to force a save.
        states: Iterable of SAE states exposing ``params`` and ``opt_state``.
        train_config: Config providing ``checkpoint_every`` and
            ``checkpoint_root``.
    """
    # The "final" check comes first so the modulo is never applied to a string.
    is_final = step == "final"
    if not is_final and step % train_config.checkpoint_every != 0:
        return

    for index, state in enumerate(states):
        target_dir = f"{train_config.checkpoint_root}/sae_{index}"
        train_utils.save_checkpoint(
            state.params, state.opt_state, step, target_dir
        )
