"""
Sparse causal graph module.

Implements a causally constrained sparse graph, including sparsity constraints and causal time-window (lag) constraints.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Dict, Any, Tuple
import numpy as np


class SparseCausalGraph(nn.Module):
    """Sparse causal graph module.

    Learns a [num_nodes, num_nodes] adjacency from free logits, sparsifies it
    (relaxed Bernoulli sampling in training mode, thresholding, no self-loops),
    optionally masks it with a causal time window, and applies one round of
    residual graph message passing.

    Adjacency convention: ``A[i, j]`` weights the message node ``j`` sends to
    node ``i`` (aggregation is ``A @ features``).  When timestamps are given,
    ``A[..., i, j]`` is kept only if ``min_lag <= t_i - t_j <= max_lag``; with
    ``min_lag >= 1`` the source ``j`` therefore precedes the target ``i``.

    Args:
        num_nodes: number of nodes (size of the learnable adjacency logits).
        feature_dim: per-node feature size.
        l1_penalty: weight of the sparsity loss returned by ``get_sparsity_loss``.
        gumbel_temperature: temperature of the training-time edge relaxation.
        sparsity_threshold: edge weights ``<=`` this value are zeroed.
        max_lag: largest allowed lag ``t_target - t_source``.
        min_lag: smallest allowed lag ``t_target - t_source``.
        causal_hidden_dim: hidden width of the internal MLPs.
    """

    def __init__(self, num_nodes: int, feature_dim: int = 256,
                 l1_penalty: float = 0.01, gumbel_temperature: float = 1.0,
                 sparsity_threshold: float = 0.1, max_lag: int = 24,
                 min_lag: int = 1, causal_hidden_dim: int = 128):
        super().__init__()
        self.num_nodes = num_nodes
        self.feature_dim = feature_dim
        self.l1_penalty = l1_penalty
        self.gumbel_temperature = gumbel_temperature
        self.sparsity_threshold = sparsity_threshold
        self.max_lag = max_lag
        self.min_lag = min_lag
        self.causal_hidden_dim = causal_hidden_dim

        # Learnable adjacency matrix logits (small initialization value)
        self.adjacency_logits = nn.Parameter(
            torch.randn(num_nodes, num_nodes) * 0.1
        )

        # Per-node MLP applied after neighbourhood aggregation
        self.causal_conv = nn.Sequential(
            nn.Linear(feature_dim, causal_hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(causal_hidden_dim, feature_dim)
        )

        # Encodes a scalar time difference into a 64-d vector
        self.time_encoder = nn.Sequential(
            nn.Linear(1, 32),
            nn.ReLU(),
            nn.Linear(32, 64)
        )

        # Scores a node pair from [f_i, f_j, time_encoding]
        self.causal_strength_predictor = nn.Sequential(
            nn.Linear(feature_dim * 2 + 64, causal_hidden_dim),
            nn.ReLU(),
            nn.Linear(causal_hidden_dim, 1)
        )

    def forward(self, features: torch.Tensor,
                time_stamps: Optional[torch.Tensor] = None,
                return_adjacency: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sparsify, causally mask and apply the learned graph.

        Args:
            features: [batch, num_nodes, feature_dim] node features.
            time_stamps: optional per-node timestamps, [batch, num_nodes] or
                [num_nodes]; enables the causal window mask.
            return_adjacency: whether to also return the adjacency used.

        Returns:
            ``(causal_features, adjacency_matrix)`` when ``return_adjacency`` is
            true, otherwise only ``causal_features`` (despite the annotation).
            ``causal_features`` is [batch, num_nodes, feature_dim].
            ``adjacency_matrix`` is [batch, num_nodes, num_nodes], except that it
            stays [num_nodes, num_nodes] when ``batch_size == 1`` and no batched
            ``time_stamps`` were given.
        """
        batch_size, _, _ = features.shape

        # 1. Sparsity constraint (stochastic in training mode)
        adjacency_matrix = self._apply_sparsity_constraint()

        # 2. Causal window constraint
        if time_stamps is not None:
            adjacency_matrix = self._apply_causal_window_constraint(
                adjacency_matrix, time_stamps
            )

        # 3. Graph convolution with residual connection
        causal_features = self._apply_causal_convolution(features, adjacency_matrix)

        if not return_adjacency:
            return causal_features
        if batch_size > 1 and adjacency_matrix.dim() == 2:
            # One learned graph is shared by every sample; expose it per sample.
            adjacency_matrix = adjacency_matrix.unsqueeze(0).expand(batch_size, -1, -1)
        return causal_features, adjacency_matrix

    @staticmethod
    def _logistic_noise(like: torch.Tensor) -> torch.Tensor:
        """Sample Logistic(0, 1) noise with the shape/dtype/device of ``like``."""
        # Clamp away from 0 and 1 so neither log term is infinite.
        uniform = torch.rand_like(like).clamp(1e-6, 1.0 - 1e-6)
        return torch.log(uniform) - torch.log1p(-uniform)

    def _apply_sparsity_constraint(self) -> torch.Tensor:
        """Return the thresholded, self-loop-free [num_nodes, num_nodes] adjacency.

        Training mode draws a relaxed Bernoulli (binary Concrete /
        Gumbel-sigmoid) sample ``sigmoid((logits + L) / temperature)`` with
        ``L ~ Logistic(0, 1)``; eval mode uses ``sigmoid(logits)``.
        """
        logits = self.adjacency_logits
        if self.training:
            # Logistic noise (a difference of two Gumbels) is what makes
            # P(sample > 0.5) == sigmoid(logit). A single Gumbel sample is skewed
            # (mean ~0.577) and would bias every edge towards being active.
            noisy_logits = logits + self._logistic_noise(logits)
            adjacency_matrix = torch.sigmoid(noisy_logits / self.gumbel_temperature)
        else:
            adjacency_matrix = torch.sigmoid(logits)

        # Hard threshold: weak edges are removed entirely
        adjacency_matrix = torch.where(
            adjacency_matrix > self.sparsity_threshold,
            adjacency_matrix,
            torch.zeros_like(adjacency_matrix)
        )

        # Zero the diagonal (no self-loops)
        no_self_loops = 1 - torch.eye(self.num_nodes, device=adjacency_matrix.device)
        return adjacency_matrix * no_self_loops

    def _apply_causal_window_constraint(self, adjacency_matrix: torch.Tensor,
                                        time_stamps: torch.Tensor) -> torch.Tensor:
        """Keep ``A[..., i, j]`` only where ``min_lag <= t_i - t_j <= max_lag``.

        ``time_stamps`` may be [batch, num_nodes] (result is batched by
        broadcasting) or [num_nodes] (result keeps the adjacency's shape).
        """
        # time_diff[..., i, j] = t_i - t_j
        time_diff = time_stamps.unsqueeze(-1) - time_stamps.unsqueeze(-2)
        causal_mask = (time_diff >= self.min_lag) & (time_diff <= self.max_lag)
        return adjacency_matrix * causal_mask.to(adjacency_matrix.dtype)

    def _apply_causal_convolution(self, features: torch.Tensor,
                                  adjacency_matrix: torch.Tensor) -> torch.Tensor:
        """Aggregate neighbours with ``A @ H``, apply the MLP, add the input back."""
        aggregated_features = torch.matmul(adjacency_matrix, features)
        causal_features = self.causal_conv(aggregated_features)
        # Residual connection
        return causal_features + features

    def predict_causal_strength(self, features: torch.Tensor,
                                time_stamps: torch.Tensor) -> torch.Tensor:
        """Score every ordered node pair (i, j), i != j, with the pair MLP.

        Args:
            features: [batch, num_nodes, feature_dim]
            time_stamps: [batch, num_nodes]

        Returns:
            [batch, num_nodes, num_nodes] scores with a zero diagonal; entry
            ``[b, i, j]`` is computed from ``[f_i, f_j, enc(t_j - t_i)]``.
            NOTE(review): the pair time difference here is ``t_j - t_i``, the
            opposite sign of the adjacency window mask (``t_i - t_j``) — confirm
            which direction is intended.
        """
        batch_size, num_nodes, feature_dim = features.shape
        off_diagonal = ~torch.eye(num_nodes, dtype=torch.bool, device=features.device)

        strength_matrix = features.new_zeros(batch_size, num_nodes, num_nodes)
        if not bool(off_diagonal.any()):
            return strength_matrix

        # [batch, i, j, feature_dim] views holding f_i and f_j respectively.
        source = features.unsqueeze(2).expand(batch_size, num_nodes, num_nodes, feature_dim)
        target = features.unsqueeze(1).expand(batch_size, num_nodes, num_nodes, feature_dim)
        # Boolean-mask indexing flattens pairs in row-major (i, j) order.
        node_pairs = torch.cat([source, target], dim=-1)[:, off_diagonal]
        # time_diffs[b, k] = t_j - t_i for the k-th off-diagonal pair.
        time_diffs = (time_stamps.unsqueeze(1) - time_stamps.unsqueeze(2))[:, off_diagonal]

        time_encoding = self.time_encoder(time_diffs.unsqueeze(-1).to(features.dtype))
        combined_features = torch.cat([node_pairs, time_encoding], dim=-1)
        causal_strengths = self.causal_strength_predictor(combined_features).squeeze(-1)

        strength_matrix[:, off_diagonal] = causal_strengths
        return strength_matrix

    def get_sparsity_loss(self) -> torch.Tensor:
        """L1 penalty on off-diagonal edge probabilities ``sigmoid(logits)``.

        Penalising the probabilities (not the raw logits) drives edges towards
        zero; an L1 term on logits would pull them towards 0, i.e. probability
        0.5, which is above the sparsity threshold and makes the graph dense.
        """
        edge_probs = torch.sigmoid(self.adjacency_logits)
        off_diagonal = 1 - torch.eye(self.num_nodes, device=edge_probs.device,
                                     dtype=edge_probs.dtype)
        return self.l1_penalty * (edge_probs * off_diagonal).sum()

    def get_causal_consistency_loss(self, time_stamps: torch.Tensor) -> torch.Tensor:
        """Mean of ``sigmoid(logits)[i, j] * relu(t_j - t_i)``.

        Penalises soft edges whose source node ``j`` is later than the target
        node ``i`` (future influencing past), weighted by how much later.
        """
        time_diff = time_stamps.unsqueeze(-1) - time_stamps.unsqueeze(-2)
        future_past_violations = torch.relu(-time_diff)
        adjacency_matrix = torch.sigmoid(self.adjacency_logits)
        causal_violations = adjacency_matrix * future_past_violations
        return torch.mean(causal_violations)

    def get_adjacency_matrix(self) -> torch.Tensor:
        """Return the current sparsified adjacency (a random sample in training mode)."""
        return self._apply_sparsity_constraint()

    def get_sparsity_ratio(self) -> float:
        """Fraction of possible (non-self-loop) edges that are inactive.

        Returns 1.0 when the graph has no possible edges (``num_nodes <= 1``).
        In training mode the value is computed on a random adjacency sample.
        """
        total_edges = self.num_nodes * (self.num_nodes - 1)  # Exclude self-loops
        if total_edges <= 0:
            return 1.0
        with torch.no_grad():
            adjacency_matrix = self.get_adjacency_matrix()
            active_edges = int((adjacency_matrix > self.sparsity_threshold).sum().item())
        return 1.0 - (active_edges / total_edges)


