import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.utils import *
import wandb
import math
from torchvision.models import resnet18, ResNet18_Weights

def compute_ext_loss(vals):
    """Histogram-divergence loss of ``vals`` over their own value range.

    Forwards to ``ext_histogram_divergence_loss`` (not defined in this file;
    presumably star-imported from ``utils.utils``) with the argument 10
    (presumably the bin count), the range ``[min(vals), max(vals)]`` and
    ``sigma`` equal to one tenth of that range.
    """
    # Extract the extrema once; each ``.item()`` is a host sync on GPU tensors.
    lo = torch.min(vals).item()
    hi = torch.max(vals).item()
    spread = hi - lo
    return ext_histogram_divergence_loss(vals, 10, lo, hi, sigma=spread / 10)


class PReLU(nn.PReLU):
    """
    PReLU with one learnable negative slope per feature of the LAST dimension.

    ``nn.PReLU`` treats dim 1 as the channel dimension. Inside a transformer
    feed-forward block the features live in the last dimension instead (e.g.
    ``(seq, batch, dim_feedforward)``). This wrapper flattens every leading
    dimension, applies ``nn.PReLU`` to the resulting ``(N, features)`` matrix and
    restores the original shape. ``num_parameters`` must therefore be 1 or equal
    to the size of the last dimension.
    """
    def __init__(self, num_parameters=1, init=0.25):
        super().__init__(num_parameters=num_parameters, init=init)

    def forward(self, input):
        original_shape = input.shape
        # reshape (not view) so non-contiguous inputs work; flattening all
        # leading dims handles 2-D (unbatched) as well as 3-D inputs.
        flat = input.reshape(-1, original_shape[-1])
        out = super().forward(flat)
        return out.reshape(original_shape)

class CustomTransformerEncoderLayer(nn.TransformerEncoderLayer):
    """
    ``nn.TransformerEncoderLayer`` whose feed-forward activation is a learnable
    PReLU with one negative slope per hidden unit.

    Note: ``activation`` is accepted for signature compatibility and forwarded to
    the parent constructor, but it is immediately replaced by the PReLU, so it
    has no effect on the computation.
    """
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
                 batch_first=False):
        # Initialize the parent class
        super().__init__(d_model, nhead, dim_feedforward, dropout, activation,
                         batch_first=batch_first)
        # num_parameters=dim_feedforward: every hidden unit of the feed-forward
        # block gets its own learnable negative slope.
        self.activation = PReLU(num_parameters=dim_feedforward)
        # The parent records whether its activation is relu/gelu so that PyTorch's
        # fused inference fast path may be used; that path hard-codes relu/gelu and
        # would silently bypass the PReLU. Mark the activation as "other".
        self.activation_relu_or_gelu = 0

class _FFN(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.lin1 = nn.Linear(d_model, d_ff)
        self.lin2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        return self.lin2(F.gelu(self.lin1(x)))


class CrossAttnDecoderLayer(nn.Module):
    """
    One pre-norm decoder layer with (optional) light query self-attn + cross-attn
    to encoder memory, followed by a GELU feed-forward block. Every sub-block is
    wrapped in a residual connection.
    Batch-first everywhere: (B, S, D).
    """
    def __init__(self, d_model: int, nhead: int, d_ff: int, use_query_self_attn: bool = False):
        super().__init__()
        self.use_qsa = use_query_self_attn
        if self.use_qsa:
            self.q_self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0, batch_first=True)
            self.norm_qsa = nn.LayerNorm(d_model)

        self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0, batch_first=True)
        self.norm_ca = nn.LayerNorm(d_model)

        self.ffn = _FFN(d_model, d_ff)
        self.norm_ffn = nn.LayerNorm(d_model)

    def forward(self, queries: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
        """
        queries: (B, Lq, D)   -- here Lq = total number of neurons (all positions)
        memory:  (B, Lm, D)   -- visible neuron tokens + image tokens after encoder;
                                 used as key/value without normalisation
        returns: (B, Lq, D)
        """
        x = queries
        if self.use_qsa:
            # Normalise once and pass the SAME tensor as query/key/value: this avoids
            # three identical LayerNorm evaluations and lets nn.MultiheadAttention
            # use its packed self-attention input projection.
            normed = self.norm_qsa(x)
            attn_out, _ = self.q_self_attn(normed, normed, normed, need_weights=False)
            x = x + attn_out

        attn_out, _ = self.cross_attn(self.norm_ca(x), memory, memory, need_weights=False)
        x = x + attn_out

        x = x + self.ffn(self.norm_ffn(x))
        return x


class MAEStyleController(nn.Module):
    """
    MAE-style controller:
      - Encoder: a batch-first ``nn.TransformerEncoder`` run only over the memory
        tokens (visible neuron tokens + image tokens).
      - Decoder: a stack of ``CrossAttnDecoderLayer`` blocks in which L neuron
        queries cross-attend to the encoder output to produce L embeddings.
    """
    def __init__(
        self,
        d_model: int,
        nhead: int,
        enc_layers: int,
        dec_layers: int,
        d_ff: int,
        use_query_self_attn: bool = False,
        activation=F.leaky_relu,
    ):
        """
        Args:
            d_model: token width shared by encoder and decoder (must be divisible by ``nhead``).
            nhead: number of attention heads in every layer.
            enc_layers: number of encoder layers.
            dec_layers: number of decoder layers.
            d_ff: feed-forward width of every layer.
            use_query_self_attn: add query self-attention to each decoder layer.
            activation: activation of the encoder feed-forward blocks.
        """
        super().__init__()

        # --- Encoder (TransformerEncoder over memory tokens)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=d_ff,
            dropout=0.0,
            activation=activation,
            batch_first=True,  # we use (B, S, D)
        )
        # nn.TransformerEncoder only builds nested tensors when a padding mask is
        # passed, which forward() never does. Disabling the option explicitly also
        # silences the constructor warning emitted for non relu/gelu activations.
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=enc_layers,
                                             enable_nested_tensor=False)

        # --- Decoder (stack of cross-attn layers)
        self.decoder = nn.ModuleList([
            CrossAttnDecoderLayer(d_model, nhead, d_ff, use_query_self_attn=use_query_self_attn)
            for _ in range(dec_layers)
        ])

    @torch.no_grad()
    def _maybe_cast(self, x):
        # Identity passthrough: returns ``x`` unchanged and performs no dtype/device
        # cast. It is not called within this class; it is kept for any external callers.
        return x

    def forward(self, memory_tokens: torch.Tensor, queries: torch.Tensor) -> torch.Tensor:
        """
        memory_tokens: (B, M+I, D)  -- visible neuron tokens + image tokens
        queries:       (B, L,   D)  -- one query per neuron position (ALL neurons)

        returns:       (B, L, D)    -- one embedding per neuron, MAE-decoded
        """
        mem = self.encoder(memory_tokens)  # (B, M+I, D)
        out = queries
        for layer in self.decoder:
            out = layer(out, mem)          # (B, L, D)
        return out



