# method.py
"""
Head-Masked Nullspace Steering (HMNS)

Minimal, dependency-free (beyond PyTorch/Transformers) implementation that:
  1) attributes causal heads via masked ablations,
  2) builds a masked write subspace per layer,
  3) samples an orthogonal nullspace direction via thin-QR (float32),
  4) injects a scaled last-position residual nudge,
  5) runs a closed-loop detect→intervene→decode procedure.

Usage (from main.py):
    from method import HMNSConfig, HMNS

    cfg = HMNSConfig(topk_heads=10, max_attempts=3)
    hmns = HMNS(cfg)

    tok, model = ...  # load with your models.py helper
    prompt = "..."

    result = hmns.run_closed_loop(model, tok, prompt)
    print(result["steered_text"])
"""

from __future__ import annotations
import contextlib
from dataclasses import dataclass
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple

import torch
from torch import Tensor


# ---------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------

@dataclass
class HMNSConfig:
    """
    Settings consumed by `HMNS`, grouped by the stage that reads them.

    Field order and defaults are part of the public interface (the dataclass
    generates a positional ``__init__``), so new fields should be appended.
    """

    # Loop / selection
    topk_heads: int = 10                  # K: number of (layer, head) pairs kept after ranking
    max_attempts: int = 3                 # number of steered decodes run by run_closed_loop
    max_layers_for_attr: Optional[int] = None  # only layers with index < this are attributed (None = all)
    use_proxy_preselection: bool = True   # shortlist by target-logit drop before scoring with exact KL
    proxy_mult: float = 3.0               # shortlist size = int(max(K, K * proxy_mult))

    # Steering
    base_alpha: float = 0.25              # α_1, the steering scale of the first attempt
    alpha_growth: float = 0.10            # α_t = base_alpha * (1 + alpha_growth*(t-1))
    eps_norm: float = 1e-8                # small ε added to norms before dividing / scaling
    orth_tol: float = 1e-6                # a direction u is accepted when ||M^T u||_∞ < orth_tol
    orth_resamples: int = 3               # extra draws of u allowed when the orthogonality check fails

    # Decoding
    max_new_tokens: int = 128
    temperature: float = 0.7
    top_p: float = 0.95
    use_cache: bool = False               # forwarded to model(...) and generate(); author's note: must be False for correctness under masking
    truncate_to: Optional[int] = 2048     # keep only the last N prompt tokens (None = no truncation)

    # Determinism
    seed: Optional[int] = 0               # passed to torch.manual_seed; None → do not set
    enable_tf32: bool = True              # allow TF32 matmuls when CUDA is available

    # Safety numerics
    prob_clip_min: float = 1e-9           # probability floor used to avoid log(0) in KL
    device_map_hint: Optional[str] = None # not used here; handled by model loader


# ---------------------------------------------------------------------
# HMNS class
# ---------------------------------------------------------------------

