# helpers.py
# ----------
"""Shared utilities for Hebbian‐alignment experiments.

▪ Model kinds: **mlp**, **regression-mlp**, the fixed-budget **mlp-100k** … **mlp-1b**, **transformer**, plus the legacy **small-cnn**  
▪ Student-teacher regression auto-generated for the two new kinds  
▪ **Recursion / weak-ref bug fixed** – `_TransformerBlockCache` keeps a
  `weakref.ref` to its parent, which is never registered as a sub-module,
  eliminating the "maximum recursion depth" and "unhashable weakref" errors.
"""

from __future__ import annotations

import time, random, weakref
from pathlib import Path
from typing import Dict, Iterable, List, Tuple
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.mixture import GaussianMixture
from torch.utils.data import DataLoader, Subset, TensorDataset
from torchvision import datasets, transforms
import yaml 

# Global compute device string ("cuda" or "cpu"), chosen once at import time
# and read by the data loaders and the teacher model below.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print("Using device:", device)

def _load_default_cfg() -> Dict:
    """Load the ``default_config`` mapping from ``experiments.yaml``.

    The YAML file is looked up next to this module, not in the current working
    directory.

    Returns:
        The value stored under the top-level ``default_config`` key.

    Raises:
        FileNotFoundError: if ``experiments.yaml`` does not exist.
        KeyError: if the file is empty, its top level is not a mapping, or it
            has no ``default_config`` key.
    """
    yaml_path = Path(__file__).with_name("experiments.yaml")
    if not yaml_path.exists():
        raise FileNotFoundError(
            f"Required file 'experiments.yaml' not found at {yaml_path.resolve()}"
        )
    with open(yaml_path, "r", encoding="utf-8") as fp:
        data = yaml.safe_load(fp)
    # safe_load returns None for an empty document and can return a list or a
    # scalar; report those with the same KeyError instead of a TypeError.
    if not isinstance(data, dict) or "default_config" not in data:
        raise KeyError("The key 'default_config' is missing from experiments.yaml")
    return data["default_config"]

# Evaluated once at import: the `default_config` mapping from the
# experiments.yaml that sits next to this file (import fails if the file or
# the key is missing).
DEFAULT_CFG: Dict = _load_default_cfg()

def _flatten_last(t: torch.Tensor) -> torch.Tensor:
    """Merge all but the last dimension so that .t() works for ≥3-D tensors."""
    if t.dim() <= 2:
        return t
    return t.reshape(-1, t.shape[-1])          # (B × S, d)

def pre_synaptic_inputs(x_batch: torch.Tensor, model: nn.Module, cfg: Dict) -> List[torch.Tensor]:
    """Collect the input tensor of each weight layer from the last forward pass.

    Reads the activation caches filled by the model's most recent forward call
    (``store_activations=True``):

    * transformer (``cfg["model"] == "transformer"``): the last entry of
      ``model.cached_act_tr`` (the pooled transformer output feeding the MLP
      head), followed by ``model.cached_act``;
    * ``SmallCNN``: ``model.cached_act`` itself (flattened conv features, then
      the fc1 / fc2 activations) -- returned as the same list object;
    * anything else (MLP): ``x_batch`` flattened to ``(batch, -1)``, followed
      by ``model.cached_act``.
    """
    if cfg["model"] == "transformer":
        head_input = model.cached_act_tr[-1]
        return [head_input] + model.cached_act

    if isinstance(model, SmallCNN):
        # Only the linear layers are Hebbian-trained, and the flattened conv
        # output is already the first cached entry.
        return model.cached_act

    flat_input = x_batch.view(x_batch.size(0), -1)
    return [flat_input] + model.cached_act  # one entry per Linear layer

def pretty_cfg_diff(cfg: Dict, default: Dict) -> str:
    """Summarise how *cfg* departs from *default* as a compact run label.

    Keys (in sorted order) that are absent from *default* or map to a
    different value are rendered as ``key=value`` and joined with ``"__"``.
    Returns ``"default"`` when nothing differs. List- and dict-valued
    differences are left out of the label, so the result is ``""`` when
    those are the only differences.
    """
    changed_keys = [
        key for key in sorted(cfg)
        if key not in default or cfg[key] != default[key]
    ]
    if not changed_keys:
        return "default"

    label_parts = []
    for key in changed_keys:
        value = cfg[key]
        if isinstance(value, (list, dict)):
            continue  # not representable in a flat label
        label_parts.append(f"{key}={value}")
    return "__".join(label_parts)

def ensure_dirs(root: Path) -> Dict[str, Path]:
    """Create *root* and its ``checkpoints``/``metrics``/``figures`` sub-directories.

    Existing directories are reused. Returns a mapping from each
    sub-directory name to its path, in the order listed above.
    """
    root.mkdir(parents=True, exist_ok=True)
    subdirs = {name: root / name for name in ("checkpoints", "metrics", "figures")}
    for path in subdirs.values():
        path.mkdir(exist_ok=True)
    return subdirs



# -----------------------------------------------------------------------------
#  Datasets (CIFAR-10, student-teacher regression) & model factory ------------
# -----------------------------------------------------------------------------

