#!/usr/bin/env python
# coding: utf-8

import torch
import torch.nn as nn
import torch.nn.functional as F

class GraphLearner(nn.Module):
    """Learns a Local Causal Graph (LCG) from input features.

    Each node (token) is projected into a "source" and a "destination"
    embedding, and the score of the edge i -> j is the bilinear form
    ``src_i^T @ W @ dst_j``. Scores are squashed with a sigmoid and the
    diagonal (self-loops) is zeroed. The resulting matrix is dense and, in
    general, asymmetric; no sparsity or acyclicity constraint is applied
    inside this module.
    """

    def __init__(self, d_model, graph_reg, num_nodes=None, embedding_dim=None, alpha=0.05):
        """Initialize the GraphLearner.

        Args:
            d_model (int): Dimension of the input features (per node/token).
            graph_reg (float): Regularization strength (e.g., for sparsity or
                acyclicity). Stored on the module but not used by ``forward``.
            num_nodes (int, optional): Max number of nodes (e.g., sequence
                length). Stored for reference only; ``forward`` always reads
                the node count from its input.
            embedding_dim (int, optional): Intermediate embedding dimension
                for graph learning. Defaults to ``max(16, d_model // 4)``.
            alpha (float): Threshold for potential edge existence. Stored on
                the module but not used by ``forward``.
        """
        super().__init__()
        self.d_model = d_model
        self.graph_reg = graph_reg  # Kept for callers that build a graph penalty.
        self.num_nodes = num_nodes  # Informational; not used in any computation here.
        self.alpha = alpha  # Kept for callers that threshold the returned matrix.

        if embedding_dim is None:
            # Learn the graph in a smaller space than the features, but never
            # narrower than 16 dimensions.
            embedding_dim = max(16, d_model // 4)
        self.embedding_dim = embedding_dim

        # Separate projections for the source and destination role of a node,
        # so that score(i -> j) is not forced to equal score(j -> i).
        self.node_embedding_src = nn.Linear(d_model, embedding_dim)
        self.node_embedding_dst = nn.Linear(d_model, embedding_dim)

        # Bilinear interaction matrix W used to score every (src, dst) pair.
        # torch.empty only allocates; the values are set by the Xavier init.
        # (An MLP over concatenated [src, dst] embeddings would be a heavier
        # alternative scorer.)
        self.weight_tensor = nn.Parameter(torch.empty(embedding_dim, embedding_dim))
        nn.init.xavier_uniform_(self.weight_tensor)

    def forward(self, x):
        """Compute the adjacency matrix of the LCG.

        Args:
            x (torch.Tensor): Input tensor of shape (..., seq_len, d_model).
                The usual case is (batch_size, seq_len, d_model); an unbatched
                (seq_len, d_model) input or extra leading batch dimensions
                are also accepted.

        Returns:
            torch.Tensor: Learned adjacency matrix G of shape
                (..., seq_len, seq_len). Entry [..., i, j] is
                ``sigmoid(src_i^T @ W @ dst_j)`` and lies in [0, 1]; the
                diagonal is exactly 0.

        Raises:
            ValueError: If ``x`` has fewer than 2 dimensions.
        """
        if x.dim() < 2:
            raise ValueError(
                "Expected input of shape (..., seq_len, d_model), got {}".format(
                    tuple(x.shape)
                )
            )
        seq_len = x.shape[-2]

        # Project features for source and destination nodes:
        # (..., seq_len, d_model) -> (..., seq_len, embedding_dim)
        x_src = self.node_embedding_src(x)
        x_dst = self.node_embedding_dst(x)

        # Bilinear edge scores.
        # (..., seq_len, embedding_dim) x (embedding_dim, embedding_dim)
        #   -> (..., seq_len, embedding_dim)
        left_proj = torch.matmul(x_src, self.weight_tensor)
        # (..., seq_len, embedding_dim) x (..., embedding_dim, seq_len)
        #   -> (..., seq_len, seq_len)
        adj_scores = torch.matmul(left_proj, x_dst.transpose(-2, -1))

        # Sigmoid maps scores to (0, 1), interpretable as edge weights/probabilities.
        adj_matrix = torch.sigmoid(adj_scores)

        # Mask out self-loops. The (seq_len, seq_len) mask broadcasts over any
        # leading batch dimensions.
        self_loops = torch.eye(seq_len, device=x.device, dtype=torch.bool)
        adj_matrix = adj_matrix.masked_fill(self_loops, 0)

        # Note: the returned graph is dense and may contain cycles. Sparsity
        # (e.g. an L1 penalty on the matrix weighted by self.graph_reg, or
        # thresholding with self.alpha) and acyclicity (e.g. a NOTEARS-style
        # constraint) are left to the training loss of the enclosing model.
        return adj_matrix

# Example Usage
if __name__ == '__main__':
    # Smoke test: run a randomly initialised learner on random features and
    # print the resulting graph for the first sample.
    batch_size, seq_len, d_model = 2, 10, 64
    graph_reg = 0.01

    # Build the module before sampling the input so the order of random draws
    # (parameter init first, then features) stays the same.
    learner = GraphLearner(d_model, graph_reg)
    features = torch.randn(batch_size, seq_len, d_model)

    learned_adj = learner(features)
    first_graph = learned_adj[0]

    print("Learned Adjacency Matrix shape:", learned_adj.shape)  # Expected: [2, 10, 10]
    print("Sample Adjacency Matrix (first batch):")
    print(first_graph)
    # Self-loops are masked in forward(), so every diagonal entry should be 0.
    print("Diagonal elements:", torch.diag(first_graph))

