import torch.nn as nn
import torch
import torch.nn.functional as F
import numpy as np
from torch.utils.checkpoint import checkpoint


class WarningVectorGenerator(nn.Module):
    """
    Refines LLM-produced warning scores into warning impact vectors.

    The scores are passed through a small MLP
    (``warning_vector_dim -> 32 -> 64 -> warning_vector_dim``) that acts on
    the last axis only and ends with a sigmoid, so every output component
    lies in [0, 1].

    Args:
        text_embedding_dim: Accepted for interface compatibility; it is not
            used by this module.
        warning_vector_dim: Size of the incoming and outgoing warning vector
            (8 by default).
    """

    def __init__(self, text_embedding_dim=768, warning_vector_dim=8):
        super().__init__()
        self.warning_vector_dim = warning_vector_dim

        # The stage order fixes the Sequential indices (0, 2, 3, 5 carry
        # parameters), which in turn fixes the state_dict key names.
        stages = [
            nn.Linear(warning_vector_dim, 32),
            nn.ReLU(),
            nn.Linear(32, 64),
            nn.LayerNorm(64),
            nn.ReLU(),
            nn.Linear(64, warning_vector_dim),
            nn.Sigmoid(),  # squashes each component into [0, 1]
        ]
        self.warning_processor = nn.Sequential(*stages)

    def forward(self, llm_warning_scores):
        """
        Run the score-refinement MLP.

        Args:
            llm_warning_scores: [batch_size, time_steps, num_nodes, 8]
                LLM-generated scores for the 8 warning dimensions:
                Q1: Spatial Impact, Q2: Delay Risk, Q3: Reroute Need,
                Q4: Duration Impact, Q5: Port Congestion, Q6: Cargo Threat,
                Q7: Speed Adjustment, Q8: Uncertainty Level

        Returns:
            Tensor with the same shape as the input holding the processed
            warning impact vectors.
        """
        return self.warning_processor(llm_warning_scores)