class CausalWindowConstraint(nn.Module):
    """Mask an adjacency matrix so that edges respect a time-lag window.

    For entry ``A[..., i, j]`` the lag is ``t_i - t_j``.  In strict mode the
    entry is kept only if ``min_lag <= t_i - t_j <= max_lag`` (a one-sided
    window); in relaxed mode it is kept if ``|t_i - t_j| <= max_lag`` and
    ``min_lag`` is ignored.

    Args:
        max_lag: largest allowed lag (absolute lag in relaxed mode).
        min_lag: smallest allowed lag in strict mode.
        strict_causality: choose the strict one-sided or the relaxed symmetric window.
    """

    def __init__(self, max_lag: int = 24, min_lag: int = 1,
                 strict_causality: bool = True):
        super().__init__()
        self.max_lag = max_lag
        self.min_lag = min_lag
        self.strict_causality = strict_causality

    def forward(self, adjacency_matrix: torch.Tensor,
                time_stamps: torch.Tensor) -> torch.Tensor:
        """
        Apply causal window constraint

        Args:
            adjacency_matrix: [batch, num_nodes, num_nodes] or [num_nodes, num_nodes]
            time_stamps: [batch, num_nodes] or [num_nodes]

        Returns:
            The masked adjacency.  A 2-D adjacency combined with 1-D (or
            batch-of-one) timestamps yields a 2-D result.  Otherwise the result
            is [batch, num_nodes, num_nodes]: a 2-D adjacency is shared across a
            batch of timestamps by broadcasting, and a 3-D input keeps its batch
            dimension even when batch == 1.
        """
        input_was_2d = adjacency_matrix.dim() == 2

        # Only the timestamps need a batch axis; the adjacency broadcasts.
        if time_stamps.dim() == 1:
            time_stamps = time_stamps.unsqueeze(0)

        # time_diff[b, i, j] = t_i - t_j
        time_diff = time_stamps.unsqueeze(-1) - time_stamps.unsqueeze(-2)

        if self.strict_causality:
            # Strict: source j must lag target i by [min_lag, max_lag]
            causal_mask = (time_diff >= self.min_lag) & (time_diff <= self.max_lag)
        else:
            # Relaxed: any direction within max_lag
            causal_mask = (time_diff >= -self.max_lag) & (time_diff <= self.max_lag)

        constrained_adjacency = adjacency_matrix * causal_mask.to(adjacency_matrix.dtype)

        # Restore the 2-D shape only for callers that passed a 2-D adjacency
        # and did not ask for a real batch via their timestamps.
        if input_was_2d and constrained_adjacency.shape[0] == 1:
            constrained_adjacency = constrained_adjacency.squeeze(0)

        return constrained_adjacency

    def get_causal_violations(self, adjacency_matrix: torch.Tensor,
                              time_stamps: torch.Tensor) -> torch.Tensor:
        """Per-pair violation magnitudes computed from ``time_stamps`` only.

        Strict mode returns ``relu(t_j - t_i)`` (how far the source ``j`` is in
        the future of the target ``i``); relaxed mode returns
        ``relu(|t_i - t_j| - max_lag)`` (how far the lag exceeds the window).
        NOTE(review): ``adjacency_matrix`` is currently unused, so every pair is
        reported regardless of whether an edge exists — confirm this is intended.
        """
        time_diff = time_stamps.unsqueeze(-1) - time_stamps.unsqueeze(-2)

        if self.strict_causality:
            violations = torch.relu(-time_diff)  # Future affecting past
        else:
            violations = torch.relu(torch.abs(time_diff) - self.max_lag)  # Exceeding time window

        return violations