def generate_CIFAR_data(cfg):
    """Build the CIFAR-10 train, test and metrics-probe data loaders.

    Images are converted to tensors and normalised per channel with the
    mean/std constants below (the commonly used CIFAR-10 statistics). The
    dataset is downloaded into ``./data`` when missing.

    Args:
        cfg: must provide ``batch_size`` (used by the train and test loaders).

    Returns:
        ``(train_loader, test_loader, metrics_loader)``. Train is shuffled,
        test is not; the metrics loader yields single, shuffled samples from a
        random 1000-image subset of the training set.
    """
    batch_size = cfg["batch_size"]
    loader_kwargs = dict(num_workers=2, pin_memory=(device == "cuda"))

    channel_norm = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                        (0.2470, 0.2435, 0.2616))
    pipeline = transforms.Compose([transforms.ToTensor(), channel_norm])

    train_set = datasets.CIFAR10("./data", train=True, download=True, transform=pipeline)
    test_set = datasets.CIFAR10("./data", train=False, download=True, transform=pipeline)
    probe_indices = torch.randperm(len(train_set))[:1000]
    probe_set = Subset(train_set, probe_indices)

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, **loader_kwargs)
    metrics_loader = DataLoader(probe_set, batch_size=1, shuffle=True, **loader_kwargs)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, **loader_kwargs)
    return train_loader, test_loader, metrics_loader


def calculate_mlp_params(layer_sizes: List[int], biases: bool = False) -> int:
    """Return the parameter count of a fully-connected stack with widths *layer_sizes*.

    Every consecutive pair ``(n_in, n_out)`` contributes ``n_in * n_out``
    weights, plus ``n_out`` biases when *biases* is true. Fewer than two sizes
    yields 0.
    """
    return sum(
        n_in * n_out + (n_out if biases else 0)
        for n_in, n_out in zip(layer_sizes, layer_sizes[1:])
    )

def calculate_cnn_params(input_channels: int = 3, num_classes: int = 10,
                         biases: bool = True) -> int:
    """Count the parameters of the small CNN architecture (as in SmallCNN).

    Architecture: Conv(input_channels, 32, 3x3) -> Conv(32, 64, 3x3) ->
    Conv(64, 128, 3x3), each followed by 2x2 max-pooling (32x32 input ->
    4x4), then FC(128*4*4, 512) -> FC(512, 256) -> FC(256, num_classes).
    BatchNorm parameters are not counted.

    Args:
        input_channels: channels of the input image (3 for CIFAR-10).
        num_classes: width of the final linear layer.
        biases: include one bias per output channel/unit. Defaults to True
            (the historical behaviour); pass False to match a SmallCNN built
            with ``biases=False``.

    Returns:
        Total number of weights (plus biases when requested).
    """
    kernel_area = 3 * 3
    conv_shapes = [(input_channels, 32), (32, 64), (64, 128)]
    fc_shapes = [(128 * 4 * 4, 512), (512, 256), (256, num_classes)]

    total = sum(c_in * c_out * kernel_area for c_in, c_out in conv_shapes)
    total += sum(n_in * n_out for n_in, n_out in fc_shapes)
    if biases:
        total += sum(n_out for _, n_out in conv_shapes + fc_shapes)
    return total

def get_layer_sizes_for_params(target_params: int, input_size: int = 3072, output_size: int = 10,
                              biases: bool = False) -> List[int]:
    """Choose MLP layer sizes whose parameter count is roughly *target_params*.

    The budget selects a depth tier; all hidden layers share one width ``h``:

    * <= 100k  ->  1 hidden layer,  h >= 64
    * <= 1M    ->  2 hidden layers, h >= 128
    * <= 10M   ->  6 hidden layers, h >= 256
    * <= 50M   ->  8 hidden layers, h >= 384
    * <= 100M  -> 14 hidden layers, h >= 512
    * larger   -> 29 hidden layers, h >= 1024

    ``h`` solves the weight count (biases ignored)
    ``(n_hidden - 1) * h**2 + (input_size + output_size) * h = target_params``
    -- a linear equation for a single hidden layer -- and is then raised to
    the tier's minimum. If the resulting count (biases included when *biases*
    is set) lies outside 80 %-120 % of the target, every hidden width is
    rescaled once by ``sqrt(target_params / actual)`` and truncated to int.

    Returns:
        ``[input_size, h, ..., h, output_size]``.
    """
    tiers = (
        (100_000, 1, 64),
        (1_000_000, 2, 128),
        (10_000_000, 6, 256),
        (50_000_000, 8, 384),
        (100_000_000, 14, 512),
    )
    n_hidden, min_width = 29, 1024  # budgets above 100M
    for limit, depth, floor in tiers:
        if target_params <= limit:
            n_hidden, min_width = depth, floor
            break

    fan = input_size + output_size
    if n_hidden == 1:
        width = int(target_params / fan)
    else:
        quad = n_hidden - 1  # number of hidden-to-hidden matrices
        width = int((-fan + (fan**2 + 4 * quad * target_params) ** 0.5) / (2 * quad))
    width = max(width, min_width)
    sizes = [input_size, *([width] * n_hidden), output_size]

    # One-shot correction when the clamped width misses the budget by > 20 %.
    actual = calculate_mlp_params(sizes, biases)
    if actual < target_params * 0.8 or actual > target_params * 1.2:
        factor = (target_params / actual) ** 0.5
        sizes[1:-1] = [int(hidden * factor) for hidden in sizes[1:-1]]
    return sizes