class HMNS:
    def __init__(self, cfg: HMNSConfig):
        self.cfg = cfg
        if cfg.seed is not None:
            torch.manual_seed(cfg.seed)
        if torch.cuda.is_available() and cfg.enable_tf32:
            torch.backends.cuda.matmul.allow_tf32 = True

    # ------------- Public entry points -------------

    @torch.inference_mode()
    def run_closed_loop(self, model, tok, user_text: str) -> Dict[str, Any]:
        """
        Baseline -> up to max_attempts of (attribute → intervene → decode).
        Returns dict with baseline_text, steered_text, per-attempt metadata.
        """
        prompt = self._apply_chat_template_if_available(tok, user_text)
        if self.cfg.truncate_to is not None:
            prompt = self._truncate_prompt(tok, prompt, self.cfg.truncate_to)

        # Baseline decode (optional but often useful to store)
        baseline_text = self._generate_once(model, tok, prompt)

        attempts = []
        steered_text = baseline_text
        for t in range(1, self.cfg.max_attempts + 1):
            alpha_t = self.cfg.base_alpha * (1.0 + self.cfg.alpha_growth * (t - 1))

            # Detect (global top-K heads)
            selected = self._select_heads(model, tok, prompt)

            # Build per-layer interventions (mask + nullspace u)
            interventions = self._build_interventions(model, selected)

            # Apply intervention for one forward decode
            with self._nullspace_intervention(interventions, steer_scale=alpha_t):
                steered_text = self._generate_once(model, tok, prompt)

            attempts.append({
                "attempt": t,
                "alpha": alpha_t,
                "selected_heads": selected,
                "steered_text": steered_text,
            })

            # Early stopping hook for users: they may add success checks here
            # We return all attempts so main.py can decide when to stop.
            # if self._is_success(steered_text): break

        return {
            "baseline_text": baseline_text,
            "steered_text": steered_text,
            "attempts": attempts,
        }

    # ------------- Detection / attribution -------------

    @torch.inference_mode()
    def _select_heads(self, model, tok, prompt_text: str) -> List[Tuple[int, int]]:
        """
        Global top-K head selection.
        Optionally uses a fast proxy shortlist (target-logit drop) before exact KL.
        """
        # 1) Baseline distribution
        base_logits, _ = self._next_logits(model, tok, prompt_text)
        base_probs = self._softmax_clip(base_logits)
        # Optional proxy shortlist
        if self.cfg.use_proxy_preselection:
            shortlist = self._proxy_preselection(model, tok, prompt_text, k=self.cfg.topk_heads)
            # shortlist contains (delta_logit, L, H)
            per_layer = self._find_attn_stack(model, max_layers=self.cfg.max_layers_for_attr)
            # exact KL only on shortlisted heads
            exact_scores: List[Tuple[float, int, int]] = []
            for _, L, H in shortlist:
                d_kl = self._kl_for_head(model, tok, prompt_text, base_probs, L, H, per_layer)
                exact_scores.append((d_kl, L, H))
            exact_scores.sort(key=lambda x: x[0], reverse=True)
            return [(L, H) for _, L, H in exact_scores[: self.cfg.topk_heads]]

        # 2) Exact KL on all heads (slower)
        per_layer = self._find_attn_stack(model, max_layers=self.cfg.max_layers_for_attr)
        scores: List[Tuple[float, int, int]] = []
        for (layer_idx, _attn, out_proj, num_heads, head_dim, hidden) in per_layer:
            base_mask = torch.ones(hidden, device=out_proj.weight.device, dtype=out_proj.weight.dtype)
            for h in range(num_heads):
                d_kl = self._kl_for_head(model, tok, prompt_text, base_probs, layer_idx, h, per_layer)
                scores.append((d_kl, layer_idx, h))
        scores.sort(key=lambda x: x[0], reverse=True)
        return [(L, H) for _, L, H in scores[: self.cfg.topk_heads]]

    @torch.inference_mode()
    def _proxy_preselection(self, model, tok, prompt_text: str, k: int) -> List[Tuple[float, int, int]]:
        """
        Fast shortlist via target-logit drop (one masked head at a time).
        """
        base_logits, _ = self._next_logits(model, tok, prompt_text)
        y_star = int(base_logits.argmax().item())
        y_logit = float(base_logits[y_star].item())

        per_layer = self._find_attn_stack(model, max_layers=self.cfg.max_layers_for_attr)
        scores: List[Tuple[float, int, int]] = []
        for (layer_idx, _attn, out_proj, num_heads, head_dim, hidden) in per_layer:
            base_mask = torch.ones(hidden, device=out_proj.weight.device, dtype=out_proj.weight.dtype)
            for h in range(num_heads):
                start, end = h * head_dim, (h + 1) * head_dim
                mask_cols = base_mask.clone()
                mask_cols[start:end] = 0.0
                with self._mask_columns(out_proj, mask_cols):
                    logits_abl, _ = self._next_logits(model, tok, prompt_text)
                delta = y_logit - float(logits_abl[y_star].item())
                scores.append((delta, layer_idx, h))

        scores.sort(key=lambda x: x[0], reverse=True)
        shortlist_size = int(max(k, k * self.cfg.proxy_mult))
        return scores[:shortlist_size]

    @torch.inference_mode()
    def _kl_for_head(
        self,
        model,
        tok,
        prompt_text: str,
        base_probs: Tensor,
        L: int,
        H: int,
        per_layer: Sequence[Tuple[int, Any, Any, int, int, int]],
    ) -> float:
        """
        KL( P || P_masked^(L,H) ) where only head (L,H) is masked.
        """
        # find that layer's out_proj + dims
        out_proj, head_dim, hidden = None, None, None
        for (layer_idx, _attn, oop, num_heads, hd, hid) in per_layer:
            if layer_idx == L:
                out_proj, head_dim, hidden = oop, hd, hid
                break
        assert out_proj is not None

        # mask that single head
        base_mask = torch.ones(hidden, device=out_proj.weight.device, dtype=out_proj.weight.dtype)
        s, e = H * head_dim, (H + 1) * head_dim
        mask_cols = base_mask.clone()
        mask_cols[s:e] = 0.0
        with self._mask_columns(out_proj, mask_cols):
            logits_abl, _ = self._next_logits(model, tok, prompt_text)
        probs_abl = self._softmax_clip(logits_abl)
        # KL(P || Q) with clipping
        d_kl = float((base_probs * (base_probs.add(self.cfg.prob_clip_min).log() - probs_abl.add(self.cfg.prob_clip_min).log())).sum().item())
        return d_kl

    # ------------- Build interventions (mask + nullspace) -------------

    @torch.inference_mode()
    def _build_interventions(self, model, selected_heads: List[Tuple[int, int]]):
        """
        Returns: dict[layer_idx] -> {out_proj, mask_cols, u}
        """
        layer2heads: Dict[int, List[int]] = {}
        for L, Hidx in selected_heads:
            layer2heads.setdefault(L, []).append(Hidx)

        per_layer = self._find_attn_stack(model)
        interventions: Dict[int, Dict[str, Any]] = {}
        for (layer_idx, _attn, out_proj, num_heads, head_dim, hidden) in per_layer:
            if layer_idx not in layer2heads:
                continue

            heads_here = layer2heads[layer_idx]
            W = out_proj.weight
            device, dtype = W.device, W.dtype

            mask_cols = torch.ones(hidden, device=device, dtype=dtype)
            cols = []
            for h in heads_here:
                s, e = h * head_dim, (h + 1) * head_dim
                mask_cols[s:e] = 0.0
                cols.append(W[:, s:e])
            M = torch.cat(cols, dim=1) if cols else torch.zeros((hidden, 0), device=device, dtype=dtype)
            u = self._orthogonal_unit_vector_to_span(M)  # float32 QR, returns dtype of M

            # orthogonality guard (||M^T u||_∞ < tol)
            if M.numel() > 0:
                ok = (M.transpose(0, 1) @ u).abs().max().item() < self.cfg.orth_tol
                resamps = 0
                while (not ok) and (resamps < self.cfg.orth_resamples):
                    u = self._orthogonal_unit_vector_to_span(M)
                    ok = (M.transpose(0, 1) @ u).abs().max().item() < self.cfg.orth_tol
                    resamps += 1
                if not ok:
                    # Degenerate layer; skip steering on this layer.
                    continue

            interventions[layer_idx] = {"out_proj": out_proj, "mask_cols": mask_cols, "u": u}
        return interventions

    # ------------- Context managers: masking + steering hook -------------

    @contextlib.contextmanager
    def _mask_columns(self, out_proj, mask_cols: Tensor):
        with torch.no_grad():
            W = out_proj.weight
            original = W.data.clone()
            W.data.mul_(mask_cols.unsqueeze(0).to(W.device, W.dtype))
        try:
            yield
        finally:
            with torch.no_grad():
                out_proj.weight.data.copy_(original)

    def _make_steer_hook(self, u_vec: Tensor, scale: float):
        eps = self.cfg.eps_norm

        def hook(module, inputs, output):
            if not torch.is_tensor(output):
                return output
            out = output
            if out.dim() == 3:
                # (B, S, H) last token only
                with torch.no_grad():
                    act = out[:, -1, :]
                    act_norm = (act.pow(2).mean(dim=-1, keepdim=True).sqrt() + eps)
                    delta = (scale * act_norm) * u_vec.to(out.dtype)[None, :]
                    out = out.clone()
                    out[:, -1, :] = out[:, -1, :] + delta
                return out
            elif out.dim() == 2:
                # (S, H)
                with torch.no_grad():
                    act = out[-1, :]
                    act_norm = (act.pow(2).mean().sqrt() + eps)
                    delta = (scale * act_norm) * u_vec.to(out.dtype)
                    out = out.clone()
                    out[-1, :] = out[-1, :] + delta
                return out
            return out

        return hook

    @contextlib.contextmanager
    def _nullspace_intervention(self, interventions: Dict[int, Dict[str, Any]], steer_scale: float):
        mask_handles = []
        hook_handles = []
        try:
            for _, obj in interventions.items():
                out_proj, mask_cols, u = obj["out_proj"], obj["mask_cols"], obj["u"]

                # Mask selected heads by zeroing their out-proj columns
                W = out_proj.weight
                orig = W.data.clone()
                with torch.no_grad():
                    W.data.mul_(mask_cols.unsqueeze(0).to(W.device, W.dtype))
                mask_handles.append((out_proj, orig))

                # Add steer hook (nullspace)
                hook = out_proj.register_forward_hook(self._make_steer_hook(u, steer_scale))
                hook_handles.append((out_proj, hook))
            yield
        finally:
            for out_proj, orig in mask_handles:
                with torch.no_grad():
                    out_proj.weight.data.copy_(orig)
            for _, hook in hook_handles:
                try:
                    hook.remove()
                except Exception:
                    pass

    # ------------- Linear algebra helpers -------------

    def _orthogonal_unit_vector_to_span(self, M: Tensor) -> Tensor:
        """
        Sample u ⟂ span(M). Uses float32 thin-QR for stability, returns dtype of M.
        """
        H = M.shape[0]
        if M.numel() == 0:
            v = torch.randn(H, device=M.device, dtype=torch.float32)
            v = v / (v.norm() + self.cfg.eps_norm)
            return v.to(M.dtype)
        M32 = M.to(torch.float32)
        Q, _ = torch.linalg.qr(M32, mode="reduced")  # thin-QR
        r = torch.randn(H, device=M.device, dtype=torch.float32)
        proj = r - Q @ (Q.transpose(0, 1) @ r)
        n = proj.norm()
        if n < 1e-8:
            # Attempt re-draws; extremely rare unless M is near-full-rank
            proj = torch.randn_like(r)
            proj = proj - Q @ (Q.transpose(0, 1) @ proj)
            n = proj.norm()
            if n < 1e-8:
                proj = torch.randn_like(r)
                n = proj.norm()
        u = proj / (n + self.cfg.eps_norm)
        return u.to(M.dtype)

    # ------------- Model plumbing -------------

    def _find_attn_stack(
        self,
        model,
        max_layers: Optional[int] = None,
    ) -> List[Tuple[int, Any, Any, int, int, int]]:
        """
        Returns tuples per layer:
          (layer_idx, attn_module, out_proj_module, num_heads, head_dim, hidden_size)
        Supports GPT-NeoX (Pythia), OPT, and LLaMA/Qwen-like decoders.
        """
        infos = []
        cfg = model.config
        hidden_size = cfg.hidden_size

        def attn_info(attn_mod):
            out_proj_attr = None
            for name in ("o_proj", "out_proj", "dense"):
                if hasattr(attn_mod, name):
                    out_proj_attr = name
                    break
            if out_proj_attr is None:
                return None
            out_proj = getattr(attn_mod, out_proj_attr)

            for hn in ("num_attention_heads", "num_heads"):
                if hasattr(attn_mod, hn):
                    num_heads = int(getattr(attn_mod, hn))
                    break
            else:
                num_heads = int(getattr(cfg, "num_attention_heads", getattr(cfg, "num_heads", 0)))

            head_dim = hidden_size // num_heads
            return out_proj, num_heads, head_dim

        # GPT-NeoX (Pythia)
        if hasattr(model, "gpt_neox"):
            for i, layer in enumerate(model.gpt_neox.layers):
                attn = layer.attention
                x = attn_info(attn)
                if x:
                    infos.append((i, attn, *x, hidden_size))
            return infos if max_layers is None else [z for z in infos if z[0] < max_layers]

        # OPT
        if hasattr(model, "model") and hasattr(model.model, "decoder") and hasattr(model.model.decoder, "layers"):
            for i, layer in enumerate(model.model.decoder.layers):
                attn = layer.self_attn
                x = attn_info(attn)
                if x:
                    infos.append((i, attn, *x, hidden_size))
            return infos if max_layers is None else [z for z in infos if z[0] < max_layers]

        # LLaMA/Qwen style
        if hasattr(model, "model") and hasattr(model.model, "layers"):
            for i, layer in enumerate(model.model.layers):
                attn = layer.self_attn if hasattr(layer, "self_attn") else getattr(layer, "attention", None)
                if attn is None:
                    continue
                x = attn_info(attn)
                if x:
                    infos.append((i, attn, *x, hidden_size))
            return infos if max_layers is None else [z for z in infos if z[0] < max_layers]

        raise RuntimeError("Unsupported architecture: could not find decoder attention stack.")

    @torch.inference_mode()
    def _next_logits(self, model, tok, text: str) -> Tuple[Tensor, Dict[str, Tensor]]:
        inputs = tok(text, return_tensors="pt").to(model.device)
        out = model(**inputs, use_cache=self.cfg.use_cache)
        return out.logits[0, -1, :], inputs

    def _softmax_clip(self, logits: Tensor) -> Tensor:
        probs = torch.softmax(logits, dim=-1)
        return torch.clamp(probs, min=self.cfg.prob_clip_min, max=1.0)

    def _generate_once(self, model, tok, prompt: str) -> str:
        inputs = tok(prompt, return_tensors="pt").to(model.device)
        out_ids = model.generate(
            **inputs,
            max_new_tokens=self.cfg.max_new_tokens,
            do_sample=True,
            temperature=self.cfg.temperature,
            top_p=self.cfg.top_p,
            use_cache=self.cfg.use_cache,  # keep False for correctness under masking
            pad_token_id=tok.pad_token_id,
            eos_token_id=tok.eos_token_id,
        )
        return tok.decode(out_ids[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True).strip()

    def _apply_chat_template_if_available(self, tokenizer, user_text: str) -> str:
        fn = getattr(tokenizer, "apply_chat_template", None)
        if callable(fn):
            messages = [{"role": "user", "content": user_text}]
            try:
                return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            except Exception:
                pass
        return user_text

    def _truncate_prompt(self, tok, text: str, max_len: int) -> str:
        ids = tok(text, return_tensors="pt", add_special_tokens=False)["input_ids"][0]
        if ids.numel() <= max_len:
            return text
        trimmed = tok.decode(ids[-max_len:], skip_special_tokens=False)
        return trimmed