class GumbelSigmoidSampler(nn.Module):
    """Relaxed Bernoulli (Gumbel-sigmoid / binary Concrete) sampler.

    Training mode returns ``sigmoid((logits + L) / temperature)`` with
    ``L ~ Logistic(0, 1)``.  With ``hard=True`` the forward value is rounded to
    {0, 1}, while gradients flow through the soft sample (straight-through
    estimator).  Eval mode returns ``sigmoid(logits)``.

    Args:
        temperature: relaxation temperature dividing the noisy logits.
        hard: apply straight-through hard thresholding in training mode.
    """

    def __init__(self, temperature: float = 1.0, hard: bool = False):
        super().__init__()
        self.temperature = temperature
        self.hard = hard

    @staticmethod
    def _logistic_noise(like: torch.Tensor) -> torch.Tensor:
        """Sample Logistic(0, 1) noise with the shape/dtype/device of ``like``."""
        # Clamp away from 0 and 1 so neither log term is infinite.
        uniform = torch.rand_like(like).clamp(1e-6, 1.0 - 1e-6)
        return torch.log(uniform) - torch.log1p(-uniform)

    def forward(self, logits: torch.Tensor) -> torch.Tensor:
        """
        Gumbel-Sigmoid sampling

        Args:
            logits: input logits (any shape)

        Returns:
            sampled_values: tensor of the same shape with values in [0, 1]
        """
        if not self.training:
            # During inference, directly use sigmoid
            return torch.sigmoid(logits)

        # Logistic noise (difference of two Gumbels) gives
        # P(sample > 0.5) == sigmoid(logits); a single Gumbel sample would be
        # skewed and bias samples towards 1.
        noisy_logits = logits + self._logistic_noise(logits)
        probs = torch.sigmoid(noisy_logits / self.temperature)

        if self.hard:
            hard_probs = (probs > 0.5).to(probs.dtype)
            # Straight-through estimator: hard values forward, soft gradients back
            probs = hard_probs - probs.detach() + probs

        return probs