# Parameter budgets of the fixed-size MLP kinds; get_layer_sizes_for_params
# turns each budget into concrete layer widths.
_MLP_PARAM_TARGETS: Dict[str, int] = {
    "mlp-100k": 100_000,
    "mlp-1m": 1_000_000,
    "mlp-10m": 10_000_000,
    "mlp-50m": 50_000_000,
    "mlp-100m": 100_000_000,
    "mlp-1b": 1_000_000_000,
}
# Kinds built as a GeneralMLP.
_MLP_KINDS = frozenset({"mlp", "regression-mlp", *_MLP_PARAM_TARGETS})
# Kinds trained on student-teacher regression data with an MSE loss; every
# other kind trains on CIFAR-10 with cross-entropy.
_REGRESSION_KINDS = frozenset({"regression-mlp", "transformer"})
_SUPPORTED_KINDS = _MLP_KINDS | {"small-cnn", "transformer"}


def _scaled_layer_sizes(kind: str, cfg: Dict, scaling_ratio: float) -> List[int]:
    """Return the layer sizes the model of *kind* should be built with."""
    original = cfg["layer_sizes"]
    if kind == "small-cnn":
        # SmallCNN has a fixed architecture; the sizes are passed through as-is.
        return original
    if kind in _MLP_PARAM_TARGETS:
        # NOTE(review): fixed-budget kinds are not widened for frozen_ratio /
        # sparsity (they are only masked) -- confirm this is intended.
        return get_layer_sizes_for_params(
            _MLP_PARAM_TARGETS[kind],
            input_size=original[0],
            output_size=original[-1],
            biases=cfg.get("biases", False),
        )
    if scaling_ratio > 0.0:
        # Widen every non-input layer by 1 / (1 - ratio); the input width is kept.
        scale_factor = 1.0 / (1.0 - scaling_ratio)
        return list(original[:1]) + [int(size * scale_factor) for size in original[1:]]
    return original


def _apply_weight_masks(model: nn.Module, ratio: float, zero_masked: bool) -> None:
    """Register a random boolean ``<param>_frozen_mask`` buffer per weight tensor.

    Every parameter with more than one dimension (linear / conv / embedding
    weights) gets a mask that is True with probability *ratio*, registered on
    the sub-module that owns the parameter. 1-D parameters -- biases and
    norm-layer gains/shifts -- are left unmasked. With *zero_masked* the
    masked entries are also set to 0 (sparsity).
    """
    modules = dict(model.named_modules())  # built once; "" maps to model itself
    with torch.no_grad():
        for name, param in model.named_parameters():
            if param.dim() <= 1:
                continue  # only weight matrices/kernels, not biases or norm params
            owner_path, _, param_name = name.rpartition(".")
            mask = (torch.rand(param.shape) < ratio).to(param.device)
            modules[owner_path].register_buffer(f"{param_name}_frozen_mask", mask)
            if zero_masked:
                param.data[mask] = 0.0


def get_dataloaders_model_crit_model(cfg):
    """Build the model, its data loaders and the loss criterion from *cfg*.

    ``cfg["model"]`` (case-insensitive) selects the architecture. When
    ``frozen_ratio`` or ``sparsity`` is positive, the larger of the two widens
    the network (generic kinds only) and drives a random
    ``<param>_frozen_mask`` buffer on each weight tensor; with ``sparsity``
    the masked weights are also zeroed. Regression kinds get student-teacher
    data and MSE; all others get CIFAR-10 and cross-entropy.

    Returns:
        ``(model, (train_loader, test_loader, metrics_loader), criterion)``

    Raises:
        ValueError: if ``cfg["model"]`` is not a supported model kind.
    """
    kind = cfg["model"].lower()
    if kind not in _SUPPORTED_KINDS:
        raise ValueError(
            f"Unsupported model kind '{cfg['model']}'; "
            f"expected one of {sorted(_SUPPORTED_KINDS)}"
        )

    frozen_ratio = cfg.get("frozen_ratio", 0.0)
    sparsity = cfg.get("sparsity", 0.0)
    # The larger of the two ratios drives both the widening and the masks.
    scaling_ratio = max(frozen_ratio, sparsity)
    layer_sizes = _scaled_layer_sizes(kind, cfg, scaling_ratio)

    if kind in _MLP_KINDS:
        model = GeneralMLP(layer_sizes, cfg["activation"],
                           initialization=cfg["initialization"],
                           regularization_mode=cfg["regularization_mode"],
                           residual=cfg["residual"], biases=cfg["biases"])
    elif kind == "small-cnn":
        model = SmallCNN(layer_sizes, cfg["activation"],
                         initialization=cfg["initialization"],
                         regularization_mode=cfg["regularization_mode"],
                         residual=cfg["residual"], biases=cfg["biases"])
    else:  # "transformer"
        # NOTE(review): cfg["max_len"] (the teacher's sequence length in
        # generate_teacher_data) is not forwarded, so the model keeps its
        # default max_len -- confirm.
        model = SimpleTransformer(layer_sizes, vocab_size=cfg["vocab_size"],
                                  regularization_mode=cfg["regularization_mode"],
                                  biases=cfg["biases"])

    if scaling_ratio > 0.0:
        _apply_weight_masks(model, scaling_ratio, zero_masked=sparsity > 0)

    if kind in _REGRESSION_KINDS:
        train_data, test_data, metrics_data = generate_teacher_data(model, cfg)
        crit = torch.nn.MSELoss()
    else:
        # All other models (MLP variants, CNN) use CIFAR-10 data.
        train_data, test_data, metrics_data = generate_CIFAR_data(cfg)
        crit = torch.nn.CrossEntropyLoss()

    return model, (train_data, test_data, metrics_data), crit

    

