# -*- coding: utf-8 -*-
from typing import Dict, Optional, Union
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..common.features_and_fusion import FeatureExtractorAndFusion
from ..modules.dynamic_segmentation import DynamicSegmentationModule
from ..modules.causal_graph import CausalGraphModuleLite

def _try_import(path_mod, name):
    """Best-effort lookup of attribute ``name`` in module ``path_mod``.

    Returns the attribute, or ``None`` when importing the module or reading the
    attribute raises any ``Exception`` (the caller decides how to handle absence).
    """
    try:
        module = __import__(path_mod, fromlist=[name])
        found = getattr(module, name)
    except Exception:
        found = None
    return found

# Optional backbone components: try the short package path first, then the
# ``model.``-prefixed one; each name ends up ``None`` if neither lookup succeeds.
MLPHead = (_try_import("backbone.head", "MLPHead")
           or _try_import("model.backbone.head", "MLPHead"))
AEBlock = (_try_import("backbone.ae", "AEBlock")
           or _try_import("model.backbone.ae", "AEBlock"))

class TCNBlock(nn.Module):
    """Residual block of two dilated causal 1-D convolutions (TCN style).

    Maps ``[B, in_channels, T]`` to ``[B, out_channels, T]``. Each convolution is
    padded by ``(kernel_size - 1) * dilation`` and its output is trimmed back to the
    first T steps, so in eval mode output step t only depends on input steps <= t
    (in training mode BatchNorm's batch statistics still mix all positions).
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, dilation=1, dropout=0.1):
        super().__init__()
        # nn.Conv1d pads BOTH ends, producing T + padding steps; forward() keeps only
        # the first T, which is equivalent to left-only (causal) padding.
        self.padding = (kernel_size - 1) * dilation

        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size,
                               dilation=dilation, padding=self.padding)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size,
                               dilation=dilation, padding=self.padding)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.norm1 = nn.BatchNorm1d(out_channels)
        self.norm2 = nn.BatchNorm1d(out_channels)

        # 1x1 convolution so the residual path matches out_channels.
        self.downsample = nn.Conv1d(in_channels, out_channels, 1) if in_channels != out_channels else None

    def forward(self, x):
        """x: [B, C_in, T] -> [B, C_out, T]."""
        length = x.size(-1)

        # Trim right after EACH convolution so the BatchNorm layers only see the
        # T causal steps, not the padded tail positions.
        out = self.conv1(x)[..., :length]
        out = self.dropout(self.relu(self.norm1(out)))
        out = self.conv2(out)[..., :length]
        out = self.norm2(out)

        residual = x if self.downsample is None else self.downsample(x)
        return self.relu(out + residual)

class PositionalEncoding(nn.Module):
    """Additive sinusoidal positional encoding for ``[B, T, D]`` inputs.

    Even feature columns receive ``sin(pos * w_k)`` and odd columns ``cos(pos * w_k)``
    with ``w_k = 10000 ** (-2k / d_model)``. Odd ``d_model`` is supported.
    """
    def __init__(self, d_model, max_len=100):
        super().__init__()
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # With odd d_model there is one fewer cosine column than sine column.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # x: [B, T, D]; T must not exceed max_len for the slice to broadcast.
        return x + self.pe[:, :x.size(1)]

class ForecastHead(nn.Module):
    """Decode a global feature vector ``[B, in_dim]`` into a ``horizon``-step forecast.

    Each future step gets a learnable query embedding added to the projected input,
    optionally plus sinusoidal positions; the ``[B, horizon, hidden]`` sequence is then
    processed by a dilated TCN stack or a GRU.

    Returns:
        * ``use_multiscale=True``: dict ``{'forecast_scale_<s>': [B, s]}`` for every
          ``s`` in ``scales``. The scale equal to ``horizon`` is the full per-step
          forecast; any other scale reads ``s`` uniformly spaced steps.
        * otherwise: tensor ``[B, horizon]``.

    Args:
        scales: output scales of the multi-scale head. ``None`` keeps the default
            pyramid ``(24, 12, 6, 3, 1)``, which only contains a full-resolution
            output when ``horizon`` is one of those values.
    """

    DEFAULT_SCALES = (24, 12, 6, 3, 1)

    def __init__(self, in_dim, horizon=24, hidden=128, n_layers=2, dropout=0.1,
                 use_tcn=True, use_multiscale=True, use_positional_encoding=True,
                 use_time_bias=True, scales=None):
        super().__init__()
        self.horizon = horizon
        self.hidden = hidden
        self.use_tcn = use_tcn
        self.use_multiscale = use_multiscale
        self.use_positional_encoding = use_positional_encoding
        self.use_time_bias = use_time_bias

        # Input projection
        self.proj_in = nn.Linear(in_dim, hidden)

        # Learnable per-step future query embedding
        self.horizon_queries = nn.Parameter(torch.randn(horizon, hidden))
        nn.init.normal_(self.horizon_queries, std=0.02)

        if use_positional_encoding:
            self.pos_encoding = PositionalEncoding(hidden, max_len=horizon * 2)

        if use_tcn:
            # Dilations 1, 2, 4, ... widen the receptive field layer by layer.
            self.tcn_layers = nn.ModuleList(
                TCNBlock(hidden, hidden, kernel_size=3, dilation=2 ** i, dropout=dropout)
                for i in range(n_layers)
            )
        else:
            self.rnn = nn.GRU(hidden, hidden, num_layers=n_layers,
                              batch_first=True, dropout=dropout if n_layers > 1 else 0)

        if use_time_bias:
            # Learned additive offsets: global, per-step, and a linear trend over the step index.
            self.global_time_bias = nn.Parameter(torch.zeros(1))
            self.step_time_bias = nn.Parameter(torch.zeros(horizon))
            self.trend_bias = nn.Parameter(torch.zeros(1))
            nn.init.normal_(self.global_time_bias, std=0.01)
            nn.init.normal_(self.step_time_bias, std=0.01)
            nn.init.normal_(self.trend_bias, std=0.01)

        if use_multiscale:
            self.scales = list(self.DEFAULT_SCALES if scales is None else scales)
            if not self.scales or min(self.scales) < 1:
                raise ValueError(f"scales must be a non-empty sequence of positive ints, got {scales!r}")
            # Each scale head maps a hidden state to one value per sampled step.
            self.scale_heads = nn.ModuleList(nn.Linear(hidden, 1) for _ in self.scales)
        else:
            self.out = nn.Linear(hidden, 1)

        self.dropout = nn.Dropout(dropout)

    def _apply_time_bias(self, out, steps, per_step):
        """Add the global, optional per-step, and linear-trend biases to ``out`` ([B, steps])."""
        out = out + self.global_time_bias
        if per_step:
            out = out + self.step_time_bias.unsqueeze(0)
        trend = self.trend_bias * torch.arange(steps, device=out.device).float()
        return out + trend.unsqueeze(0)

    def forward(self, feat):
        """feat: [B, in_dim] global representation (must be 2-D)."""
        batch, _ = feat.shape

        # Condition every future-step query on the projected input: [B, horizon, hidden]
        queries = self.horizon_queries.unsqueeze(0).expand(batch, -1, -1)
        h = self.proj_in(feat).unsqueeze(1) + queries

        if self.use_positional_encoding:
            h = self.pos_encoding(h)

        if self.use_tcn:
            # Conv1d expects [B, channels, steps].
            h_tcn = h.transpose(1, 2)
            for layer in self.tcn_layers:
                h_tcn = layer(h_tcn)
            h = h_tcn.transpose(1, 2)
        else:
            h, _ = self.rnn(h)

        h = self.dropout(h)

        if not self.use_multiscale:
            out = self.out(h).squeeze(-1)  # [B, horizon]
            if self.use_time_bias:
                out = self._apply_time_bias(out, self.horizon, per_step=True)
            return out

        outputs = {}
        for scale, head in zip(self.scales, self.scale_heads):
            full = scale == self.horizon
            if full:
                h_scale = h
            else:
                # Uniformly sample `scale` steps across the horizon.
                indices = torch.linspace(0, self.horizon - 1, scale, dtype=torch.long, device=h.device)
                h_scale = h[:, indices, :]
            out = head(h_scale).squeeze(-1)  # [B, scale]
            if self.use_time_bias:
                # Only the full-resolution output has one per-step bias per position.
                out = self._apply_time_bias(out, scale, per_step=full)
            outputs[f'forecast_scale_{scale}'] = out
        return outputs

class ClassificationHead(nn.Module):
    """
    Two parallel MLP heads over a shared feature vector:
    1. Regression prediction ``y_pred``: [B, F]
    2. Missing prediction ``m_pred`` (sigmoid probabilities): [B, F]

    With ``out_len > 1`` both outputs are shaped [B, out_len, F] instead.
    """
    def __init__(self, projection_dim, n_feature, out_len=1, dropout=0.1,
                 use_batch_norm=False, use_layer_norm=False):
        super().__init__()
        if MLPHead is None:
            # The module-level MLPHead lookup returns None when no backbone path imports.
            raise ImportError(
                "MLPHead could not be imported from 'backbone.head' or 'model.backbone.head'"
            )
        hidden = min(256, projection_dim * 2)

        # Regression prediction head; assumed to emit n_feature * out_len values per sample.
        self.reg_head = MLPHead(
            projection_dim,
            n_feature * out_len,
            hidden=hidden,
            dropout=dropout,
            use_batch_norm=use_batch_norm,
            use_layer_norm=use_layer_norm
        )

        # Missing-value prediction head (same layout as the regression head).
        self.miss_head = MLPHead(
            projection_dim,
            n_feature * out_len,
            hidden=hidden,
            dropout=dropout,
            use_batch_norm=use_batch_norm,
            use_layer_norm=use_layer_norm
        )

        self.n_feature = n_feature
        self.out_len = out_len

    def forward(self, x):
        batch = x.size(0)
        # [B, F] for a single step; [B, out_len, F] otherwise (a flat view(B, F)
        # cannot hold n_feature * out_len values).
        if self.out_len == 1:
            shape = (batch, self.n_feature)
        else:
            shape = (batch, self.out_len, self.n_feature)
        y_pred = self.reg_head(x).view(shape)
        m_pred = torch.sigmoid(self.miss_head(x).view(shape))
        return {"y_pred": y_pred, "m_pred": m_pred}

class MultiModalPredictor(nn.Module):
    """
    Multi-modal model with classification and temporal-forecast heads:
      encoding → EAMC (dynamic segmentation) → cross-modal fusion → causal graph → concat → heads

    Ablation switches are read from ``config`` first and fall back to boolean
    environment variables (see ``_get_bool_env``). The resolved switches are written
    back into ``config`` in place before the sub-modules are built.
    """
    def __init__(self, config: Dict):
        super().__init__()
        self.config = config
        self.shared_dim = self.config.get("shared_dim", 128)

        # Regularization parameters
        self.dropout_rate = self.config.get("dropout_rate", 0.3)
        self.use_batch_norm = self.config.get("use_batch_norm", False)  # off by default to avoid batch size issues
        self.use_layer_norm = self.config.get("use_layer_norm", True)  # layer norm enabled by default

        # Ablation switches: a value present in config wins over the environment variable.
        self.use_dynseg = self._config_or_env('use_dynseg', 'USE_DYNSEG')
        self.use_causal = self._config_or_env('use_causal', 'USE_CAUSAL')
        self.use_nbc = self._config_or_env('use_nbc', 'USE_NBC')
        self.frequency_branch_enabled = self._config_or_env('frequency_branch_enabled', 'FREQUENCY_BRANCH_ENABLED')
        self.use_fft = self._config_or_env('use_fft', 'USE_FFT')
        self.use_wavelet = self._config_or_env('use_wavelet', 'USE_WAVELET')

        # Multi-modal switches
        self.use_multimodal = self._config_or_env('use_multimodal', 'USE_MULTIMODAL')
        self.use_image = self._config_or_env('use_image', 'USE_IMAGE')
        self.use_text = self._config_or_env('use_text', 'USE_TEXT')
        self.use_static = self._config_or_env('use_static', 'USE_STATIC')

        # Causal-mode switches come from the environment only (config values are not consulted).
        self.use_causal_frequency = self._get_bool_env('USE_CAUSAL_FREQUENCY', True)
        self.use_causal_pooling = self._get_bool_env('USE_CAUSAL_POOLING', True)

        # Write the resolved switches back so sub-modules built from `config` see them.
        config.update({
            'use_dynseg': self.use_dynseg,
            'use_causal': self.use_causal,
            'use_nbc': self.use_nbc,
            'frequency_branch_enabled': self.frequency_branch_enabled,
            'use_fft': self.use_fft,
            'use_wavelet': self.use_wavelet,
            'use_causal_frequency': self.use_causal_frequency,
            'use_causal_pooling': self.use_causal_pooling,
            'use_multimodal': self.use_multimodal,
            'use_image': self.use_image,
            'use_text': self.use_text,
            'use_static': self.use_static,
            'pooling_tau': self.config.get('pooling_tau', 5.0)
        })

        self.ff = FeatureExtractorAndFusion(config)

        # Optional modules are only built when their ablation switch is on.
        self.seg_mod = DynamicSegmentationModule(shared_dim=self.shared_dim, config=config) if self.use_dynseg else None
        self.cg_mod = CausalGraphModuleLite(shared_dim=self.shared_dim, config=config) if self.use_causal else None

        # Fused vector = cross-modal global feature ++ graph pool (presumably shared_dim each).
        final_dim = self.shared_dim * 2
        projection_dim = final_dim // 2
        self.feature_dropout = nn.Dropout(self.dropout_rate)
        # Pre-norm projection: LayerNorm -> Linear -> LayerNorm
        self.feature_input_norm = nn.LayerNorm(final_dim)
        self.feature_projection_linear = nn.Linear(final_dim, projection_dim)
        self.feature_output_norm = nn.LayerNorm(projection_dim)
        # Residual projection (only used when `use_feature_residual` is enabled)
        self.feature_residual_proj = nn.Linear(final_dim, projection_dim)
        self.feature_activation = nn.GELU()
        self.feature_dropout_after = nn.Dropout(self.dropout_rate)
        # Residual path is off by default to stabilize training.
        self.use_feature_residual = self.config.get("use_feature_residual", False)

        # Kaiming init scaled down by 1e-3 for a near-zero starting projection.
        for layer in (self.feature_projection_linear, self.feature_residual_proj):
            nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')
            layer.weight.data *= 0.001
            nn.init.constant_(layer.bias, 0.0)

        self.classification_head = ClassificationHead(
            projection_dim,
            self.config.get("out_dim", 20),
            out_len=self.config.get("out_len", 1),
            dropout=self.dropout_rate,
            use_batch_norm=self.use_batch_norm,
            use_layer_norm=self.use_layer_norm
        )

        self.forecast_head = ForecastHead(
            projection_dim,
            horizon=self.config.get("out_len", 24),   # Prediction steps
            hidden=min(256, projection_dim * 2),
            n_layers=2,
            dropout=self.dropout_rate,
            use_tcn=self.config.get("forecast_use_tcn", True),
            use_multiscale=self.config.get("forecast_use_multiscale", True),
            use_positional_encoding=self.config.get("forecast_use_positional_encoding", True),
            use_time_bias=self.config.get("forecast_use_time_bias", True)
        )

    def _get_bool_env(self, env_var: str, default: bool = True) -> bool:
        """Parse a boolean environment variable (case-insensitive).

        'true'/'1'/'yes'/'on' -> True, 'false'/'0'/'no'/'off' -> False; unset or any
        other value -> ``default``.
        """
        import os
        value = os.environ.get(env_var, '').lower()
        if value in ('true', '1', 'yes', 'on'):
            return True
        if value in ('false', '0', 'no', 'off'):
            return False
        return default

    def _config_or_env(self, key: str, env_var: str, default: bool = True):
        """Return ``config[key]`` if present, otherwise the boolean env var ``env_var``."""
        return self.config.get(key, self._get_bool_env(env_var, default))

    @staticmethod
    def _to_btd(ts_feat: torch.Tensor) -> torch.Tensor:
        """Collapse time-series features to ``[B, T, D]``.

        5-D ``[B, N, M, T, D]`` keeps the first path and modality, 4-D ``[B, N, T, D]``
        keeps the first path, 2-D ``[B, D]`` becomes a single step.

        Raises:
            ValueError: for any other rank.
        """
        if ts_feat.dim() == 5:
            return ts_feat[:, 0, 0]
        if ts_feat.dim() == 4:
            return ts_feat[:, 0]
        if ts_feat.dim() == 3:
            return ts_feat
        if ts_feat.dim() == 2:
            return ts_feat.unsqueeze(1)
        raise ValueError(f"Unexpected ts_feat shape: {ts_feat.shape}, expected 2-5 dims")

    @staticmethod
    def _sequence_lengths(packed_ts, ts_feat: torch.Tensor) -> torch.Tensor:
        """Per-sample lengths from ``packed_ts["seq_lengths"]`` (min over dim 1 if 2-D),
        or the full length ``T`` of ``ts_feat`` for every sample."""
        if isinstance(packed_ts, dict) and ("seq_lengths" in packed_ts) and packed_ts["seq_lengths"].dim() >= 1:
            lengths = packed_ts["seq_lengths"]
            if lengths.dim() == 2:
                lengths = lengths.min(dim=1).values
            return lengths
        return torch.full((ts_feat.size(0),), ts_feat.size(1), device=ts_feat.device, dtype=torch.long)

    @staticmethod
    def _zero_non_finite(t: torch.Tensor, stage: str) -> torch.Tensor:
        """Replace NaN/Inf entries with zeros, printing a warning when any are found."""
        finite = torch.isfinite(t)
        if not bool(finite.all()):
            print(f"[WARNING] feature_projection {stage} contains NaN/Inf, replacing with zeros")
            t = torch.where(finite, t, torch.zeros_like(t))
        return t

    def _project_features(self, feat: torch.Tensor) -> torch.Tensor:
        """Stabilized pre-norm projection of the fused vector to ``projection_dim``.

        Order: non-finite -> 0, clamp ±10, dropout, LayerNorm, clamp ±5, Linear,
        clamp ±10, [optional 0.1-scaled residual], LayerNorm, clamp ±5,
        non-finite -> 0, GELU, dropout.
        """
        feat = self._zero_non_finite(feat, "input")
        feat = torch.clamp(feat, min=-10.0, max=10.0)
        feat = self.feature_dropout(feat)

        feat_norm = torch.clamp(self.feature_input_norm(feat), min=-5.0, max=5.0)
        feat_proj = torch.clamp(self.feature_projection_linear(feat_norm), min=-10.0, max=10.0)

        # getattr keeps instances pickled before this attribute existed working.
        if getattr(self, "use_feature_residual", False):
            feat_residual = torch.clamp(self.feature_residual_proj(feat_norm), min=-10.0, max=10.0)
            feat_proj = feat_proj + feat_residual * 0.1  # very small residual weight

        feat_proj = torch.clamp(self.feature_output_norm(feat_proj), min=-5.0, max=5.0)
        feat_proj = self._zero_non_finite(feat_proj, "output")
        return self.feature_dropout_after(self.feature_activation(feat_proj))

    @staticmethod
    def _select_main_forecast(forecast_output, horizon: int):
        """Return the full-horizon forecast from the forecast head's output.

        A multi-scale head returns a dict whose full-resolution entry is keyed
        ``forecast_scale_<horizon>``; ``'forecast_scale_24'`` is the fallback when that
        key is absent. Non-dict outputs are returned unchanged.
        """
        if not isinstance(forecast_output, dict):
            return forecast_output
        key = f'forecast_scale_{horizon}'
        if key not in forecast_output:
            key = 'forecast_scale_24'
        return forecast_output[key]

    def forward(self,
                images: torch.Tensor,
                text_item: Optional[Dict[str, torch.Tensor]],
                packed_ts: Union[Dict[str, torch.Tensor], torch.Tensor],
                static_data: torch.Tensor,
                task: str = None,
                return_ae: bool = False) -> Dict[str, torch.Tensor]:
        """Run the full pipeline.

        Args:
            images, text_item, packed_ts, static_data: raw inputs forwarded to
                ``self.ff.extract_features``. A dict ``packed_ts`` with a
                ``"seq_lengths"`` tensor supplies per-sample lengths.
            task: ``'forecasting'`` or ``'regression'`` makes ``logits`` the main
                forecast; anything else uses the classification head.
            return_ae: currently unused.

        Returns:
            dict with ``logits``, ``features``, ``clip_loss``, ``eamc_info``,
            ``graph_info``, ``forecast`` and, for a multi-scale forecast head,
            every ``forecast_scale_<s>`` entry.
        """
        img_tok, txt_tok, ts_feat, static_vec = self.ff.extract_features(images, text_item, packed_ts, static_data)

        ts_feat = self._to_btd(ts_feat)
        if static_vec.dim() == 3:  # [B, N, D] -> first path
            static_vec = static_vec[:, 0]
        ts_lengths = self._sequence_lengths(packed_ts, ts_feat)

        # Dynamic segmentation (EAMC)
        if self.use_dynseg and self.seg_mod is not None:
            seg_tokens, seg_seq, eamc_info = self.seg_mod(ts_feat, ts_lengths)
        else:
            # Use the time-series features directly; ts_feat is 3-D here.
            seg_tokens = ts_feat
            seg_seq = ts_feat.unsqueeze(2)  # [B, T, 1, D]
            eamc_info = {"use_dynseg": False, "use_nbc": False}

        # Cross-modal fusion
        fused_global = self.ff.cross_modal_fusion(img_tok, txt_tok, seg_tokens, static_vec)

        # Causal graph (CausalGraphModuleLite takes seg_seq, img_tok, txt_tok)
        if self.use_causal and self.cg_mod is not None:
            graph_nodes, graph_info = self.cg_mod(seg_seq, img_tok, txt_tok)
            graph_pool = graph_nodes.mean(dim=1)
        else:
            # Mean-pool over time as a stand-in for the graph readout.
            graph_pool = seg_tokens.mean(dim=1) if seg_tokens.dim() == 3 else seg_tokens
            graph_info = {"use_causal": False}

        feat = self._project_features(torch.cat([fused_global, graph_pool], dim=-1))

        eamc_info["use_dynseg"] = self.use_dynseg
        eamc_info["use_nbc"] = self.use_nbc
        graph_info["use_causal"] = self.use_causal

        # Exposed as attributes for training scripts / MultiTaskWrapper.
        self.last_eamc_info = eamc_info
        self.last_graph_info = graph_info
        self.last_features = feat

        # Run the forecast head once so `logits` and `forecast` agree (dropout).
        forecast_output = self.forecast_head(feat)
        forecast = self._select_main_forecast(forecast_output, self.forecast_head.horizon)
        if task in ('forecasting', 'regression'):
            logits = forecast
        else:  # default: classification head
            logits = self.classification_head(feat)

        clip_loss = self._compute_alignment_loss(img_tok, txt_tok, seg_tokens, static_vec)

        result = {
            "logits": logits,
            "features": feat,
            "clip_loss": clip_loss,
            "eamc_info": eamc_info,
            "graph_info": graph_info
        }
        result["forecast"] = forecast
        if isinstance(forecast_output, dict):
            result.update(forecast_output)  # every scale output
        return result

    @staticmethod
    def _global_repr(t: torch.Tensor, mean_tokens: bool) -> torch.Tensor:
        """Reduce a modality tensor to ``[B, D']``: 2-D as is, 3-D mean over dim 1
        when ``mean_tokens``, anything else flattened per sample."""
        if t.dim() == 2:
            return t
        if mean_tokens and t.dim() == 3:
            return t.mean(dim=1)
        return t.reshape(t.size(0), -1)

    def _compute_alignment_loss(self, img_tok, txt_tok, ts_tok, static_vec):
        """CLIP-style symmetric contrastive loss averaged over all pairs of present modalities.

        A modality counts as present when it is not None and not all (near-)zero.
        Features are truncated to the smallest width among modalities before
        L2-normalization; temperature is 0.07. Returns 0 with fewer than two modalities.
        """
        import itertools

        device = next(self.parameters()).device
        alignment_loss = torch.tensor(0.0, device=device)

        # Fixed order: image, text, time series (token-averaged), static (flattened).
        candidates = ((img_tok, True), (txt_tok, True), (ts_tok, True), (static_vec, False))
        modalities = [
            self._global_repr(t, mean_tokens)
            for t, mean_tokens in candidates
            if t is not None and not torch.allclose(t, torch.zeros_like(t))
        ]
        if len(modalities) < 2:
            return alignment_loss

        # Truncate to a common width so pairwise similarities are defined.
        min_dim = min(m.size(1) for m in modalities)
        modalities = [m[:, :min_dim] for m in modalities]

        temperature = 0.07
        labels = torch.arange(modalities[0].size(0), device=device)  # positives on the diagonal
        total_loss = 0.0
        pair_count = 0
        for feat_i, feat_j in itertools.combinations(modalities, 2):
            sim_matrix = torch.mm(F.normalize(feat_i, p=2, dim=1),
                                  F.normalize(feat_j, p=2, dim=1).t()) / temperature
            loss_i2j = F.cross_entropy(sim_matrix, labels)
            loss_j2i = F.cross_entropy(sim_matrix.t(), labels)
            total_loss += (loss_i2j + loss_j2i) / 2.0
            pair_count += 1

        if pair_count > 0:
            alignment_loss = total_loss / pair_count
        return alignment_loss
