import torch
import torch.nn as nn
import torch.nn.functional as F
from utils.utils import *
import wandb
import math
from s4torch import S4Model
from net.transformer import CrossAttentionBlock


class PReLU(nn.PReLU):
    """
    PReLU whose per-channel slopes are applied along the LAST dimension.

    ``nn.PReLU`` applies its ``num_parameters`` slopes along dim 1. Inside a
    sequence-first transformer feed-forward block the features live in the last
    dimension (dim 1 is the batch), so this wrapper flattens every leading
    dimension, applies PReLU with one slope per feature, and restores the
    original shape. Works for any input rank >= 1.
    """
    def __init__(self, num_parameters=1, init=0.25):
        super().__init__(num_parameters=num_parameters, init=init)

    def forward(self, input):
        original_shape = input.shape
        # (..., features) -> (N, features) so that nn.PReLU's channel dim (dim 1)
        # is the feature dim. reshape (not view) also handles non-contiguous input.
        flat = input.reshape(-1, original_shape[-1])
        return super().forward(flat).reshape(original_shape)

class CustomTransformerEncoderLayer(nn.TransformerEncoderLayer):
    """
    ``nn.TransformerEncoderLayer`` whose feed-forward activation is a learnable PReLU.

    The ``activation`` argument is forwarded to the parent constructor and then
    replaced: the layer always uses ``PReLU(num_parameters=dim_feedforward)``,
    i.e. one learnable negative slope per hidden unit of the feed-forward block.
    """
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu"):
        super().__init__(d_model, nhead, dim_feedforward, dropout, activation)
        self.activation = PReLU(num_parameters=dim_feedforward)
        # The parent sets this flag from the ORIGINAL activation (e.g. 1 for the
        # default "relu") and uses it to decide whether its fused inference fast
        # path may run; that fused kernel applies relu/gelu itself and would skip
        # the PReLU. Clear the stale flag so the PReLU is always applied.
        self.activation_relu_or_gelu = 0

class CReLU(nn.Module):
    """
    Concatenated ReLU: ``relu(cat(x, -x))``, which doubles the channel dimension.

    2-D input ``(N, C)`` is concatenated along the last dim and 4-D input
    ``(N, C, H, W)`` along dim 1; any other rank raises ``ValueError``.
    """

    def __init__(self, inplace=False):
        # ``inplace`` is accepted for nn.ReLU-style compatibility but is unused.
        super().__init__()

    def forward(self, x):
        if x.dim() == 2:
            cat_dim = -1
        elif x.dim() == 4:
            cat_dim = 1
        else:
            # The original raised a plain string (a TypeError) and referenced the
            # misspelled ``x.shpe``; raise a real exception instead.
            raise ValueError(f"{x.shape} is invalid in CReLU")
        return F.relu(torch.cat((x, -x), cat_dim))

class MLPController(nn.Module):
    """
    Flat MLP used as the neuron controller in the ``'simple_mlp'`` mode.

    Maps a flattened vector of ``(num_neurons + image_size) * embed_dim`` values
    to ``num_neurons * embed_dim`` values. Every ``CReLU`` doubles the width of
    its input, so each following ``nn.Linear`` takes twice the previous output size.
    """
    def __init__(self, image_size, num_neurons, embed_dim):
        super().__init__()
        print('**** MLPController ****')
        in_features = (num_neurons + image_size) * embed_dim
        out_features = num_neurons * embed_dim
        first_hidden, second_hidden = 64, 16
        self.model_controller = nn.Sequential(
            nn.Linear(in_features, first_hidden),
            CReLU(),
            nn.Linear(2 * first_hidden, second_hidden),  # CReLU doubled 64 -> 128
            CReLU(),
            nn.Linear(2 * second_hidden, out_features),  # CReLU doubled 16 -> 32
        )

    def forward(self, x):
        """Apply the MLP; ``x`` is presumably ``(batch, in_features)`` — TODO confirm."""
        return self.model_controller(x)

# # different decoders for cnn and fc? 
class Single_NEW_MIX_Sample_Based(nn.Module):
    """
    Deeper CNN with a global transformer-based gating using hybrid node representations.


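    For each neuron a representation is computed by concatenating:
      - A fixed part: computed from (layer index, neuron index) augmented with higher-frequency
        sin/cos positional encoding (similar to NeRF). This fixed part has dimension 2+4*num_freq.
      - A learnable part: stored as trainable parameters (``learnable_dim`` values, default 4).
      - Eight per-neuron statistics (mean/max/min/std of activations and of weights), kept in
        the ``*_ema`` buffers.

    These neuron tokens, together with tokens encoded from the controller input, are processed
    by the controller selected by ``attention_neuro_sync`` (self-attention, S4, MLP or
    cross-attention), and the neuron output embeddings are decoded to yield gating parameters
    (6 values per neuron via ``decoder_conv``; ``decoder_fc`` yields 2 values unless
    ``last_layer_act`` is set, in which case it shares ``decoder_conv``).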
    """
    def __init__(self,
                 input_type='conv',
                 controller_input_type ='conv',
                 input_shape=(3, 32, 32),
                 num_classes=10,
                 cnn_channels=[8, 16, 32, 64],
                 kernel_size=[3, 3, 3, 3],
                 padding=[1, 1, 1, 1],
                 stride=[1, 1, 1, 1],
                 pooling_type=['max', 'max', 'max', 'max'],
                 pooling_kernel=[2, 2, 2, 2],
                 fc_channels=[],
                 transformer_embed_dim=256,  # feed-forward / MLP width used by the controller
                 learnable_dim=4,  # learnable parameters per neuron
                 neuro_sync_conv_kernel_size=5,
                 k=1,           # number of transformer layers (global)
                 num_heads=2,
                 ema_decay=0.999,  # decay rate for EMA updates
                 num_freq=0,       # number of additional frequencies for positional encoding
                 last_layer_act=False,
                 prelu_transformer=True,
                 WC=True,
                 SR=True,
                 SM=True,
                 AL=True,
                 ARM=True,
                 SM_detach=False,
                 simplified=False,
                 attention_neuro_sync='self',
                 track_params=True,
                 device='cpu'
                ):
        """Build the backbone conv/fc layers and the per-neuron gating controller.

        Args:
            input_type: 'conv' derives the first fc input size from the conv stack via
                ``calculate_network_dims``; anything else uses ``prod(input_shape)``.
            controller_input_type: 'conv' (Conv2d patch encoder) or 'fc' (kernel-1 Conv1d)
                for the controller's image tokens; only used when ``simplified`` is False.
            input_shape: input shape without the batch dimension.
            num_classes: width of the final fc layer (appended to ``fc_channels``).
            cnn_channels, kernel_size, padding, stride, pooling_type, pooling_kernel:
                per-conv-layer settings, indexed in parallel.
            fc_channels: hidden fc widths placed before the output layer.
            transformer_embed_dim: feed-forward / MLP width passed to the controller.
            learnable_dim: learnable features per neuron token and per image token.
            neuro_sync_conv_kernel_size: kernel size AND stride of the Conv2d image encoder.
            k: number of controller encoder layers.
            num_heads: attention heads in the controller.
            ema_decay: stored as ``self.ema_decay`` for EMA weight updates.
            num_freq: extra frequencies in the neuron positional encoding.
            last_layer_act: if True the output layer reuses the 6-value ``decoder_conv``.
            prelu_transformer: use ``CustomTransformerEncoderLayer`` in the controller.
            WC, SR, SM, AL, ARM: gating-signal switches; ``forward`` multiplies each decoded
                signal by ``int(flag)``.
            SM_detach, track_params: stored as attributes of the same name.
            simplified: replace the controller with one directly learned vector per signal.
            attention_neuro_sync: controller type: 'self', 's4', 'simple_mlp', 'cross_only',
                'cross_fusion' or 'fusion_cross'.
            device: device for the controller buffers/parameters created here.

        Raises:
            ValueError: (non-simplified mode) unknown ``controller_input_type`` or
                ``attention_neuro_sync``.
        """
        super().__init__()

        # -------------------------------------------------------
        # 1) Bookkeeping for the gated neurons (conv channels + fc units).
        # -------------------------------------------------------
        # Copy list arguments so the instance never aliases the shared default lists.
        cnn_channels = list(cnn_channels)
        fc_channels = list(fc_channels) + [num_classes]
        self.layer_channels = cnn_channels + fc_channels
        self.pooling_type = pooling_type
        self.pooling_kernel = pooling_kernel
        self.cnn_channels = cnn_channels
        self.fc_channels = fc_channels
        self.padding = padding
        self.stride = stride
        self.input_type = input_type
        self.controller_input_type = controller_input_type
        self.track_params = track_params
        self.input_shape = input_shape
        self.last_layer_act = last_layer_act
        self.WC = WC
        self.SR = SR
        self.SM = SM
        self.AL = AL
        self.ARM = ARM
        self.SM_detach = SM_detach
        self.simplified = simplified
        self.device = device
        self.prelu_transformer = prelu_transformer
        self.num_heads = num_heads
        self.transformer_embed_dim = transformer_embed_dim
        self.k = k
        self.attention_neuro_sync = attention_neuro_sync

        # offsets[i] = global index of the first neuron of layer i.
        self.offsets = []
        running_sum = 0
        for channels in self.layer_channels:
            self.offsets.append(running_sum)
            running_sum += channels
        self.total_nodes = running_sum

        if not self.simplified:
            supported_modes = ('self', 's4', 'simple_mlp', 'cross_only', 'cross_fusion', 'fusion_cross')
            if attention_neuro_sync not in supported_modes:
                raise ValueError(f"Unknown attention_neuro_sync: {attention_neuro_sync}")

            # Number of image tokens produced by ``controller_image_encoder`` (built below).
            if controller_input_type == 'conv':
                # Must mirror the Conv2d encoder exactly (kernel == stride ==
                # neuro_sync_conv_kernel_size, padding=1) so ``image_features`` matches its
                # output; height and width are computed separately for non-square inputs.
                token_rows = conv_output_dim(input_shape[1], neuro_sync_conv_kernel_size,
                                             stride=neuro_sync_conv_kernel_size, padding=1)
                token_cols = conv_output_dim(input_shape[2], neuro_sync_conv_kernel_size,
                                             stride=neuro_sync_conv_kernel_size, padding=1)
                image_token_len = token_rows * token_cols
            elif controller_input_type == 'fc':
                # kernel-1 Conv1d: one token per input feature.
                image_token_len = input_shape[0]
            else:
                raise ValueError(f"Unknown input type: {controller_input_type}")
            self.image_token_len = image_token_len

            # Per-neuron statistics appended to every neuron token, each (total_nodes, 1).
            for stat_name in ('activationmean_ema', 'activationmax_ema', 'activationmin_ema', 'activationstd_ema',
                              'weightmean_ema', 'weightmax_ema', 'weightmin_ema', 'weightstd_ema'):
                self.register_buffer(stat_name, torch.zeros(self.total_nodes, 1).to(self.device))

            # Token width = positional encoding + learnable features + 4 activation + 4 weight stats.
            fixed_dim = 2 + 4 * num_freq
            self.embed_dim = fixed_dim + learnable_dim + 4 + 4

            base_positions = [[float(layer_idx), float(neuron_idx)]
                              for layer_idx, ch in enumerate(self.layer_channels)
                              for neuron_idx in range(ch)]
            base_positions = torch.tensor(base_positions).to(self.device)  # (total_nodes, 2)
            pos_encoded = self.positional_encoding(base_positions, num_freq)  # (total_nodes, 2+4*num_freq)
            self.register_buffer('neuron_features', pos_encoded)

            self.create_controller()

            # Encoder emits embed_dim - learnable_dim channels; learnable image features fill the rest.
            token_dim = self.embed_dim - learnable_dim
            if controller_input_type == 'conv':
                self.controller_image_encoder = nn.Conv2d(self.input_shape[0], token_dim,
                                                          kernel_size=neuro_sync_conv_kernel_size,
                                                          stride=neuro_sync_conv_kernel_size, padding=1)
            else:  # 'fc' (validated above)
                self.controller_image_encoder = nn.Conv1d(1, token_dim, kernel_size=1, stride=1, padding=0)

            # Learnable features for each neuron and each image token.
            self.learnable_features = nn.Parameter(torch.randn(self.total_nodes, learnable_dim).to(self.device) * 0.01)
            self.image_features = nn.Parameter(torch.randn(image_token_len, learnable_dim).to(self.device) * 0.01)
            self.learnable_dim = learnable_dim

            # Decodes 6 values per neuron: c_vals, dynamic, sigma, miu, ema, modulation.
            self.decoder_conv = nn.Linear(self.embed_dim, 6, bias=True)
            if last_layer_act:
                self.decoder_fc = self.decoder_conv
            else:
                self.decoder_fc = nn.Linear(self.embed_dim, 2, bias=True)

            # Gating-parameter count (controller + decoders + learnable neuron features).
            if attention_neuro_sync in ('cross_fusion', 'fusion_cross'):
                gating_modules = [self.fusion_encoder, self.cross_block]
            elif attention_neuro_sync == 'cross_only':
                gating_modules = [self.cross_block]
            else:  # 'self', 's4', 'simple_mlp'
                gating_modules = [self.transformer]
            gating_modules.append(self.decoder_conv)
            if not last_layer_act:
                gating_modules.append(self.decoder_fc)
            num_gating_params = sum(p.numel() for module in gating_modules for p in module.parameters())
            num_gating_params += self.learnable_features.numel()

        else:
            num_gating_params = 0
            # One directly learned value per neuron for every enabled signal. Disabled
            # signals are fixed zeros, registered as NON-persistent buffers so they follow
            # .to()/.cuda()/.double() without changing the state_dict.
            for flag_name in ('WC', 'SR', 'SM', 'AL', 'ARM'):
                attr_name = f'{flag_name}_params'
                if getattr(self, flag_name):
                    setattr(self, attr_name, nn.Parameter(torch.randn(self.total_nodes).to(self.device) * 0.01))
                    num_gating_params += self.total_nodes
                else:
                    self.register_buffer(attr_name, torch.zeros(self.total_nodes, device=self.device),
                                         persistent=False)

            self.decoder_conv = nn.Identity()
            self.decoder_fc = nn.Identity()

        # -------------------------------------------------------
        # 2) Backbone: conv1..convN followed by fc1..fcM.
        # -------------------------------------------------------
        in_channels = input_shape[0]
        for i, out_channels in enumerate(cnn_channels):
            conv_layer = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size[i],
                                   padding=padding[i], stride=stride[i])
            setattr(self, f'conv{i+1}', conv_layer)
            in_channels = out_channels

        if input_type == 'conv':
            layers = [{'kernel_size': kernel_size[i], 'stride': stride[i], 'padding': padding[i],
                       'filters': cnn_channels[i], 'pool_size': pooling_kernel[i]}
                      for i in range(len(cnn_channels))]
            output_shapes = calculate_network_dims(input_shape[1], input_shape[2], input_shape[0], layers)
            fc_in = output_shapes[-1]
        else:
            fc_in = math.prod(input_shape)
        for i, out_dim in enumerate(fc_channels):
            setattr(self, f'fc{i+1}', nn.Linear(fc_in, out_dim))
            fc_in = out_dim

        # Exposed for inspection (previously computed and discarded).
        self.num_gating_params = num_gating_params
        self.num_total_params = sum(p.numel() for p in self.parameters())

        # -------------------------------------------------------
        # 3) Snapshot initial backbone parameters (for soft reset) and initialise the EMA copies.
        # -------------------------------------------------------
        self._initial_params = {}
        self._ema_params = {}  # exponential moving average of each backbone weight
        for name, param in self.named_parameters():
            if name.startswith("conv") or name.startswith("fc"):
                self._initial_params[name] = param.detach().clone()
                self._ema_params[name] = param.detach().clone()

        self.ema_decay = ema_decay  # hyperparameter for EMA updates

        self.stored_c_vals = []
        self.stored_dynamic_vals = []
        self.stored_sigma_vals = []
        self.stored_miu_vals = []
        self.stored_ema_vals = []
        self.stored_modulation_vals = []

    def create_controller(self):
        """Instantiate the neuron controller selected by ``self.attention_neuro_sync``.

        Reads ``prelu_transformer``, ``num_heads``, ``transformer_embed_dim``, ``k``,
        ``embed_dim``, ``total_nodes`` and ``image_token_len`` from ``self``, so it must run
        after those are set (as ``__init__`` does). Creates:

        * ``self.transformer`` for 'self', 's4' and 'simple_mlp';
        * ``self.fusion_encoder`` and ``self.cross_block`` for 'cross_fusion' / 'fusion_cross';
        * ``self.cross_block`` for 'cross_only'.

        Raises:
            ValueError: if ``attention_neuro_sync`` is not one of the modes above.
        """
        mode = self.attention_neuro_sync
        num_heads = self.num_heads
        transformer_embed_dim = self.transformer_embed_dim
        layer_cls = CustomTransformerEncoderLayer if self.prelu_transformer else nn.TransformerEncoderLayer

        def build_encoder(dim_feedforward):
            # Sequence-first (batch_first is not set) stack of ``self.k`` layers. A LeakyReLU
            # module replaces the former locally defined function, which could not be
            # pickled (torch.save(model) failed); CustomTransformerEncoderLayer overrides it
            # with PReLU anyway.
            layer = layer_cls(d_model=self.embed_dim, nhead=num_heads,
                              activation=nn.LeakyReLU(negative_slope=0.01),
                              dropout=0.0, dim_feedforward=dim_feedforward)
            return nn.TransformerEncoder(layer, num_layers=self.k)

        if mode == 'self':
            self.transformer = build_encoder(transformer_embed_dim)

        elif mode == 's4':
            self.transformer = S4Model(d_input=self.embed_dim, d_model=48, d_output=self.embed_dim, n=64,
                                       n_blocks=1, l_max=self.total_nodes + self.image_token_len)

        elif mode == 'simple_mlp':
            self.transformer = MLPController(image_size=self.image_token_len, num_neurons=self.total_nodes,
                                             embed_dim=self.embed_dim)

        elif mode in ('cross_fusion', 'fusion_cross'):
            self.fusion_encoder = build_encoder(transformer_embed_dim // 2)
            # Cross-attention: neuron tokens query the image tokens.
            self.cross_block = CrossAttentionBlock(
                dim_q=self.embed_dim,
                dim_kv=self.embed_dim,
                num_heads=num_heads,
                mlp_dim=transformer_embed_dim // 2,
                dropout=0.0
            )

        elif mode == 'cross_only':
            # Cross-attention: neuron tokens query the image tokens.
            self.cross_block = CrossAttentionBlock(
                dim_q=self.embed_dim,
                dim_kv=self.embed_dim,
                num_heads=num_heads,
                mlp_dim=transformer_embed_dim,
                dropout=0.0
            )

        else:
            raise ValueError(f"Unknown attention_neuro_sync: {mode}")

    def load_consolidated_weights(self, wc_weights):
        """Overwrite EMA snapshots in ``self._ema_params`` from ``wc_weights``.

        Each value is stored as a detached clone, so later changes to the source
        tensors (or their autograd graph) do not affect the EMA copies. Keys not
        present in ``wc_weights`` are left untouched.
        """
        for weight_name, weight in wc_weights.items():
            snapshot = weight.detach().clone()
            self._ema_params[weight_name] = snapshot

    def forward_neuro_sync(self, x):
        """Run the neuron controller and return one output embedding per gated neuron.

        Neuron tokens (``_controller_neuron_tokens``) and image tokens encoded from ``x``
        (``_controller_image_tokens``) are combined according to ``self.attention_neuro_sync``:

        * ``'self'``: joint self-attention over [neurons; image] (sequence-first encoder).
        * ``'simple_mlp'``: one MLP over the flattened [neurons; image] tokens.
        * ``'s4'``: S4 model over the [neurons; image] sequence, passed batch-first.
        * ``'fusion_cross'``: self-attention among neurons, then neurons cross-attend to the image.
        * ``'cross_fusion'``: neurons cross-attend to the image, then self-attention among neurons.
        * ``'cross_only'``: neurons cross-attend to the image.

        Args:
            x: controller input; a 2-D tensor gets a singleton dim inserted at position 1
                before encoding.

        Returns:
            Neuron embeddings laid out ``(batch, total_nodes, width)``; any image-token outputs
            are dropped. ``decoder_conv`` expects ``width == embed_dim``.

        Raises:
            ValueError: if ``self.attention_neuro_sync`` is not one of the modes above.
        """
        mode = self.attention_neuro_sync
        if mode not in ('self', 'simple_mlp', 's4', 'fusion_cross', 'cross_fusion', 'cross_only'):
            # The original raised a plain string, which itself fails with TypeError.
            raise ValueError(f"{mode} is not supported")

        if len(x.shape) == 2:
            x = x.unsqueeze(1)
        batch_size = x.shape[0]
        img_tokens = self._controller_image_tokens(x)              # (batch, n_img, embed_dim)
        node_tokens = self._controller_neuron_tokens(batch_size)   # (batch, total_nodes, embed_dim)

        if mode == 'self':
            # self.transformer is a sequence-first nn.TransformerEncoder: (seq, batch, embed_dim).
            tokens = torch.cat([node_tokens, img_tokens], dim=1).permute(1, 0, 2)
            encoded = self.transformer(tokens).permute(1, 0, 2)    # back to (batch, seq, embed_dim)
            return encoded[:, :self.total_nodes, :]

        if mode == 'simple_mlp':
            tokens = torch.cat([node_tokens, img_tokens], dim=1)
            flat_out = self.transformer(tokens.view(batch_size, -1))
            return flat_out.view((flat_out.shape[0], self.total_nodes, -1))

        if mode == 's4':
            tokens = torch.cat([node_tokens, img_tokens], dim=1)
            return self.transformer(tokens)[:, :self.total_nodes, :]

        # Cross-attention modes: both fusion_encoder (sequence-first nn.TransformerEncoder) and
        # cross_block are fed (seq, batch, embed_dim) tensors.
        node_seq = node_tokens.permute(1, 0, 2)
        img_seq = img_tokens.permute(1, 0, 2)
        if mode == 'fusion_cross':
            # Bug fix: the encoder used to receive the batch-first tensor, so it attended
            # across the batch instead of across neurons.
            node_seq = self.fusion_encoder(node_seq)
            out = self.cross_block(node_seq, img_seq)
        elif mode == 'cross_fusion':
            out = self.fusion_encoder(self.cross_block(node_seq, img_seq))
        else:  # 'cross_only'
            out = self.cross_block(node_seq, img_seq)
        return out.permute(1, 0, 2)

    def _controller_image_tokens(self, x):
        """Encode controller input ``x`` into image tokens of width ``embed_dim``.

        The encoder output is flattened over its trailing positions and transposed to
        ``(batch, positions, embed_dim - learnable_dim)``; the learnable ``image_features``
        (broadcast over the batch) are then appended along the last dim.
        """
        batch_size = x.shape[0]
        encoded = self.controller_image_encoder(x)
        tokens = encoded.view(batch_size, self.embed_dim - self.learnable_dim, -1).transpose(1, 2)
        learnable = self.image_features.unsqueeze(0).expand(batch_size, -1, -1)
        return torch.cat([tokens, learnable], dim=-1)

    def _controller_neuron_tokens(self, batch_size):
        """Return one token per gated neuron, broadcast to ``(batch, total_nodes, embed_dim)``.

        Each token concatenates the fixed positional encoding, the learnable per-neuron
        features and the eight ``*_ema`` activation/weight statistic buffers.
        """
        per_neuron = torch.cat([self.neuron_features,
                                self.learnable_features,
                                self.activationmean_ema,
                                self.activationmax_ema,
                                self.activationmin_ema,
                                self.activationstd_ema,
                                self.weightmean_ema,
                                self.weightmax_ema,
                                self.weightmin_ema,
                                self.weightstd_ema],
                               dim=-1)
        return per_neuron.unsqueeze(0).expand(batch_size, -1, -1)
    
    
    def forward(self, x):
        #breakpoint()
        """
        x: (B, 3, 32, 32)
        Returns: (logits, total_loss)
        """
        self.activations = {}
        info = {}
        if not self.simplified:
            transformer_output = self.forward_neuro_sync(x)
        # ---- B) Decode gating for conv layers.
        miu_list = []
        sigma_list = []
        c_conv_list = []         # list of gating values for conv layers.
        dynamic_vals_list = []   # list of neuron-specific dynamic values.
        ema_conv_list = []       # list of neuron-specific EMA factors.
        modulation_list = []     # list of neuron-specific modulation factors.
        bound_loss = 0.0
        uniform_loss = 0.0
        variance_loss = 0.0
        batch_size = x.shape[0]
        #conv_layer_indices = [0, 1, 2, 3]  # corresponding to conv1..conv4 (global layer indices)
        layer_indices = range(len(self.layer_channels) - 1)  # 0, 1, 2, 3
        for layer_idx in layer_indices:
            start = self.offsets[layer_idx]
            end = start + self.layer_channels[layer_idx]
            if not self.simplified:
                current_embedding = transformer_output[:, start:end]  # (n_neurons, embed_dim)
                gating_out = self.decoder_conv(current_embedding)  # (n_neurons, 5)
                c_vals = gating_out[:, :, 0] * int(self.AL)
                dynamic_vals = F.softplus(gating_out[:, :, 1] * 0.5) * int(self.SR)
                sigma_vals = F.softplus(gating_out[:, :, 2] * 0.5)
                miu_vals = gating_out[:, :, 3] * int(self.ARM)
                ema_vals = F.softplus(gating_out[:, :, 4] * 0.5) * int(self.WC)
                modulation = gating_out[:, :, 5] * int(self.SM)
            else:
                c_vals = self.AL_params[start:end].unsqueeze(0)
                dynamic_vals = F.softplus(self.SR_params[start:end] * 0.5).unsqueeze(0).repeat(batch_size, 1)
                sigma_vals = F.softplus(self.AL_params[start:end] * 0.5).unsqueeze(0).repeat(batch_size, 1) * 0.0
                miu_vals = self.ARM_params[start:end].unsqueeze(0).repeat(batch_size, 1)
                ema_vals = F.softplus(self.WC_params[start:end] * 0.5).unsqueeze(0).repeat(batch_size, 1)
                modulation = self.SM_params[start:end].unsqueeze(0).repeat(batch_size, 1)
           
            miu_list.append(miu_vals)
            sigma_list.append(sigma_vals)
            c_conv_list.append(c_vals)
            dynamic_vals_list.append(dynamic_vals)
            ema_conv_list.append(ema_vals)
            modulation_list.append(modulation)
            
            if not self.simplified:
                uniform_loss += ext_histogram_divergence_loss(dynamic_vals, 10, torch.min(dynamic_vals).item(),
                                                        torch.max(dynamic_vals).item(), sigma=(torch.max(dynamic_vals).item() - torch.min(dynamic_vals).item())/10)
                uniform_loss += ext_histogram_divergence_loss(ema_vals, 10, torch.min(ema_vals).item(),
                                                        torch.max(ema_vals).item(), sigma=(torch.max(ema_vals).item() - torch.min(ema_vals).item())/10)
                uniform_loss += ext_histogram_divergence_loss(modulation, 10, torch.min(modulation).item(),
                                                        torch.max(modulation).item(), sigma=(torch.max(modulation).item() - torch.min(modulation).item())/10)
                uniform_loss += ext_histogram_divergence_loss(miu_vals, 10, torch.min(miu_vals).item(),
                                                            torch.max(miu_vals).item(), sigma=(torch.max(miu_vals).item() - torch.min(miu_vals).item())/10)
                uniform_loss += ext_histogram_divergence_loss(c_vals, 10, torch.min(c_vals).item(),
                                                            torch.max(c_vals).item(), sigma=(torch.max(c_vals).item() - torch.min(c_vals).item())/10)

                # uniform_loss += ext_histogram_divergence_loss(dynamic_vals, 10, -1, 1, sigma=2/10)
                # uniform_loss += ext_histogram_divergence_loss(ema_vals, 10, -1 , 1, sigma=2/10)
                # uniform_loss += ext_histogram_divergence_loss(modulation, 10, -1, 1, sigma=2/10)
                # uniform_loss += ext_histogram_divergence_loss(miu_vals, 10, -1, 1, sigma=2/10)
                # uniform_loss += ext_histogram_divergence_loss(c_vals, 10, -1, 1, sigma=2/10)

                bound_loss += ext_soft_max(dynamic_vals) + ext_soft_min(dynamic_vals)
                bound_loss += ext_soft_max(ema_vals) + ext_soft_min(ema_vals)
                bound_loss += ext_soft_max(miu_vals) + ext_soft_min(miu_vals)
                bound_loss += ext_soft_max(c_vals) + ext_soft_min(c_vals)
                bound_loss += ext_soft_max(modulation) + ext_soft_min(modulation)

        # ---- C) Decode gating for the output (fc) layer.
        last_layer_idx = len(self.layer_channels) - 1  # fc layer is the last in the list.
        start = self.offsets[last_layer_idx]
        end = start + self.layer_channels[last_layer_idx]
        if self.last_layer_act:
            if not self.simplified:
                fc_embedding = transformer_output[:, start:end]
                fc_gating = self.decoder_conv(fc_embedding)
                fc_c_vals = fc_gating[:, :, 0] * int(self.AL)
                fc_dynamic_vals = F.softplus(fc_gating[:, :, 1] * 0.5)  * int(self.SR)
                fc_sigma_vals = F.softplus(fc_gating[:, :, 2] * 0.5) 
                fc_miu_vals = fc_gating[:, :, 3]  * int(self.ARM)
                fc_ema_vals = F.softplus(fc_gating[:, :, 4] * 0.5) * int(self.WC)
                fc_modulation = fc_gating[:, :, 5] * int(self.SM)
            else:
                c_vals = self.AL_params[start:end].unsqueeze(0).repeat(batch_size, 1)
                dynamic_vals = F.softplus(self.SR_params[start:end] * 0.5).unsqueeze(0).repeat(batch_size, 1)
                sigma_vals = F.softplus(self.AL_params[start:end] * 0.5).unsqueeze(0).repeat(batch_size, 1) * 0.0
                miu_vals = self.ARM_params[start:end].unsqueeze(0).repeat(batch_size, 1)
                ema_vals = F.softplus(self.WC_params[start:end] * 0.5).unsqueeze(0).repeat(batch_size, 1)
                modulation = self.SM_params[start:end].unsqueeze(0).repeat(batch_size, 1)
           
            miu_list.append(fc_miu_vals)
            sigma_list.append(fc_sigma_vals)
            c_conv_list.append(fc_c_vals)
            dynamic_vals_list.append(fc_dynamic_vals)
            ema_conv_list.append(fc_ema_vals)
            modulation_list.append(fc_modulation)
            
            if not self.simplified:
                uniform_loss += ext_histogram_divergence_loss(fc_dynamic_vals, 10, torch.min(fc_dynamic_vals).item(),
                                                        torch.max(fc_dynamic_vals).item(), sigma=(torch.max(fc_dynamic_vals).item() - torch.min(fc_dynamic_vals).item())/10)
                uniform_loss += ext_histogram_divergence_loss(fc_ema_vals, 10, torch.min(fc_ema_vals).item(),
                                                        torch.max(fc_ema_vals).item(), sigma=(torch.max(fc_ema_vals).item() - torch.min(fc_ema_vals).item())/10)
                uniform_loss += ext_histogram_divergence_loss(fc_modulation, 10, torch.min(fc_modulation).item(),
                                                        torch.max(fc_modulation).item(), sigma=(torch.max(fc_modulation).item() - torch.min(fc_modulation).item())/10)
                uniform_loss += ext_histogram_divergence_loss(fc_miu_vals, 10, torch.min(fc_miu_vals).item(),
                                                            torch.max(fc_miu_vals).item(), sigma=(torch.max(fc_miu_vals).item() - torch.min(fc_miu_vals).item())/10)
                uniform_loss += ext_histogram_divergence_loss(fc_c_vals, 10, torch.min(fc_c_vals).item(),
                                                            torch.max(fc_c_vals).item(), sigma=(torch.max(fc_c_vals).item() - torch.min(fc_c_vals).item())/10)

                bound_loss += ext_soft_max(torch.abs(fc_dynamic_vals)) #+ ext_soft_min(fc_dynamic_vals)
                bound_loss += ext_soft_max(torch.abs(fc_ema_vals)) #+ ext_soft_min(fc_ema_vals)
                bound_loss += ext_soft_max(torch.abs(fc_miu_vals)) #+ ext_soft_min(fc_miu_vals)
                bound_loss += ext_soft_max(torch.abs(fc_c_vals)) #+ ext_soft_min(fc_c_vals)
                bound_loss += ext_soft_max(torch.abs(fc_modulation)) #+ ext_soft_min(fc_modulation)
        else:
            if not self.simplified:
                fc_embedding = transformer_output[:, start:end]
                fc_gating = self.decoder_fc(fc_embedding)
                fc_dynamic_vals = F.softplus(fc_gating[:, :, 0] * 0.5)  * int(self.SR)
                fc_ema_vals = F.softplus(fc_gating[:, :, 1] * 0.5) * int(self.WC)
            else:
                fc_dynamic_vals = F.softplus(self.SR_params[start:end] * 0.5).unsqueeze(0).repeat(batch_size, 1)
                fc_ema_vals = F.softplus(self.WC_params[start:end] * 0.5).unsqueeze(0).repeat(batch_size, 1)

            if not self.simplified:
                uniform_loss += ext_histogram_divergence_loss(fc_dynamic_vals, 10, torch.min(fc_dynamic_vals).item(),
                                                            torch.max(fc_dynamic_vals).item(), sigma=(torch.max(fc_dynamic_vals).item() - torch.min(fc_dynamic_vals).item())/10)
                uniform_loss += ext_histogram_divergence_loss(fc_ema_vals, 10, torch.min(fc_ema_vals).item(),
                                                            torch.max(fc_ema_vals).item(), sigma=(torch.max(fc_ema_vals).item() - torch.min(fc_ema_vals).item())/10)
                
                # uniform_loss += ext_histogram_divergence_loss(fc_dynamic_vals, 10, -1, 1, sigma=2/10)
                # uniform_loss += ext_histogram_divergence_loss(fc_ema_vals, 10, -1, 1, sigma=2/10)
                
                bound_loss += ext_soft_max(torch.abs(fc_dynamic_vals)) #+ ext_soft_min(fc_dynamic_vals)
                bound_loss += ext_soft_max(torch.abs(fc_ema_vals)) #+ ext_soft_min(fc_ema_vals)


       
        # ---- D) Update CNN branch parameters (conv and fc) in a differentiable manner.
        updated_cnn_params = {}
        for name, param in self.named_parameters():
            if not (name.startswith("conv") or name.startswith("fc")):
                continue
            init_val = self._initial_params[name]
           
            if name.startswith("conv"):
                try:
                    layer_idx = int(name[4]) - 1 # e.g., "conv1.weight" -> 1
                except Exception:
                    layer_idx = None
                if layer_idx is not None:
                    # conv layer i corresponds to conv_layer_indices[i-1]
                    dyn = dynamic_vals_list[layer_idx]  # shape: (n_neurons,)
                    ema_factor = ema_conv_list[layer_idx] # shape: (n_neurons,)
                    modulation_factor = modulation_list[layer_idx]
                    n_neurons = self.layer_channels[layer_idx]
                else:
                    dyn = None
            elif name.startswith("fc"):
                layer_idx = int(name[2]) - 1 # e.g., "fc1.weight" -> 1
                layer_idx += len(self.cnn_channels)
                if layer_idx < len(self.layer_channels) - 1:
                    dyn = dynamic_vals_list[layer_idx]  # shape: (n_neurons,)
                    ema_factor = ema_conv_list[layer_idx] # shape: (n_neurons,)
                    modulation_factor = modulation_list[layer_idx]
                    n_neurons = self.layer_channels[layer_idx]
                else:
                    dyn = fc_dynamic_vals   # shape: (n_fc,)
                    ema_factor = fc_ema_vals
                    if self.last_layer_act:
                        modulation_factor = fc_modulation
                    else:
                        modulation_factor = torch.zeros_like(fc_dynamic_vals).to(self.device)
                n_neurons = self.layer_channels[layer_idx]
            else:
                print(f"Unexpected parameter name: {name}")
                dyn = None
           
            if dyn is not None and param.shape[0] == n_neurons:
                batch_size = x.shape[0]
                new_shape = [batch_size, n_neurons] + [1] * (param.dim() - 1)
                dynamic_factor = dyn.view(*new_shape)
                ema_factor = ema_factor.view(*new_shape)
                modulation_factor = modulation_factor.view(*new_shape)
            else:
                #print(f"Skipping parameter {name} due to missing dynamic values.")
                dynamic_factor = dyn.mean() if dyn is not None else 0.0
                ema_factor = ema_factor.mean() if ema_factor is not None else 0.0
                modulation_factor = modulation_factor.mean() if modulation_factor is not None else 0.0
           
            ema_val = self._ema_params[name].to(self.device).clone().detach().unsqueeze(0)
           
            info = {**info, f'alpha/SM_{name}': wandb.Histogram(modulation_factor.view(-1).detach().cpu().numpy()), 
                    f'alpha/WC_{name}': wandb.Histogram(ema_factor.view(-1).detach().cpu().numpy())}
            
            param = param.unsqueeze(0).repeat_interleave(x.shape[0], dim=0)  # shape: (B, n_neurons, ...)
            if self.SM_detach:
                updated_cnn_params[name] = param + param.clone().detach() * modulation_factor + ema_val * ema_factor + init_val.to(self.device) * dynamic_factor
            else:
                updated_cnn_params[name] = (1 + modulation_factor) * param + ema_val * ema_factor + init_val.to(self.device) * dynamic_factor  

        logits, acts, info = self.forward_net(x, c_conv_list, miu_list, sigma_list, updated_cnn_params, info=info)

        if self.track_params:
            # Log gating values.
            c_conv_cat = torch.cat([tensor.mean(dim=0) for tensor in c_conv_list]).cpu()
            dynamic_cat = torch.cat([tensor.mean(dim=0) for tensor in dynamic_vals_list]).cpu()
            sigma_cat = torch.cat([tensor.mean(dim=0) for tensor in sigma_list]).cpu()
            miu_cat = torch.cat([tensor.mean(dim=0) for tensor in miu_list]).cpu()
            ema_conv_cat = torch.cat([tensor.mean(dim=0) for tensor in ema_conv_list]).cpu()
            modulation_cat = torch.cat([tensor.mean(dim=0) for tensor in modulation_list]).cpu()
        
            self.stored_c_vals.append(c_conv_cat)
            self.stored_dynamic_vals.append(dynamic_cat)
            self.stored_sigma_vals.append(sigma_cat)
            self.stored_miu_vals.append(miu_cat)
            self.stored_ema_vals.append(ema_conv_cat)
            self.stored_modulation_vals.append(modulation_cat)

        if self.training:
            self.update_ema()

            if not self.simplified:
                # NEW: Update the per-neuron EMA buffers for activation and weight norm.
                #conv_layer_indices = [0, 1, 2, 3]  # global layer indices for conv layers
                conv_layer_indices = range(len(self.layer_channels) - 1)  # 0, 1, 2, 3
                with torch.no_grad():
                    # Update activation EMA for conv layers and fc.
                    for i, conv_idx in enumerate(conv_layer_indices):
                        start = self.offsets[conv_idx]
                        end = start + self.layer_channels[conv_idx]
                        mean_old_val = self.activationmean_ema[start:end]
                        max_old_val = self.activationmax_ema[start:end]
                        min_old_val = self.activationmin_ema[start:end]
                        std_old_val = self.activationstd_ema[start:end]
                        if i < len(self.cnn_channels):
                            mean_new_val = 0.999 * mean_old_val + 0.001 * acts[i].mean(dim=[0,2,3]).view(-1, 1)
                            max_new_val = 0.999 * max_old_val + 0.001 * torch.amax(acts[i], dim=(0,2,3)).view(-1, 1)
                            min_new_val = 0.999 * min_old_val + 0.001 * torch.amin(acts[i], dim=(0,2,3)).view(-1, 1)
                            std_new_val = 0.999 * std_old_val + 0.001 * acts[i].std(dim=[0,2,3]).view(-1, 1)
                        else:
                            # TODO: check if this is correct.
                            mean_new_val = 0.999 * mean_old_val + 0.001 * acts[i].mean(dim=[0]).view(-1, 1)
                            max_new_val = 0.999 * max_old_val + 0.001 * torch.amax(acts[i], dim=(0)).view(-1, 1)
                            min_new_val = 0.999 * min_old_val + 0.001 * torch.amin(acts[i], dim=(0)).view(-1, 1)
                            if acts[i].shape[0] == 1:
                                std_new_val = 0.999 * std_old_val + 0.001 * acts[i].std(dim=0, unbiased=False).view(-1, 1)
                            else:
                                std_new_val = 0.999 * std_old_val + 0.001 * acts[i].std(dim=0).view(-1, 1)
                        self.activationmean_ema[start:end] = mean_new_val.detach()  # NEW
                        self.activationmax_ema[start:end] = max_new_val.detach()  # NEW
                        self.activationmin_ema[start:end] = min_new_val.detach()  # NEW
                        self.activationstd_ema[start:end] = std_new_val.detach()  # NEW
                    
                
                    # Update weight norm EMA for conv and fc layers.
                    for name, param in updated_cnn_params.items():
                        if name.startswith("conv") or name.startswith("fc"):
                            if name.endswith("weight"):
                                param = param.mean(0)
                                #breakpoint()
                                # For conv: (out_channels, in_channels, kh, kw); for fc: (out_features, in_features)
                                mean_val = param.view(param.shape[0], -1).mean(dim=1).detach()  # NEW
                                max_val = torch.amax(param.view(param.shape[0], -1), dim=1).detach()  # NEW
                                min_val = torch.amin(param.view(param.shape[0], -1), dim=1).detach()  # NEW
                                std_val = param.view(param.shape[0], -1).std(dim=1).detach()  # NEW
                                if name.startswith("conv"):
                                    layer_idx = int(name[4]) - 1  # e.g., conv1 -> 1
                                else:  # fc
                                    layer_idx = int(name[2]) - 1  + len(self.cnn_channels)
                                start = self.offsets[layer_idx]
                                end = start + self.layer_channels[layer_idx]
                                mean_old_val = self.weightmean_ema[start:end]
                                max_old_val = self.weightmax_ema[start:end]
                                min_old_val = self.weightmin_ema[start:end]
                                std_old_val = self.weightstd_ema[start:end]
                                mean_new_val = 0.999 * mean_old_val + 0.001 * mean_val.view(-1, 1)
                                max_new_val = 0.999 * max_old_val + 0.001 * max_val.view(-1, 1)
                                min_new_val = 0.999 * min_old_val + 0.001 * min_val.view(-1, 1)
                                std_new_val = 0.999 * std_old_val + 0.001 * std_val.view(-1, 1)


                                self.weightmean_ema[start:end] = mean_new_val.detach()  # NEW
                                self.weightmax_ema[start:end] = max_new_val.detach()  # NEW
                                self.weightmin_ema[start:end] = min_new_val.detach()  # NEW
                                self.weightstd_ema[start:end] = std_new_val.detach()  # NEW
        if self.simplified:
            uniform_loss = torch.zeros((1)).requires_grad_(True).to(self.device)
            bound_loss = torch.zeros((1, 1)).requires_grad_(True).to(self.device)
            info['uniform_loss'] = uniform_loss
            info['bound_loss'] = bound_loss
    
        info['uniform_loss'] = uniform_loss
        info['bound_loss'] = bound_loss.mean(0)
        info['hidden'] = acts[-2]
        #info['ema_logits'] = ema_logits
        return logits, info


    def forward_net(self, x, c_conv_list, miu_list, sigma_list, updated_cnn_params, info):
        """
        Run the CNN branch using per-sample (dynamically updated) parameters.

        Conv stage: per-sample conv -> gated channel-wise activation -> optional pooling.
        FC stage: input is flattened, then a per-sample linear layer is applied per
        fc layer, followed by the gated activation on every fc layer except the
        last one (the last one is also activated when ``self.last_layer_act`` is set).

        Args:
            x: network input; presumably an image batch (B, C, H, W) — TODO confirm.
            c_conv_list, miu_list, sigma_list: per-layer gating values indexed by
                global layer index (conv layers first, then fc layers).
            updated_cnn_params: dict keyed by ``conv{k}.weight`` / ``conv{k}.bias`` /
                ``fc{k}.weight`` / ``fc{k}.bias`` holding per-sample parameters.
            info: logging dict or None. When not None, a new dict extended with
                ``alpha/AL_*`` and ``alpha/ARM_*`` wandb histograms is returned
                (the caller's dict is not mutated).

        Returns:
            (logits, acts, info). ``acts`` holds the detached output of every layer
            (full tensors, not means); conv entries are captured before pooling.
            ``self.activations`` is also updated per conv layer (after pooling)
            and per activated fc layer.
        """
        n_conv = len(self.cnn_channels)
        n_fc = len(self.fc_channels)
        collected = []
        h = x

        for layer in range(n_conv):
            tag = f'conv{layer + 1}'
            h = conv2d_per_sample(
                h,
                updated_cnn_params[f'{tag}.weight'],
                updated_cnn_params[f'{tag}.bias'],
                padding=self.padding[layer],
                stride=self.stride[layer],
            )
            h = self.channelwise_activation(
                h, c_conv_list[layer], miu_list[layer], sigma_list[layer], type='conv'
            )
            if info is not None:
                info = {
                    **info,
                    f'alpha/AL_{tag}': wandb.Histogram(c_conv_list[layer].view(-1).detach().cpu().numpy()),
                    f'alpha/ARM_{tag}': wandb.Histogram(miu_list[layer].view(-1).detach().cpu().numpy()),
                }
            # Activation snapshot is taken before pooling.
            collected.append(h.detach())

            pool = self.pooling_type[layer]
            if pool == 'max':
                h = F.max_pool2d(h, self.pooling_kernel[layer])
            elif pool == 'avg':
                h = F.avg_pool2d(h, self.pooling_kernel[layer])
            elif pool != 'none':
                raise ValueError(f"Unknown pooling type: {pool}")

            self.activations[tag] = h.detach()  # post-pooling activations

        # Flatten everything but the batch dimension for the fc stack.
        h = h.view(h.size(0), -1)

        for j in range(n_fc):
            g = n_conv + j  # global layer index into the gating lists
            tag = f'fc{j + 1}'
            h = linear_per_sample(h, updated_cnn_params[f'{tag}.weight'], updated_cnn_params[f'{tag}.bias'])

            if j == n_fc - 1 and not self.last_layer_act:
                # Output layer without gating: raw linear output is the logits.
                collected.append(h.detach())
                continue

            h = self.channelwise_activation(h, c_conv_list[g], miu_list[g], sigma_list[g], type='fc')
            collected.append(h.detach())
            self.activations[tag] = h.detach()
            if info is not None:
                info = {
                    **info,
                    f'alpha/AL_{tag}': wandb.Histogram(c_conv_list[g].view(-1).detach().cpu().numpy()),
                    f'alpha/ARM_{tag}': wandb.Histogram(miu_list[g].view(-1).detach().cpu().numpy()),
                }

        return h, collected, info


    def channelwise_activation(self, x, c_vector, miu, sigma, type):
        """
        Apply a per-channel leaky gate followed by a per-channel additive shift.

        Element-wise result: ``relu(x) + c * x * [x < 0] + miu``. Non-negative
        inputs pass through unchanged, negative inputs are scaled by the channel
        slope ``c``, and ``miu`` is added everywhere.

        Args:
            x: pre-activation tensor; (B, C, H, W) for ``type='conv'``,
                presumably (B, C) for ``type='fc'``.
            c_vector: per-channel negative slopes, 2-D ((B, C) or (1, C) as built
                by the caller); broadcast over the spatial dims for conv.
            miu: per-channel additive offsets, same layout as ``c_vector``.
            sigma: unused; kept for call-site compatibility.
            type: ``'conv'`` or ``'fc'``.

        Returns:
            Tensor with the same shape as ``x`` (after broadcasting).

        Raises:
            ValueError: if ``type`` is neither ``'conv'`` nor ``'fc'``.
        """
        if type == 'conv':
            # Two trailing singleton dims so (B, C) broadcasts over (H, W).
            miu = miu.unsqueeze(-1).unsqueeze(-1)
            c_shaped = c_vector.unsqueeze(-1).unsqueeze(-1)
        elif type == 'fc':
            c_shaped = c_vector
        else:
            # Previously this fell through and failed later with UnboundLocalError.
            raise ValueError(f"Unknown activation type: {type!r} (expected 'conv' or 'fc')")
        negative_mask = (x < 0).float()
        return F.relu(x) + c_shaped * x * negative_mask + miu
   
    @staticmethod
    def positional_encoding(x, num_freq):
        """
        x: (N, 2) tensor of base positions.
        Returns: (N, 2 + 4*num_freq) tensor.
        For each coordinate, include the original value and for each frequency i compute:
           sin(2^i * coordinate) and cos(2^i * coordinate)
        """
        encodings = [x]  # include the original (x, y)
        for i in range(num_freq):
            frequency = 2.0 ** i
            encodings.append(torch.sin(frequency * x))
            encodings.append(torch.cos(frequency * x))
        return torch.cat(encodings, dim=-1)
   
    def update_ema(self):
        """
        Fold the current ``conv*`` / ``fc*`` parameter values into ``self._ema_params``.

        Each matching entry becomes ``decay * ema + (1 - decay) * param``, with
        ``decay = self.ema_decay``. The parameter is detached and copied to CPU
        before blending. Parameters with other name prefixes are ignored.
        """
        decay = self.ema_decay
        with torch.no_grad():
            for pname, weight in self.named_parameters():
                if not pname.startswith(("conv", "fc")):
                    continue
                previous = self._ema_params[pname]
                self._ema_params[pname] = decay * previous + (1 - decay) * weight.detach().cpu()


    def plot_params(self):
        """
        Summarize the gating values accumulated by ``forward`` into wandb histograms.

        Each ``self.stored_*`` list holds one tensor per tracked forward call. The
        tensors of each list are stacked and averaged over the call dimension, then
        wrapped in a ``wandb.Histogram``. All lists are reset to ``[]`` afterwards.

        Returns:
            dict with keys ``"dyn_values"``, ``"c_values"``, ``"sigma_values"``,
            ``"miu_values"``, ``"ema_values"`` and ``"modulation_values"``. A key
            whose list is empty (e.g. tracking was disabled) is omitted, so nothing
            tracked yields ``{}``. Previously ``torch.stack([])`` raised instead.
        """
        # (stored attribute, log key), in the original key order.
        sources = (
            ("stored_dynamic_vals", "dyn_values"),
            ("stored_c_vals", "c_values"),
            ("stored_sigma_vals", "sigma_values"),
            ("stored_miu_vals", "miu_values"),
            ("stored_ema_vals", "ema_values"),
            ("stored_modulation_vals", "modulation_values"),
        )
        log_info_neuro_sync = {}
        for attr, key in sources:
            history = getattr(self, attr)
            if len(history) == 0:
                continue
            # Averaged in a local, so the attribute never temporarily holds a
            # tensor in place of a list (which would break later .append calls).
            averaged = torch.stack(history, dim=0).mean(dim=0)
            log_info_neuro_sync[key] = wandb.Histogram(averaged.detach().numpy())
        for attr, _ in sources:
            setattr(self, attr, [])
        return log_info_neuro_sync

    def get_model_weights_l2_norm(self):
        """
        Return ``utils_l2_norm`` over the trainable ``fc*`` / ``conv*`` parameters.

        The (name, parameter) pairs are passed lazily as a generator; the exact
        norm definition lives in ``utils_l2_norm``.
        """
        trainable_backbone = (
            (pname, p)
            for pname, p in self.named_parameters()
            if p.requires_grad and pname.startswith(('fc', 'conv'))
        )
        return utils_l2_norm(trainable_backbone)

    def compute_l1_norm(self):
        """
        Return ``utils_l1_norm`` over the trainable ``fc*`` / ``conv*`` parameters.

        The (name, parameter) pairs are passed lazily as a generator; the exact
        norm definition lives in ``utils_l1_norm``.
        """
        trainable_backbone = (
            (pname, p)
            for pname, p in self.named_parameters()
            if p.requires_grad and pname.startswith(('fc', 'conv'))
        )
        return utils_l1_norm(trainable_backbone)
    
    def compute_l2_norm(self):
        """
        Return ``new_utils_l2_norm`` over the trainable ``fc*`` / ``conv*`` parameters.

        The (name, parameter) pairs are passed lazily as a generator; the exact
        norm definition lives in ``new_utils_l2_norm``.
        """
        trainable_backbone = (
            (pname, p)
            for pname, p in self.named_parameters()
            if p.requires_grad and pname.startswith(('fc', 'conv'))
        )
        return new_utils_l2_norm(trainable_backbone)
    
    def compute_total_params(self):
        """
        Count the elements of trainable ``fc*`` / ``conv*`` parameters.

        Parameters whose name contains ``layer_norm``, ``init_params`` or
        ``original_last_layer_params`` are excluded. Returns 0 when nothing matches.
        """
        excluded_tags = ('layer_norm', 'init_params', 'original_last_layer_params')
        return sum(
            p.numel()
            for pname, p in self.named_parameters()
            if p.requires_grad
            and pname.startswith(('fc', 'conv'))
            and not any(tag in pname for tag in excluded_tags)
        )