def generate_teacher_data(model: nn.Module, cfg: Dict
                         ) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Auto-generate regression data from a cloned, freshly-initialized teacher model.

    The teacher is a deep copy of *model* whose sub-modules are re-initialized
    with their own PyTorch initializers, so the student does not start at the
    teacher's weights. Inputs are random token ids in ``[1, vocab_size)`` for a
    SimpleTransformer and standard-normal vectors for a GeneralMLP; the targets
    are the teacher's outputs. *model* itself is not modified.

    Args:
        model: the student network (SimpleTransformer or GeneralMLP).
        cfg: reads ``batch_size``; optionally ``n_train`` (default 20 000),
            ``n_test`` (2 000) and ``n_metrics`` (500); ``max_len`` and
            ``vocab_size`` for transformers, ``layer_sizes`` for MLPs.

    Returns:
        train_loader, test_loader, metrics_loader (the latter yields single
        samples).

    Raises:
        ValueError: if *model* is neither a SimpleTransformer nor a GeneralMLP.
    """
    # 1) Clone & re-init teacher
    teacher = copy.deepcopy(model)

    def _reset(m: nn.Module) -> None:
        if hasattr(m, "reset_parameters"):
            m.reset_parameters()
        elif isinstance(m, nn.MultiheadAttention):
            # MultiheadAttention only has the private _reset_parameters();
            # without it in_proj_weight / in_proj_bias would stay identical to
            # the student's initial values.
            m._reset_parameters()

    teacher.apply(_reset)
    # NOTE(review): the "high"/"low" initialization scaling applied in
    # GeneralMLP.__init__ is not re-applied to the teacher -- confirm intended.

    # The teacher only produces targets: skip the student's activation caches
    # and gradient hooks (outputs are unaffected).
    teacher.store_activations = False
    teacher.to(device)
    teacher.eval()

    # 2) Determine sizes
    n_train = cfg.get("n_train", 20_000)
    n_test = cfg.get("n_test", 2_000)
    n_metrics = cfg.get("n_metrics", 500)
    batch_sz = cfg["batch_size"]
    common = dict(num_workers=2, pin_memory=(device == "cuda"))
    total = n_train + n_test + n_metrics

    # 3) Build inputs. SimpleTransformer subclasses GeneralMLP, so test it first.
    if isinstance(teacher, SimpleTransformer):
        seq_len = cfg["max_len"]
        vocab_size = cfg["vocab_size"]
        # Token ids drawn from [1, vocab_size): id 0 never appears.
        X = torch.randint(1, vocab_size, (total, seq_len))
    elif isinstance(teacher, GeneralMLP):
        input_dim = cfg["layer_sizes"][0]
        X = torch.randn(total, input_dim)
    else:
        raise ValueError(f"Unsupported model type {type(teacher)} for teacher data")

    # 4) Compute targets in chunks of 4 * batch_size.
    chunk = batch_sz * 4
    with torch.no_grad():
        preds = [
            teacher(X[start:start + chunk].to(device)).cpu()
            for start in range(0, total, chunk)
        ]
    Y = torch.cat(preds, dim=0)  # (total, output_dim)

    # 5) Split contiguously. The metrics slice uses an explicit offset because
    #    X[-0:] would return the whole tensor when n_metrics == 0.
    test_end = n_train + n_test
    X_train, Y_train = X[:n_train], Y[:n_train]
    X_test, Y_test = X[n_train:test_end], Y[n_train:test_end]
    X_metrics, Y_metrics = X[test_end:], Y[test_end:]

    # 6) Wrap in TensorDataset + DataLoader
    train_loader = DataLoader(TensorDataset(X_train, Y_train), batch_sz, shuffle=True, **common)
    test_loader = DataLoader(TensorDataset(X_test, Y_test), batch_sz, shuffle=False, **common)
    metrics_loader = DataLoader(TensorDataset(X_metrics, Y_metrics), 1, shuffle=True, **common)

    return train_loader, test_loader, metrics_loader
    
# -----------------------------------------------------------------------------
#  Models ---------------------------------------------------------------------
# -----------------------------------------------------------------------------



class GeneralMLP(nn.Module):
    """Fully-connected network that can record per-layer activations and gradients.

    Layer ``i`` maps ``layer_sizes[i] -> layer_sizes[i + 1]``. Every layer but
    the last applies, in order: optional BatchNorm1d, the activation, optional
    Dropout(p=0.5). The last layer's output is returned without activation.

    When ``store_activations`` is set, each ``forward`` call resets and fills:

    * ``cached_pre`` -- ``F.linear(x, W)`` for every layer (see the bias note
      in ``forward``);
    * ``cached_act`` -- the post-activation (post-dropout) output of every
      hidden layer, i.e. the input of the next layer; the final output is not
      cached;
    * ``cached_pre_grad`` / ``cached_act_grad`` -- gradients appended by tensor
      hooks during backward, in the order autograd reaches them (presumably
      last layer first -- confirm before indexing them against the
      activation lists).
    """
    def __init__(self, layer_sizes: List[int], activation: str = "relu",
                 store_activations: bool = True, initialization: str = "default", regularization_mode= "", residual=False, biases=False):
        """
        Args:
            layer_sizes: widths ``[in, hidden..., out]``.
            activation: one of "relu", "sigmoid", "tanh", "linear", "elu".
            store_activations: enable the activation/gradient caches.
            initialization: "default", "high" or "low"; multiplies every
                parameter (BatchNorm affine parameters included) by 1.0, 2.0
                or 0.5 after PyTorch's default initialization.
            regularization_mode: "batch_norm" / "L2_weight_decay_bn" add
                BatchNorm1d and "drop_out" adds Dropout(p=0.5) on the hidden
                layers; any other value adds neither.
            residual: add a skip connection on layers whose input and output
                widths match.
            biases: give every Linear layer a bias.

        Raises:
            ValueError: for an unsupported activation name.
            KeyError: for an unknown ``initialization`` value.
        """
        super().__init__()
        self.store_activations = store_activations
        acts = {"relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(),
                "linear": nn.Identity(), "elu": nn.ELU()}
        if activation not in acts:
            raise ValueError(f"Unsupported activation '{activation}'.")
        self.act_fn = acts[activation]

        self.layers = nn.ModuleList(
            nn.Linear(layer_sizes[i], layer_sizes[i + 1], bias=biases)
            for i in range(len(layer_sizes) - 1)
        )
        # Every layer except the last is a hidden layer.
        self.n_hidden = len(self.layers) - 1
        self.residual = residual
        self.regularization_mode = regularization_mode
        self.biases = biases

        # BatchNorm for hidden layers only
        if self.regularization_mode in ["batch_norm","L2_weight_decay_bn"]:
            self.batch_norms = nn.ModuleList(
                nn.BatchNorm1d(layer_sizes[i + 1]) for i in range(len(layer_sizes) - 2)
            )
        else:
            self.batch_norms = None

        # Dropout for hidden layers only
        if self.regularization_mode == "drop_out":
            self.dropouts = nn.ModuleList(
                nn.Dropout(p=0.5) for _ in range(len(layer_sizes) - 2)
            )
        else:
            self.dropouts = None

        # Rescale all parameters created above (Linear and, if present, BatchNorm).
        scale = {"default": 1.0, "high": 2.0, "low": 0.5}[initialization]

        for param in self.parameters():
            param.data.mul_(scale)

    @staticmethod
    def _cache(t: torch.Tensor, store: List[torch.Tensor],
               grad_store: List[torch.Tensor]) -> torch.Tensor:
        """Append *t* to *store* and arrange for its gradient to land in *grad_store*.

        A tensor that does not require grad is first cloned into a leaf that
        does, so the hook can fire. Returns the tensor the caller should keep
        using (the clone, if one was made).
        """
        if not t.requires_grad:
            t = t.clone().requires_grad_()
        store.append(t); t.register_hook(lambda g: grad_store.append(g))
        return t


    def forward(self, x: torch.Tensor):
        """Run the network on ``x`` flattened to ``(batch, -1)``; return the last layer's output."""
        x = x.view(x.size(0), -1)
        if self.store_activations:
            # Fresh caches on every call; _cache appends to them below.
            self.cached_pre, self.cached_act = [], []
            self.cached_pre_grad, self.cached_act_grad = [], []

        for i, layer in enumerate(self.layers):
            # Bias-free matmul; the bias (if any) is added separately below.
            pre = F.linear(x, layer.weight, bias=None)
            if self.store_activations:
                pre = self._cache(pre, self.cached_pre, self.cached_pre_grad)
            # Skip connection only when the layer preserves the width.
            if self.residual and x.size(1) == pre.size(1):
                x = x + pre


            else:
                x = pre
            # NOTE(review): in the non-residual branch ``x`` is the same tensor
            # as ``pre``, so this in-place add also writes the bias into the
            # tensor stored in ``cached_pre``; in the residual branch ``x`` is a
            # new tensor and the cached value excludes the bias. Also, if
            # ``_cache`` had to clone ``pre`` (weights not requiring grad), the
            # in-place add targets a leaf that requires grad, which PyTorch
            # rejects outside ``torch.no_grad()``. Confirm intended behaviour.
            if self.biases:
                x += layer.bias

            if i < self.n_hidden:
                # Batch norm if applicable
                if self.regularization_mode in ["batch_norm", "L2_weight_decay_bn"]:
                    x = self.batch_norms[i](x)

                x = self.act_fn(x)

                # Dropout if applicable
                if self.regularization_mode == "drop_out":
                    x = self.dropouts[i](x)

                # The cached activation is exactly the next layer's input.
                if self.store_activations:
                    x = self._cache(x, self.cached_act, self.cached_act_grad)

        return x