# Test functions
def test_sparse_causal_graph():
    """Smoke-run SparseCausalGraph on random features and an evenly spaced timeline.

    Prints input/output shapes, the sparsity ratio, and both auxiliary losses.
    """
    batch_size, num_nodes, feature_dim = 2, 10, 64

    # Random node features plus the same 0..num_nodes-1 timeline for each sample
    features = torch.randn(batch_size, num_nodes, feature_dim)
    timeline = torch.arange(num_nodes, dtype=torch.float32)
    time_stamps = timeline.unsqueeze(0).expand(batch_size, -1)

    graph_kwargs = dict(
        num_nodes=num_nodes,
        feature_dim=feature_dim,
        l1_penalty=0.01,
        max_lag=5,
        min_lag=1,
    )
    model = SparseCausalGraph(**graph_kwargs)

    output_features, adjacency = model(features, time_stamps)

    print("SparseCausalGraph test:")
    print(f"  Input features: {features.shape}")
    print(f"  Output features: {output_features.shape}")
    print(f"  Adjacency matrix: {adjacency.shape}")
    ratio = model.get_sparsity_ratio()
    print(f"  Sparsity ratio: {ratio:.4f}")
    l1_loss = model.get_sparsity_loss().item()
    print(f"  Sparsity loss: {l1_loss:.4f}")
    consistency = model.get_causal_consistency_loss(time_stamps).item()
    print(f"  Causal consistency loss: {consistency:.4f}")
    print()


