"""
Gated Multi-scale Attention (GMA) Module
Time/frequency dual branches with gated fusion for multi-scale temporal analysis
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from typing import Dict, List, Tuple, Optional, Union
import math


class GatedMultiscaleAttention(nn.Module):
    """
    Gated Multi-scale Attention with dual-branch processing.

    Three feature streams are computed from an input of shape
    [batch, seq_len, input_dim] and fused back to the same shape:

    * a time branch (parallel Conv1d filters -> self-attention -> LayerNorm),
    * a frequency branch (FFT magnitude projection and a pooled/interpolated
      "wavelet" pyramid -> self-attention -> LayerNorm),
    * a dilated, grouped convolution filter bank with per-filter gates.

    The time and frequency streams are balanced by a learned sigmoid gate and
    the concatenation of all streams is reduced by an MLP fusion layer.

    Every stage is best-effort: NaN/inf values are replaced and reported, and
    a failing stage falls back to its input instead of aborting the forward
    pass.

    Recognised ``config`` keys (all optional):
        scales, num_heads, head_dim,
        time_branch: {enabled, filter_orders},
        frequency_branch: {enabled, use_fft, use_wavelet, wavelet_levels,
                           fft_bins},
        filter_bank: {num_filters, filter_sizes, dilation_rates},
        gating: {depth, hidden_dim},
        attention: {dropout}, fusion: {dropout}
    """

    def __init__(self,
                 input_dim: int,
                 config: Dict,
                 device: Optional[torch.device] = None):
        """
        Args:
            input_dim: Size of the feature axis; must be divisible by
                ``num_heads`` (required by nn.MultiheadAttention).
            config: Nested configuration dictionary (see class docstring).
            device: Stored on the instance only; parameters are not moved.
        """
        super().__init__()
        self.input_dim = input_dim
        self.config = config
        self.device = device or torch.device('cpu')

        # Multi-scale parameters
        self.scales = config.get('scales', [8, 16, 32, 64])
        self.num_heads = config.get('num_heads', 8)
        self.head_dim = config.get('head_dim', input_dim // self.num_heads)

        # Time branch configuration
        time_cfg = config.get('time_branch', {})
        self.time_branch_enabled = time_cfg.get('enabled', True)
        self.time_filter_orders = time_cfg.get('filter_orders', [3, 5, 7, 9])

        # Frequency branch configuration
        freq_cfg = config.get('frequency_branch', {})
        self.freq_branch_enabled = freq_cfg.get('enabled', True)
        self.use_fft = freq_cfg.get('use_fft', True)
        self.use_wavelet = freq_cfg.get('use_wavelet', True)
        self.wavelet_levels = freq_cfg.get('wavelet_levels', 4)
        self.fft_bins = freq_cfg.get('fft_bins', 64)

        # Filter bank
        self.filter_bank_config = config.get('filter_bank', {})
        self.num_filters = self.filter_bank_config.get('num_filters', 8)
        self.filter_sizes = self.filter_bank_config.get('filter_sizes', [3, 5, 7, 9])
        self.dilation_rates = self.filter_bank_config.get('dilation_rates', [1, 2, 4, 8])

        # Gating mechanism
        self.gating_config = config.get('gating', {})
        self.gating_depth = self.gating_config.get('depth', 3)
        self.gating_hidden_dim = self.gating_config.get('hidden_dim', 128)

        # Build components. The order is kept stable because it determines
        # both state_dict key order and random-number consumption at init.
        self._build_time_branch()
        self._build_frequency_branch()
        self._build_filter_bank()
        self._build_gating_mechanism()
        self._build_fusion_layer()

        # Initialize weights with stable values
        self._initialize_weights()

    # ------------------------------------------------------------------
    # Diagnostics helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _warn(message: str) -> None:
        """Emit a diagnostic message on stdout (the file's existing channel)."""
        print(f"[GMA] {message}")

    def _sanitize(self, tensor: torch.Tensor, name: str, fill: float = 0.0) -> torch.Tensor:
        """Replace NaN/inf entries of ``tensor`` with ``fill`` and report them.

        Returns ``tensor`` itself (no copy) when every entry is finite.
        """
        bad = torch.isnan(tensor) | torch.isinf(tensor)
        if bad.any():
            self._warn(f"{name}: replaced {int(bad.sum())} NaN/inf value(s) with {fill}")
            tensor = torch.where(bad, torch.full_like(tensor, fill), tensor)
        return tensor

    @staticmethod
    def _require_3d(x: torch.Tensor) -> None:
        """Raise ValueError unless ``x`` is [batch, seq_len, features]."""
        if x.dim() != 3:
            raise ValueError(
                f"expected input of shape [batch, seq_len, features], got {tuple(x.shape)}"
            )

    # ------------------------------------------------------------------
    # Initialisation
    # ------------------------------------------------------------------
    def _initialize_weights(self):
        """Initialize weights with small, stable values to prevent NaN during training.

        Conv1d/Linear weights and every matrix parameter of MultiheadAttention
        use Xavier-uniform with gain 0.1; biases are zeroed; LayerNorm is reset
        to the identity transform.
        """
        for _, module in self.named_modules():
            if isinstance(module, (nn.Conv1d, nn.Linear)):
                nn.init.xavier_uniform_(module.weight, gain=0.1)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)
            elif isinstance(module, nn.LayerNorm):
                nn.init.constant_(module.weight, 1.0)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.0)
            elif isinstance(module, nn.MultiheadAttention):
                for param in module.parameters():
                    if param.dim() > 1:
                        nn.init.xavier_uniform_(param, gain=0.1)
                    else:
                        nn.init.constant_(param, 0.0)

    def check_weights(self):
        """Check if any weights contain NaN or inf values.

        Returns:
            True when at least one parameter holds a NaN or inf entry. Each
            offending parameter is reported by name.
        """
        nan_found = False
        inf_found = False
        for name, param in self.named_parameters():
            if torch.isnan(param).any():
                self._warn(f"parameter '{name}' contains NaN values")
                nan_found = True
            if torch.isinf(param).any():
                self._warn(f"parameter '{name}' contains inf values")
                inf_found = True
        return nan_found or inf_found

    # ------------------------------------------------------------------
    # Component construction
    # ------------------------------------------------------------------
    def _build_time_branch(self):
        """Build temporal convolution branch"""
        if not self.time_branch_enabled:
            self.time_branch = None
            return

        # Multi-scale temporal convolutions; padding k // 2 preserves the
        # sequence length for odd kernel sizes.
        self.temporal_convs = nn.ModuleList()
        for filter_order in self.time_filter_orders:
            conv_layer = nn.Conv1d(
                in_channels=self.input_dim,
                out_channels=self.input_dim,
                kernel_size=filter_order,
                padding=filter_order // 2,
                groups=1
            )
            self.temporal_convs.append(conv_layer)

        # Temporal attention
        self.temporal_attention = nn.MultiheadAttention(
            embed_dim=self.input_dim,
            num_heads=self.num_heads,
            dropout=self.config.get('attention', {}).get('dropout', 0.1),
            batch_first=True
        )

        # Temporal normalization
        self.temporal_norm = nn.LayerNorm(self.input_dim)

    def _build_frequency_branch(self):
        """Build frequency domain analysis branch"""
        if not self.freq_branch_enabled:
            self.frequency_branch = None
            return

        # FFT processing: maps the (resampled) frequency-bin axis to input_dim
        if self.use_fft:
            self.fft_projection = nn.Linear(self.fft_bins, self.input_dim)

        # Wavelet processing: one feature projection per pyramid level
        if self.use_wavelet:
            self.wavelet_projections = nn.ModuleList([
                nn.Linear(self.input_dim, self.input_dim)
                for _ in range(self.wavelet_levels)
            ])

        # Frequency attention
        self.frequency_attention = nn.MultiheadAttention(
            embed_dim=self.input_dim,
            num_heads=self.num_heads,
            dropout=self.config.get('attention', {}).get('dropout', 0.1),
            batch_first=True
        )

        # Frequency normalization
        self.frequency_norm = nn.LayerNorm(self.input_dim)

    def _build_filter_bank(self):
        """Build multi-scale filter bank"""
        self.filter_bank = nn.ModuleList()

        # Conv1d requires `groups` to divide the channel count. input_dim // 4
        # is 0 for input_dim < 4 and may not divide input_dim otherwise; the
        # gcd equals input_dim // 4 whenever that value was already valid.
        groups = math.gcd(self.input_dim, max(1, self.input_dim // 4))

        for filter_size, dilation in zip(self.filter_sizes, self.dilation_rates):
            conv = nn.Conv1d(
                in_channels=self.input_dim,
                out_channels=self.input_dim,
                kernel_size=filter_size,
                dilation=dilation,
                # Keeps the sequence length unchanged for odd kernel sizes
                padding=(filter_size - 1) * dilation // 2,
                groups=groups  # Depthwise-like convolution
            )
            self.filter_bank.append(conv)

        # Filter bank fusion. Sized from the filters actually built, because
        # zip() above truncates when the two config lists differ in length.
        self.filter_fusion = nn.Conv1d(
            in_channels=self.input_dim * len(self.filter_bank),
            out_channels=self.input_dim,
            kernel_size=1
        )

    def _build_gating_mechanism(self):
        """Build learnable gating mechanism"""
        # Input gate for time vs frequency balance
        self.input_gate = nn.Sequential(
            nn.Linear(self.input_dim * 2, self.gating_hidden_dim),
            nn.Tanh(),
            nn.Linear(self.gating_hidden_dim, 1),
            nn.Sigmoid()
        )

        # Scale-specific gates (built and registered, but not referenced by
        # the forward pass in this class)
        self.scale_gates = nn.ModuleList()
        for _ in self.scales:
            gate = nn.Sequential(
                nn.Linear(self.input_dim, self.gating_hidden_dim),
                nn.ReLU(),
                nn.Linear(self.gating_hidden_dim, 1),
                nn.Sigmoid()
            )
            self.scale_gates.append(gate)

        # Feature gates for different filter outputs
        self.feature_gates = nn.ModuleList()
        for _ in range(len(self.filter_sizes)):
            gate = nn.Sequential(
                nn.Linear(self.input_dim, self.input_dim),
                nn.Sigmoid()
            )
            self.feature_gates.append(gate)

    def _build_fusion_layer(self):
        """Build final fusion layer"""
        # One input_dim-wide slot per enabled branch plus the filter bank
        fusion_input_dim = self.input_dim  # Filter bank output
        if self.time_branch_enabled:
            fusion_input_dim += self.input_dim
        if self.freq_branch_enabled:
            fusion_input_dim += self.input_dim

        self.fusion_layer = nn.Sequential(
            nn.Linear(fusion_input_dim, self.input_dim * 2),
            nn.GELU(),
            nn.Dropout(self.config.get('fusion', {}).get('dropout', 0.1)),
            nn.Linear(self.input_dim * 2, self.input_dim),
            nn.LayerNorm(self.input_dim)
        )

    # ------------------------------------------------------------------
    # Branches
    # ------------------------------------------------------------------
    def process_time_branch(self, x: torch.Tensor) -> torch.Tensor:
        """
        Process input through temporal branch

        Args:
            x: Input tensor [batch, seq_len, features]

        Returns:
            time_features: Temporal features [batch, seq_len, features], or
            None when the time branch is disabled. Falls back to the
            (sanitized, clamped) input if the branch raises.
        """
        if not self.time_branch_enabled:
            return None

        self._require_3d(x)
        x = self._sanitize(x, "time branch input")

        if torch.allclose(x, torch.zeros_like(x)):
            self._warn("time branch input is all zeros; returning zeros")
            return torch.zeros_like(x)

        # Bound the input so convolutions cannot overflow
        x = torch.clamp(x, min=-50.0, max=50.0)

        try:
            # Apply temporal convolutions
            channels_first = x.transpose(1, 2)  # [batch, features, seq_len]
            conv_outputs = [
                self._sanitize(conv(channels_first), f"temporal conv {i} output")
                for i, conv in enumerate(self.temporal_convs)
            ]

            # Combine temporal convolution outputs (mean over filters)
            combined = torch.stack(conv_outputs, dim=0).mean(dim=0)  # [batch, features, seq_len]
            combined = combined.transpose(1, 2)  # [batch, seq_len, features]
            combined = self._sanitize(combined, "combined temporal convolutions")
            combined = torch.clamp(combined, min=-100.0, max=100.0)

            # Apply temporal attention
            try:
                attended, _ = self.temporal_attention(combined, combined, combined)
                attended = self._sanitize(attended, "temporal attention output")
            except Exception as exc:
                self._warn(f"temporal attention failed ({exc!r}); using convolution output")
                attended = combined

            # Residual connection and normalization
            residual = self._sanitize(attended + x, "temporal residual")
            try:
                temporal_output = self._sanitize(
                    self.temporal_norm(residual), "temporal LayerNorm output"
                )
            except Exception as exc:
                self._warn(f"temporal normalization failed ({exc!r}); using residual")
                temporal_output = residual

        except Exception as exc:
            self._warn(f"time branch failed ({exc!r}); passing input through")
            temporal_output = x

        return temporal_output

    def _fft_features(self, x: torch.Tensor) -> torch.Tensor:
        """Project the FFT magnitude spectrum of ``x`` to [batch, seq_len, features].

        Returns ``x`` unchanged when the spectrum contains NaN/inf.
        """
        seq_len = x.shape[1]
        clamped = torch.clamp(x, min=-50.0, max=50.0)

        # Run the FFT in float32 when given half-precision input
        if clamped.dtype == torch.float16:
            clamped = clamped.float()
        spectrum = torch.fft.fft(clamped, dim=1)  # [batch, seq_len, features]

        if torch.isnan(spectrum).any() or torch.isinf(spectrum).any():
            self._warn("FFT produced NaN/inf values; using raw input as FFT features")
            return x

        magnitude = torch.clamp(torch.abs(spectrum), min=1e-8, max=1e8)

        # Resample the frequency axis to exactly fft_bins entries
        if seq_len > self.fft_bins:
            indices = torch.linspace(0, seq_len - 1, self.fft_bins,
                                     dtype=torch.long, device=x.device)
            magnitude = magnitude[:, indices, :]
        elif seq_len < self.fft_bins:
            magnitude = F.interpolate(
                magnitude.transpose(1, 2),
                size=self.fft_bins,
                mode='linear'
            ).transpose(1, 2)

        # NOTE(review): the projection maps the bin axis to input_dim, so the
        # time axis has length input_dim before the resampling below --
        # confirm this layout is intended.
        projected = self.fft_projection(magnitude.transpose(1, 2)).transpose(1, 2)
        if projected.shape[1] != seq_len:
            projected = F.interpolate(
                projected.transpose(1, 2),
                size=seq_len,
                mode='linear'
            ).transpose(1, 2)

        return self._sanitize(projected, "FFT features")

    def process_frequency_branch(self, x: torch.Tensor) -> torch.Tensor:
        """
        Process input through frequency branch

        Args:
            x: Input tensor [batch, seq_len, features]

        Returns:
            freq_features: Frequency features [batch, seq_len, features], or
            None when the frequency branch is disabled.
        """
        if not self.freq_branch_enabled:
            return None

        self._require_3d(x)
        x = self._sanitize(x, "frequency branch input")

        if torch.allclose(x, torch.zeros_like(x)):
            self._warn("frequency branch input is all zeros; returning zeros")
            return torch.zeros_like(x)

        frequency_features = []

        # FFT processing
        if self.use_fft:
            try:
                frequency_features.append(self._fft_features(x))
            except Exception as exc:
                self._warn(f"FFT processing failed ({exc!r}); using raw input as FFT features")
                frequency_features.append(x)

        # Wavelet processing
        if self.use_wavelet:
            try:
                wavelet = self._apply_wavelet_transform(torch.clamp(x, min=-50.0, max=50.0))
                frequency_features.append(self._sanitize(wavelet, "wavelet features"))
            except Exception as exc:
                self._warn(f"wavelet processing failed ({exc!r}); skipping wavelet features")

        # Combine frequency features (mean over the available sources)
        if frequency_features:
            freq_combined = torch.stack(frequency_features, dim=0).mean(dim=0)
        else:
            freq_combined = x

        # Apply frequency attention
        try:
            freq_combined = self._sanitize(freq_combined, "combined frequency features")
            freq_combined = torch.clamp(freq_combined, min=-100.0, max=100.0)

            freq_attended, _ = self.frequency_attention(freq_combined, freq_combined, freq_combined)

            if torch.isnan(freq_attended).any() or torch.isinf(freq_attended).any():
                self._warn("frequency attention produced NaN/inf values; using its input instead")
                freq_attended = freq_combined
        except Exception as exc:
            self._warn(f"frequency attention failed ({exc!r}); using its input instead")
            freq_attended = freq_combined

        # Residual connection and normalization
        try:
            freq_output = self._sanitize(
                self.frequency_norm(freq_attended + x), "frequency LayerNorm output"
            )
        except Exception as exc:
            self._warn(f"frequency normalization failed ({exc!r}); using attention output")
            freq_output = freq_attended

        return freq_output

    def _apply_wavelet_transform(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply simplified wavelet transform

        A learnable approximation of a discrete wavelet transform: each level
        applies a linear projection and halves the sequence with average
        pooling; the pyramid is then rebuilt coarse-to-fine by linear
        interpolation with additive skip connections.

        Args:
            x: Input tensor [batch, seq_len, features]

        Returns:
            wavelet_features: Tensor with the same shape as ``x``. Zeros when
            ``x`` holds NaN/inf, and ``x`` itself when no level could be
            computed or an error occurred.
        """
        if torch.isnan(x).any() or torch.isinf(x).any():
            self._warn("wavelet input contains NaN/inf values; returning zeros")
            return torch.zeros_like(x)

        levels = []
        current = x

        try:
            # Analysis: project, then downsample by 2 (approximation coefficients)
            for level, projection in enumerate(self.wavelet_projections):
                if current.shape[1] <= 1:
                    break  # nothing left to pool

                projected = self._sanitize(
                    projection(current), f"wavelet level {level} projection"
                )
                current = F.avg_pool1d(
                    projected.transpose(1, 2),
                    kernel_size=2,
                    stride=2
                ).transpose(1, 2)
                current = self._sanitize(current, f"wavelet level {level} downsample")
                levels.append(current)

            if not levels:
                return x

            # Synthesis: start with the smallest scale and progressively upsample
            reconstructed = levels[-1]
            for finer in reversed(levels[:-1]):
                reconstructed = F.interpolate(
                    reconstructed.transpose(1, 2),
                    size=finer.shape[1],
                    mode='linear'
                ).transpose(1, 2)
                reconstructed = self._sanitize(reconstructed, "wavelet upsample")

                # Add residual connection
                reconstructed = self._sanitize(reconstructed + finer, "wavelet reconstruction")

            # Final interpolation to match input length
            if reconstructed.shape[1] != x.shape[1]:
                reconstructed = F.interpolate(
                    reconstructed.transpose(1, 2),
                    size=x.shape[1],
                    mode='linear'
                ).transpose(1, 2)
                reconstructed = self._sanitize(reconstructed, "wavelet final interpolation")

        except Exception as exc:
            self._warn(f"wavelet transform failed ({exc!r}); returning input unchanged")
            reconstructed = x

        return reconstructed

    def process_filter_bank(self, x: torch.Tensor) -> torch.Tensor:
        """
        Process input through multi-scale filter bank

        Args:
            x: Input tensor [batch, seq_len, features]

        Returns:
            filtered_features: Multi-scale filtered features
            [batch, seq_len, features]; the sanitized input if the bank raises.
        """
        self._require_3d(x)
        x = self._sanitize(x, "filter bank input")

        if torch.allclose(x, torch.zeros_like(x)):
            self._warn("filter bank input is all zeros; returning zeros")
            return torch.zeros_like(x)

        try:
            channels_first = x.transpose(1, 2)  # [batch, features, seq_len]

            # Apply filters, each followed by its own sigmoid feature gate
            filter_outputs = []
            for i, (filter_conv, gate) in enumerate(zip(self.filter_bank, self.feature_gates)):
                filtered = self._sanitize(filter_conv(channels_first), f"filter {i} output")
                filtered = filtered.transpose(1, 2)  # [batch, seq_len, features]

                # A broken gate defaults to 1.0, i.e. "let the filter through"
                gate_weights = self._sanitize(gate(filtered), f"filter {i} gate", fill=1.0)
                gated = self._sanitize(filtered * gate_weights, f"gated filter {i} output")

                filter_outputs.append(gated.transpose(1, 2))  # Back to [batch, features, seq_len]

            # Concatenate all filter outputs
            concatenated = torch.cat(filter_outputs, dim=1)  # [batch, features * num_filters, seq_len]
            concatenated = self._sanitize(concatenated, "concatenated filter outputs")

            # Fuse using 1x1 convolution
            fused = self.filter_fusion(concatenated).transpose(1, 2)  # [batch, seq_len, features]
            fused = self._sanitize(fused, "fused filter bank output")

        except Exception as exc:
            self._warn(f"filter bank failed ({exc!r}); passing input through")
            fused = x

        return fused

    # ------------------------------------------------------------------
    # Fusion
    # ------------------------------------------------------------------
    def apply_gating(self, time_features: Optional[torch.Tensor],
                    freq_features: Optional[torch.Tensor],
                    filter_features: torch.Tensor) -> List[torch.Tensor]:
        """
        Apply gating mechanism to combine features

        When both time and frequency features are present, a scalar gate g in
        (0, 1) per position weights them as ``g * time`` and ``(1 - g) * freq``.
        A single available branch is passed through ungated.

        Args:
            time_features: Temporal features (or None)
            freq_features: Frequency features (or None)
            filter_features: Filter bank features

        Returns:
            List of feature tensors to be concatenated along the last axis;
            the filter bank features are always the final element.
        """
        features_list = []

        if time_features is not None and freq_features is not None:
            # Compute gate weight for time vs frequency balance
            gate_input = torch.cat([time_features, freq_features], dim=-1)
            time_freq_gate = self.input_gate(gate_input)  # [batch, seq_len, 1]

            features_list.append(time_features * time_freq_gate)
            features_list.append(freq_features * (1 - time_freq_gate))
        elif time_features is not None:
            features_list.append(time_features)
        elif freq_features is not None:
            features_list.append(freq_features)

        # Add filter features
        features_list.append(filter_features)

        return features_list

    def forward(self, x: torch.Tensor, return_attention: bool = False) -> Dict[str, torch.Tensor]:
        """
        Forward pass through GMA

        Args:
            x: Input tensor [batch, seq_len, features]
            return_attention: Accepted for API compatibility; attention
                weights are currently not included in the result.

        Returns:
            output: Dictionary with keys 'features' (fused output),
            'time_features', 'freq_features' (None when the branch is
            disabled) and 'filter_features'.
        """
        self._require_3d(x)

        # Recover from corrupted parameters by re-initialising them
        if self.check_weights():
            self._warn("non-finite weights detected; re-initializing all weights")
            self._initialize_weights()

        x = self._sanitize(x, "forward input")

        # Process through different branches
        time_features = self.process_time_branch(x)
        freq_features = self.process_frequency_branch(x)
        filter_features = self.process_filter_bank(x)

        if time_features is not None:
            time_features = self._sanitize(time_features, "time features")
        if freq_features is not None:
            freq_features = self._sanitize(freq_features, "frequency features")
        if filter_features is not None:
            filter_features = self._sanitize(filter_features, "filter features")

        # Apply gating mechanism
        try:
            gated_features_list = self.apply_gating(time_features, freq_features, filter_features)
        except Exception as exc:
            self._warn(f"gating failed ({exc!r}); falling back to raw input")
            gated_features_list = [x]

        # Concatenate all features
        if gated_features_list:
            try:
                all_features = torch.cat(gated_features_list, dim=-1)
                all_features = self._sanitize(all_features, "concatenated features")
            except Exception as exc:
                self._warn(f"feature concatenation failed ({exc!r}); falling back to raw input")
                all_features = x
        else:
            all_features = x

        # Final fusion
        try:
            output_features = self.fusion_layer(all_features)
            if torch.isnan(output_features).any() or torch.isinf(output_features).any():
                self._warn("fusion layer produced NaN/inf values; returning input as output")
                output_features = x
        except Exception as exc:
            self._warn(f"fusion layer failed ({exc!r}); returning input as output")
            output_features = x

        # Prepare output dictionary
        return {
            'features': output_features,
            'time_features': time_features,
            'freq_features': freq_features,
            'filter_features': filter_features
        }


def test_gated_multiscale_attention():
    """Smoke-test the GMA module on random data and print the output shapes.

    Builds a module with every branch enabled, runs one forward pass on a
    random [2, 64, 128] batch and checks that the fused output keeps the
    input shape.
    """

    # Test configuration
    config = {
        'scales': [8, 16, 32],
        'num_heads': 4,
        'time_branch': {
            'enabled': True,
            'filter_orders': [3, 5, 7]
        },
        'frequency_branch': {
            'enabled': True,
            'use_fft': True,
            'use_wavelet': True,
            'wavelet_levels': 3,
            'fft_bins': 32
        },
        'filter_bank': {
            'num_filters': 4,
            'filter_sizes': [3, 5, 7, 9],
            'dilation_rates': [1, 2, 4, 8]
        },
        'gating': {
            'depth': 2,
            'hidden_dim': 64
        },
        'attention': {
            'dropout': 0.1
        },
        'fusion': {
            'dropout': 0.1
        }
    }

    # Create test data
    batch_size = 2
    seq_len = 64
    features = 128

    x = torch.randn(batch_size, seq_len, features)

    # Create GMA module
    gma = GatedMultiscaleAttention(input_dim=features, config=config)

    # Forward pass
    result = gma(x, return_attention=True)

    print("Gated Multi-scale Attention test results:")
    print(f"Input shape: {tuple(x.shape)}")
    print(f"Output features shape: {tuple(result['features'].shape)}")

    # Branch outputs are None when the corresponding branch is disabled
    for key, label in (('time_features', 'Time'),
                       ('freq_features', 'Frequency'),
                       ('filter_features', 'Filter')):
        if result[key] is not None:
            print(f"{label} features shape: {tuple(result[key].shape)}")

    # The fused output must keep the input layout
    assert result['features'].shape == x.shape

    print("GMA test completed successfully!")


# Run the smoke test when this file is executed directly as a script.
if __name__ == "__main__":
    test_gated_multiscale_attention() 