class SimpleTransformer(GeneralMLP):
    """
    A small Transformer-based regressor.  Inputs are integer token sequences;
    outputs are continuous vectors (regression).  Inherits GeneralMLP for the
    final MLP head and its activation-caching machinery, and adds its own
    caches for the transformer layers.

    With ``store_activations`` set, each forward pass resets and fills, in
    addition to the head's GeneralMLP caches:

    * ``cached_pre_tr`` -- the summed token + positional embeddings;
    * ``cached_act_tr`` -- ``[encoder output, pooled unit-norm vector]``; the
      last entry is exactly the tensor passed to the MLP head;
    * ``cached_pre_grad_tr`` / ``cached_act_grad_tr`` -- gradients appended by
      tensor hooks during backward.
    """
    def __init__(
        self,
        layer_sizes: List[int],
        *,
        vocab_size: int,
        max_len: int = 32,
        n_heads: int = 4,
        n_layers: int = 2,
        ff_dim: int | None = None,
        dropout: float = 0.0,
        store_activations: bool = True,
        initialization: str = "default",
        regularization_mode = "",
        biases = False
    ):
        """
        Args:
            layer_sizes: list of sizes for the MLP head, including input
                dim (== d_model) and final output dim.
            vocab_size: size of token vocabulary.
            max_len: maximum sequence length (for positional embeddings).
            n_heads: number of attention heads.
            n_layers: number of TransformerEncoderLayer blocks.
            ff_dim: dimensionality of the inner feed-forward in each block
                (defaults to d_model if None or 0).
            dropout: dropout probability in attention blocks.
            store_activations: whether to cache pre/post activations.
            initialization: scaling of the head parameters ("default", "high",
                "low"); applied by GeneralMLP before the embeddings and the
                encoder exist, so only the head is scaled.
            regularization_mode, biases: forwarded to the GeneralMLP head.
        """
        # embedding dimension is layer_sizes[0]
        d_model = layer_sizes[0]
        super().__init__(
            layer_sizes,
            activation="tanh",           # tanh on the head's hidden layers; its last layer is linear
            store_activations=store_activations,
            initialization=initialization,
            regularization_mode=regularization_mode,
            biases=biases
        )

        self.vocab_size = vocab_size
        self.max_len    = max_len
        self.d_model    = d_model

        # token + positional embeddings
        self.embed     = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_len, d_model)
        # feed-forward inner size (deliberately d_model here, not 4 * d_model)
        ff_dim = ff_dim or d_model

        # build transformer encoder stack
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=ff_dim,
            dropout=dropout,
            activation="relu",
            batch_first=True,           # work in (B, S, D) space
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: LongTensor of shape (B, S), token indices in [0, vocab_size).
        returns: FloatTensor of shape (B, output_dim).

        Raises:
            ValueError: if S exceeds ``max_len``.
        """
        B, S = x.shape
        if S > self.max_len:
            raise ValueError(f"Sequence length {S} > max_len {self.max_len}")

        # --- transformer caching ---
        if self.store_activations:
            self.cached_pre_tr, self.cached_act_tr = [], []
            self.cached_pre_grad_tr, self.cached_act_grad_tr = [], []

        # 1) embed tokens + add pos embeddings
        token_emb = self.embed(x)                        # (B, S, D)
        positions = torch.arange(S, device=x.device)     # (S,)
        positions = positions.unsqueeze(0).expand(B, S)  # (B, S)
        pos_emb   = self.pos_embed(positions)            # (B, S, D)
        t = token_emb + pos_emb                          # (B, S, D)
        if self.store_activations:
            t = self._cache(t, self.cached_pre_tr, self.cached_pre_grad_tr)

        # 2) transformer layers
        t = self.transformer(t)                          # (B, S, D)
        if self.store_activations:
            t = self._cache(t, self.cached_act_tr, self.cached_act_grad_tr)

        # 3) mean-pool over the sequence and project onto the unit sphere.
        #    Normalise *before* caching so the cached vector (and its gradient)
        #    is exactly the head's input -- pre_synaptic_inputs treats
        #    cached_act_tr[-1] as the head's pre-synaptic activity.
        h = t.mean(dim=1)                                # (B, D)
        h = h / (h.norm(dim=1, keepdim=True) + 1e-12)
        if self.store_activations:
            h = self._cache(h, self.cached_act_tr, self.cached_act_grad_tr)

        return super().forward(h)                        # (B, output_dim)


class SmallCNN(nn.Module):
    """
    A small CNN for CIFAR-10 classification.

    Three stages of conv(3x3, padding=1) -> [BatchNorm2d] -> activation ->
    2x2 max-pool (3 -> 32 -> 64 -> 128 channels, 32x32 -> 4x4 spatial) feed
    three linear layers 2048 -> 512 -> 256 -> 10. The conv and linear layers
    hold 1,275,232 parameters with ``biases=False`` (1,276,234 with biases);
    optional BatchNorm parameters come on top.

    ``layer_sizes`` and ``residual`` are accepted for signature parity with
    GeneralMLP but are not used.

    When ``store_activations`` is set, each forward pass resets and fills:

    * ``cached_act`` -- [flattened conv features, fc1 output, fc2 output]
      taken after activation and dropout, i.e. the inputs of fc1/fc2/fc3;
    * ``cached_pre`` -- [fc1, fc2, fc3 outputs before any activation];
    * ``cached_act_grad`` / ``cached_pre_grad`` -- gradients appended by
      tensor hooks during backward.
    """
    def __init__(self, layer_sizes: List[int], activation: str = "relu",
                 store_activations: bool = True, initialization: str = "default",
                 regularization_mode: str = "", residual: bool = False, biases: bool = False):
        super().__init__()
        self.store_activations = store_activations
        self.regularization_mode = regularization_mode
        self.biases = biases

        acts = {"relu": nn.ReLU(), "sigmoid": nn.Sigmoid(), "tanh": nn.Tanh(),
                "linear": nn.Identity(), "elu": nn.ELU()}
        if activation not in acts:
            raise ValueError(f"Unsupported activation '{activation}'.")
        self.act_fn = acts[activation]

        # Convolutional trunk: padding=1 keeps the spatial size of each 3x3 conv.
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1, bias=biases)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1, bias=biases)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=biases)

        # Shared 2x2 pooling halves the spatial size after every stage.
        self.pool = nn.MaxPool2d(2, 2)

        # 32x32 -> 16x16 -> 8x8 -> 4x4 after three pools, with 128 channels.
        self.fc1 = nn.Linear(128 * 4 * 4, 512, bias=biases)
        self.fc2 = nn.Linear(512, 256, bias=biases)
        self.fc3 = nn.Linear(256, 10, bias=biases)  # 10 classes for CIFAR-10

        # BatchNorm after each conv, if enabled.
        if self.regularization_mode in ["batch_norm", "L2_weight_decay_bn"]:
            self.bn1 = nn.BatchNorm2d(32)
            self.bn2 = nn.BatchNorm2d(64)
            self.bn3 = nn.BatchNorm2d(128)
        else:
            self.bn1 = self.bn2 = self.bn3 = None

        # One Dropout module, shared by fc1 and fc2, if enabled.
        if self.regularization_mode == "drop_out":
            self.dropout = nn.Dropout(p=0.5)
        else:
            self.dropout = None

        # "high"/"low" multiply every parameter created above (BatchNorm
        # included) by 2.0 / 0.5; unknown values raise KeyError.
        scale = {"default": 1.0, "high": 2.0, "low": 0.5}[initialization]
        for param in self.parameters():
            param.data.mul_(scale)

    @staticmethod
    def _cache(t: torch.Tensor, store: List[torch.Tensor],
               grad_store: List[torch.Tensor]) -> torch.Tensor:
        """Append *t* to *store*; a backward hook appends its gradient to *grad_store*.

        A tensor that does not require grad is first cloned into a grad-
        requiring leaf so the hook can fire. Returns the tensor to keep using.
        """
        if not t.requires_grad:
            t = t.clone().requires_grad_()
        store.append(t)
        t.register_hook(lambda g: grad_store.append(g))
        return t

    def forward(self, x: torch.Tensor):
        """Classify a batch; accepts (C, H, W), (B, 3, 32, 32) or (B, 3072) input."""
        if self.store_activations:
            self.cached_pre, self.cached_act = [], []
            self.cached_pre_grad, self.cached_act_grad = [], []

        # Reshape input if needed (CIFAR-10: 3x32x32).
        if x.dim() == 3:
            x = x.unsqueeze(0)
        if x.size(1) != 3:  # channels not in the second dimension, e.g. flat vectors
            x = x.view(x.size(0), 3, 32, 32)

        # Convolutional stages.
        for conv, bn in ((self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)):
            x = conv(x)
            if bn is not None:
                x = bn(x)
            x = self.pool(self.act_fn(x))

        # Flatten; this is the input of fc1.
        x = x.view(x.size(0), -1)
        if self.store_activations:
            x = self._cache(x, self.cached_act, self.cached_act_grad)

        # Hidden linear layers.
        for fc in (self.fc1, self.fc2):
            x = fc(x)
            if self.store_activations:
                x = self._cache(x, self.cached_pre, self.cached_pre_grad)
            x = self.act_fn(x)
            if self.dropout is not None:
                x = self.dropout(x)
            if self.store_activations:
                x = self._cache(x, self.cached_act, self.cached_act_grad)

        # Output layer (logits).
        x = self.fc3(x)
        if self.store_activations:
            x = self._cache(x, self.cached_pre, self.cached_pre_grad)

        return x


def choose_optimizer(params, cfg: Dict) -> optim.Optimizer:
    """Instantiate the torch optimizer named by ``cfg["optimizer"]`` (default "SGD").

    "Hebb", "DFA" and "RandomNN" also receive plain SGD because their real
    weight update happens elsewhere.

    Raises:
        ValueError: for an unrecognised optimizer name.
    """
    name = cfg.get("optimizer", "SGD")
    lr = cfg["lr"]
    if name == "Adam":
        return optim.Adam(params, lr)
    if name == "SGD" or name in ["Hebb", "DFA", "RandomNN"]:
        return optim.SGD(params, lr)
    raise ValueError(f"Unknown optimizer '{name}'.")

# -----------------------------------------------------------------------------
#  Alignment maths (unchanged) -------------------------------------------------
# -----------------------------------------------------------------------------
def cosine_similarity(a: torch.Tensor, b: torch.Tensor) -> float:
    """Cosine of the angle between *a* and *b*, both flattened to vectors.

    A 1e-8 term in the denominator keeps zero vectors from dividing by zero
    (they yield 0.0). Returns a Python float.
    """
    vec_a = a.flatten()
    vec_b = b.flatten()
    denominator = (vec_a.norm() * vec_b.norm()) + 1e-8
    return float(torch.dot(vec_a, vec_b) / denominator)



def add_param_noise(model, noise_frac):
    """Perturb every non-empty parameter of *model* in place with Gaussian noise.

    Each element receives ``N(0, 1) * noise_frac``; note the standard
    deviation is *noise_frac* itself, not a fraction of the parameter's
    magnitude. Runs without autograd tracking.
    """
    with torch.no_grad():
        non_empty = (param for param in model.parameters() if param.numel() != 0)
        for param in non_empty:
            param.add_(torch.randn_like(param) * noise_frac)


def add_input_noise(data, noise_frac):
    """Return *data* plus ``N(0, 1) * noise_frac`` Gaussian noise.

    Used instead of perturbing model parameters. When *noise_frac* is not
    positive, *data* itself is returned unchanged; otherwise a new tensor is
    built without autograd tracking.
    """
    with torch.no_grad():
        if not noise_frac > 0.0:
            return data
        perturbation = torch.randn_like(data) * noise_frac
        return data + perturbation


def get_frozen_mask(param, model):
    """Get the frozen mask for a parameter from its containing module.

    *param* is located in ``model.named_parameters()`` by identity. The module
    that owns it is resolved from the dotted name, and that module's
    ``<attr>_frozen_mask`` attribute (e.g. ``weight_frozen_mask``) is returned.

    Returns:
        The mask attribute, or ``None`` if *param* does not belong to *model*
        or the owning module defines no such attribute.
    """
    for qualified_name, candidate in model.named_parameters():
        if candidate is not param:
            continue
        # "a.b.weight" -> owner path "a.b", attribute "weight";
        # a top-level "weight" has an empty owner path, meaning model itself.
        owner_path, _, attr_name = qualified_name.rpartition('.')
        owner = dict(model.named_modules())[owner_path] if owner_path else model
        return getattr(owner, f'{attr_name}_frozen_mask', None)
    return None


def _hebb_weight_delta(x: torch.Tensor, y: torch.Tensor,
                       weight: torch.Tensor, rule: str) -> torch.Tensor:
    """Return the sample-averaged weight change for one linear layer.

    Args:
        x: pre-synaptic input to the layer (last dim = in_features).
        y: post-synaptic signal (last dim = out_features).
        weight: the layer's current weight, (out_features, in_features).
        rule: lower-cased rule name; ``"oja"`` selects Oja's subspace rule,
            anything else the plain Hebb rule.

    Inputs with more than two dims (e.g. transformer token activations) are
    merged over all but the last dim via ``_flatten_last``, so each row counts
    as one sample. The outer product is divided by that row count.
    """
    y_eff = _flatten_last(y)
    x_eff = _flatten_last(x)
    n_rows = y_eff.size(0)
    if rule == "oja":
        # Oja's subspace rule: ΔW = E[y xᵀ] − E[y yᵀ] W
        return (y_eff.t() @ (x_eff - y_eff @ weight)) / n_rows
    # Plain Hebb rule: ΔW = E[y xᵀ]
    return (y_eff.t() @ x_eff) / n_rows


def _apply_hebb_step(layer: nn.Module, delta_w: torch.Tensor, model: nn.Module,
                     lr: float, renormalize: bool) -> None:
    """Add ``lr * delta_w`` to ``layer.weight`` in place.

    Entries selected by the layer's frozen mask (see ``get_frozen_mask``) are
    zeroed in ``delta_w`` itself first, so the caller also sees the masked
    update. If ``renormalize`` is set, the whole weight is then divided by its
    standard deviation (+1e-8).
    """
    frozen_mask = get_frozen_mask(layer.weight, model)
    if frozen_mask is not None:
        delta_w[frozen_mask] = 0.0
    layer.weight.data += lr * delta_w
    if renormalize:
        # NOTE(review): this rescales every entry, frozen ones included, so
        # frozen weights still change in scale — confirm this is intended.
        layer.weight.data.div_(1e-8 + layer.weight.data.std())


@torch.no_grad()
def hebbian_update(model, x_batch, targets, cfg):
    """Compute (and, for the Hebb optimizer, apply) a Hebbian step per layer.

    Args:
        model: a model exposing ``layers``/``n_hidden`` (MLP / transformer) or a
            ``SmallCNN`` (only its fc1/fc2/fc3 are updated). It must also expose
            ``cached_pre``, the per-layer pre-activations, presumably filled by
            the preceding forward pass — TODO confirm.
        x_batch: input batch passed to ``pre_synaptic_inputs``.
        targets: unused; kept for interface compatibility.
        cfg: config dict. Reads ``lr``, ``update_rule`` (``"oja"`` or Hebb),
            ``optimizer`` (weights are mutated only when it is ``"Hebb"``),
            ``model`` and ``activation_update`` (``"pre"`` selects
            pre-activations as the post-synaptic signal).

    Returns:
        List with one ``delta_w`` per layer, shaped like the layer weight.
        When weights are mutated, the entries covered by a frozen mask are zeroed.

    Raises:
        ValueError: if the model type is not supported.
    """
    lr = cfg["lr"]
    rule = cfg["update_rule"].lower()
    mutate = cfg["optimizer"] == "Hebb"

    acts = pre_synaptic_inputs(x_batch, model, cfg)  # input to each layer
    pre_acts = model.cached_pre                      # each element: (B, out_features)

    if hasattr(model, "layers"):
        # MLP / Transformer models
        layers = model.layers
        n_hidden = model.n_hidden
    elif isinstance(model, SmallCNN):
        # CNN models: only the linear layers take Hebbian updates
        layers = [model.fc1, model.fc2, model.fc3]
        n_hidden = 2  # fc1 and fc2 are hidden, fc3 is output
    else:
        raise ValueError(f"Unsupported model type: {type(model)}")

    # For these model kinds the layer at index n_hidden takes the hidden-layer
    # path, always using its pre-activation.
    head_as_hidden = cfg["model"] in ["transformer", "regression-mlp"]

    layer_updates = []
    for idx, layer in enumerate(layers):
        x = acts[idx]  # pre-synaptic activity: (B, in_features)

        if idx < n_hidden or (idx == n_hidden and head_as_hidden):
            if cfg["activation_update"] == "pre" or idx == n_hidden:
                y = pre_acts[idx]
            else:
                y = acts[idx + 1]  # post-activation = next layer's input
            # Hidden-layer Oja updates are not std-renormalised (Oja's rule is
            # presumably relied on to keep the weights bounded).
            renormalize = rule != "oja"
        else:
            # Output layer: also Hebbian (not a delta rule).
            if cfg["activation_update"] == "pre":
                y = pre_acts[idx]
            else:
                y = acts[idx + 1] if idx + 1 < len(acts) else pre_acts[idx]
            renormalize = True

        delta_w = _hebb_weight_delta(x, y, layer.weight, rule)
        if mutate:
            _apply_hebb_step(layer, delta_w, model, lr, renormalize)
        layer_updates.append(delta_w)

    return layer_updates