from __future__ import annotations

from enum import Enum, auto
from typing import Callable, FrozenSet, Optional, Tuple

import jax.numpy as jnp
from flax import struct


class HookType(Enum):
    """Names of the activation sites that ``HookRequest`` and ``Edit`` refer to.

    Values are assigned by ``auto()`` and therefore follow declaration order.
    Plain ``Enum`` members define no ordering (``<``) between themselves.
    """

    RESID_PRE = auto()  # after LN, before attention
    ATTN_OUT = auto()  # result of multi-head attention
    RESID_MID = auto()  # after attn add-back, before MLP
    MLP_ACT = auto()  # gate activation (optional but handy)
    MLP_OUT = auto()  # FFN output before add-back
    RESID_POST = auto()  # end-of-block residual


@struct.dataclass
class HookRequest:
    """Request to capture the activation at one ``(layer, kind)`` site.

    ``capture`` reads ``layer`` and ``kind`` to build the site key and uses
    ``stream`` to decide which of its two result sets the key goes into.
    """

    # Index of the layer whose activation is requested.
    layer: int
    # Which site inside that layer is requested.
    kind: HookType
    stream: bool = False  # True -> jax.debug.callback; False -> returned


def capture(
    *reqs: HookRequest,
) -> Tuple[FrozenSet[Tuple[int, HookType]], FrozenSet[Tuple[int, HookType]]]:
    """Split hook requests into "returned" and "streamed" site keys.

    Each request contributes the key ``(layer, kind)``. Duplicate keys
    collapse because the results are sets.

    Returns:
        A pair ``(returned, streamed)`` of frozensets: keys of requests with a
        falsy ``stream`` flag first, keys of requests with a truthy one second.
    """
    returned = frozenset((req.layer, req.kind) for req in reqs if not req.stream)
    streamed = frozenset((req.layer, req.kind) for req in reqs if req.stream)
    return returned, streamed


def unpack_captured(spec: FrozenSet[Tuple[int, HookType]], values) -> dict:
    """Pair captured values with their ``(layer, kind)`` keys.

    The keys in ``spec`` are put in a deterministic order -- by layer, then by
    the enum member's ``value`` -- and zipped positionally with ``values``.

    Sorting uses an explicit key because plain ``Enum`` members do not support
    ``<``: sorting the raw tuples raises ``TypeError`` as soon as two keys
    share a layer. Whenever the raw sort succeeds (all layers distinct) this
    ordering is identical to it.

    Args:
        spec: Set of ``(layer, kind)`` keys that were captured.
        values: Iterable of captured values, expected in the same order as the
            sorted keys. NOTE(review): the producer of ``values`` is not
            visible here; confirm it orders sites by ``(layer, kind.value)``.

    Returns:
        Dict mapping each key to its value. As with ``zip``, surplus keys or
        values are dropped silently when the lengths differ.
    """
    ordered_keys = sorted(spec, key=lambda site: (site[0], site[1].value))
    return dict(zip(ordered_keys, values))


@struct.dataclass
class Edit:
    """One activation edit aimed at a ``(layer, kind)`` site.

    ``ActivationEditor.apply`` interprets the fields according to ``op``,
    always acting on ``x[:, tok_slice, :]``:

    * ``"add"``     -- adds ``scale * vec``.
    * ``"replace"`` -- overwrites with ``vec`` (``scale`` is not used).
    * ``"swap"``    -- overwrites with ``vec`` and is written to assign the
      previous activations back to ``vec``.
    * ``"call"``    -- overwrites with ``callback(x[:, tok_slice, :])`` cast
      to ``x``'s dtype.
    """

    # Layer index the edit applies to.
    layer: int
    # Site within the layer the edit applies to.
    kind: HookType
    # Vector used by "add" / "replace" / "swap". Declared with
    # pytree_node=False, i.e. kept out of the pytree leaves.
    # NOTE(review): presumably broadcastable against x[:, tok_slice, :] -- confirm.
    vec: Optional[jnp.ndarray] = struct.field(
        pytree_node=False,
        default=None,
    )
    # Multiplier applied to ``vec``; only the "add" op reads it.
    scale: float = 1.0
    # Slice over the second axis of the activation; None selects every position.
    tok_slice: slice | None = None
    op: str = "add"  # "add" | "replace" | "swap" | "call"
    # Function used by the "call" op; also kept out of the pytree leaves.
    callback: Optional[Callable[[jnp.ndarray], jnp.ndarray]] = struct.field(
        pytree_node=False,
        default=None,
    )


