#!/usr/bin/env python
# coding: utf-8

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

# Import other modules
from .graph_learner import GraphLearner
from .intervention import InterventionModule
# Counterfactual module is likely used for estimation, not directly in forward pass
# from .counterfactual import CounterfactualModule 

class ECAM(nn.Module):
    """Endogenous Causal Attention Mechanism (ECAM).

    Multi-head scaled dot-product self-attention in which the raw attention
    scores are multiplied element-wise by a graph produced by
    ``GraphLearner`` and then handed to ``InterventionModule`` before the
    attention mask and softmax are applied.
    """

    def __init__(self, d_model, n_heads, dropout=0.1, graph_reg=0.01):
        """Initialize the ECAM layer.

        Args:
            d_model (int): Dimension of the input/output embeddings.
            n_heads (int): Number of attention heads; must divide ``d_model``.
            dropout (float): Dropout probability applied to attention weights.
            graph_reg (float): Regularization strength, stored on the layer
                and forwarded to ``GraphLearner``.

        Raises:
            ValueError: If ``d_model`` is not a multiple of ``n_heads``.
        """
        super().__init__()
        if d_model % n_heads != 0:
            raise ValueError(
                f"The hidden size ({d_model}) is not a multiple of the number of heads ({n_heads})"
            )

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # per-head feature width
        self.graph_reg = graph_reg

        # Standard attention projections (query, key, value, output).
        self.wq = nn.Linear(d_model, d_model)
        self.wk = nn.Linear(d_model, d_model)
        self.wv = nn.Linear(d_model, d_model)
        self.wo = nn.Linear(d_model, d_model)

        # Causal components. Both classes live in sibling modules; their
        # exact behavior is not visible from this file.
        self.graph_learner = GraphLearner(d_model, graph_reg)
        self.intervention_module = InterventionModule(d_model)
        # NOTE: a CounterfactualModule is intentionally not instantiated
        # here; it is meant to be used separately for estimation tasks.

        self.dropout = nn.Dropout(dropout)

    def modulate_with_graph(self, scores, G):
        """Scale attention scores element-wise by the learned graph ``G``.

        Args:
            scores (torch.Tensor): Attention scores, expected as
                (batch_size, n_heads, seq_len, seq_len).
            G (torch.Tensor): Graph adjacency. A 3-D tensor
                (batch_size, seq_len, seq_len) is shared across heads by
                inserting a singleton head axis; a 4-D tensor is used as-is
                (e.g. head-specific graphs).

        Returns:
            torch.Tensor: ``scores * G`` after broadcasting.

        Raises:
            ValueError: If ``G`` is neither 3-D nor 4-D.
        """
        if G.dim() == 3:
            # Same graph for every head: (B, S, S) -> (B, 1, S, S).
            causal_modulation = G.unsqueeze(1)
        elif G.dim() == 4:
            causal_modulation = G
        else:
            raise ValueError(f"Unexpected graph dimension: {G.dim()}")

        # Simple multiplicative modulation of the pre-softmax scores.
        return scores * causal_modulation

    def forward(self, x, mask=None, intervention=None, return_graph=False):
        """Run causal-graph-modulated multi-head self-attention.

        Args:
            x (torch.Tensor): Input of shape (batch_size, seq_len, d_model).
            mask (torch.Tensor, optional): Attention mask broadcastable to
                (batch_size, n_heads, seq_len, seq_len), e.g.
                (batch_size, 1, seq_len, seq_len) or (batch_size, 1, 1, seq_len).
                A boolean mask marks positions to hide with ``True``; any
                other dtype is treated as 1 = keep, 0 = hide.
            intervention (dict, optional): Intervention specification passed
                through unchanged to ``InterventionModule``
                (e.g. ``{"node_idx": 5, "value": tensor(...)}``).
            return_graph (bool): Whether to also return the learned graph.

        Returns:
            tuple: ``(output, attn)`` or, when ``return_graph`` is true,
            ``(output, attn, G)`` where ``output`` is
            (batch_size, seq_len, d_model), ``attn`` is
            (batch_size, n_heads, seq_len, seq_len) and ``G`` is whatever
            ``GraphLearner`` returned (expected (batch_size, seq_len, seq_len)).
        """
        batch_size, seq_len, _ = x.size()

        def split_heads(projected):
            # (B, S, d_model) -> (B, n_heads, S, d_k)
            return projected.view(
                batch_size, seq_len, self.n_heads, self.d_k
            ).transpose(1, 2)

        # 1. Project inputs to per-head Q, K, V.
        q = split_heads(self.wq(x))
        k = split_heads(self.wk(x))
        v = split_heads(self.wv(x))

        # 2. Learn the local causal graph from the raw inputs. Any graph
        #    regularization loss is expected to be computed by the caller.
        G = self.graph_learner(x)

        # 3. Scaled dot-product scores: (B, H, S, d_k) x (B, H, d_k, S) -> (B, H, S, S).
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)

        # 4. Modulate scores with the learned graph.
        scores = self.modulate_with_graph(scores, G)

        # 5. Let the intervention module adjust the scores (if requested).
        scores = self.intervention_module(scores, intervention, G)

        # 6. Apply the attention mask (if provided).
        if mask is not None:
            if mask.dtype == torch.bool:
                # Use the most negative finite value rather than -inf: rows
                # with at least one visible key get identical weights (masked
                # entries underflow to exactly 0), while a fully masked row
                # stays finite instead of turning into NaN. This matches the
                # finite bias used for numeric masks below.
                scores = scores.masked_fill(mask, torch.finfo(scores.dtype).min)
            else:
                # Numeric mask: 0 = masked, 1 = keep.
                scores = scores + (1.0 - mask) * -1e9

        # 7. Normalize over keys and apply dropout to the weights.
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # 8. Weighted sum of values: (B, H, S, S) x (B, H, S, d_k) -> (B, H, S, d_k).
        output = torch.matmul(attn, v)

        # 9. Merge heads back to (B, S, d_model) and apply the output projection.
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.d_model)
        output = self.wo(output)

        if return_graph:
            return output, attn, G
        return output, attn

# Example Usage (Conceptual - requires training loop)
if __name__ == '__main__':
    # Smoke-test the layer on random data. This file uses relative imports,
    # so it has to be executed as a module within its package.
    d_model, n_heads = 64, 4
    seq_len, batch_size = 10, 2

    ecam_layer = ECAM(d_model, n_heads)
    x = torch.randn(batch_size, seq_len, d_model)

    # Plain forward pass that also returns the learned graph.
    # Expected shapes: output [2, 10, 64], attention [2, 4, 10, 10], graph [2, 10, 10].
    output, attn, G = ecam_layer(x, return_graph=True)
    for label, tensor in (
        ("Output shape:", output),
        ("Attention shape:", attn),
        ("Graph shape:", G),
    ):
        print(label, tensor.shape)

    # Forward pass with an intervention on node 3.
    intervention_details = {"node_idx": 3}
    output_interv, attn_interv = ecam_layer(x, intervention=intervention_details)
    for label, tensor in (
        ("\nOutput shape (with intervention):", output_interv),
        ("Attention shape (with intervention):", attn_interv),
    ):
        print(label, tensor.shape)

    # Largest attention weight flowing to the intervened key position; it is
    # expected to be close to zero if the intervention module suppresses the
    # scores of that column before the softmax.
    print("Attention weights to intervened node 3 (should be ~0):")
    print(attn_interv[..., 3].max())

