# linear_rls.py
from __future__ import annotations
import torch, torch.nn as nn
from dataclasses import dataclass
from typing import Dict, List, Set


@dataclass
class _St:
    """Recursive-least-squares state kept for one column block."""

    # RLS matrix, initialised to I / lambda; shape (k_live, k_live), float32.
    P: torch.Tensor
    # Running weight estimate for the block; shape (r, k_live), bfloat16.
    W: torch.Tensor


class LinearRLSReconstructor:
    """Streaming ridge least-squares for a column-prunable nn.Linear.

    A forward hook on a *teacher* layer feeds (input, output) batches into one
    recursive-least-squares (RLS) state per registered column block.  Calling
    :meth:`stop_and_flush` writes the estimated weights into the matching
    columns of the *student* layer.

    Per block the state is ``P`` (k, k), kept in float32, and ``W`` (r, k),
    stored in bfloat16; all arithmetic is carried out in ``P``'s dtype.
    """

    # ---------------------------------------------------------------- init
    def __init__(self,
                 student_layer: nn.Linear,
                 block_size: int = 1,
                 lambda_ridge: float = 1e-4):
        """Create a reconstructor writing into ``student_layer``.

        Args:
            student_layer: layer whose weight columns are overwritten on flush.
            block_size: nominal block width (stored; the actual width of each
                block comes from the column lists given to set_live_blocks).
            lambda_ridge: ridge strength; ``P`` starts at ``I / lambda_ridge``.

        Raises:
            ValueError: if ``lambda_ridge`` is not strictly positive.
        """
        if lambda_ridge <= 0:
            # I / lambda would be inf (or negative) and poison every update.
            raise ValueError("lambda_ridge must be > 0")
        self.s_layer = student_layer
        self.bs      = block_size
        self.lam     = lambda_ridge
        self._state: Dict[int, _St] = {}          # blk → state
        self._blk_cols: Dict[int, List[int]] = {} # blk → column list
        self._pruned_cols: Set[int] = set()
        self._hook  = None                        # teacher forward-hook

    # -------- live block registration (HALPE calls once) -------------------
    def set_live_blocks(self,
                        blk_to_cols: Dict[int, List[int]],
                        r_dim: int, device):
        """Register *all currently live* blocks (candidate+non-candidate).

        Any previously registered state is discarded.

        Args:
            blk_to_cols: block id → list of input-column indices of that block.
            r_dim: number of rows of each block's weight estimate (must match
                the student's output dimension for the flush to fit).
            device: device on which the RLS state is allocated.
        """
        self._blk_cols = blk_to_cols
        # Drop stale entries so every state has a matching column list.
        self._state.clear()
        for blk, cols in blk_to_cols.items():
            k_live = len(cols)
            # Size P from the block itself: slicing a block_size identity
            # yields a too-small P whenever a block has more columns.
            self._state[blk] = _St(
                P = torch.eye(k_live, device=device,
                              dtype=torch.float32) / self.lam,
                W = torch.zeros(r_dim, k_live,
                                device=device, dtype=torch.bfloat16))

    def set_pruned_cols(self, pruned_cols: List[int]):
        """Tell the reconstructor which columns have become permanently zero."""
        self._pruned_cols = set(pruned_cols)

    # -------- start / stop --------------------------------------------------
    def start(self, teacher_layer: nn.Linear):
        """Attach the streaming forward hook to ``teacher_layer``."""
        assert self._hook is None, "hook already active"
        self._hook = teacher_layer.register_forward_hook(self._fwd_hook)

    def stop_and_flush(self):
        """Flush W* into the student layer and remove hook.

        Only columns that are not pruned are written; pruned columns of the
        student weight are left untouched.  All RLS state is cleared.
        """
        if self._hook is not None:
            self._hook.remove()
            self._hook = None
        weight = self.s_layer.weight
        for blk, st in self._state.items():
            cols = self._blk_cols[blk]
            # Select by *position inside the block* so a column pruned from
            # the middle does not shift the remaining estimates.
            keep = [i for i, c in enumerate(cols)
                    if c not in self._pruned_cols]
            if not keep:
                continue
            live = [cols[i] for i in keep]
            weight.data[:, live] = st.W[:, keep].to(device=weight.device,
                                                    dtype=weight.dtype)
        self._state.clear()

    # -------- internal forward-hook ----------------------------------------
    def _fwd_hook(self, _mod, inp, out):
        """Feed one teacher batch into every registered block's RLS state."""
        X_full = inp[0].detach()
        Y_full = out.detach()
        # Collapse any leading dims, e.g. (B, T, C) → (B*T, C).
        X_full = X_full.reshape(-1, X_full.shape[-1])   # (N , C_in)
        Y_full = Y_full.reshape(-1, Y_full.shape[-1])   # (N , C_out)
        for blk, st in self._state.items():
            cols = self._blk_cols[blk]
            dead = [i for i, c in enumerate(cols) if c in self._pruned_cols]
            if len(dead) == len(cols):
                continue                      # nothing live (or empty block)
            Xb = X_full[:, cols]              # (N , k) — list indexing copies
            if dead:
                # Pruned columns contribute nothing to the student output.
                # Masking them keeps the state shape fixed after pruning.
                Xb[:, dead] = 0
            # The regression target is the full teacher output (N , C_out) so
            # that it matches W's (r , k) shape.
            # NOTE(review): every block regresses the whole output on its own
            # columns; confirm that is the intended target when several
            # blocks are registered.
            st.P, st.W = self._rls_update(st.P, st.W, Xb, Y_full)

    # -------- rank-B Woodbury update ---------------------------------------
    @staticmethod
    @torch.inference_mode()
    def _rls_update(P, W, Xb, Yb):
        """Apply one exact block (rank-B) RLS update.

        With ``K = P Xᵀ (I + X P Xᵀ)⁻¹`` the update is
        ``W ← W + (Yᵀ − W Xᵀ) Kᵀ`` and ``P ← P − K X P``.

        Args:
            P: (k, k) RLS matrix; updated in place.
            W: (r, k) weight estimate; updated in place, dtype preserved.
            Xb: (B, k) regressors, any floating dtype/device.
            Yb: (B, r) targets, any floating dtype/device.

        Returns:
            The same ``P`` and ``W`` tensors after the update.
        """
        # Do all arithmetic in P's dtype/device; mixing fp32 P with bf16 W
        # (or bf16 activations) in a matmul would raise.
        X = Xb.to(device=P.device, dtype=P.dtype)
        Y = Yb.to(device=P.device, dtype=P.dtype)
        n, k = X.shape
        if n == 0:
            return P, W
        W_acc = W.to(P.dtype)

        A   = X @ P                                       # (B,k)
        PXt = P @ X.T                                     # (k,B)
        if n <= k:
            # Solve in sample space: K = P Xᵀ S⁻¹ with S = I + X P Xᵀ (B,B).
            S = torch.eye(n, device=P.device, dtype=P.dtype) + X @ PXt
            K = torch.linalg.solve(S.T, PXt.T).T          # (k,B)
        else:
            # Push-through identity gives the same gain from a (k,k) system:
            # K = (I + P Xᵀ X)⁻¹ P Xᵀ — much cheaper when B ≫ k.
            M = torch.eye(k, device=P.device, dtype=P.dtype) + PXt @ X
            K = torch.linalg.solve(M, PXt)                # (k,B)

        W_acc = W_acc + (Y.T - W_acc @ X.T) @ K.T         # (r,k)
        P.sub_(K @ A)                                     # (k,k)
        # Re-symmetrise to stop round-off drift from accumulating.
        P.copy_((P + P.T) * 0.5)
        W.copy_(W_acc)                                    # casts back to W.dtype
        return P, W
    