def test_causal_window_constraint():
    """Smoke-run CausalWindowConstraint on a random 2-D adjacency and a 1-D timeline."""
    node_count = 8
    raw_adjacency = torch.rand(node_count, node_count)
    timeline = torch.arange(node_count, dtype=torch.float32)

    window = CausalWindowConstraint(max_lag=3, min_lag=1, strict_causality=True)

    masked_adjacency = window(raw_adjacency, timeline)
    violation_map = window.get_causal_violations(raw_adjacency, timeline)

    print("CausalWindowConstraint test:")
    print(f"  Original adjacency matrix shape: {raw_adjacency.shape}")
    print(f"  Constrained adjacency matrix shape: {masked_adjacency.shape}")
    print(f"  Causal violation count: {(violation_map > 0).sum().item()}")
    print()


def test_gumbel_sigmoid_sampler():
    """Smoke-run a soft GumbelSigmoidSampler on a random 5x5 logit matrix."""
    input_logits = torch.randn(5, 5)

    soft_sampler = GumbelSigmoidSampler(temperature=1.0, hard=False)
    samples = soft_sampler(input_logits)

    print("GumbelSigmoidSampler test:")
    print(f"  Input logits shape: {input_logits.shape}")
    print(f"  Sampled values shape: {samples.shape}")
    print(f"  Sampled values range: [{samples.min().item():.4f}, {samples.max().item():.4f}]")
    print()


if __name__ == "__main__":
    # Run every demo in order after a short banner.
    print("Testing sparse causal graph module...")
    for demo in (
        test_sparse_causal_graph,
        test_causal_window_constraint,
        test_gumbel_sigmoid_sampler,
    ):
        demo()