@struct.dataclass
class ActivationEditor:
    """Applies a fixed tuple of ``Edit`` records to activations.

    ``edits`` is declared with ``pytree_node=False``, so the edits are carried
    as static metadata rather than as pytree leaves.
    """

    edits: Tuple[Edit, ...] = struct.field(pytree_node=False)

    @staticmethod
    def _blend(select, new, old):
        """Return ``new`` where ``select`` is non-zero and ``old`` elsewhere.

        ``select=None`` means "no mask": ``new`` is used everywhere.
        """
        if select is None:
            return new
        return jnp.where(select, new, old)

    def apply(self, *, layer: int, kind: HookType, x, token_mask=None):
        """Apply every edit registered for ``(layer, kind)`` to ``x``.

        Edits run in the order they appear in ``self.edits``; each one sees the
        result of the previous one. Edits for other sites are skipped.

        Args:
            layer: Layer index of the site being processed.
            kind: Site within the layer. Matched against ``Edit.kind`` by
                ``.value`` rather than by identity.
            x: Activation array, indexed as ``x[:, tok_slice, :]`` (three
                axes; the middle one is sliced by ``Edit.tok_slice``).
            token_mask: Optional array indexed as ``token_mask[:, tok_slice]``.
                For ``"add"`` the update is multiplied by it; for the other
                ops non-zero entries mark positions that receive the new
                value. ``None`` means every selected position is edited.

        Returns:
            The edited array (``x`` itself is not modified in place).

        Raises:
            ValueError: if a matching edit has an unknown ``op``, a ``"call"``
                edit has no ``callback``, or an ``"add"`` / ``"replace"`` /
                ``"swap"`` edit has no ``vec``.
        """
        for edit in self.edits:
            if edit.layer != layer or edit.kind.value != kind.value:
                continue

            slc = edit.tok_slice if edit.tok_slice is not None else slice(None)
            current = x[:, slc, :]
            # Trailing axis of size 1 lets the mask broadcast over features.
            # Stays None when no mask was given; the original code indexed
            # None here and crashed for every op except "add".
            select = None if token_mask is None else token_mask[:, slc][..., None]

            if edit.op == "call":
                if edit.callback is None:
                    raise ValueError("callback op requested but no function")
                new_val = edit.callback(current).astype(x.dtype)
                x = x.at[:, slc, :].set(self._blend(select, new_val, current))
                continue

            if edit.op not in ("add", "replace", "swap"):
                # Previously an unrecognised op was silently ignored.
                raise ValueError(f"unknown edit op: {edit.op!r}")
            if edit.vec is None:
                raise ValueError(f"{edit.op} op requested but no vec")

            if edit.op == "add":
                update = edit.scale * edit.vec
                if select is not None:
                    update = update * select
                x = x.at[:, slc, :].add(update)
            else:
                x = x.at[:, slc, :].set(self._blend(select, edit.vec, current))
                if edit.op == "swap":
                    # struct.dataclass instances are frozen, so a plain
                    # ``edit.vec = ...`` raises FrozenInstanceError; bypass
                    # the frozen __setattr__ to keep the intended hand-back of
                    # the previous activations.
                    # NOTE(review): under jax.jit this stores a tracer on the
                    # Edit, so "swap" is only meaningful in eager execution --
                    # confirm intended usage.
                    object.__setattr__(edit, "vec", current)
        return x