import torch
from torch import nn
from contextlib import contextmanager
from accelerate.hooks import ModelHook


@contextmanager
def _disable_hooks(sae):
    """
    No-op for local SAEs; kept for API parity.
    """
    yield


class AmlifySAEHook(ModelHook):
    """
    Forward hook that amplifies selected SAE features on the *last* token and
    writes the reconstructed delta back into the layer output.

    Robust to both [B, D] and [B, T, D] layer outputs:
    - If output is [B, D], it is temporarily expanded to [B, 1, D] and then
      squeezed back to [B, D] before returning.
    - If the original module returns a tuple or list, the first element is
      assumed to be the hidden states tensor and the rest are passed through
      unchanged (a list stays a list, anything else is returned as a tuple).
    - If output is ``None`` or the sequence dimension is empty, the output is
      passed through untouched.

    Args:
        layer: Index of the hooked layer; stored as ``int`` for reference only.
        sae: Object providing ``encode(hidden)`` and ``decode(feature_acts)``.
        features: Iterable of feature indices to amplify. Indices outside
            ``[0, F)`` are ignored at call time.
        amp_factor: Multiplier applied to the per-sample maximum last-token
            activation before it is added to each selected feature.
        device: Device the hidden states are moved to before calling the SAE.
    """
    def __init__(self, layer, sae, features, amp_factor, device) -> None:
        super().__init__()
        self.amp_factor = float(amp_factor)
        self.sae = sae
        self.device = torch.device(device)
        self.layer = int(layer)
        self.features = [int(f) for f in features]

    def _amplify_last_token(self, feature_acts):
        """
        Return ``feature_acts`` ([B, T, F]) with the selected features boosted
        on the last token by ``amp_factor * max(last-token activations)``.
        """
        last_feats = feature_acts[:, -1, :]  # [B, F]
        num_features = last_feats.shape[-1]
        idx = torch.as_tensor(self.features, device=last_feats.device, dtype=torch.long)
        # Silently drop indices that do not exist in this SAE.
        idx = idx[(idx >= 0) & (idx < num_features)]
        if idx.numel() == 0:
            return feature_acts

        batch = last_feats.size(0)
        # Per-sample max activation on the last token (one scalar per sample).
        max_val = last_feats.amax(dim=-1, keepdim=True)  # [B, 1]
        # The same scaled value is added to every selected feature index.
        src = max_val.expand(batch, idx.numel()) * self.amp_factor  # [B, K]
        boosted = last_feats.scatter_add(
            dim=-1,
            index=idx.unsqueeze(0).expand(batch, -1),
            src=src,
        )

        # Writing in place into the encoder output while autograd is recording
        # can invalidate tensors saved for backward, so write into a copy in
        # that case. Without grad tracking the in-place write avoids an extra
        # [B, T, F] allocation.
        if feature_acts.requires_grad:
            feature_acts = feature_acts.clone()
        feature_acts[:, -1, :] = boosted
        return feature_acts

    def __call__(self, module, args, output):
        """
        Replace the layer's hidden states with the SAE-amplified version.

        Returns the modified hidden states, repacked like ``output``.

        Raises:
            RuntimeError: If the hidden states are not a 2D or 3D tensor.
        """
        # A hook registered with ``always_call=True`` may be invoked with no
        # output after the forward pass failed; there is nothing to modify.
        if output is None:
            return output

        # Unpack output: it can be a Tensor or a tuple/list
        if isinstance(output, (tuple, list)):
            hidden = output[0]
            others = list(output[1:])
        else:
            hidden = output
            others = None

        if not torch.is_tensor(hidden):
            raise RuntimeError(f"Expected tensor from layer, got {type(hidden)}")

        # Save original dtype/device to restore later
        orig_device = hidden.device
        orig_dtype = hidden.dtype
        squeezed = False

        # Normalize to [B, T, D]
        if hidden.ndim == 2:  # [B, D] -> [B, 1, D]
            hidden = hidden.unsqueeze(1)
            squeezed = True
        elif hidden.ndim != 3:
            raise RuntimeError(f"Expected 2D/3D tensor, got {hidden.ndim}D")

        # No tokens means no "last token" to amplify; leave the output alone.
        if hidden.size(1) == 0:
            return output

        # Move to SAE device if needed
        if hidden.device != self.device:
            hidden = hidden.to(self.device)

        # SAE encode; assumes the SAE accepts [B, T, D] and returns [B, T, F]
        feature_acts = self.sae.encode(hidden)

        # Clean reconstruction (no amplification), used below to measure the
        # SAE's own reconstruction error on the unmodified input.
        with torch.no_grad():
            with _disable_hooks(self.sae):
                feature_acts_clean = self.sae.encode(hidden)               # [B, T, F]
                x_reconstruct_clean = self.sae.decode(feature_acts_clean)  # [B, T, D]

        # Amplify selected features *on the last token only*
        feature_acts = self._amplify_last_token(feature_acts)

        # Decode amplified activations
        sae_out = self.sae.decode(feature_acts)  # [B, T, D]

        # Residual error compensation: add back what the clean SAE pass failed
        # to reconstruct, so only the amplification changes the output.
        sae_out = sae_out + (hidden.to(torch.float32) - x_reconstruct_clean.to(torch.float32))
        sae_out = sae_out.to(orig_dtype)

        # Restore original rank and device
        if squeezed:  # back to [B, D]
            sae_out = sae_out.squeeze(1)
        if sae_out.device != orig_device:
            sae_out = sae_out.to(orig_device)

        # Repack return to match original output type
        if others is None:
            return sae_out
        if isinstance(output, list):
            return [sae_out, *others]
        return (sae_out, *others)


def init_hook(pipeline, sae, layer, feature, device, args):
    """
    Register the SAE amplification forward hook on the specified decoder layer.

    Args:
        pipeline: Object whose decoder blocks are reachable as
            ``pipeline.model.model.layers``.
        sae: SAE object handed to the hook unchanged.
        layer: Index into ``pipeline.model.model.layers`` of the block to hook.
        feature: A single feature index, or a list/tuple/set/frozenset/range of
            feature indices, to amplify.
        device: Device handed to the hook for running the SAE.
        args: Object exposing an ``amp_factor`` attribute.

    Returns:
        Whatever ``register_forward_hook`` returns for the hooked block
        (for a torch ``nn.Module`` this is a handle whose ``remove()`` detaches
        the hook).
    """
    # Accept either one index or a collection of indices; a bare index is
    # wrapped so the hook always receives a list.
    if isinstance(feature, (list, tuple, set, frozenset, range)):
        features = list(feature)
    else:
        features = [feature]

    sae_hook = AmlifySAEHook(layer, sae, features, args.amp_factor, device)
    model_block_to_hook = pipeline.model.model.layers[layer]
    # always_call=True: the hook is invoked even when the forward pass raises.
    handle = model_block_to_hook.register_forward_hook(sae_hook, always_call=True)
    return handle