class BackDoorCausalAdjustment(nn.Module):
    """
    Implements Back-door causal adjustment: A → B → {X, V} → D
    Where A: raw warnings, B: warning scores, X: historical flow,
    V: spatial context, D: future flow.

    Args:
        model_dim: Feature size of the flow representation (X).
        warning_dim: Feature size of the warning scores (B).
        hidden_dim: Width of every internal encoder / estimator.
        dropout: Dropout probability used inside the encoders/estimators.
        num_nodes: Optional number of graph nodes. The spatial encoder
            consumes a flattened ``num_nodes * model_dim * 2`` vector per
            node, so its input width depends on the graph size. When given,
            the layer is sized eagerly; when omitted, a lazily-initialised
            linear layer infers the width on the first forward pass.
            NOTE: with the lazy variant, run one forward pass before
            building an optimizer or loading weights into a fresh instance,
            because the layer's parameters do not exist until then.
    """

    def __init__(self, model_dim, warning_dim=8, hidden_dim=64, dropout=0.1, num_nodes=None):
        super().__init__()
        self.model_dim = model_dim
        self.warning_dim = warning_dim
        self.num_nodes = num_nodes

        # Warning encoder (B encoder)
        self.warning_encoder = nn.Sequential(
            nn.Linear(warning_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Historical flow encoder (X encoder)
        self.flow_encoder = nn.Sequential(
            nn.Linear(model_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Spatial context encoder (V encoder).
        # Its input width must match the flattened spatial_input built in
        # forward(): num_nodes * model_dim * 2. The graph size is not always
        # known at construction time, so fall back to a lazy layer.
        if num_nodes is None:
            spatial_in = nn.LazyLinear(hidden_dim)
        else:
            spatial_in = nn.Linear(num_nodes * model_dim * 2, hidden_dim)
        self.spatial_encoder = nn.Sequential(
            spatial_in,
            nn.LayerNorm(hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # B→X causal pathway estimator
        self.b_to_x_estimator = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

        # B→V causal pathway estimator
        self.b_to_v_estimator = nn.Sequential(
            nn.Linear(hidden_dim * 2, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid()
        )

        # X→D effect estimator
        self.x_to_d_estimator = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, model_dim),
        )

        # V→D effect estimator
        self.v_to_d_estimator = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, model_dim),
        )

    def forward(self, warning_features, flow_features, adjacency_matrix):
        """
        Implements Back-door adjustment: P(D|do(B)) = ∑_{x,v} P(X,V|B) P(D|X,V)

        Args:
            warning_features: [batch_size, time_steps, num_nodes, 8] - Warning impact scores (B)
            flow_features: [batch_size, time_steps, num_nodes, model_dim] - Historical flow (X)
            adjacency_matrix: [batch_size, time_steps, num_nodes, num_nodes] - Spatial relationships

        Returns:
            causal_effect: [batch_size, time_steps, num_nodes, model_dim] - Deconfounded causal effect
            causal_weights: [batch_size, time_steps, num_nodes, num_nodes] - Causal strength matrix

        Raises:
            ValueError: If ``num_nodes`` was given at construction and the
                input carries a different number of nodes.
        """
        batch_size, time_steps, num_nodes, feature_dim = flow_features.shape
        if self.num_nodes is not None and num_nodes != self.num_nodes:
            raise ValueError(
                f"expected {self.num_nodes} nodes, got {num_nodes}"
            )

        # Encode inputs
        B_encoded = self.warning_encoder(warning_features)  # [B, T, N, hidden_dim]
        X_encoded = self.flow_encoder(flow_features)        # [B, T, N, hidden_dim]

        # Prepare spatial context by combining adjacency and flow.
        # Entry [b, t, i, j] holds node i's flow vector followed by the
        # adjacency value a[i, j] repeated feature_dim times.
        # NOTE(review): this repeats node i's own flow for every j rather
        # than gathering neighbour j's flow — confirm that is intended.
        flow_expanded = flow_features.unsqueeze(3).expand(-1, -1, -1, num_nodes, -1)
        adj_expanded = adjacency_matrix.unsqueeze(-1).expand(-1, -1, -1, -1, feature_dim)
        spatial_input = torch.cat([flow_expanded, adj_expanded], dim=-1)  # [B, T, N, N, 2*D]
        # Flatten to one row per (batch, time, node) so the encoder sees a
        # vector of width num_nodes * feature_dim * 2.
        spatial_input = spatial_input.reshape(
            batch_size * time_steps * num_nodes, num_nodes * feature_dim * 2
        )
        V_encoded = self.spatial_encoder(spatial_input)     # [B*T*N, hidden_dim]
        # Restore the leading [B, T, N] axes.
        V_encoded = V_encoded.reshape(batch_size, time_steps, num_nodes, -1)

        # Calculate causal pathway strengths
        # P(X|B)
        bx_input = torch.cat([B_encoded, X_encoded], dim=-1)
        b_to_x_strength = self.b_to_x_estimator(bx_input)   # [B, T, N, 1]

        # P(V|B)
        bv_input = torch.cat([B_encoded, V_encoded], dim=-1)
        b_to_v_strength = self.b_to_v_estimator(bv_input)   # [B, T, N, 1]

        # Calculate effects: P(D|X) and P(D|V)
        x_to_d_effect = self.x_to_d_estimator(X_encoded)    # [B, T, N, model_dim]
        v_to_d_effect = self.v_to_d_estimator(V_encoded)    # [B, T, N, model_dim]

        # Back-door adjustment: combine pathways
        pathway_x = b_to_x_strength * x_to_d_effect         # [B, T, N, model_dim]
        pathway_v = b_to_v_strength * v_to_d_effect         # [B, T, N, model_dim]

        causal_effect = pathway_x + pathway_v               # [B, T, N, model_dim]

        # Generate causal weights matrix for attention: pairwise similarity
        # of the encoded warnings, squashed to (0, 1).
        causal_weights = torch.matmul(B_encoded, B_encoded.transpose(-1, -2))  # [B, T, N, N]
        causal_weights = torch.sigmoid(causal_weights)

        return causal_effect, causal_weights


class ODERipplePropagation(nn.Module):
    """
    Models warning propagation as continuous-time dynamical system using ODEs.
    Captures both immediate and lingering effects through adaptive decay.

    Args:
        num_nodes: Number of graph nodes; sizes the learnable 2-D
            geographical embedding table.
        hidden_dim: Feature size of the propagated state.
        num_integration_steps: Maximum number of Euler steps in forward().
    """

    def __init__(self, num_nodes, hidden_dim=64, num_integration_steps=10):
        super().__init__()
        self.num_nodes = num_nodes
        self.hidden_dim = hidden_dim
        self.num_integration_steps = num_integration_steps

        # Learnable geographical embeddings
        self.geo_embeddings = nn.Parameter(torch.randn(num_nodes, 2))

        # Dissipation rate parameters
        self.dissipation_net = nn.Sequential(
            nn.Linear(hidden_dim + 2, 32),  # hidden_dim + capacity + weather
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Softplus()  # Ensure positive dissipation
        )

        # Propagation rate parameters
        self.propagation_net = nn.Sequential(
            nn.Linear(hidden_dim + 2, 32),  # hidden_dim + connectivity + flow
            nn.ReLU(),
            nn.Linear(32, 1),
            nn.Softplus()  # Ensure positive propagation
        )

        # Adaptive decay function parameters
        self.decay_params = nn.Parameter(torch.tensor([0.5, 0.1, 0.05]))  # alpha, beta, gamma

    @staticmethod
    def _normalize_by_max(dist_matrix, eps=1e-12):
        """Scale distances by their maximum, guarding the all-zero case."""
        # Without the clamp a single-node graph (max distance 0) gives 0/0.
        return dist_matrix / dist_matrix.max().clamp_min(eps)

    def compute_distance_matrix(self):
        """Compute geographical distance matrix from learnable embeddings"""
        return torch.cdist(self.geo_embeddings, self.geo_embeddings)

    def compute_propagation_strength(self, dist_matrix, time_delay, D_t):
        """
        R_ij(t) = exp(-d_ij^2 / D(t)) * exp(-λ * h_ij)

        Args:
            dist_matrix: [num_nodes, num_nodes] pairwise distances.
            time_delay: Currently unused; kept for interface compatibility.
            D_t: Diffusion coefficient dividing the squared distance.

        Returns:
            [num_nodes, num_nodes] propagation strengths.
        """
        # Network hops (simplified as normalized distance)
        h_ij = self._normalize_by_max(dist_matrix)

        # Distance-based decay
        spatial_decay = torch.exp(-dist_matrix ** 2 / D_t)
        network_decay = torch.exp(-0.1 * h_ij)  # λ = 0.1

        return spatial_decay * network_decay

    def adaptive_decay(self, x, delta_t):
        """
        g(x, Δt) = x * [α*exp(-β*Δt) + (1-α)*exp(-γ*Δt²)]
        """
        alpha, beta, gamma = self.decay_params
        immediate_decay = alpha * torch.exp(-beta * delta_t)
        lingering_decay = (1 - alpha) * torch.exp(-gamma * delta_t ** 2)
        return x * (immediate_decay + lingering_decay)

    def ode_step(self, x, causal_weights, port_capacity, weather_conditions):
        """
        dx_i/dt = -σ_i(t)*x_i(t) + Σ_j ρ_ij(t)*R_ij(t)*g(x_j(t), Δt_ij)

        ρ is predicted per receiving node i, so ρ_ij = ρ_i here.

        Args:
            x: [batch_size, time_steps, num_nodes, hidden_dim] current state.
            causal_weights: [batch_size, time_steps, num_nodes, num_nodes]
            port_capacity: [batch_size, time_steps, num_nodes, 1]
            weather_conditions: [batch_size, time_steps, num_nodes, 1]

        Returns:
            dx_dt with the same shape as ``x``.
        """
        # Dissipation rates σ_i(t) from state + capacity + weather
        dissipation_input = torch.cat([x, port_capacity, weather_conditions], dim=-1)
        sigma = self.dissipation_net(dissipation_input)  # [B, T, N, 1]

        # Propagation rates ρ_i(t) from state + connectivity + flow intensity
        connectivity = causal_weights.sum(dim=-1, keepdim=True)  # [B, T, N, 1]
        flow_intensity = x.norm(dim=-1, keepdim=True)            # [B, T, N, 1]
        prop_input = torch.cat([x, connectivity, flow_intensity], dim=-1)
        rho = self.propagation_net(prop_input)  # [B, T, N, 1]

        # Time delays (simplified as distance-based)
        dist_matrix = self.compute_distance_matrix()
        delta_t = self._normalize_by_max(dist_matrix)  # [N, N]

        # Propagation strength R_ij(t)
        D_t = 1.0  # Simplified diffusion coefficient
        R_ij = self.compute_propagation_strength(dist_matrix, delta_t, D_t)  # [N, N]

        # g(x_j, Δt_ij) is x_j times a factor that depends only on (i, j),
        # so the pairwise sum collapses into one [N, N] coupling matrix.
        # This avoids materialising a [B, T, N, N, H] tensor.
        decay_ij = self.adaptive_decay(torch.ones_like(delta_t), delta_t)  # [N, N]
        coupling = R_ij * decay_ij  # [N, N], row i = receiver, column j = source

        # Σ_j coupling[i, j] * x_j — the sum runs over SOURCE nodes j.
        incoming = torch.matmul(coupling, x)  # [B, T, N, H]
        propagation_term = rho * incoming     # [B, T, N, H]

        # Local dissipation term
        dissipation_term = sigma * x  # [B, T, N, H]

        # ODE: dx/dt = -dissipation + propagation
        return -dissipation_term + propagation_term

    def forward(self, initial_state, causal_weights, port_capacity, weather_conditions, dt=0.1):
        """
        Solve ODE using Euler method for simplicity

        Args:
            initial_state: [batch_size, time_steps, num_nodes, hidden_dim]
            causal_weights: [batch_size, time_steps, num_nodes, num_nodes]
            port_capacity: [batch_size, time_steps, num_nodes, 1]
            weather_conditions: [batch_size, time_steps, num_nodes, 1]
            dt: Integration step size

        Returns:
            State after at most ``num_integration_steps`` Euler steps, same
            shape as ``initial_state``.
        """
        x = initial_state
        # Relative stopping threshold; loop-invariant, so compute it once.
        tolerance = 1e-6 * torch.norm(initial_state)

        for _ in range(self.num_integration_steps):
            dx_dt = self.ode_step(x, causal_weights, port_capacity, weather_conditions)
            x = x + dt * dx_dt

            # Stop early once the derivative is negligible.
            if torch.norm(dx_dt) < tolerance:
                break

        return x



class RippleNet(nn.Module):
    """
    RippleNet: Learning Causal Maritime Dynamics for Forecasting
    Warning-Induced Ripple Effects.

    Pipeline: embed the flow input, refine the LLM warning scores, run the
    back-door causal adjustment, propagate the resulting effect with the
    ODE ripple module, add it to the embedded features, apply the causal
    attention stack and project to the forecast horizon.
    """

    def __init__(
        self,
        num_nodes,
        in_steps=4,
        out_steps=4,
        steps_per_day=2,  # 12-hour intervals = 2 per day
        input_dim=3,
        output_dim=1,
        warning_dim=8,  # 8-dimensional warning vectors
        input_embedding_dim=24,
        tod_embedding_dim=24,
        dow_embedding_dim=24,
        spatial_embedding_dim=0,
        adaptive_embedding_dim=24,
        feed_forward_dim=256,
        num_heads=4,
        num_layers=3,
        dropout=0.1,
        use_mixed_proj=True,
    ):
        super().__init__()

        # ---- plain hyper-parameters -------------------------------------
        self.num_nodes = num_nodes
        self.in_steps = in_steps
        self.out_steps = out_steps
        self.steps_per_day = steps_per_day
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.warning_dim = warning_dim
        self.tod_embedding_dim = tod_embedding_dim
        self.dow_embedding_dim = dow_embedding_dim
        self.spatial_embedding_dim = spatial_embedding_dim
        self.adaptive_embedding_dim = adaptive_embedding_dim
        self.use_mixed_proj = use_mixed_proj
        # Width of the concatenated per-node feature vector.
        self.model_dim = sum((
            input_embedding_dim,
            tod_embedding_dim,
            dow_embedding_dim,
            spatial_embedding_dim,
            adaptive_embedding_dim,
        ))

        # ---- learnable components ---------------------------------------
        # NOTE: creation order is kept stable on purpose; it determines both
        # the parameter ordering and the order of random initialisation.
        self.input_proj = nn.Linear(input_dim, input_embedding_dim)

        self.warning_generator = WarningVectorGenerator(
            warning_vector_dim=warning_dim
        )

        # Temporal lookup tables (time-of-day / day-of-week)
        if tod_embedding_dim > 0:
            self.tod_embedding = nn.Embedding(steps_per_day, tod_embedding_dim)
        if dow_embedding_dim > 0:
            self.dow_embedding = nn.Embedding(7, dow_embedding_dim)

        # Static per-node embedding
        if spatial_embedding_dim > 0:
            self.node_emb = nn.Parameter(
                nn.init.xavier_uniform_(
                    torch.empty(num_nodes, spatial_embedding_dim)
                )
            )

        # Per-(step, node) adaptive embedding
        if adaptive_embedding_dim > 0:
            self.adaptive_embedding = nn.Parameter(
                nn.init.xavier_uniform_(
                    torch.empty(in_steps, num_nodes, adaptive_embedding_dim)
                )
            )

        # Causal modules
        self.back_door_adjuster = BackDoorCausalAdjustment(
            model_dim=self.model_dim,
            warning_dim=warning_dim,
            hidden_dim=64,
            dropout=dropout
        )
        self.ode_ripple = ODERipplePropagation(
            num_nodes=num_nodes,
            hidden_dim=self.model_dim,
            num_integration_steps=10
        )

        # Attention stack
        attention_stack = [
            CausalSelfAttentionLayer(
                self.model_dim, feed_forward_dim, num_heads, dropout
            )
            for _ in range(num_layers)
        ]
        self.causal_attention_layers = nn.ModuleList(attention_stack)

        # Output head: either one joint (time x feature) projection, or a
        # temporal projection followed by a feature projection.
        if use_mixed_proj:
            self.output_proj = nn.Linear(
                in_steps * self.model_dim, out_steps * output_dim
            )
        else:
            self.temporal_proj = nn.Linear(in_steps, out_steps)
            self.output_proj = nn.Linear(self.model_dim, output_dim)

        # Heads that estimate the port characteristics fed to the ODE
        self.capacity_estimator = nn.Linear(self.model_dim, 1)
        self.weather_estimator = nn.Linear(self.model_dim, 1)

    def forward(self, x, warning_scores, adjacency_matrix=None):
        """
        Run the complete RippleNet pipeline.

        Args:
            x: [batch_size, in_steps, num_nodes, input_dim] - Flow features.
                The last channel is read as time-of-day and the second-last
                as day-of-week (when the respective embeddings are enabled).
                Time-of-day is multiplied by ``steps_per_day`` and truncated
                to an index — presumably it is a fraction in [0, 1); confirm
                against the data loader.
            warning_scores: [batch_size, in_steps, num_nodes, warning_dim] - LLM warning scores
            adjacency_matrix: [batch_size, in_steps, num_nodes, num_nodes] -
                Optional; when None it is derived from feature similarity.

        Returns:
            out: [batch_size, out_steps, num_nodes, output_dim] - Predicted flows
        """
        n_batch = x.shape[0]

        # ---- build the embedded feature vector --------------------------
        parts = [self.input_proj(x[..., :self.input_dim])]

        if self.tod_embedding_dim > 0:
            tod_index = (x[..., -1] * self.steps_per_day).long()
            parts.append(self.tod_embedding(tod_index))

        if self.dow_embedding_dim > 0:
            dow_index = x[..., -2].long()
            parts.append(self.dow_embedding(dow_index))

        if self.spatial_embedding_dim > 0:
            parts.append(
                self.node_emb.expand(
                    n_batch, self.in_steps, *self.node_emb.shape
                )
            )

        if self.adaptive_embedding_dim > 0:
            parts.append(
                self.adaptive_embedding.expand(
                    size=(n_batch, *self.adaptive_embedding.shape)
                )
            )

        hidden = torch.cat(parts, dim=-1)  # [batch_size, in_steps, num_nodes, model_dim]

        # ---- causal branch ----------------------------------------------
        refined_warnings = self.warning_generator(warning_scores)

        if adjacency_matrix is None:
            # Fall back to a similarity-based adjacency in (0, 1).
            adjacency_matrix = torch.sigmoid(
                torch.matmul(hidden, hidden.transpose(-1, -2))
            )

        causal_effect, causal_weights = self.back_door_adjuster(
            refined_warnings, hidden, adjacency_matrix
        )

        # Port characteristics consumed by the ODE module
        port_capacity = torch.sigmoid(self.capacity_estimator(hidden))     # [B, T, N, 1]
        weather_conditions = torch.sigmoid(self.weather_estimator(hidden)) # [B, T, N, 1]

        ripple = self.ode_ripple(
            causal_effect, causal_weights, port_capacity, weather_conditions
        )

        # Residual merge of ripple effects with the embedded features
        enhanced = hidden + ripple

        for layer in self.causal_attention_layers:
            enhanced = layer(enhanced, causal_weights, dim=2)

        # ---- output head ------------------------------------------------
        if not self.use_mixed_proj:
            # Project time first, then features.
            swapped = enhanced.transpose(1, 3)       # [batch_size, model_dim, num_nodes, in_steps]
            projected = self.temporal_proj(swapped)  # [batch_size, model_dim, num_nodes, out_steps]
            # -> [batch_size, out_steps, num_nodes, output_dim]
            return self.output_proj(projected.transpose(1, 3))

        # Joint projection over the flattened (time, feature) axes per node.
        per_node = enhanced.transpose(1, 2).reshape(
            n_batch, self.num_nodes, self.in_steps * self.model_dim
        )
        forecast = self.output_proj(per_node).view(
            n_batch, self.num_nodes, self.out_steps, self.output_dim
        )
        return forecast.transpose(1, 2)  # [batch_size, out_steps, num_nodes, output_dim]



class CausalSelfAttentionLayer(nn.Module):
    """
    Pre-norm transformer block whose attention is modulated by a causal
    adjacency matrix.

    Layout: LayerNorm -> causal attention -> dropout -> residual add, then
    LayerNorm -> feed-forward -> dropout -> residual add.
    """

    def __init__(self, model_dim, feed_forward_dim=256, num_heads=4, dropout=0.1):
        super().__init__()

        self.causal_attn = CausalAttentionLayer(model_dim, num_heads)
        ffn_stages = [
            nn.Linear(model_dim, feed_forward_dim),
            nn.ReLU(),
            nn.Linear(feed_forward_dim, model_dim),
        ]
        self.feed_forward = nn.Sequential(*ffn_stages)
        self.ln1 = nn.LayerNorm(model_dim)
        self.ln2 = nn.LayerNorm(model_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, causal_adj, dim=-2):
        """
        Apply the block along axis ``dim`` of ``x``.

        Args:
            x: Input tensor whose last axis has size ``model_dim``.
            causal_adj: Adjacency weights handed unchanged to the attention
                layer.
            dim: Axis to attend over; it is swapped into position -2 for
                the computation and swapped back afterwards.

        Returns:
            Tensor with the same shape as ``x``.
        """
        hidden = x.transpose(dim, -2)

        # Attention sub-layer (self-attention on the normalised input)
        normed = self.ln1(hidden)
        hidden = hidden + self.dropout1(
            self.causal_attn(normed, normed, normed, causal_adj)
        )

        # Feed-forward sub-layer
        hidden = hidden + self.dropout2(self.feed_forward(self.ln2(hidden)))

        return hidden.transpose(dim, -2)


class CausalAttentionLayer(nn.Module):
    """
    Multi-head scaled dot-product attention whose logits are multiplied
    element-wise by a causal adjacency matrix before the softmax.

    Args:
        model_dim: Feature size of query/key/value and of the output.
        num_heads: Number of attention heads; must divide ``model_dim``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``model_dim``.
    """

    def __init__(self, model_dim, num_heads=4):
        super().__init__()

        # An uneven split would produce more chunks than heads and only
        # fail later, inside forward(), with an opaque shape error.
        if num_heads <= 0 or model_dim % num_heads != 0:
            raise ValueError(
                f"model_dim ({model_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        self.model_dim = model_dim
        self.num_heads = num_heads
        self.head_dim = model_dim // num_heads

        self.FC_Q = nn.Linear(model_dim, model_dim)
        self.FC_K = nn.Linear(model_dim, model_dim)
        self.FC_V = nn.Linear(model_dim, model_dim)
        self.out_proj = nn.Linear(model_dim, model_dim)

    def forward(self, query, key, value, causal_adj):
        """
        Args:
            query: [batch_size, ..., tgt_length, model_dim]
            key:   [batch_size, ..., src_length, model_dim]
            value: [batch_size, ..., src_length, model_dim]
            causal_adj: Weights broadcastable to
                [batch_size, ..., tgt_length, src_length]. The same weights
                are applied to every head.

        Returns:
            [batch_size, ..., tgt_length, model_dim]
        """
        batch_size = query.shape[0]

        query = self.FC_Q(query)
        key = self.FC_K(key)
        value = self.FC_V(value)

        # Split heads: feature chunks are stacked along the batch axis,
        # giving a head-major layout [num_heads * batch_size, ..., head_dim].
        query = torch.cat(torch.split(query, self.head_dim, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.head_dim, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.head_dim, dim=-1), dim=0)

        # Scaled dot-product scores: [num_heads * batch_size, ..., tgt, src]
        attn_score = (query @ key.transpose(-1, -2)) / self.head_dim**0.5

        # Apply the causal adjacency to every head. Viewing the scores as
        # [num_heads, batch_size, ...] lets broadcasting share the weights
        # across heads without copying them and without assuming a rank.
        score_shape = attn_score.shape
        per_head = attn_score.reshape(self.num_heads, batch_size, *score_shape[1:])
        attn_score = (per_head * causal_adj).reshape(score_shape)

        attn_score = torch.softmax(attn_score, dim=-1)

        # Weighted sum of values, then merge the heads back onto the
        # feature axis and apply the output projection.
        out = attn_score @ value
        out = torch.cat(torch.split(out, batch_size, dim=0), dim=-1)
        out = self.out_proj(out)

        return out

