
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Maximum L2 norm allowed for a single per-parameter update produced by
# DFA.step; updates with a larger norm are rescaled down to this value.
grad_norm_clip = 5.0

class DFA:
    """Feedback-alignment style update rule using fixed random matrices.

    Every trainable parameter ``p_k`` gets one fixed random matrix ``B_k`` of
    shape ``[out_dim, p_k.numel()]``.  :meth:`step` projects the batch-averaged
    output error through ``B_k`` and writes the (norm-clipped) result into
    ``p_k.grad``.  Applying the update to the weights is left to the caller.
    """

    def __init__(self, model: nn.Module):
        """Create the fixed feedback matrices for ``model``.

        Parameters
        ----------
        model : nn.Module
            Network whose trainable parameters will receive updates.

        Raises
        ------
        ValueError
            If the output width of the network cannot be determined.
        """
        first = next(model.parameters(), None)
        self.device = first.device if first is not None else torch.device("cpu")
        self.params = [p for p in model.parameters() if p.requires_grad]

        out_dim = self._infer_out_dim(model, self.params)

        self.shapes = [p.shape for p in self.params]
        self.sizes = [p.numel() for p in self.params]

        # One fixed projection matrix per parameter.  Values are drawn in the
        # default dtype and then cast, so each matrix lives on the same
        # device / dtype as its parameter and ``p.grad`` can be assigned
        # directly in ``step``.
        scale = math.sqrt(1.0 / out_dim)
        self.B = [
            (torch.randn(out_dim, sz, device=p.device) * scale).to(p.dtype)
            for p, sz in zip(self.params, self.sizes)
        ]

    @staticmethod
    def _infer_out_dim(model: nn.Module, params: list) -> int:
        """Return the output width (number of classes) of ``model``."""
        # (1) Preferred: the last entry of ``model.layers`` that exposes
        #     ``out_features``; trailing activations etc. are skipped.
        layers = getattr(model, "layers", None)
        if layers is not None:
            try:
                ordered = list(layers)
            except TypeError:
                ordered = []
            for layer in reversed(ordered):
                width = getattr(layer, "out_features", None)
                if isinstance(width, int):
                    return width

        # (2) Fall-back: length of the last 1-D trainable parameter
        #     (presumably the bias of the output layer).
        for p in reversed(params):
            if p.dim() == 1:
                return p.shape[0]

        raise ValueError(
            "DFA could not determine the output dimension: the model has no "
            "`layers` entry with `out_features` and no 1-D trainable parameter"
        )

    def step(self, _data: torch.Tensor, error_signal):  # flipped
        """Write a DFA update into ``.grad`` of every trainable parameter.

        Parameters
        ----------
        _data : torch.Tensor
            Unused; present so the signature matches the other update rules.
        error_signal : torch.Tensor
            dL/dlogits with shape ``[batch, out_dim]``.  Extra leading
            dimensions are flattened into the batch dimension before
            averaging.
        """
        with torch.no_grad():
            # mean_b(e_b @ B) == mean_b(e_b) @ B, so average the error first;
            # this avoids materialising a [batch, numel] tensor per parameter.
            flat_err = error_signal.reshape(-1, error_signal.size(-1))
            mean_err = flat_err.mean(dim=0)

            for p, Bk, shape in zip(self.params, self.B, self.shapes):
                err_k = mean_err.to(device=Bk.device, dtype=Bk.dtype)
                # [out_dim] @ [out_dim, numel] -> [numel] -> parameter shape
                dw = (err_k @ Bk).view(shape)

                # Rescale so the update norm never exceeds grad_norm_clip.
                norm = dw.norm()
                if norm > grad_norm_clip:
                    dw.mul_(grad_norm_clip / (norm + 1e-12))
                p.grad = dw

class RandomNN:
    def __init__(self, model: nn.Module, proj_dim: int = 4):
        self.device = next(model.parameters()).device
        self.model = model
        self.proj_dim = proj_dim

        # ─ model parameter bookkeeping ──────────────────────────────────────
        self.params  = [p for p in model.parameters() if p.requires_grad]
        self.Q = [torch.randn_like(p) for p in model.parameters() if p.requires_grad]

        self.shapes  = [tuple(p.shape) for p in self.params]
        self.sizes   = [p.numel() for p in self.params]

        # ─ lazy‑initialised buffers ─────────────────────────────────────────
        self.P: nn.Module | None = None             # now an MLP instead of a projection
        self.B:     list[torch.Tensor] = []                 # z     → Δθₖ
        self.in_dim: int | None = None


    # ---------------------------------------------------------------------
    # internal helpers
    # ---------------------------------------------------------------------
    def _lazy_init(self, in_dim: int) -> None:
        """Initialise the *fixed* random projection matrices once input size is known."""
        self.in_dim = in_dim

        # Input projection P  (d × D)
        self.P = nn.Sequential(
            nn.Linear(in_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 128),
            nn.Tanh(),
            nn.Linear(128, 128),
            nn.Tanh(),
            nn.Linear(128, self.proj_dim),
            nn.Tanh()
        ).to(self.device)

        for param in self.P.parameters():
            param.requires_grad = False  # keep it fixed like the original random projection
        # One Bₖ per parameter (d × sₖ)
        for s_k in self.sizes:
            Bk = torch.randn(self.proj_dim, s_k, device=self.device)
            nn.init.normal_(Bk, mean=0.0, std=1.0 / math.sqrt(self.proj_dim))
            Bk.requires_grad = False
            self.B.append(Bk)

    # ---------------------------------------------------------------------
    # public API – called every training step
    # ---------------------------------------------------------------------
    @torch.no_grad()
    def step(self, data: torch.Tensor, _error_signal: torch.Tensor | None = None):
        """Compute parameter updates for **one minibatch**.

        Parameters
        ----------
        data : torch.Tensor
            The raw minibatch (images, token IDs, …). Will be flattened.
        _error_signal : ignored
            Present for API compatibility with DFA.
        """
        # ──────────────────────────────────────────────────────────────────
        # 1. Flatten inputs
        if data.dtype == torch.long:
            # For language models using embeddings
            data = data.to(self.device)
            emb  = self.model.embed(data)                                 # (B, S, D)

            B, S = data.size()
            positions = torch.arange(S, device=self.device).unsqueeze(0).expand(B, S)  # (B, S)
            pos_emb = self.model.pos_embed(positions)                    # (B, S, D)

            emb  = emb + pos_emb
            x    = emb.view(B, -1)                                       # (B, S·D)
        else:
            x = data.to(self.device).view(data.size(0), -1)               # (B, D)

        # 2. Lazy initialisation
        if self.P is None:
            self._lazy_init(x.size(1))

        # 3. Low‑dimensional projection: z ∈ ℝ^{B×d}
        z = self.P(x)  
        z = F.tanh(z)                                               # (B, d)

        # 4. Predicted updates for each parameter
        for p, Bk, shape, Q in zip(self.params, self.B, self.shapes,self.Q):
            # (B, d) @ (d, sₖ) → (B, sₖ)
            pred = z @ Bk
            dw   = pred.mean(dim=0).view(shape)                       # (shape of p)
            
            incA = (p.data-dw).norm() > p.data.norm()
            incA = 2*(incA*1.0 - 0.5)

            wnorm = p.data.norm() #.item()    
            scale = (100.0 - wnorm).sign()

            dw = dw * scale * incA
            dwnorm = dw.norm().double().item()
            if dwnorm > 1:

                dw = 1 * dw / (dwnorm + 1e-8)

            p.grad = dw