class Single_NEW_MIX_Sample_Based_Sparse_Res(nn.Module):
    """
    For each neuron a representation is computed by concatenating:
      - A fixed part: computed from (layer index, neuron index) augmented with higher-frequency
        sin/cos positional encoding (similar to NeRF). This fixed part has dimension 2+4*num_freq.
      - A learnable part: stored as trainable parameters (``learnable_dim`` values per neuron, 4 by default).
     
    The overall embedding is then processed by a single transformer (across all neurons),
    and the output embeddings are decoded to yield gating parameters (for conv layers: 4 values;
    for the FC layer: 1 value).
    """
    def __init__(self,
                 input_type='conv',
                 controller_input_type='conv',
                 input_shape=(3, 32, 32),
                 num_classes=10,
                 transformer_embed_dim=256,  # feed-forward width (d_ff) of the controller transformer layers
                 learnable_dim=4,  # learnable parameters per neuron
                 neuro_sync_conv_kernel_size=5,  # kernel size == stride of the controller's image-patch conv
                 k=1,           # number of transformer (encoder) layers
                 num_heads=2,
                 ema_decay=0.999,  # decay rate for EMA updates
                 num_freq=0,       # number of additional frequencies for positional encoding
                 prelu_transformer=True,
                 WC=True,
                 SM=True,
                 AL=True,
                 ARM=True,
                 SM_detach=False,
                 attention_neuro_sync='mae',    # controller type: 'mae' or 'self'
                 mask_portion=0.10,             # portion of neurons per layer to KEEP
                 decoder_layers=None,           # MAE decoder depth; if None, uses k
                 load_pretrained=False,
                 dropout_percentage=0.0,
                 use_query_self_attn=False,
                 disable_bn=False,
                 device='cpu'
                 ):
        """
        Build the ResNet-18-style backbone together with its neuron-level controller.

        Notable arguments:
            controller_input_type: only 'conv' is implemented (image tokens come
                from a strided Conv2d patch embedding).
            attention_neuro_sync: 'mae' builds a ``MAEStyleController`` (encoder +
                cross-attention decoder); 'self' builds a plain
                ``nn.TransformerEncoder`` (with PReLU feed-forward when
                ``prelu_transformer`` is True).
            mask_portion: stored as ``self.mask_portion``; must lie in [0, 1].
            input_type, WC, SM, AL, ARM, SM_detach: stored on ``self`` for use by
                other methods of this class.

        Raises:
            ValueError: if ``attention_neuro_sync`` is not 'mae'/'self', if
                ``controller_input_type`` is not 'conv', or if ``mask_portion`` is
                outside [0, 1]. All checks run before any network is built or any
                pretrained weights are downloaded.
        """
        super().__init__()

        # -------------------------------------------------------
        # 0) Validate arguments before doing any expensive work.
        # -------------------------------------------------------
        if attention_neuro_sync not in ('mae', 'self'):
            raise ValueError(f"{attention_neuro_sync} is not supported")
        if controller_input_type != 'conv':
            raise ValueError(
                f"controller_input_type={controller_input_type!r} is not supported; only 'conv' is implemented"
            )
        mask_portion = float(mask_portion)
        if not 0.0 <= mask_portion <= 1.0:
            raise ValueError("mask_portion must be in [0, 1].")

        # -------------------------------------------------------
        # 1) Backbone and gating bookkeeping.
        # -------------------------------------------------------
        self.num_classes = num_classes
        self.input_type = input_type
        self.controller_input_type = controller_input_type
        self.input_shape = input_shape
        self.WC = WC
        self.SM = SM
        self.AL = AL
        self.ARM = ARM
        self.SM_detach = SM_detach
        self.device = device
        self.load_pretrained = load_pretrained
        self.dropout_percentage = dropout_percentage
        self.disable_bn = disable_bn

        self.attention_neuro_sync = attention_neuro_sync

        self.build_network()

        print(f'Pretrained model is loaded: {self.load_pretrained}')
        if load_pretrained:
            self._load_pretrained_resnet18()

        # Number of image tokens produced by the controller's patch conv (square input assumed).
        neuro_sync_encoder_size = conv_output_dim(input_shape[1], neuro_sync_conv_kernel_size,
                                                  stride=neuro_sync_conv_kernel_size, padding=1)
        image_token_len = neuro_sync_encoder_size * neuro_sync_encoder_size

        # Per-neuron EMA statistics, each of shape (total_nodes, 1), registered as
        # activation{mean,max,min,std}_ema followed by weight{mean,max,min,std}_ema.
        for source in ('activation', 'weight'):
            for stat in ('mean', 'max', 'min', 'std'):
                self.register_buffer(f'{source}{stat}_ema',
                                     torch.zeros(self.total_nodes, 1, device=self.device))

        # -------------------------------------------------------
        # 2) Neuron embeddings.
        # -------------------------------------------------------
        fixed_dim = 2 + 4 * num_freq
        self.embed_dim = fixed_dim + learnable_dim + 4 + 4  # overall dimension for transformer input.

        self.query_in_dim = fixed_dim + learnable_dim
        self.query_proj = nn.Linear(self.query_in_dim, self.embed_dim)

        # (layer index, neuron index) for every neuron, in name_to_start_end order.
        base_positions = torch.tensor([
            [float(layer_idx), float(neuron_idx)]
            for layer_idx, channels in enumerate(self.layer_channels.values())
            for neuron_idx in range(channels)
        ]).to(self.device)  # shape: (total_nodes, 2)

        pos_encoded = self.positional_encoding(base_positions, num_freq)  # shape: (total_nodes, 2+4*num_freq)
        self.register_buffer('neuron_features', pos_encoded)

        def leaky_relu(x):
            return F.leaky_relu(x, negative_slope=0.01)

        self.mask_portion = mask_portion
        self.dec_layers = decoder_layers if decoder_layers is not None else k

        # -------------------------------------------------------
        # 3) Controller.
        # -------------------------------------------------------
        if attention_neuro_sync == 'mae':
            self.controller = MAEStyleController(
                d_model=self.embed_dim,
                nhead=num_heads,
                enc_layers=k,
                dec_layers=self.dec_layers,
                d_ff=transformer_embed_dim,
                use_query_self_attn=use_query_self_attn,
                activation=leaky_relu,
            )
        else:  # 'self' (validated above)
            layer_cls = CustomTransformerEncoderLayer if prelu_transformer else nn.TransformerEncoderLayer
            encoder_layer = layer_cls(d_model=self.embed_dim, nhead=num_heads, activation=leaky_relu,
                                      dropout=0.0, dim_feedforward=transformer_embed_dim)
            # No padding mask is ever passed, so nested tensors are never used;
            # disabling them avoids a constructor warning for non relu/gelu activations.
            self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=k,
                                                     enable_nested_tensor=False)

        # Image-patch embedding. Its kernel/stride must match the ones used for
        # image_token_len above, otherwise image_features has the wrong length.
        self.controller_image_encoder = nn.Conv2d(
            self.input_shape[0], self.embed_dim - learnable_dim,
            kernel_size=neuro_sync_conv_kernel_size, stride=neuro_sync_conv_kernel_size, padding=1)

        # Learnable features per neuron and per image token (drawn on CPU, then moved).
        self.learnable_features = nn.Parameter(torch.randn(self.total_nodes, learnable_dim).to(self.device) * 0.01)
        self.image_features = nn.Parameter(torch.randn(image_token_len, learnable_dim).to(self.device) * 0.01)

        self.learnable_dim = learnable_dim
        self.num_heads = num_heads
        self.k = k

        # Output heads: 4 gating values per conv-layer neuron, 1 per FC-layer neuron.
        self.decoder_conv = nn.Linear(self.embed_dim, 4, bias=True)
        self.decoder_fc = nn.Linear(self.embed_dim, 1, bias=True)

        # Report how many parameters belong to the gating machinery.
        if attention_neuro_sync == 'mae':
            gating_modules = [self.controller, self.query_proj]
        else:
            gating_modules = [self.transformer]
        gating_modules += [self.decoder_conv, self.decoder_fc]
        num_gating = sum(p.numel() for module in gating_modules for p in module.parameters())
        num_gating += self.learnable_features.numel()
        print(f"Gating number of parameters: {num_gating}")

        num_total = sum(p.numel() for p in self.parameters())
        print(f"Total number of parameters: {num_total}")
        print(f"Percentage of gating parameters: {num_gating / num_total}")

        # -------------------------------------------------------
        # 4) Save initial CNN branch parameters (for soft reset) and initialize EMA.
        # -------------------------------------------------------
        # NOTE(review): plain dicts -> not in state_dict and not moved by .to(); confirm intended.
        self._initial_params = {}
        self._ema_params = {}  # exponential moving average of each CNN weight
        for name, param in self.named_parameters():
            if name.startswith(("conv", "fc")):
                self._initial_params[name] = param.detach().clone()
                self._ema_params[name] = param.detach().clone()  # initialize EMA

        self.ema_decay = ema_decay  # hyperparameter for EMA updates
    
    def _load_pretrained_resnet18(self):
        """
        Copy ImageNet-pretrained torchvision ResNet-18 weights into this network.

        Loaded: conv weights of BLOCK-2..BLOCK-5, including the downsample
        projections ``conv_concat_adjust_*``. Conv biases are copied only when both
        sides have one. Also loaded: BatchNorm affine parameters and running
        statistics, and ``fc1`` when its shape matches ResNet's classifier.

        Not loaded: BLOCK-1 (``conv1`` / ``batchnorm1``), by design. BatchNorm
        slots that are ``nn.Identity`` (``disable_bn=True``) are also skipped,
        since they have nothing to receive.
        """
        import torchvision.models as models

        # The module-level ``ResNet18_Weights`` import already requires a
        # torchvision with the ``weights=`` API, so no legacy fallback is needed.
        # Download/hash errors now surface directly.
        resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)

        # --------- Convolution layers (skip BLOCK-1) ---------
        conv_pairs = [
            # BLOCK-2  (ResNet layer1)
            (resnet.layer1[0].conv1, self.conv2_1_1),
            (resnet.layer1[0].conv2, self.conv2_1_2),
            (resnet.layer1[1].conv1, self.conv2_2_1),
            (resnet.layer1[1].conv2, self.conv2_2_2),

            # BLOCK-3  (ResNet layer2)
            (resnet.layer2[0].conv1,        self.conv3_1_1),
            (resnet.layer2[0].conv2,        self.conv3_1_2),
            (resnet.layer2[0].downsample[0], self.conv_concat_adjust_3),
            (resnet.layer2[1].conv1,        self.conv3_2_1),
            (resnet.layer2[1].conv2,        self.conv3_2_2),

            # BLOCK-4  (ResNet layer3)
            (resnet.layer3[0].conv1,        self.conv4_1_1),
            (resnet.layer3[0].conv2,        self.conv4_1_2),
            (resnet.layer3[0].downsample[0], self.conv_concat_adjust_4),
            (resnet.layer3[1].conv1,        self.conv4_2_1),
            (resnet.layer3[1].conv2,        self.conv4_2_2),

            # BLOCK-5  (ResNet layer4)
            (resnet.layer4[0].conv1,        self.conv5_1_1),
            (resnet.layer4[0].conv2,        self.conv5_1_2),
            (resnet.layer4[0].downsample[0], self.conv_concat_adjust_5),
            (resnet.layer4[1].conv1,        self.conv5_2_1),
            (resnet.layer4[1].conv2,        self.conv5_2_2),
        ]

        # --------- BatchNorm layers ---------
        bn_pairs = [
            # BLOCK-2 (layer1)
            (resnet.layer1[0].bn1,          self.batchnorm2_1_1),
            (resnet.layer1[0].bn2,          self.batchnorm2_1_2),
            (resnet.layer1[1].bn1,          self.batchnorm2_2_1),
            (resnet.layer1[1].bn2,          self.batchnorm2_2_2),

            # BLOCK-3 (layer2)
            (resnet.layer2[0].bn1,          self.batchnorm3_1_1),
            (resnet.layer2[0].bn2,          self.batchnorm3_1_2),
            (resnet.layer2[0].downsample[1], self.batchnorm_adjust_3),
            (resnet.layer2[1].bn1,          self.batchnorm3_2_1),
            (resnet.layer2[1].bn2,          self.batchnorm3_2_2),

            # BLOCK-4 (layer3)
            (resnet.layer3[0].bn1,          self.batchnorm4_1_1),
            (resnet.layer3[0].bn2,          self.batchnorm4_1_2),
            (resnet.layer3[0].downsample[1], self.batchnorm_adjust_4),
            (resnet.layer3[1].bn1,          self.batchnorm4_2_1),
            (resnet.layer3[1].bn2,          self.batchnorm4_2_2),

            # BLOCK-5 (layer4)
            (resnet.layer4[0].bn1,          self.batchnorm5_1_1),
            (resnet.layer4[0].bn2,          self.batchnorm5_1_2),
            (resnet.layer4[0].downsample[1], self.batchnorm_adjust_5),
            (resnet.layer4[1].bn1,          self.batchnorm5_2_1),
            (resnet.layer4[1].bn2,          self.batchnorm5_2_2),
        ]

        # copy_ converts device/dtype itself, so the source model can stay on CPU.
        with torch.no_grad():
            for src, dst in conv_pairs:
                dst.weight.copy_(src.weight)
                if dst.bias is not None and src.bias is not None:
                    dst.bias.copy_(src.bias)

            for src, dst in bn_pairs:
                if not isinstance(dst, nn.BatchNorm2d):
                    # disable_bn=True replaces BatchNorm with nn.Identity (no params/stats).
                    continue
                dst.weight.copy_(src.weight)
                dst.bias.copy_(src.bias)
                dst.running_mean.copy_(src.running_mean)
                dst.running_var.copy_(src.running_var)
                dst.num_batches_tracked.copy_(src.num_batches_tracked)

            # AdaptiveAvgPool2d has no learnable parameters, so nothing to copy.

            # --------- Fully connected layer ---------
            # Copied only when the output size matches (e.g. num_classes == 1000);
            # otherwise fc1 keeps its fresh initialization.
            if (self.fc1.weight.shape == resnet.fc.weight.shape
                    and self.fc1.bias.shape == resnet.fc.bias.shape):
                self.fc1.weight.copy_(resnet.fc.weight)
                self.fc1.bias.copy_(resnet.fc.bias)
        # with its existing initialization so it can be trained from scratch.

    
    def build_network(self, double=False):
        # tailored for 32*32 images (CIFAR100 and ImageNet Tiny)
        # BLOCK-1 (starting block) input=(224x224) output=(56x56)
        if self.input_shape[-1] == 32:
            self.conv1 = torch.nn.Conv2d(3, 32 if double else 64, kernel_size=3, stride=1, padding=1)
            self.batchnorm1 = torch.nn.Identity()
            self.maxpool1 = nn.MaxPool2d(kernel_size=(3,3), stride=(2,2), padding=(1,1))
        elif self.input_shape[-1] == 224:
            self.conv1 = nn.Conv2d(in_channels=3, out_channels=32 if double else 64, kernel_size=(7,7), stride=(2,2), padding=(3,3), bias=True if self.disable_bn else False) 
            
            if self.disable_bn:
                self.batchnorm1 = nn.Identity()
            else:
                self.batchnorm1 = nn.BatchNorm2d(32 if double else 64, eps=1e-05, momentum=0.1, affine=True) 
            
            self.maxpool1 = nn.MaxPool2d(kernel_size=(3,3), stride=(2,2), padding=(1,1))
        elif self.input_shape[-1] == 84:
            self.conv1 = nn.Conv2d(3, 32 if double else 64, kernel_size=7, stride=2, padding=3, bias=True if self.disable_bn else False)
            self.batchnorm1 = nn.BatchNorm2d(32 if double else 64, eps=1e-05, momentum=0.1, affine=True)
            self.maxpool1 = nn.Identity()
        else:
            raise ValueError(f"{self.input_shape[-1]} is not supported in ResNet")
        
        # BLOCK-2 (1) input=(56x56) output = (56x56)
        self.conv2_1_1 = nn.Conv2d(in_channels=64, out_channels=32 if double else 64, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=True if self.disable_bn else False)
        if self.disable_bn:
            self.batchnorm2_1_1 = nn.Identity()
        else:
            self.batchnorm2_1_1 = nn.BatchNorm2d(32 if double else 64, eps=1e-05, momentum=0.1, affine=True)
        self.conv2_1_2 = nn.Conv2d(in_channels=64, out_channels=32 if double else 64, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=True if self.disable_bn else False)
        if self.disable_bn:
            self.batchnorm2_1_2 = nn.Identity()
        else:
            self.batchnorm2_1_2 = nn.BatchNorm2d(32 if double else 64, eps=1e-05, momentum=0.1, affine=True)
        self.dropout2_1 = nn.Dropout(p=self.dropout_percentage)
        # BLOCK-2 (2)
        self.conv2_2_1 = nn.Conv2d(in_channels=64, out_channels=32 if double else 64, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=True if self.disable_bn else False)
        if self.disable_bn:
            self.batchnorm2_2_1 = nn.Identity()
        else:
            self.batchnorm2_2_1 = nn.BatchNorm2d(32 if double else 64, eps=1e-05, momentum=0.1, affine=True)
        
        self.conv2_2_2 = nn.Conv2d(in_channels=64, out_channels=32 if double else 64, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=True if self.disable_bn else False)
        if self.disable_bn:
            self.batchnorm2_2_2 = nn.Identity()
        else:
            self.batchnorm2_2_2 = nn.BatchNorm2d(32 if double else 64, eps=1e-05, momentum=0.1, affine=True)
        self.dropout2_2 = nn.Dropout(p=self.dropout_percentage)
        
        # BLOCK-3 (1) input=(56x56) output = (28x28)
        self.conv3_1_1 = nn.Conv2d(in_channels=64, out_channels=64 if double else 128, kernel_size=(3,3), stride=(2,2), padding=(1,1), bias=True if self.disable_bn else False)
        if self.disable_bn:
            self.batchnorm3_1_1 = nn.Identity()
        else:
            self.batchnorm3_1_1 = nn.BatchNorm2d(64 if double else 128, eps=1e-05, momentum=0.1, affine=True)
        self.conv3_1_2 = nn.Conv2d(in_channels=128, out_channels=64 if double else 128, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=True if self.disable_bn else False)
        if self.disable_bn:
            self.batchnorm3_1_2 = nn.Identity()
        else:
            self.batchnorm3_1_2 = nn.BatchNorm2d(64 if double else 128, eps=1e-05, momentum=0.1, affine=True)
        self.conv_concat_adjust_3 = nn.Conv2d(in_channels=64, out_channels=64 if double else 128, kernel_size=(1,1), stride=(2,2), padding=(0,0), bias=True if self.disable_bn else False)
        if self.disable_bn:
            self.batchnorm_adjust_3 = nn.Identity()
        else:
            self.batchnorm_adjust_3 = nn.BatchNorm2d(64 if double else 128, eps=1e-05, momentum=0.1, affine=True)
        self.dropout3_1 = nn.Dropout(p=self.dropout_percentage)
        # BLOCK-3 (2)
        self.conv3_2_1 = nn.Conv2d(in_channels=128, out_channels=64 if double else 128, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=True if self.disable_bn else False)
        if self.disable_bn:
            self.batchnorm3_2_1 = nn.Identity()
        else:
            self.batchnorm3_2_1 = nn.BatchNorm2d(64 if double else 128, eps=1e-05, momentum=0.1, affine=True)
        
        self.conv3_2_2 = nn.Conv2d(in_channels=128, out_channels=64 if double else 128, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=True if self.disable_bn else False)
        if self.disable_bn:
            self.batchnorm3_2_2 = nn.Identity()
        else:
            self.batchnorm3_2_2 = nn.BatchNorm2d(64 if double else 128, eps=1e-05, momentum=0.1, affine=True)
        self.dropout3_2 = nn.Dropout(p=self.dropout_percentage)
        
        # BLOCK-4 (1) input=(28x28) output = (14x14)
        self.conv4_1_1 = nn.Conv2d(in_channels=128, out_channels=128 if double else 256, kernel_size=(3,3), stride=(2,2), padding=(1,1), bias=True if self.disable_bn else False)
        if self.disable_bn:
            self.batchnorm4_1_1 = nn.Identity()
        else:
            self.batchnorm4_1_1 = nn.BatchNorm2d(128 if double else 256, eps=1e-05, momentum=0.1, affine=True)
        self.conv4_1_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=True if self.disable_bn else False)
        if self.disable_bn:
            self.batchnorm4_1_2 = nn.Identity()
        else:
            self.batchnorm4_1_2 = nn.BatchNorm2d(128 if double else 256, eps=1e-05, momentum=0.1, affine=True)
        self.conv_concat_adjust_4 = nn.Conv2d(in_channels=128, out_channels=128 if double else 256, kernel_size=(1,1), stride=(2,2), padding=(0,0), bias=True if self.disable_bn else False)
        if self.disable_bn:
            self.batchnorm_adjust_4 = nn.Identity()
        else:
            self.batchnorm_adjust_4 = nn.BatchNorm2d(128 if double else 256, eps=1e-05, momentum=0.1, affine=True)
        self.dropout4_1 = nn.Dropout(p=self.dropout_percentage)
        # BLOCK-4 (2)
        self.conv4_2_1 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=True if self.disable_bn else False)
        if self.disable_bn:
            self.batchnorm4_2_1 = nn.Identity()
        else:
            self.batchnorm4_2_1 = nn.BatchNorm2d(128 if double else 256, eps=1e-05, momentum=0.1, affine=True)
        self.conv4_2_2 = nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=True if self.disable_bn else False)
        if self.disable_bn:
            self.batchnorm4_2_2 = nn.Identity()
        else:
            self.batchnorm4_2_2 = nn.BatchNorm2d(128 if double else 256,eps=1e-05, momentum=0.1, affine=True)
        self.dropout4_2 = nn.Dropout(p=self.dropout_percentage)
        
        # BLOCK-5 (1) input=(14x14) output = (7x7)
        self.conv5_1_1 = nn.Conv2d(in_channels=256, out_channels=256 if double else 512, kernel_size=(3,3), stride=(2,2), padding=(1,1), bias=True if self.disable_bn else False)
        if self.disable_bn:
            self.batchnorm5_1_1 = nn.Identity()
        else:
            self.batchnorm5_1_1 = nn.BatchNorm2d(256 if double else 512, eps=1e-05, momentum=0.1, affine=True)
        self.conv5_1_2 = nn.Conv2d(in_channels=512, out_channels=256 if double else 512, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=True if self.disable_bn else False)
        if self.disable_bn:
            self.batchnorm5_1_2 = nn.Identity()
        else:
            self.batchnorm5_1_2 = nn.BatchNorm2d(256 if double else 512, eps=1e-05, momentum=0.1, affine=True)
        self.conv_concat_adjust_5 = nn.Conv2d(in_channels=256, out_channels=256 if double else 512, kernel_size=(1,1), stride=(2,2), padding=(0,0), bias=True if self.disable_bn else False)
        if self.disable_bn:
            self.batchnorm_adjust_5 = nn.Identity()
        else:
            self.batchnorm_adjust_5 = nn.BatchNorm2d(256 if double else 512, eps=1e-05, momentum=0.1, affine=True)
        self.dropout5_1 = nn.Dropout(p=self.dropout_percentage)
        # BLOCK-5 (2)
        self.conv5_2_1 = nn.Conv2d(in_channels=512, out_channels=256 if double else 512, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=True if self.disable_bn else False)
        if self.disable_bn:
            self.batchnorm5_2_1 = nn.Identity()
        else:
            self.batchnorm5_2_1 = nn.BatchNorm2d(256 if double else 512, eps=1e-05, momentum=0.1, affine=True)
        self.conv5_2_2 = nn.Conv2d(in_channels=512, out_channels=256 if double else 512, kernel_size=(3,3), stride=(1,1), padding=(1,1), bias=True if self.disable_bn else False)
        if self.disable_bn:
            self.batchnorm5_2_2 = nn.Identity()
        else:
            self.batchnorm5_2_2 = nn.BatchNorm2d(256 if double else 512, eps=1e-05, momentum=0.1, affine=True)
        self.dropout5_2 = nn.Dropout(p=self.dropout_percentage)
        
        # Final Block input=(7x7) 
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.fc1 = nn.Linear(in_features=512, out_features=self.num_classes)

        self.name_to_start_end = {

            "conv1": (0, 64),

            "conv2_1_1": (64, 128),
            "conv2_1_2": (128, 192),
            "conv2_2_1": (192, 256),
            "conv2_2_2": (256, 320),

            "conv3_1_1": (320, 448),
            "conv3_1_2": (448, 576),
            "conv_concat_adjust_3": (576, 704),
            "conv3_2_1": (704, 832),
            "conv3_2_2": (832, 960),

            "conv4_1_1": (960, 1216),
            "conv4_1_2": (1216, 1472),
            "conv_concat_adjust_4": (1472, 1728),
            "conv4_2_1": (1728, 1984),
            "conv4_2_2": (1984, 2240),

            "conv5_1_1": (2240, 2752),
            "conv5_1_2": (2752, 3264),
            "conv_concat_adjust_5": (3264, 3776),
            "conv5_2_1": (3776, 4288),
            "conv5_2_2": (4288, 4800),

            "fc1": (4800, 4800 + self.num_classes)
        }

        self.layer_channels = {
            name: end - start
            for name, (start, end) in self.name_to_start_end.items()
        }

        self.name_layers = list(self.name_to_start_end.keys())

        self.total_nodes = 4800 + self.num_classes    

    def forward_net(self, x, AL_list, ARM_list, updated_cnn_params, info):
        """
        Run the CNN branch with per-sample (dynamically modulated) parameters.

        The network is a stem (``conv1`` -> ``batchnorm1`` -> activation ->
        ``maxpool1``), four stages (2-5) of two residual units each, then
        ``avgpool`` and ``fc1``. The first unit of stages 3-5 uses a projected
        shortcut (``conv_concat_adjust_<stage>`` + ``batchnorm_adjust_<stage>``);
        every other unit uses an identity shortcut.

        Args:
            x: input batch fed to ``conv1``.
            AL_list: layer name -> negative-slope values for
                ``channelwise_activation``.
            ARM_list: layer name -> additive offsets for
                ``channelwise_activation``.
            updated_cnn_params: parameter name (e.g. ``'conv1.weight'``) ->
                per-sample parameter tensor, consumed by ``conv2d_per_sample`` /
                ``linear_per_sample``. Conv bias entries are only read when
                ``self.disable_bn`` is truthy.
            info: logging dict, returned unchanged.

        Returns:
            ``(logits, act_means, info)``. Despite its name, ``act_means`` maps
            each layer name to the full *detached* output tensor of that layer,
            not a channel mean. For ``conv1`` and the first conv of each unit
            the stored tensor is post-activation. For the second conv of each
            unit and for the shortcut projections it is the batch-norm output
            before dropout / residual add / activation. For ``fc1`` it is the
            logits.
        """
        act_means = {}  # layer name -> detached output tensor (see docstring)

        # Stem.
        x = self._forward_net_conv_bn(x, 'conv1', self.batchnorm1, updated_cnn_params)
        x = self.channelwise_activation(x, AL_list['conv1'], ARM_list['conv1'], type='conv')
        act_means['conv1'] = x.detach()
        out = self.maxpool1(x)

        # Stages 2-5, two residual units each. Only the first unit of stages
        # 3-5 projects its shortcut through conv_concat_adjust_<stage>.
        for stage in (2, 3, 4, 5):
            adjust_stage = stage if stage >= 3 else None
            out = self._forward_net_residual_unit(
                out, f'{stage}_1', AL_list, ARM_list, updated_cnn_params, act_means,
                adjust_stage=adjust_stage)
            out = self._forward_net_residual_unit(
                out, f'{stage}_2', AL_list, ARM_list, updated_cnn_params, act_means)

        # Head.
        x = self.avgpool(out)
        x = x.reshape(x.shape[0], -1)
        x = linear_per_sample(x,
                              updated_cnn_params['fc1.weight'],
                              updated_cnn_params['fc1.bias'])
        act_means['fc1'] = x.detach()

        logits = x
        return logits, act_means, info

    def _forward_net_conv_bn(self, x, conv_name, batchnorm, params):
        """Apply per-sample conv ``conv_name`` (padding/stride of ``self.<conv_name>``), then ``batchnorm``."""
        conv = getattr(self, conv_name)
        x = conv2d_per_sample(x,
                              params[f'{conv_name}.weight'],
                              # The bias is only used when self.disable_bn is truthy.
                              params[f'{conv_name}.bias'] if self.disable_bn else None,
                              padding=conv.padding,
                              stride=conv.stride)
        return batchnorm(x)

    def _forward_net_residual_unit(self, inp, tag, AL_list, ARM_list, params, act_means, adjust_stage=None):
        """
        Run one residual unit identified by ``tag`` (e.g. ``'3_1'``).

        conv<tag>_1 -> BN -> act -> conv<tag>_2 -> BN -> dropout<tag>, added to the
        (optionally projected) input, then the activation keyed by conv<tag>_2.
        Layer outputs are recorded in ``act_means`` in execution order.
        """
        first, second = f'conv{tag}_1', f'conv{tag}_2'

        x = self._forward_net_conv_bn(inp, first, getattr(self, f'batchnorm{tag}_1'), params)
        x = self.channelwise_activation(x, AL_list[first], ARM_list[first], type='conv')
        act_means[first] = x.detach()

        x = self._forward_net_conv_bn(x, second, getattr(self, f'batchnorm{tag}_2'), params)
        act_means[second] = x.detach()  # recorded before dropout and the residual add
        x = getattr(self, f'dropout{tag}')(x)

        shortcut = inp
        if adjust_stage is not None:
            # Projected shortcut; computed after the main-path dropout, as before.
            adjust_name = f'conv_concat_adjust_{adjust_stage}'
            shortcut = self._forward_net_conv_bn(
                inp, adjust_name, getattr(self, f'batchnorm_adjust_{adjust_stage}'), params)
            act_means[adjust_name] = shortcut.detach()

        return self.channelwise_activation(x + shortcut, AL_list[second], ARM_list[second], type='conv')



    def _sample_visible_indices(self) -> torch.LongTensor:
        """
        Pick the GLOBAL neuron indices that stay visible to the MAE encoder.

        For each layer in ``self.layer_channels`` (name -> width), keep
        ``max(1, ceil(width * self.mask_portion))`` neurons. Despite its name,
        ``mask_portion`` is therefore the fraction KEPT. Layers of width <= 2,
        or whose keep count covers the whole layer, are kept entirely.
        Otherwise the first and last neurons are always kept and the rest are
        drawn uniformly without replacement from the interior. When the keep
        count is 1, that means 2 neurons are kept. Local indices are offset by
        the layer's start in ``self.name_to_start_end``.
        """
        chunks = []
        for layer_name, width in self.layer_channels.items():
            offset, _ = self.name_to_start_end[layer_name]
            keep = max(1, int(math.ceil(width * self.mask_portion)))

            if width > 2 and keep < width:
                endpoints = torch.tensor([0, width - 1], device=self.device)
                interior = torch.arange(1, width - 1, device=self.device)
                shuffled = interior[torch.randperm(interior.numel(), device=self.device)]
                extra = max(0, keep - endpoints.numel())
                local = torch.cat([endpoints, shuffled[:extra]], dim=0)
            else:
                local = torch.arange(width, device=self.device)

            chunks.append(local + offset)
        return torch.cat(chunks, dim=0).long().to(self.device)


    def load_consolidated_weights(self, wc_weights):
        """
        Seed the EMA parameter store from a mapping of name -> tensor.

        Each tensor is detached and copied, so later changes to the source
        tensors do not leak into ``self._ema_params``. The copies stay on the
        source tensors' devices.
        """
        for key, tensor in wc_weights.items():
            snapshot = tensor.detach().clone()
            self._ema_params[key] = snapshot

    
    def forward_neuro_sync(self, x):
        """
        Run the controller and return one embedding per CNN neuron.

        Image tokens come from ``controller_image_encoder`` (plus learnable
        ``image_features``). Neuron tokens concatenate the static/learnable
        neuron features with the running activation/weight statistics.

        - ``attention_neuro_sync == 'mae'``: queries for all neurons are built
          from ``neuron_features``/``learnable_features`` via ``query_proj``.
          They attend to a memory of image tokens plus, when
          ``mask_portion > 0``, a sampled subset of neuron tokens.
        - ``attention_neuro_sync == 'self'``: neuron and image tokens go
          through ``self.transformer`` (sequence-first), and the first
          ``total_nodes`` outputs are kept.

        Args:
            x: input batch. A 2-D input gets a singleton dim inserted at axis 1.

        Returns:
            Tensor of shape (B, N, D) with one embedding per neuron.

        Raises:
            ValueError: if ``controller_input_type`` or ``attention_neuro_sync``
                is not supported.
        """
        if len(x.shape) == 2:
            x = x.unsqueeze(1)
        B = x.shape[0]

        # ----- image tokens
        # Every attention mode below consumes img_tokens. Without this check,
        # a non-'conv' input type fell through to an unbound local variable.
        if self.controller_input_type != 'conv':
            raise ValueError(f"{self.controller_input_type} is not a supported controller_input_type")
        img_features = self.controller_image_encoder(x)  # (B, D_img, H', W')
        # reshape (rather than view) also accepts non-contiguous encoder outputs.
        img_tokens_core = img_features.reshape(B, self.embed_dim - self.learnable_dim, -1).transpose(1, 2)
        image_learnable = self.image_features.unsqueeze(0).expand(B, -1, -1)
        img_tokens = torch.cat([img_tokens_core, image_learnable], dim=-1)  # (B, I, D)

        # ----- neuron tokens (static features + learnable + running stats), no batch dim
        neuron_full_tokens = torch.cat([
            self.neuron_features,           # (N, fixed_dim)
            self.learnable_features,        # (N, learnable_dim)
            self.activationmean_ema,        # (N, 1)
            self.activationmax_ema,         # (N, 1)
            self.activationmin_ema,         # (N, 1)
            self.activationstd_ema,         # (N, 1)
            self.weightmean_ema,            # (N, 1)
            self.weightmax_ema,             # (N, 1)
            self.weightmin_ema,             # (N, 1)
            self.weightstd_ema,             # (N, 1)
        ], dim=-1).to(self.device)  # (N, D)

        if self.attention_neuro_sync == 'mae':
            # A) Encoder memory: sampled visible neuron tokens + image tokens.
            if self.mask_portion > 0:
                vis_idx = self._sample_visible_indices()                   # (M,)
                vis_tokens = neuron_full_tokens.index_select(0, vis_idx)    # (M, D)
                vis_tokens = vis_tokens.unsqueeze(0).expand(B, -1, -1)      # (B, M, D)
                memory_tokens = torch.cat([vis_tokens, img_tokens], dim=1)  # (B, M+I, D)
            else:
                memory_tokens = img_tokens

            # B) Queries for ALL neurons: (static + learnable) -> query_proj.
            all_queries_base = torch.cat([self.neuron_features, self.learnable_features], dim=-1)
            all_queries = self.query_proj(all_queries_base)            # (N, D)
            all_queries = all_queries.unsqueeze(0).expand(B, -1, -1)   # (B, N, D)

            # C) MAE-style encode/decode.
            transformer_output = self.controller(memory_tokens, all_queries)  # (B, N, D)

        elif self.attention_neuro_sync == 'self':
            gating_features = neuron_full_tokens.unsqueeze(0).expand(B, -1, -1)
            gating_features = torch.cat([gating_features, img_tokens], dim=1)   # (B, N+I, D)
            gating_features = gating_features.permute(1, 0, 2)                  # (N+I, B, D)
            transformer_output = self.transformer(gating_features).permute(1, 0, 2)[:, :self.total_nodes, :]
        else:
            raise ValueError(f"{self.attention_neuro_sync} is not supported")

        return transformer_output  # (B, N, D)


    def forward(self, x):
        """
        Full forward pass: controller -> per-neuron modulation -> modulated CNN.

        Steps:
          1. ``forward_neuro_sync`` produces one embedding per neuron.
          2. ``decoder_conv`` (all layers but the last) and ``decoder_fc``
             (last layer) turn those embeddings into per-sample modulation
             values AL, ARM, WC, SM. Each is multiplied by ``int(self.AL)`` /
             ``int(self.ARM)`` / ``int(self.WC)`` / ``int(self.SM)``, so a falsy
             flag zeroes it. The regularisers ``uniform_loss`` and
             ``bound_loss`` are accumulated from these values.
          3. Every ``conv*`` / ``fc*`` parameter is expanded to a per-sample
             tensor ``(1 + SM) * param + ema * WC``. With ``SM_detach`` the SM
             term uses a detached copy of the parameter.
          4. ``forward_net`` runs the CNN with those per-sample parameters.
          5. Side effects: ``update_ema`` is called, and the per-neuron
             activation/weight statistic buffers are updated with momentum
             0.999 under ``torch.no_grad()``. There is no ``self.training``
             check, so this also happens in eval mode.

        x: presumably a (B, 3, 32, 32) image batch — TODO confirm against callers.
        Returns: ``(logits, info)``. ``info`` holds 'uniform_loss',
            'bound_loss' (``bound_loss.mean(0)``) and 'hidden' (the stored
            activation of ``self.name_layers[-2]``).
        """
        self.activations = {}  # reset; not otherwise used in this method
        info = {}
        
        transformer_output = self.forward_neuro_sync(x)  # (B, N, D)
        # ---- B) Decode gating for conv layers.
        ARM_list = {}        # layer -> additive activation offsets, (B, n_neurons)
        AL_list = {}         # layer -> negative-slope gating values, (B, n_neurons)
        WC_list = {}       # layer -> per-neuron EMA-weight mixing factors, (B, n_neurons)
        SM_list = {}     # layer -> per-neuron self-modulation factors, (B, n_neurons)
        bound_loss = 0.0
        uniform_loss = 0.0
        batch_size = x.shape[0]

        # Every layer except the final output layer is decoded with decoder_conv.
        for name in self.name_layers[:-1]:
            (start, end) = self.name_to_start_end[name]
            
            current_embedding = transformer_output[:, start:end]  # (B, n_neurons, embed_dim)
            gating_out = self.decoder_conv(current_embedding)  # (B, n_neurons, K); channels 0-3 used below
            AL_vals = gating_out[:, :, 0] * int(self.AL)
            ARM_vals = gating_out[:, :, 1] * int(self.ARM)
            # softplus keeps the EMA mixing factor non-negative.
            WC_vals = F.softplus(gating_out[:, :, 2] * 0.5) * int(self.WC)
            SM_vals = gating_out[:, :, 3] * int(self.SM)
        
           
            ARM_list[name] = ARM_vals
            AL_list[name] = AL_vals
            WC_list[name] = WC_vals
            SM_list[name] = SM_vals
            
            uniform_loss += compute_ext_loss(WC_vals)
            uniform_loss += compute_ext_loss(SM_vals)
            uniform_loss += compute_ext_loss(ARM_vals)
            uniform_loss += compute_ext_loss(AL_vals)

            bound_loss += ext_soft_max(WC_vals) + ext_soft_min(WC_vals)
            bound_loss += ext_soft_max(ARM_vals) + ext_soft_min(ARM_vals)
            bound_loss += ext_soft_max(AL_vals) + ext_soft_min(AL_vals)
            bound_loss += ext_soft_max(SM_vals) + ext_soft_min(SM_vals)

        # ---- C) Decode gating for the output (fc) layer: only a WC factor is produced.
        (start, end) = self.name_to_start_end[self.name_layers[-1]]
        
        fc_embedding = transformer_output[:, start:end]
        fc_gating = self.decoder_fc(fc_embedding)
        fc_WC_vals = F.softplus(fc_gating[:, :, 0] * 0.5) * int(self.WC)

        uniform_loss += compute_ext_loss(fc_WC_vals)
        
        bound_loss += ext_soft_max(torch.abs(fc_WC_vals)) #+ ext_soft_min(fc_WC_vals)
       
        # ---- D) Update CNN branch parameters (conv and fc) in a differentiable manner.
        updated_cnn_params = {}
        for name, param in self.named_parameters():
            if not (name.startswith("conv") or name.startswith("fc")):
                continue
            # Module name, e.g. 'conv2_1_1' from 'conv2_1_1.weight'.
            name_in_list = name.split('.')[0]
            if name.startswith("conv"):
                WC_factor = WC_list[name_in_list] # shape: (B, n_neurons)
                SM_factor = SM_list[name_in_list]
                n_neurons = self.layer_channels[name_in_list]
            
            elif name.startswith("fc"):
                if name_in_list == self.name_layers[-1]:
                    # The output layer has no SM factor; use zeros so (1 + SM) == 1.
                    WC_factor = fc_WC_vals
                    SM_factor = torch.zeros_like(fc_WC_vals).to(self.device)
                else:
                    WC_factor = WC_list[name_in_list] # shape: (B, n_neurons)
                    SM_factor = SM_list[name_in_list]

                n_neurons = self.layer_channels[name_in_list]
            else:
                # NOTE(review): unreachable — the prefix filter above already
                # skipped names not starting with "conv"/"fc".
                print(f"Unexpected parameter name: {name}")
            # Reshape (B, n_neurons) factors to (B, n_neurons, 1, ...) so they
            # broadcast over the per-sample parameter built below.
            if SM_factor is not None and param.shape[0] == n_neurons:
                batch_size = x.shape[0]
                new_shape = [batch_size, n_neurons] + [1] * (param.dim() - 1)
                WC_factor = WC_factor.view(*new_shape)
                SM_factor = SM_factor.view(*new_shape)
            else:
                raise ValueError('due to missing dynamic values')
                # print(f"Skipping parameter {name} due to missing dynamic values.")
                # WC_factor = WC_factor.mean() if WC_factor is not None else 0.0
                # SM_factor = SM_factor.mean() if SM_factor is not None else 0.0
           
            # EMA copy of this parameter as a constant (no gradient), with a
            # leading singleton batch dim.
            ema_val = self._ema_params[name].to(self.device).clone().detach().unsqueeze(0)
           
            # info = {**info, f'alpha/SM_{name}': wandb.Histogram(SM_factor.view(-1).detach().cpu().numpy()), 
            #         f'alpha/WC_{name}': wandb.Histogram(WC_factor.view(-1).detach().cpu().numpy())}
            
            param = param.unsqueeze(0).repeat_interleave(x.shape[0], dim=0)  # shape: (B, n_neurons, ...)
            if self.SM_detach:
                # The SM term multiplies a detached copy, so gradient reaches
                # the parameter only through the identity term.
                updated_cnn_params[name] = param + param.clone().detach() * SM_factor + ema_val * WC_factor
            else:
                updated_cnn_params[name] = (1 + SM_factor) * param + ema_val * WC_factor

        logits, acts, info = self.forward_net(x, AL_list, ARM_list, updated_cnn_params, info=info)
        
        self.update_ema()

        # Update the per-neuron EMA buffers for activation and weight statistics.
        with torch.no_grad():
            # Activation statistics per output channel/neuron, blended with momentum 0.999.
            for _, name in enumerate(self.name_layers):
                (start, end) = self.name_to_start_end[name]
                mean_old_val = self.activationmean_ema[start:end]
                max_old_val = self.activationmax_ema[start:end]
                min_old_val = self.activationmin_ema[start:end]
                std_old_val = self.activationstd_ema[start:end]
                if name.startswith("conv"):
                    # Reduce over batch and spatial dims -> one value per channel.
                    mean_new_val = 0.999 * mean_old_val + 0.001 * acts[name].mean(dim=[0,2,3]).view(-1, 1)
                    max_new_val = 0.999 * max_old_val + 0.001 * torch.amax(acts[name], dim=(0,2,3)).view(-1, 1)
                    min_new_val = 0.999 * min_old_val + 0.001 * torch.amin(acts[name], dim=(0,2,3)).view(-1, 1)
                    std_new_val = 0.999 * std_old_val + 0.001 * acts[name].std(dim=[0,2,3]).view(-1, 1)
                elif name.startswith("fc"):
                    mean_new_val = 0.999 * mean_old_val + 0.001 * acts[name].mean(dim=[0]).view(-1, 1)
                    max_new_val = 0.999 * max_old_val + 0.001 * torch.amax(acts[name], dim=(0)).view(-1, 1)
                    min_new_val = 0.999 * min_old_val + 0.001 * torch.amin(acts[name], dim=(0)).view(-1, 1)
                    # With a single sample the unbiased std is undefined (NaN),
                    # so fall back to the population std.
                    if acts[name].shape[0] == 1:
                        std_new_val = 0.999 * std_old_val + 0.001 * acts[name].std(dim=0, unbiased=False).view(-1, 1)
                    else:
                        std_new_val = 0.999 * std_old_val + 0.001 * acts[name].std(dim=0).view(-1, 1)
                else:
                    raise ValueError(f"{name} is not valid")
                 
                self.activationmean_ema[start:end] = mean_new_val.detach()  # NEW
                self.activationmax_ema[start:end] = max_new_val.detach()  # NEW
                self.activationmin_ema[start:end] = min_new_val.detach()  # NEW
                self.activationstd_ema[start:end] = std_new_val.detach()  # NEW
            
        
            # Weight statistics per output neuron, computed on the batch-mean
            # of the per-sample (modulated) weights; biases are skipped.
            for name, param in updated_cnn_params.items():
                if name.startswith("conv") or name.startswith("fc"):
                    if name.endswith("weight"):
                        param = param.mean(0)
                        #breakpoint()
                        # For conv: (out_channels, in_channels, kh, kw); for fc: (out_features, in_features)
                        mean_val = param.view(param.shape[0], -1).mean(dim=1).detach()  # NEW
                        max_val = torch.amax(param.view(param.shape[0], -1), dim=1).detach()  # NEW
                        min_val = torch.amin(param.view(param.shape[0], -1), dim=1).detach()  # NEW
                        std_val = param.view(param.shape[0], -1).std(dim=1).detach()  # NEW
                        (start, end) = self.name_to_start_end[name.split('.')[0]]
                        mean_old_val = self.weightmean_ema[start:end]
                        max_old_val = self.weightmax_ema[start:end]
                        min_old_val = self.weightmin_ema[start:end]
                        std_old_val = self.weightstd_ema[start:end]
                        mean_new_val = 0.999 * mean_old_val + 0.001 * mean_val.view(-1, 1)
                        max_new_val = 0.999 * max_old_val + 0.001 * max_val.view(-1, 1)
                        min_new_val = 0.999 * min_old_val + 0.001 * min_val.view(-1, 1)
                        std_new_val = 0.999 * std_old_val + 0.001 * std_val.view(-1, 1)

                        self.weightmean_ema[start:end] = mean_new_val.detach()  # NEW
                        self.weightmax_ema[start:end] = max_new_val.detach()  # NEW
                        self.weightmin_ema[start:end] = min_new_val.detach()  # NEW
                        self.weightstd_ema[start:end] = std_new_val.detach()  # NEW
    
        info['uniform_loss'] = uniform_loss
        info['bound_loss'] = bound_loss.mean(0)
        # Stored (detached) output of the penultimate layer in name_layers.
        info['hidden'] = acts[self.name_layers[-2]]
        return logits, info

    def channelwise_activation(self, x, c_vector, miu, type):
        """
        Apply a parametric leaky activation with a per-channel additive offset.

        result = relu(x) + c * x_neg + miu, where x_neg equals x where x < 0
        and 0 elsewhere.

        Args:
            x: pre-activation tensor. For ``type='conv'`` it is indexed as
                (B, C, H, W); for ``type='fc'`` as (B, C).
            c_vector: negative-slope values. For ``'conv'`` two trailing
                singleton dims are appended, so a (B, C) tensor broadcasts
                against (B, C, H, W). For ``'fc'`` it is used as-is.
            miu: additive offset, reshaped exactly like ``c_vector``.
            type: ``'conv'`` or ``'fc'``.

        Returns:
            Tensor broadcast from ``x``, ``c_vector`` and ``miu``.

        Raises:
            ValueError: if ``type`` is neither ``'conv'`` nor ``'fc'``.
        """
        if type == 'conv':
            miu = miu.unsqueeze(-1).unsqueeze(-1)
            c_shaped = c_vector.unsqueeze(-1).unsqueeze(-1)
        elif type == 'fc':
            c_shaped = c_vector
        else:
            # Previously this fell through with c_shaped unbound.
            raise ValueError(f"Unsupported activation type: {type!r}")

        # Negative part of x. torch.where replaces x * (x < 0).float() because
        # that form gives inf * 0 = NaN for +inf inputs; values and gradients
        # are otherwise identical.
        negative_part = torch.where(x < 0, x, torch.zeros_like(x))
        return F.relu(x) + c_shaped * negative_part + miu
   
    @staticmethod
    def positional_encoding(x, num_freq):
        """
        x: (N, 2) tensor of base positions.
        Returns: (N, 2 + 4*num_freq) tensor.
        For each coordinate, include the original value and for each frequency i compute:
           sin(2^i * coordinate) and cos(2^i * coordinate)
        """
        encodings = [x]  # include the original (x, y)
        for i in range(num_freq):
            frequency = 2.0 ** i
            encodings.append(torch.sin(frequency * x))
            encodings.append(torch.cos(frequency * x))
        return torch.cat(encodings, dim=-1)
   
    def update_ema(self):
        """
        Blend the current conv/fc parameters into the EMA store.

        For every parameter whose name starts with ``conv`` or ``fc``:
            ema <- ema_decay * ema + (1 - ema_decay) * param
        The parameter is moved to the device of the stored EMA tensor, so the
        store keeps its own device. Previously the parameter was always moved
        to CPU, which raised a device-mismatch error whenever the store held
        non-CPU tensors (e.g. seeded by ``load_consolidated_weights`` from GPU
        weights).
        """
        decay = self.ema_decay
        with torch.no_grad():
            for name, param in self.named_parameters():
                if not name.startswith(("conv", "fc")):
                    continue
                ema = self._ema_params[name]
                self._ema_params[name] = decay * ema + (1 - decay) * param.detach().to(ema.device)

    def plot_params(self):
        """Return a fresh, empty dict of controller plots to log (nothing is plotted yet)."""
        return {}

    def get_model_weights_l2_norm(self):
        """Delegate to ``utils_l2_norm`` over the trainable ``fc*``/``conv*`` (name, param) pairs."""
        trainable_cnn = (
            (pname, tensor)
            for pname, tensor in self.named_parameters()
            if tensor.requires_grad and pname.startswith(('fc', 'conv'))
        )
        return utils_l2_norm(trainable_cnn)

    def compute_l1_norm(self):
        """Delegate to ``utils_l1_norm`` over the trainable ``fc*``/``conv*`` (name, param) pairs."""
        trainable_cnn = (
            (pname, tensor)
            for pname, tensor in self.named_parameters()
            if tensor.requires_grad and pname.startswith(('fc', 'conv'))
        )
        return utils_l1_norm(trainable_cnn)
    
    def compute_l2_norm(self):
        """Delegate to ``new_utils_l2_norm`` over the trainable ``fc*``/``conv*`` (name, param) pairs."""
        trainable_cnn = (
            (pname, tensor)
            for pname, tensor in self.named_parameters()
            if tensor.requires_grad and pname.startswith(('fc', 'conv'))
        )
        return new_utils_l2_norm(trainable_cnn)
    
    def compute_total_params(self):
        """
        Count elements of the trainable ``fc*``/``conv*`` parameters.

        Parameters whose name contains 'layer_norm', 'init_params' or
        'original_last_layer_params' are excluded. Returns 0 when nothing
        matches.
        """
        excluded_tags = ('layer_norm', 'init_params', 'original_last_layer_params')
        return sum(
            tensor.numel()
            for pname, tensor in self.named_parameters()
            if tensor.requires_grad
            and pname.startswith(('fc', 'conv'))
            and not any(tag in pname for tag in excluded_tags)
        )

    
