import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import math
from typing import Optional, Tuple, Dict, List
import warnings
warnings.filterwarnings('ignore')

from torch.cuda.amp import autocast

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding.

    The table is stored as the non-trainable buffer ``pe`` with shape
    ``(max_len, 1, d_model)``. Even columns hold ``sin(pos * div)`` and odd
    columns hold ``cos(pos * div)``. ``forward(seq_len)`` returns the first
    ``seq_len`` rows as a ``(1, seq_len, 1, d_model)`` tensor (presumably to
    broadcast against ``(batch, seq_len, num_nodes, d_model)`` features).
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        self.d_model = d_model

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there are ceil(d/2) sine columns but only floor(d/2)
        # cosine columns; slice div_term so the shapes agree instead of raising.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, seq_len: int) -> torch.Tensor:
        """Return the encoding for positions ``0 .. seq_len-1`` as ``(1, seq_len, 1, d_model)``."""
        return self.pe[:seq_len].transpose(0, 1).unsqueeze(2)

class EmbeddingProcessModule(nn.Module):
    """Embed three input streams with one shared two-layer MLP.

    Each of ``x_h``, ``x_w`` and ``x_d`` is permuted with ``(0, 3, 1, 2)``.
    That moves the original axis 2 to the last position, where it is mixed
    by ``Linear(in_channels, d_ff_emb) -> ReLU -> Linear(d_ff_emb, d_model)``.
    All three streams share the same weights. The stream names presumably
    mean hour/week/day views; confirm against the caller.
    """

    def __init__(self, in_channels: int, d_model: int, d_ff_emb: int):
        super().__init__()
        self.d_model = d_model
        self.d_ff_emb = d_ff_emb

        expand = nn.Linear(in_channels, d_ff_emb, bias=True)
        project = nn.Linear(d_ff_emb, d_model, bias=True)
        self.feature_ff = nn.Sequential(expand, nn.ReLU(), project)

    def forward(self, x_h: torch.Tensor, x_w: torch.Tensor, x_d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        def embed(stream: torch.Tensor) -> torch.Tensor:
            # Original axis 2 becomes the feature axis consumed by the MLP.
            return self.feature_ff(stream.permute(0, 3, 1, 2))

        return embed(x_h), embed(x_w), embed(x_d)

class EnhancedMultiScaleTemporalConv(nn.Module):
    """Parallel 1-D temporal convolutions at several kernel sizes, fused by an MLP.

    Input and output have shape ``(B, seq_len, N, d_model)``. Each node's
    time series is convolved along time by one ``Conv1d + ReLU`` branch per
    kernel size in ``time_scales``. The branch outputs are concatenated on
    the feature axis and fused back to ``d_model``. The result is then
    blended with the input as
    ``w * fused + (1 - w) * x``, where ``w = residual_weight`` is learnable
    and unconstrained.
    """

    def __init__(self, d_model: int):
        super().__init__()
        self.d_model = d_model
        self.time_scales = [1, 3, 5, 7, 9, 11, 13, 15]
        self.num_scales = len(self.time_scales)

        # Odd kernel k with padding k // 2 keeps the temporal length unchanged.
        self.multi_scale_conv = nn.ModuleList(
            nn.Sequential(
                nn.Conv1d(d_model, d_model, kernel_size=k, padding=k // 2, groups=1),
                nn.ReLU(),
            )
            for k in self.time_scales
        )

        hidden_width = d_model * 2
        self.fusion = nn.Sequential(
            nn.Linear(d_model * self.num_scales, hidden_width),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_width, d_model),
        )

        self.residual_weight = nn.Parameter(torch.tensor([0.5]))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, seq_len, N, d_model = x.shape

        # (B, T, N, D) -> (B*N, D, T): Conv1d expects channels before time.
        channels_first = x.permute(0, 2, 3, 1).contiguous().view(B * N, d_model, seq_len)

        per_scale = []
        for branch in self.multi_scale_conv:
            out = branch(channels_first)
            # Safety net for kernels that would change the length.
            if out.size(-1) != seq_len:
                out = F.interpolate(out, size=seq_len, mode='linear', align_corners=False)
            # (B*N, D, T) -> (B, N, D, T) -> (B, T, N, D)
            per_scale.append(out.view(B, N, d_model, seq_len).permute(0, 3, 1, 2))

        fused = self.fusion(torch.cat(per_scale, dim=-1))

        blend = self.residual_weight
        return blend * fused + (1 - blend) * x

class ImprovedCausalEstimator(nn.Module):
    """Score every ordered node pair with a low-rank bilinear model.

    ``forward`` takes ``node_features`` of shape ``(B, seq_len, N, d_model)``
    and returns a ``(B, N, N)`` tensor of sigmoid scores in (0, 1), with the
    diagonal (self-loops) zeroed. Entry ``[b, i, j]`` is
    ``sigmoid(source_i . target_j + node_bias[i, j])``. Here ``source``/``target``
    are rank-``self.rank`` projections of the time-pooled features
    concatenated with a learned per-node embedding. ``source`` additionally
    gets ``0.1 * time_delay_encoder(pooled features)``.
    """

    def __init__(self, d_model: int, num_nodes: int, rank: Optional[int] = None):
        super().__init__()
        self.num_nodes = num_nodes
        self.d_model = d_model

        # Default rank is num_nodes // 4 capped at 32, clamped to at least 1.
        # For num_nodes < 4 an unclamped default is 0: the bilinear term is
        # then identically zero and the output ignores node_features.
        self.rank = rank if rank is not None else max(1, min(32, num_nodes // 4))

        # Learned per-node identity vector appended to the pooled features.
        self.node_embedding = nn.Parameter(torch.randn(num_nodes, d_model // 4))

        self.source_transform = nn.Linear(d_model + d_model // 4, self.rank, bias=True)
        self.target_transform = nn.Linear(d_model + d_model // 4, self.rank, bias=True)

        self.time_delay_encoder = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Linear(d_model // 2, self.rank)
        )

        self.node_bias = nn.Parameter(torch.randn(num_nodes, num_nodes) * 0.01)

        # 1 off the diagonal, 0 on it: multiplying removes self-loops.
        self.register_buffer('self_loop_mask',
                           torch.ones(num_nodes, num_nodes) - torch.eye(num_nodes))

    def forward(self, node_features: torch.Tensor) -> torch.Tensor:
        B, seq_len, N, d_model = node_features.shape
        device = node_features.device

        if seq_len > 1:
            # Weights rise linearly from 0.5 to 1.0 and are normalised to sum
            # to 1, so later time steps count more in the pooled features.
            time_weights = torch.linspace(0.5, 1.0, seq_len, device=device)
            time_weights = time_weights / time_weights.sum()
            time_weights = time_weights.view(1, seq_len, 1, 1)
            avg_features = (node_features * time_weights).sum(dim=1)
        else:
            avg_features = node_features.squeeze(1)

        node_emb = self.node_embedding.unsqueeze(0).expand(B, -1, -1)
        node_repr = torch.cat([avg_features, node_emb], dim=-1)

        source_repr = self.source_transform(node_repr)
        target_repr = self.target_transform(node_repr)

        time_delay_features = self.time_delay_encoder(avg_features)
        source_repr = source_repr + 0.1 * time_delay_features

        # (B, N, rank) x (B, rank, N) -> (B, N, N); row index = source node.
        causal_matrix = torch.bmm(source_repr, target_repr.transpose(1, 2))
        causal_matrix = causal_matrix + self.node_bias.unsqueeze(0)
        causal_matrix = torch.sigmoid(causal_matrix)
        causal_matrix = causal_matrix * self.self_loop_mask.to(device)

        return causal_matrix

class BatchedTemporalCausalEstimator(nn.Module):
    """One ImprovedCausalEstimator per time segment, each fed segment features plus a segment embedding.

    ``forward`` takes a list of ``T`` tensors of shape ``(B, N, d_model)``
    (one per segment) and returns ``(B, T, N, N)`` causal matrices. Segment
    ``t`` is concatenated with a learned ``time_emb_dim``-dim embedding and
    scored by its own estimator. Segments with ``t >= num_time_segments``
    reuse the last segment's embedding and estimator.
    """

    def __init__(self, d_model: int, num_nodes: int, num_time_segments: int, rank: Optional[int] = None,
                 time_emb_dim: int = 32):
        super().__init__()
        self.num_nodes = num_nodes
        self.num_time_segments = num_time_segments
        self.d_model = d_model
        self.time_emb_dim = time_emb_dim

        self.segment_estimators = nn.ModuleList([
            ImprovedCausalEstimator(d_model + time_emb_dim, num_nodes, rank)
            for _ in range(num_time_segments)
        ])

        self.time_embedding = nn.Embedding(num_time_segments, time_emb_dim)

    def forward(self, all_segment_features: List[torch.Tensor]) -> torch.Tensor:
        stacked_features = torch.stack(all_segment_features, dim=1)
        B, T, N, _ = stacked_features.shape
        device = stacked_features.device

        last_segment = self.num_time_segments - 1
        # Clamp so extra segments reuse the last embedding, matching the
        # estimator fallback below. Without the clamp, the embedding lookup
        # raises before that fallback is reached.
        time_indices = torch.arange(T, device=device).clamp(max=last_segment)
        time_embs = self.time_embedding(time_indices)

        causal_matrices_list = []
        for t in range(T):
            segment_features = stacked_features[:, t, :, :]
            time_emb = time_embs[t].expand(B, N, -1)

            # Estimators expect (B, seq_len, N, D); use a single-step sequence.
            features_with_time = torch.cat([segment_features, time_emb], dim=-1).unsqueeze(1)

            estimator = self.segment_estimators[min(t, last_segment)]
            causal_matrices_list.append(estimator(features_with_time))

        return torch.stack(causal_matrices_list, dim=1)

class ImprovedBayesianCausalFusion(nn.Module):
    """Fuse per-segment causal matrices into one matrix, blended with a learnable prior.

    ``forward`` maps ``(B, T, N, N)`` segment matrices to a single
    ``(B, N, N)`` matrix in [0, 1] with a zero diagonal:

    1. Segment weights come from attention of the mean node embedding over
       the learned segment embeddings. They depend only on parameters, not
       on the input. A softmax is then applied to them.
    2. The weighted sum over segments is blended with
       ``sigmoid(prior_causal)`` using weight ``sigmoid(prior_confidence)``.
    3. A soft threshold (``consistency_threshold``) is applied, values are
       clamped to [0, 1], and self-loops are removed.
    """

    def __init__(self, num_nodes: int, num_time_segments: int, d_model: int = 64,
                 adj_matrix: Optional[torch.Tensor] = None):
        super().__init__()
        self.num_nodes = num_nodes
        self.num_time_segments = num_time_segments
        self.d_model = d_model

        self.node_embedding = nn.Parameter(torch.randn(num_nodes, d_model))

        self.segment_embedding = nn.Parameter(torch.randn(num_time_segments, d_model))

        self.segment_attention = nn.MultiheadAttention(
            d_model, num_heads=4, dropout=0.1, batch_first=True
        )

        # NOTE(review): matrix_encoder and weight_predictor are not used by
        # forward; they are kept so existing state_dicts still load.
        self.matrix_encoder = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1)
        )

        self.weight_predictor = nn.Sequential(
            nn.Linear(8 + d_model, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Sigmoid()
        )

        # Stored as logits; forward applies sigmoid (so a 0/1 adjacency maps
        # to 0.5/0.73 initially).
        if adj_matrix is not None:
            self.prior_causal = nn.Parameter(adj_matrix.float())
        else:
            self.prior_causal = nn.Parameter(torch.zeros(num_nodes, num_nodes))

        # Logit of the prior blend weight (sigmoid(0.3) ~= 0.574 at init).
        self.prior_confidence = nn.Parameter(torch.tensor([0.3]))

        self.consistency_threshold = 0.01

    def forward(self, segment_causal_matrices: torch.Tensor) -> torch.Tensor:
        B, T, N, _ = segment_causal_matrices.shape

        segment_emb = self.segment_embedding.unsqueeze(0).expand(B, -1, -1)[:, :T, :]

        node_emb_avg = self.node_embedding.mean(dim=0, keepdim=True)
        query = node_emb_avg.unsqueeze(0).expand(B, 1, -1)

        _, attention_weights = self.segment_attention(
            query, segment_emb, segment_emb,
            need_weights=True, average_attn_weights=True
        )

        # attention_weights: (B, 1, T), already attention probabilities.
        # NOTE(review): a second softmax over probabilities pulls the weights
        # toward uniform; kept to preserve trained behaviour — confirm intent.
        segment_weights = F.softmax(attention_weights.squeeze(1), dim=1)

        weights_expanded = segment_weights.unsqueeze(-1).unsqueeze(-1)
        data_driven_causal = torch.sum(segment_causal_matrices * weights_expanded, dim=1)

        prior_weight = torch.sigmoid(self.prior_confidence)
        prior_matrix = torch.sigmoid(self.prior_causal).unsqueeze(0).expand(B, -1, -1)

        fused_causal = (1 - prior_weight) * data_driven_causal + prior_weight * prior_matrix

        return self._apply_consistency_constraints(fused_causal)

    def _apply_consistency_constraints(self, causal_matrix: torch.Tensor) -> torch.Tensor:
        """Soft-threshold ``(B, N, N)`` scores, clamp to [0, 1] and zero the diagonal.

        Above ``threshold`` a value keeps ``value - threshold`` plus a small
        sigmoid bump; below it the output decays smoothly toward 0.
        """
        B, N, _ = causal_matrix.shape
        off_diagonal = 1 - torch.eye(N, device=causal_matrix.device)

        threshold = self.consistency_threshold
        causal_matrix = F.relu(causal_matrix - threshold) + threshold * torch.sigmoid((causal_matrix - threshold) * 10)

        causal_matrix = torch.clamp(causal_matrix, 0.0, 1.0)

        # Mask after thresholding: the soft threshold maps 0 to ~threshold/2,
        # so masking first would let self-loops reappear.
        return causal_matrix * off_diagonal

class InterventionValidator(nn.Module):
    """Pick intervention nodes round-robin and score how well a causal matrix predicts their effects.

    Node selection tracks which nodes were used in the current "round"
    (``current_round_used``) and resets when too few unused nodes remain.
    ``compute_intervention_effect`` compares each intervened node's row of
    the causal matrix with the normalised observed prediction change.
    """

    def __init__(self, num_nodes: int, d_model: int):
        super().__init__()
        self.num_nodes = num_nodes
        self.d_model = d_model

        self.register_buffer('node_pool', torch.arange(num_nodes))
        self.register_buffer('current_round_used', torch.zeros(num_nodes, dtype=torch.bool))
        self.current_round_idx = 0

        # Maps a d_model feature vector to a scalar intervention value in (-1, 1).
        self.intervention_generator = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Linear(d_model // 2, 1),
            nn.Tanh()
        )

        self.validity_scorer = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 1),
            nn.Sigmoid()
        )

    def select_intervention_nodes(self, batch_size: int, num_interventions: int = 1) -> torch.Tensor:
        """Return ``(batch_size, num_interventions)`` node ids, distinct within each row.

        Rows draw from the nodes not yet used this round. Only row 0's picks
        are marked as used. If fewer than ``num_interventions`` unused nodes
        remain, the round resets and ``current_round_idx`` is incremented.

        Raises:
            ValueError: if ``num_interventions`` exceeds ``num_nodes``.
        """
        if num_interventions > self.num_nodes:
            raise ValueError(
                f"num_interventions ({num_interventions}) cannot exceed num_nodes ({self.num_nodes})"
            )

        device = self.node_pool.device

        available_nodes = self.node_pool[~self.current_round_used]

        if available_nodes.numel() < num_interventions:
            self.current_round_used.fill_(False)
            available_nodes = self.node_pool.clone()
            self.current_round_idx += 1

        intervention_nodes = torch.zeros(batch_size, num_interventions, dtype=torch.long, device=device)

        for b in range(batch_size):
            perm = torch.randperm(available_nodes.numel(), device=device)[:num_interventions]
            intervention_nodes[b] = available_nodes[perm]

        if batch_size > 0:
            self.current_round_used[intervention_nodes[0]] = True

        return intervention_nodes

    def compute_intervention_effect(self,
                                   original_pred: torch.Tensor,
                                   intervened_pred: torch.Tensor,
                                   intervention_nodes: torch.Tensor,
                                   causal_matrix: torch.Tensor) -> torch.Tensor:
        """Score agreement between causal rows and observed intervention effects.

        Args:
            original_pred, intervened_pred: ``(B, N, T)`` predictions without
                and with the intervention.
            intervention_nodes: ``(B, K)`` intervened node ids.
            causal_matrix: ``(B, N, N)`` scores; row ``i`` is read as the
                expected effect of node ``i`` on every node.

        Returns:
            ``(B, N, N)`` tensor. Row ``i`` holds ``1 - |causal[i] - change|``
            if node ``i`` was intervened on (summed over ``k`` if repeated),
            otherwise zeros. ``change`` is the mean absolute prediction change
            per node, divided by its per-sample maximum.
        """
        B, N, _ = original_pred.shape
        device = original_pred.device

        pred_change_mean = torch.abs(intervened_pred - original_pred).mean(dim=2)
        # Independent of k, so compute once.
        normalized_change = pred_change_mean / (pred_change_mean.max(dim=1, keepdim=True)[0] + 1e-8)

        batch_idx = torch.arange(B, device=device)
        validity_score = torch.zeros(B, N, N, device=device)

        for k in range(intervention_nodes.shape[1]):
            int_nodes = intervention_nodes[:, k]

            # Row of the intervened node in each batch element: (B, N).
            causal_strength = causal_matrix[batch_idx, int_nodes]
            consistency = 1.0 - torch.abs(causal_strength - normalized_change)

            int_mask = F.one_hot(int_nodes, num_classes=N).float()
            validity_score += int_mask.unsqueeze(2) * consistency.unsqueeze(1)

        return validity_score

class CausalPropagationSimulator(nn.Module):
    """Simulate a per-node scalar effect of propagating features through a causal matrix.

    At each time step, node features are mixed as ``torch.bmm(A, features_t)``
    and mapped to one scalar per node. The mapping is either direct
    (``direct_effect``) or through a GRU cell that carries state across time
    steps (``gru_to_effect``). Optional interventions replace each intervened
    node's features with an encoding of its intervention value and zero that
    node's row of ``A``.

    NOTE(review): ``bmm(A, X)`` aggregates ``X[j]`` into node ``i`` via
    ``A[i, j]``. ImprovedCausalEstimator builds its matrix as
    ``source @ target.T``, i.e. row = source. If that convention is
    intended, this aggregation runs effect -> cause; confirm the direction.
    """

    def __init__(self, d_model: int, num_time_segments: int, use_gru: bool = True):
        super().__init__()
        self.d_model = d_model
        self.use_gru = use_gru
        # num_time_segments is accepted for interface compatibility; unused.

        self.direct_effect = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Linear(d_model // 2, 1)
        )

        if self.use_gru:
            self.propagation_gru = nn.GRUCell(d_model, d_model)

            self.gru_to_effect = nn.Sequential(
                nn.Linear(d_model, d_model // 2),
                nn.ReLU(),
                nn.Linear(d_model // 2, 1)
            )

        # Maps a scalar intervention value to a d_model feature vector.
        self.intervention_encoder = nn.Sequential(
            nn.Linear(1, d_model // 2),
            nn.ReLU(),
            nn.Linear(d_model // 2, d_model)
        )

    def forward(self, node_features: torch.Tensor, causal_matrix: torch.Tensor,
                intervention_nodes: Optional[torch.Tensor] = None,
                intervention_values: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Return ``(B, seq_len, N, 1)`` simulated effects.

        Args:
            node_features: ``(B, seq_len, N, d_model)``.
            causal_matrix: ``(B, N, N)``.
            intervention_nodes: ``(B, K)`` node ids. Interventions apply only
                if both this and ``intervention_values`` are given.
            intervention_values: ``(B, K)`` scalar values.
        """
        B, seq_len, N, d_model = node_features.shape
        device = node_features.device
        batch_idx = torch.arange(B, device=device)

        effects = torch.zeros(B, seq_len, N, 1, device=device)

        A = causal_matrix.clone()

        # (nodes, encoded_values) per intervention. The encoding does not
        # depend on t, so it is computed once instead of at every time step.
        interventions = []
        if intervention_nodes is not None and intervention_values is not None:
            for k in range(intervention_nodes.shape[1]):
                nodes = intervention_nodes[:, k]
                # Zero the intervened node's row: it no longer receives input.
                mask = torch.ones_like(A)
                mask[batch_idx, nodes, :] = 0
                A = A * mask
                encoded = self.intervention_encoder(intervention_values[:, k:k + 1])
                interventions.append((nodes, encoded))

        hidden = torch.zeros(B * N, d_model, device=device) if self.use_gru else None

        for t in range(seq_len):
            # Clone so the in-place overwrite below never touches node_features.
            features_t = node_features[:, t, :, :].clone()
            for nodes, encoded in interventions:
                features_t[batch_idx, nodes] = encoded

            causal_influence = torch.bmm(A, features_t)

            if self.use_gru:
                hidden = self.propagation_gru(causal_influence.view(B * N, d_model), hidden)
                effects[:, t, :, :] = self.gru_to_effect(hidden.view(B, N, d_model))
            else:
                effects[:, t, :, :] = self.direct_effect(causal_influence)

        return effects

def improved_causal_loss_with_intervention(causal_matrices: torch.Tensor,
                                          validity_scores: Optional[torch.Tensor] = None,
                                          adj_matrix: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]:
    """Compute the individual (unweighted) causal regularisation terms.

    Args:
        causal_matrices: ``(B, T, N, N)`` causal scores.
        validity_scores: optional intervention validity scores. When given,
            the sparsity term is scaled by ``2 - mean(clamp(scores, 0, 1))``
            and ``validity = -mean(scores)``.
        adj_matrix: accepted but unused.

    Returns:
        Dict of scalar tensors, keyed (in order) ``sparsity``, ``validity``,
        ``min_connections``, ``distribution`` (always 0),
        ``temporal_consistency`` and ``dag``. ``dag`` is the mean of
        ``diag(A @ A)`` for the max-normalised, off-diagonal average matrix,
        so it penalises 2-cycles.
    """
    B, T, N, _ = causal_matrices.shape
    device = causal_matrices.device

    # Average over batch and time; shared by the degree and DAG terms.
    mean_matrix = causal_matrices.mean(dim=(0, 1))

    sparsity = torch.mean(torch.abs(causal_matrices))
    if validity_scores is None:
        validity = torch.tensor(0.0, device=device)
    else:
        sparsity = sparsity * (2.0 - torch.clamp(validity_scores, 0, 1).mean())
        validity = -torch.mean(validity_scores)

    # Penalise nodes whose summed outgoing score falls below one connection.
    min_connections = 1
    connection_penalty = F.relu(min_connections - mean_matrix.sum(dim=1)).mean()

    if T <= 1:
        temporal = torch.tensor(0.0, device=device)
    else:
        temporal = torch.mean(torch.abs(torch.diff(causal_matrices, dim=1)))

    off_diagonal = mean_matrix * (1 - torch.eye(N, device=device))
    normalized = off_diagonal / (torch.max(off_diagonal) + 1e-8)
    dag = torch.mean(torch.diagonal(torch.matmul(normalized, normalized)))

    return {
        'sparsity': sparsity,
        'validity': validity,
        'min_connections': connection_penalty,
        'distribution': torch.tensor(0.0, device=device),
        'temporal_consistency': temporal,
        'dag': dag,
    }

class TECausGAT(nn.Module):
    """Node-to-node attention modulated by a learned, time-segmented causal graph.

    ``forward`` takes ``x`` of shape ``(B, seq_len, N, d_model)`` and returns
    ``(attention_output, causal_matrices, validity_scores)``:

    * ``attention_output``: same shape as ``x``.
    * ``causal_matrices``: ``(B, num_time_segments, N, N)`` per-segment matrices.
    * ``validity_scores``: ``(B, N, N)`` when an intervention check ran
      (training with ``perform_intervention=True``), otherwise ``None``.
    """

    def __init__(self, d_model: int, num_nodes: int, num_time_segments: int, head_s: int,
                dropout: float = 0.1, causal_threshold: float = 0.1, adj_matrix: Optional[torch.Tensor] = None,
                rank: int = None, use_gru: bool = True):
        super().__init__()
        self.d_model = d_model
        self.num_nodes = num_nodes
        self.num_time_segments = num_time_segments
        self.head_s = head_s
        self.dropout = dropout
        # NOTE(review): causal_threshold is accepted but not used in this class.

        # Logit of the causal gate; sigmoid(0) = 0.5 at init.
        self.causal_gate = nn.Parameter(torch.tensor(0.0))

        assert d_model % head_s == 0, f"d_model {d_model} must be divisible by head_s {head_s}"
        self.d_k = d_model // head_s

        self.multi_scale_temporal = EnhancedMultiScaleTemporalConv(d_model)

        self.causal_estimator = BatchedTemporalCausalEstimator(
            d_model, num_nodes, num_time_segments, rank
        )

        self.causal_fusion = ImprovedBayesianCausalFusion(
            num_nodes, num_time_segments, d_model, adj_matrix
        )

        self.propagation_simulator = CausalPropagationSimulator(
            d_model, num_time_segments, use_gru=use_gru
        )

        self.intervention_validator = InterventionValidator(num_nodes, d_model)

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, d_model, bias=False)
        self.v_proj = nn.Linear(d_model, d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)

        # Per-node, per-step strength in (0, 1) for the value modulation.
        self.causal_modulation = nn.Sequential(
            nn.Linear(d_model * 2, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, d_model // 2),
            nn.ReLU(),
            nn.Linear(d_model // 2, 1),
            nn.Sigmoid()
        )

        self.dropout_layer = nn.Dropout(dropout)

    def _causal_modulated_attention(self, x: torch.Tensor, causal_matrix: torch.Tensor,
                                   causal_effects: torch.Tensor) -> torch.Tensor:
        """Multi-head attention across nodes, separately for each batch item and time step.

        Args:
            x: ``(B, seq_len, N, d_model)`` features.
            causal_matrix: ``(B, N, N)``. It is added to the attention logits,
                scaled by ``5 * sigmoid(causal_gate)``.
            causal_effects: ``(B, seq_len, N, 1)`` effects used to rescale the
                values, or ``None`` to skip value modulation.

        Returns:
            ``(B, seq_len, N, d_model)``.
        """
        B, seq_len, N, d_model = x.shape

        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)

        alpha = torch.sigmoid(self.causal_gate)

        if causal_effects is not None:
            effects = causal_effects.expand(-1, -1, -1, d_model)
            modulation_strength = self.causal_modulation(torch.cat([x, effects], dim=-1))
            V = V * (1 + alpha * 2.0 * modulation_strength * effects)

        # (B, T, N, D) -> (B, H, T, N, d_k)
        Q = Q.view(B, seq_len, N, self.head_s, self.d_k).permute(0, 3, 1, 2, 4)
        K = K.view(B, seq_len, N, self.head_s, self.d_k).permute(0, 3, 1, 2, 4)
        V = V.view(B, seq_len, N, self.head_s, self.d_k).permute(0, 3, 1, 2, 4)

        scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)

        # Bias the logits with the causal matrix, broadcast over heads and time.
        modulation_matrix = causal_matrix.unsqueeze(1).unsqueeze(2).expand(B, self.head_s, seq_len, N, N)
        enhanced_scores = scores + alpha * 5.0 * modulation_matrix

        attn_weights = F.softmax(enhanced_scores, dim=-1)
        attn_weights = self.dropout_layer(attn_weights)

        context = torch.matmul(attn_weights, V)

        context = context.permute(0, 2, 3, 1, 4)
        context = context.contiguous().view(B, seq_len, N, d_model)

        return self.out_proj(context)

    def causal_regularization_loss(self, causal_matrices: torch.Tensor,
                                  validity_scores: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Weighted sum of the terms from ``improved_causal_loss_with_intervention``.

        The sparsity weight is a constant 0.01. The helper already scales its
        sparsity term by ``2 - mean(clamp(validity, 0, 1))``, so applying a
        second validity-based factor here would count that adjustment twice.
        """
        loss_dict = improved_causal_loss_with_intervention(causal_matrices, validity_scores)

        sparsity_weight = 0.01

        total_loss = (
            sparsity_weight * loss_dict['sparsity'] +
            0.5 * loss_dict['validity'] +
            0.1 * loss_dict['min_connections'] +
            0.05 * loss_dict['distribution'] +
            0.01 * loss_dict['temporal_consistency'] +
            0.1 * loss_dict['dag']
        )

        return total_loss

    @staticmethod
    def _pool_segment(features: torch.Tensor, segment_mask: torch.Tensor) -> torch.Tensor:
        """Masked temporal mean of ``(B, seq_len, N, D)`` features, giving ``(B, N, D)``.

        The fallback is per sample: a sample with no steps in the segment gets
        the plain temporal mean. A batch-wide check would instead leave such
        samples all-zero whenever another sample had steps in the segment.
        """
        counts = segment_mask.sum(dim=1, keepdim=True).unsqueeze(-1)  # (B, 1, 1)
        weights = segment_mask.unsqueeze(-1).unsqueeze(-1)  # (B, seq_len, 1, 1)
        masked_mean = (features * weights).sum(dim=1) / (counts + 1e-8)
        return torch.where(counts > 0, masked_mean, features.mean(dim=1))

    def forward(self, x: torch.Tensor, time_indices: torch.Tensor,
            perform_intervention: bool = False,
            num_interventions: int = 2) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Run multi-scale conv, causal estimation/fusion/propagation and causal attention.

        Args:
            x: ``(B, seq_len, N, d_model)`` input features.
            time_indices: per-step segment ids, compared against
                ``0 .. num_time_segments-1``. The pooling sums over dim 1,
                so ``(B, seq_len)`` is expected. If ``None`` or empty, the
                sequence is split into ``num_time_segments`` contiguous chunks.
            perform_intervention: run the intervention validity check. It runs
                only while ``self.training``.
            num_interventions: number of nodes intervened on per sample.
        """
        B, seq_len, N, original_d_model = x.shape
        device = x.device

        assert original_d_model == self.d_model, f"输入维度错误: {original_d_model} != {self.d_model}"

        multi_scale_features = self.multi_scale_temporal(x)

        segment_features_list = []
        for t in range(self.num_time_segments):
            if time_indices is not None and time_indices.numel() > 0:
                segment_mask = (time_indices == t).float()
            else:
                segment_start = t * seq_len // self.num_time_segments
                segment_end = (t + 1) * seq_len // self.num_time_segments
                segment_mask = torch.zeros(B, seq_len, device=device)
                segment_mask[:, segment_start:segment_end] = 1.0

            segment_features_list.append(self._pool_segment(multi_scale_features, segment_mask))

        causal_matrices = self.causal_estimator(segment_features_list)

        fused_causal_matrix = self.causal_fusion(causal_matrices)

        causal_effects = self.propagation_simulator(multi_scale_features, fused_causal_matrix)

        validity_scores = None
        if perform_intervention and self.training:
            intervention_nodes = self.intervention_validator.select_intervention_nodes(
                B, num_interventions=num_interventions
            )

            # One scalar intervention value per sample, shared by all K nodes.
            avg_features = multi_scale_features.mean(dim=(1, 2))
            intervention_values = self.intervention_validator.intervention_generator(avg_features)
            intervention_values = intervention_values.expand(B, num_interventions)

            intervened_effects = self.propagation_simulator(
                multi_scale_features, fused_causal_matrix,
                intervention_nodes, intervention_values
            )

            original_pred = causal_effects.squeeze(-1)
            intervened_pred = intervened_effects.squeeze(-1)

            # (B, T, N) -> (B, N, T) as expected by compute_intervention_effect.
            validity_scores = self.intervention_validator.compute_intervention_effect(
                original_pred.transpose(1, 2),
                intervened_pred.transpose(1, 2),
                intervention_nodes,
                fused_causal_matrix
            )

        attention_output = self._causal_modulated_attention(
            multi_scale_features, fused_causal_matrix, causal_effects
        )

        return attention_output, causal_matrices, validity_scores

class HourBeltBlock(nn.Module):
    """Post-norm block: causal spatial attention, then per-node temporal attention, then FFN.

    Each sub-layer is wrapped as ``norm(x + dropout(sublayer(x)))``.
    Input and output shape: ``(B, seq_len, N, d_model)``. ``forward`` also
    returns the TECausGAT causal matrices, per-node temporal attention
    weights ``(B, N, head_t, seq_len, seq_len)`` and validity scores.
    """

    def __init__(self, d_model: int, num_nodes: int, num_time_segments: int, head_s: int,
                 head_t: int, d_ff_belt: int, dropout: float = 0.1,
                 adj_matrix: Optional[torch.Tensor] = None, rank: int = None):
        super().__init__()
        self.d_model = d_model
        self.num_nodes = num_nodes
        self.head_t = head_t

        assert d_model % head_t == 0, f"d_model {d_model} must be divisible by head_t {head_t}"
        self.d_kt = d_model // head_t

        self.te_causgat = TECausGAT(
            d_model, num_nodes, num_time_segments, head_s, dropout,
            adj_matrix=adj_matrix, rank=rank
        )

        self.temporal_attention = nn.MultiheadAttention(d_model, head_t, dropout=dropout, batch_first=True)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

        self.feedforward = nn.Sequential(
            nn.Linear(d_model, d_ff_belt, bias=True),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff_belt, d_model, bias=True)
        )

    def forward(self, x: torch.Tensor, time_indices: torch.Tensor,
                perform_intervention: bool = False,
                num_interventions: int = 2) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        B, seq_len, N, _ = x.shape

        spatial_out, causal_matrices, validity_scores = self.te_causgat(
            x, time_indices, perform_intervention, num_interventions
        )
        x = self.norm1(x + self.dropout(spatial_out))

        # Temporal attention runs per node, so fold nodes into the batch.
        # A bare reshape of (B, T, N, D) to (B*N, T, D) would interleave time
        # steps and nodes, so move N next to B first.
        x_temporal = x.permute(0, 2, 1, 3).reshape(B * N, seq_len, self.d_model)

        temporal_out, temporal_weights = self.temporal_attention(
            x_temporal, x_temporal, x_temporal, need_weights=True, average_attn_weights=False
        )
        # Undo the fold: (B*N, T, D) -> (B, N, T, D) -> (B, T, N, D).
        temporal_out = temporal_out.reshape(B, N, seq_len, self.d_model).permute(0, 2, 1, 3)

        if temporal_weights is not None:
            # (B*N, head_t, T, T) -> (B, N, head_t, T, T)
            temporal_attention_weights = temporal_weights.reshape(B, N, self.head_t, seq_len, seq_len)
        else:
            temporal_attention_weights = None

        x = self.norm2(x + self.dropout(temporal_out))

        ff_out = self.feedforward(x)
        x = self.norm3(x + self.dropout(ff_out))

        return x, causal_matrices, temporal_attention_weights, validity_scores

class WeekBeltBlock(nn.Module):
    """Post-norm block: causal spatial attention, then per-node temporal attention, then FFN.

    Each sub-layer is wrapped as ``norm(x + dropout(sublayer(x)))``.
    Input and output shape: ``(B, seq_len, N, d_model)``. ``forward`` also
    returns the TECausGAT causal matrices, per-node temporal attention
    weights ``(B, N, head_t, seq_len, seq_len)`` and validity scores.
    """

    def __init__(self, d_model: int, num_nodes: int, num_time_segments: int, head_s: int,
                 head_t: int, d_ff_belt: int, dropout: float = 0.1,
                 adj_matrix: Optional[torch.Tensor] = None, rank: int = None):
        super().__init__()
        self.d_model = d_model
        self.num_nodes = num_nodes
        self.head_t = head_t

        assert d_model % head_t == 0, f"d_model {d_model} must be divisible by head_t {head_t}"
        self.d_kt = d_model // head_t

        self.te_causgat = TECausGAT(
            d_model, num_nodes, num_time_segments, head_s, dropout,
            adj_matrix=adj_matrix, rank=rank
        )

        self.temporal_attention = nn.MultiheadAttention(d_model, head_t, dropout=dropout, batch_first=True)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

        self.feedforward = nn.Sequential(
            nn.Linear(d_model, d_ff_belt, bias=True),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff_belt, d_model, bias=True)
        )

    def forward(self, x: torch.Tensor, time_indices: torch.Tensor,
                perform_intervention: bool = False,
                num_interventions: int = 2) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        B, seq_len, N, _ = x.shape

        spatial_out, causal_matrices, validity_scores = self.te_causgat(
            x, time_indices, perform_intervention, num_interventions
        )
        x = self.norm1(x + self.dropout(spatial_out))

        # Temporal attention runs per node, so fold nodes into the batch.
        # A bare reshape of (B, T, N, D) to (B*N, T, D) would interleave time
        # steps and nodes, so move N next to B first.
        x_temporal = x.permute(0, 2, 1, 3).reshape(B * N, seq_len, self.d_model)

        temporal_out, temporal_weights = self.temporal_attention(
            x_temporal, x_temporal, x_temporal, need_weights=True, average_attn_weights=False
        )
        # Undo the fold: (B*N, T, D) -> (B, N, T, D) -> (B, T, N, D).
        temporal_out = temporal_out.reshape(B, N, seq_len, self.d_model).permute(0, 2, 1, 3)

        if temporal_weights is not None:
            # (B*N, head_t, T, T) -> (B, N, head_t, T, T)
            temporal_attention_weights = temporal_weights.reshape(B, N, self.head_t, seq_len, seq_len)
        else:
            temporal_attention_weights = None

        x = self.norm2(x + self.dropout(temporal_out))

        ff_out = self.feedforward(x)
        x = self.norm3(x + self.dropout(ff_out))

        return x, causal_matrices, temporal_attention_weights, validity_scores

class DayBeltBlock(nn.Module):
    """Belt block: spatial TECausGAT, per-node temporal attention, then a feed-forward net.

    Every sublayer is wrapped as ``LayerNorm(x + Dropout(sublayer(x)))``.
    """

    def __init__(self, d_model: int, num_nodes: int, num_time_segments: int, head_s: int, 
                 head_t: int, d_ff_belt: int, dropout: float = 0.1, 
                 adj_matrix: Optional[torch.Tensor] = None, rank: int = None):
        super().__init__()
        self.d_model = d_model
        self.num_nodes = num_nodes
        self.head_t = head_t

        assert d_model % head_t == 0, f"d_model {d_model} must be divisible by head_t {head_t}"
        self.d_kt = d_model // head_t  # per-head width of the temporal attention

        # Spatial sublayer; head_s, dropout, adj_matrix and rank are forwarded unchanged.
        self.te_causgat = TECausGAT(
            d_model, num_nodes, num_time_segments, head_s, dropout, 
            adj_matrix=adj_matrix, rank=rank
        )

        self.temporal_attention = nn.MultiheadAttention(d_model, head_t, dropout=dropout, batch_first=True)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

        self.feedforward = nn.Sequential(
            nn.Linear(d_model, d_ff_belt, bias=True),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff_belt, d_model, bias=True)
        )

    def forward(self, x: torch.Tensor, time_indices: torch.Tensor, 
                perform_intervention: bool = False,
                num_interventions: int = 2) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Run the three sublayers.

        Args:
            x: features unpacked as ``(B, seq_len, N, d_model)``.
            time_indices, perform_intervention, num_interventions: forwarded
                unchanged to ``self.te_causgat``.

        Returns:
            ``(x, causal_matrices, temporal_attention_weights, validity_scores)``;
            ``temporal_attention_weights`` is ``(B, N, head_t, seq_len, seq_len)``
            or ``None``.
        """
        B, seq_len, N, _ = x.shape

        spatial_out, causal_matrices, validity_scores = self.te_causgat(
            x, time_indices, perform_intervention, num_interventions
        )
        x = self.norm1(x + self.dropout(spatial_out))

        # Temporal attention must run per node: move nodes next to batch first.
        # A direct reshape (B, T, N, D) -> (B*N, T, D) would interleave time and nodes.
        x_temporal = x.permute(0, 2, 1, 3).reshape(B * N, seq_len, self.d_model)

        temporal_out, temporal_weights = self.temporal_attention(
            x_temporal, x_temporal, x_temporal, need_weights=True, average_attn_weights=False
        )
        # (B*N, T, D) -> (B, N, T, D) -> (B, T, N, D)
        temporal_out = temporal_out.reshape(B, N, seq_len, self.d_model).permute(0, 2, 1, 3)

        if temporal_weights is not None:
            # Per-head weights are (B*N, head_t, T, T), batch ordered as (b, n).
            temporal_attention_weights = temporal_weights.reshape(B, N, self.head_t, seq_len, seq_len)
        else:
            temporal_attention_weights = None

        x = self.norm2(x + self.dropout(temporal_out))

        ff_out = self.feedforward(x)
        x = self.norm3(x + self.dropout(ff_out))

        return x, causal_matrices, temporal_attention_weights, validity_scores

class MultiPeriodFusion(nn.Module):
    """Fuse three period branches (hour / week / day) position by position.

    For every (batch, time, node) position the three period vectors form a
    length-3 sequence that goes through self-attention (residual + LayerNorm).
    They are then combined with softmax weights predicted from their
    concatenation, followed by a residual feed-forward sublayer.
    """

    def __init__(self, d_model: int, head_f: int, d_ff_fusion: int, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.head_f = head_f

        self.fusion_attention = nn.MultiheadAttention(d_model, head_f, dropout=dropout, batch_first=True)

        # Maps the concatenated period features to 3 softmax weights (one per period).
        self.period_weight_net = nn.Sequential(
            nn.Linear(d_model * 3, d_model),
            nn.ReLU(),
            nn.Linear(d_model, 3),
            nn.Softmax(dim=-1)
        )

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        self.feedforward = nn.Sequential(
            nn.Linear(d_model, d_ff_fusion, bias=True),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff_fusion, d_model, bias=True)
        )

    def forward(self, x_h: torch.Tensor, x_w: torch.Tensor, x_d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Fuse the three branches.

        Args:
            x_h, x_w, x_d: branch features, each unpacked as ``(B, seq_len, N, d_model)``.

        Returns:
            ``(fused_output, attn_weights)``. ``fused_output`` is ``(B, seq_len, N, d_model)``.
            ``attn_weights`` is the head-averaged fusion attention,
            ``(B*seq_len*N, 3, 3)``.
        """
        B, seq_len, N, d_model = x_h.shape

        # Stack periods on a new axis *after* the node axis -> (B, T, N, 3, D), so that
        # flattening the leading axes yields one 3-token sequence per (b, t, n).
        # Stacking on dim=2 would put the period axis before N, and the reshape
        # below would then mix different nodes inside one attention sequence.
        x_stack = torch.stack([x_h, x_w, x_d], dim=3)

        period_weights = self.period_weight_net(torch.cat([x_h, x_w, x_d], dim=-1))  # (B, T, N, 3)

        x_reshaped = x_stack.reshape(B * seq_len * N, 3, d_model)
        attn_out, attn_weights = self.fusion_attention(x_reshaped, x_reshaped, x_reshaped)
        attn_out = attn_out.reshape(B, seq_len, N, 3, d_model)

        x_stack = self.norm1(x_stack + self.dropout(attn_out))

        # Weighted sum over the period axis -> (B, T, N, D).
        weighted_sum = torch.sum(x_stack * period_weights.unsqueeze(-1), dim=3)

        ff_out = self.feedforward(weighted_sum)
        fused_output = self.norm2(weighted_sum + self.dropout(ff_out))

        return fused_output, attn_weights

class OrionModel(nn.Module):
    """Multi-period (hour / week / day) spatio-temporal forecaster.

    Pipeline:
    1. Each period's input is embedded by ``EmbeddingProcessModule``.
    2. Sinusoidal positional encodings are added.
    3. Each period runs through its own stack of belt blocks.
    4. The three branches are fused by ``MultiPeriodFusion``.
    5. The fused features are averaged over the time axis.
    6. They are mapped to ``target_len`` values per node by ``prediction_head``.
    """

    def __init__(self, config: dict, adj_matrix: Optional[np.ndarray] = None):
        """Build all submodules from ``config``.

        Required config keys: ``num_of_vertices``, ``in_channels``, ``target_len``,
        ``source_len``, ``d_model``, ``n_belt_block``, ``num_time_segments``,
        ``d_ff_emb``, ``head_s``, ``head_t``, ``d_ff_belt``, ``dropout``, ``head_f``,
        ``d_ff_fusion``, ``d_ff_reverse``. Optional: ``causal_rank``.

        Args:
            config: hyper-parameter dictionary (keys above).
            adj_matrix: optional adjacency matrix. It is converted to a float32 tensor
                and shared by every belt block; array-likes and tensors are accepted.
        """
        super().__init__()

        self.num_nodes = config['num_of_vertices']
        self.in_channels = config['in_channels']
        self.target_len = config['target_len']
        self.source_len = config['source_len']
        self.d_model = config['d_model']
        self.n_belt_block = config['n_belt_block']
        self.num_time_segments = config['num_time_segments']

        # NOTE(review): the default is 0 when num_of_vertices < 4; confirm that
        # TECausGAT accepts rank=0.
        self.causal_rank = config.get('causal_rank', min(32, self.num_nodes // 4))

        if adj_matrix is not None:
            adj_matrix_tensor = torch.as_tensor(adj_matrix, dtype=torch.float32)
        else:
            adj_matrix_tensor = None

        self.embedding_process = EmbeddingProcessModule(
            self.in_channels, self.d_model, config['d_ff_emb']
        )

        self.pos_encoding = PositionalEncoding(self.d_model)

        self.hour_belt_blocks = nn.ModuleList([
            HourBeltBlock(
                self.d_model, self.num_nodes, self.num_time_segments,
                config['head_s'], config['head_t'], config['d_ff_belt'],
                config['dropout'], adj_matrix=adj_matrix_tensor, rank=self.causal_rank
            ) for _ in range(self.n_belt_block)
        ])

        self.week_belt_blocks = nn.ModuleList([
            WeekBeltBlock(
                self.d_model, self.num_nodes, self.num_time_segments,
                config['head_s'], config['head_t'], config['d_ff_belt'],
                config['dropout'], adj_matrix=adj_matrix_tensor, rank=self.causal_rank
            ) for _ in range(self.n_belt_block)
        ])

        self.day_belt_blocks = nn.ModuleList([
            DayBeltBlock(
                self.d_model, self.num_nodes, self.num_time_segments,
                config['head_s'], config['head_t'], config['d_ff_belt'],
                config['dropout'], adj_matrix=adj_matrix_tensor, rank=self.causal_rank
            ) for _ in range(self.n_belt_block)
        ])

        self.fusion_module = MultiPeriodFusion(
            self.d_model, config['head_f'], config['d_ff_fusion'], config['dropout']
        )

        self.prediction_head = nn.Sequential(
            nn.Linear(self.d_model, config['d_ff_reverse'], bias=True),
            nn.ReLU(),
            nn.Dropout(config['dropout']),
            nn.Linear(config['d_ff_reverse'], config['d_ff_reverse'] // 2, bias=True),
            nn.ReLU(),
            nn.Linear(config['d_ff_reverse'] // 2, self.target_len, bias=True)
        )

    def _run_belt_stack(self, blocks: nn.ModuleList, x: torch.Tensor, time_indices: torch.Tensor,
                        perform_intervention: bool, num_interventions: int
                        ) -> Tuple[torch.Tensor, Optional[tuple]]:
        """Run ``x`` through ``blocks`` in order.

        Returns:
            The stack output, plus the last block's ``(causal, temporal_attn, validity)``.
            The second element is ``None`` when the stack is empty.
        """
        last_idx = len(blocks) - 1
        last_diagnostics = None
        for idx, block in enumerate(blocks):
            is_last = idx == last_idx
            # Interventions are only requested from the final block of the stack.
            x, causal, temporal_attn, validity = block(
                x, time_indices, perform_intervention and is_last, num_interventions
            )
            if is_last:
                last_diagnostics = (causal, temporal_attn, validity)
        return x, last_diagnostics

    def forward(self, x_h: torch.Tensor, x_w: torch.Tensor, x_d: torch.Tensor, 
            time_indices: torch.Tensor, perform_intervention: bool = False,
            num_interventions: int = 2) -> Dict[str, torch.Tensor]:
        """Predict ``target_len`` values per node from the three period inputs.

        Returns:
            A dict with:
            - ``predictions``;
            - ``causal_matrices`` and ``temporal_attention_weights``: lists holding
              the last block's output of each non-empty branch, in hour, week, day order;
            - ``fusion_weights``;
            - ``validity_scores``: the mean of the non-``None`` last-block validity
              scores, or ``None``.
        """
        x_h_emb, x_w_emb, x_d_emb = self.embedding_process(x_h, x_w, x_d)

        all_causal_matrices = []
        all_temporal_attention_weights = []
        all_validity_scores = []
        branch_outputs = []

        branches = (
            (self.hour_belt_blocks, x_h_emb),
            (self.week_belt_blocks, x_w_emb),
            (self.day_belt_blocks, x_d_emb),
        )
        for blocks, emb in branches:
            # Size the encoding from the actual sequence length (axis 1) instead of
            # config['source_len'], so other input lengths broadcast correctly.
            emb = emb + self.pos_encoding(emb.shape[1])
            out, last_diagnostics = self._run_belt_stack(
                blocks, emb, time_indices, perform_intervention, num_interventions
            )
            branch_outputs.append(out)
            if last_diagnostics is not None:
                causal, temporal_attn, validity = last_diagnostics
                all_causal_matrices.append(causal)
                all_temporal_attention_weights.append(temporal_attn)
                if validity is not None:
                    all_validity_scores.append(validity)

        fused_output, fusion_weights = self.fusion_module(*branch_outputs)

        # Average over the time axis: (B, T, N, D) -> (B, N, D).
        aggregated_features = fused_output.mean(dim=1)

        predictions = self.prediction_head(aggregated_features)

        validity_scores = None
        if all_validity_scores:
            validity_scores = torch.stack(all_validity_scores).mean(dim=0)

        return {
            'predictions': predictions,
            'causal_matrices': all_causal_matrices,
            'fusion_weights': fusion_weights,
            'temporal_attention_weights': all_temporal_attention_weights,
            'validity_scores': validity_scores
        }

    def _causal_matrix_penalties(self, mat: torch.Tensor) -> List[torch.Tensor]:
        """Return the sparsity term and any per-slice cycle terms for one causal tensor."""
        terms = [torch.mean(torch.abs(mat))]
        if mat.dim() >= 3:
            # For 4-D tensors only the first entry along axis 0 is penalised.
            mat_3d = mat[0] if mat.dim() == 4 else mat
            if mat_3d.dim() == 3:
                # |trace((A/N)^3)| for every slice along axis 0, batched instead of a Python loop.
                # NOTE(review): this only penalises length-3 cycles.
                A = mat_3d / self.num_nodes
                cubed = torch.matmul(torch.matmul(A, A), A)
                traces = cubed.diagonal(dim1=-2, dim2=-1).sum(dim=-1)
                terms.extend(torch.abs(traces) * 0.1)
        return terms

    def compute_loss(self, predictions: torch.Tensor, targets: torch.Tensor, 
                 causal_matrices: List, validity_scores: Optional[torch.Tensor] = None,
                 lambda_causal: float = 0.1) -> Dict[str, torch.Tensor]:
        """L1 prediction loss plus a scaled causal regulariser.

        ``causal_matrices`` entries may be:
        - ``None``, which is skipped;
        - a tensor, which contributes its mean absolute value;
        - a list of tensors, where each contributes its mean absolute value plus
          cycle terms for 3-D/4-D tensors.

        All terms are averaged. ``-0.5 * mean(validity_scores)`` is added when given.
        The sum is scaled by ``lambda_causal``.

        Returns:
            A dict with:
            - ``total_loss``, ``pred_loss``;
            - ``causal_loss``: unscaled, validity bonus included;
            - ``cf_loss``: always zeros;
            - ``causal_contribution``: float, scaled causal loss as a percent of ``pred_loss``;
            - ``validity_bonus``: float.
        """
        pred_loss = F.l1_loss(predictions, targets)

        causal_loss = torch.tensor(0.0, device=predictions.device, requires_grad=True)

        all_causal_losses = []
        for entry in causal_matrices or []:
            if isinstance(entry, list):
                for mat in entry:
                    if isinstance(mat, torch.Tensor):
                        all_causal_losses.extend(self._causal_matrix_penalties(mat))
            elif isinstance(entry, torch.Tensor):
                all_causal_losses.append(torch.mean(torch.abs(entry)))

        if all_causal_losses:
            causal_loss = torch.stack(all_causal_losses).mean()

        validity_bonus = torch.tensor(0.0, device=predictions.device)
        if validity_scores is not None:
            validity_bonus = -0.5 * torch.mean(validity_scores)
            causal_loss = causal_loss + validity_bonus

        causal_loss_scaled = causal_loss * lambda_causal

        total_loss = pred_loss + causal_loss_scaled

        causal_contribution = (causal_loss_scaled.abs() / (pred_loss + 1e-8)).item() * 100

        return {
            'total_loss': total_loss,
            'pred_loss': pred_loss,
            'causal_loss': causal_loss,
            'cf_loss': torch.zeros_like(causal_loss),
            'causal_contribution': causal_contribution,
            'validity_bonus': validity_bonus.item() if validity_scores is not None else 0.0
        }

    def get_causal_insights(self) -> Dict[str, torch.Tensor]:
        """Return sigmoid of the first belt block's ``prior_causal`` / ``prior_confidence`` per branch."""
        insights = {}

        insights['hour_prior'] = torch.sigmoid(self.hour_belt_blocks[0].te_causgat.causal_fusion.prior_causal)
        insights['week_prior'] = torch.sigmoid(self.week_belt_blocks[0].te_causgat.causal_fusion.prior_causal)
        insights['day_prior'] = torch.sigmoid(self.day_belt_blocks[0].te_causgat.causal_fusion.prior_causal)

        insights['hour_confidence'] = torch.sigmoid(self.hour_belt_blocks[0].te_causgat.causal_fusion.prior_confidence)
        insights['week_confidence'] = torch.sigmoid(self.week_belt_blocks[0].te_causgat.causal_fusion.prior_confidence)
        insights['day_confidence'] = torch.sigmoid(self.day_belt_blocks[0].te_causgat.causal_fusion.prior_confidence)

        return insights

def create_orion_model(config: dict, adj_matrix: Optional[np.ndarray] = None) -> OrionModel:
    """Factory that builds an ``OrionModel`` from ``config`` and an optional adjacency matrix."""
    model = OrionModel(config, adj_matrix=adj_matrix)
    return model