import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch.nn import Linear, ReLU, BatchNorm1d
import torch.nn as nn
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import get_laplacian, to_dense_adj, dense_to_sparse


class SuperpixelsMPNN(MessagePassing):
    """Sum-aggregation layer whose per-edge message pairs both endpoint features.

    Data flow: ``x -> mlp1 -> message cat([x_i, x_j]) -> sum -> mlp2``.
    The message concatenates two ``hidden_channels`` vectors. The summed
    aggregate is therefore ``2 * hidden_channels`` wide, and ``mlp2`` maps it
    back to ``hidden_channels``.

    ``edge_transform`` (``k -> hidden_channels``) is constructed, so it owns
    parameters, but ``message`` does not currently apply it. ``edge_attr`` is
    accepted and ignored.

    Args:
        in_channels: width of the incoming node features.
        hidden_channels: width of the produced node features.
        k: input width of ``edge_transform``.
    """

    def __init__(self, in_channels, hidden_channels, k=3):
        super().__init__(aggr='sum')

        def dense(fan_in):
            # Linear -> ReLU -> BatchNorm, always projecting to hidden_channels.
            return nn.Sequential(
                Linear(fan_in, hidden_channels),
                ReLU(),
                BatchNorm1d(hidden_channels),
            )

        # Keep this construction order: it fixes seeded initialisation and
        # the state_dict key order.
        self.mlp1 = dense(in_channels)              # node pre-transform
        self.mlp2 = dense(2 * hidden_channels)      # aggregated [x_i, x_j] -> hidden
        self.edge_transform = dense(k)              # currently unused (see docstring)

    def forward(self, x, edge_index, edge_attr=None):
        h = self.mlp1(x)
        return self.propagate(edge_index, x=h, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        # edge_attr is currently unused; only node features form the message.
        return torch.cat((x_i, x_j), dim=1)

    def update(self, aggr_out):
        return self.mlp2(aggr_out)

class InitialMPNN(MessagePassing):
    """Sum-aggregation layer that sends ``mlp(x_j)`` along every edge.

    ``x`` is not transformed before propagation, and there is no ``update``
    override. The output is therefore the per-node sum of the transformed
    ``x_j`` messages.

    Args:
        hidden_channels: width of the node features (input and output).
    """

    def __init__(self, hidden_channels):
        super().__init__(aggr='sum')
        layers = [
            Linear(hidden_channels, hidden_channels),
            ReLU(),
            BatchNorm1d(hidden_channels),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, x, edge_index):
        return self.propagate(edge_index, x=x)

    def message(self, x_j):
        # Transform applied per edge, so BatchNorm statistics are over edges.
        return self.mlp(x_j)

class SpectralMPNN(MessagePassing):
    """Sum-aggregation layer whose messages are conditioned on edge features.

    Node features first go through ``mlp1``. Each edge then emits
    ``combine_mlp(cat([x_i, x_j, edge_mlp(edge_attr)]))``, and the messages
    are summed per node. There is no ``update`` override, so the summed
    messages are the layer output (PyG's default ``update`` passes the
    aggregate through).

    Args:
        hidden_channels: width of node features (in and out) and of the
            embedded edge features.
        pe_dim: width of the raw ``edge_attr`` vectors.
    """

    def __init__(self, hidden_channels, pe_dim):
        super().__init__(aggr='sum')

        def dense(fan_in):
            # Linear -> ReLU -> BatchNorm, always projecting to hidden_channels.
            return nn.Sequential(
                Linear(fan_in, hidden_channels),
                ReLU(),
                BatchNorm1d(hidden_channels),
            )

        # Keep this construction order: it fixes seeded initialisation and
        # the state_dict key order.
        self.edge_mlp = dense(pe_dim)                   # raw edge features -> hidden
        self.combine_mlp = dense(3 * hidden_channels)   # [x_i, x_j, edge] -> hidden
        self.mlp1 = dense(hidden_channels)              # node pre-transform

    def forward(self, x, edge_index, edge_attr):
        h = self.mlp1(x)
        return self.propagate(edge_index, x=h, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        edge_emb = self.edge_mlp(edge_attr)
        joint = torch.cat([x_i, x_j, edge_emb], dim=1)
        return self.combine_mlp(joint)

class NetWithPEandMP(torch.nn.Module):
    """Graph classifier mixing local and per-graph fully connected message passing.

    ``forward`` runs three stages:

    1. ``cat([data.x, data.pos_enc])`` is projected to ``hidden_channels`` by
       ``init_proj``. It is then refined by ``initial_iters`` InitialMPNN
       layers on ``data.edge_index``.
    2. ``spectral_iters`` SpectralMPNN layers run on a fully connected,
       self-loop-free graph built per example. The edge features are
       ``pos_enc[u] * pos_enc[v]``.
    3. Node embeddings are mean-pooled per graph and passed to
       ``classifier``. The result is log-softmax over ``num_classes``.

    Every message-passing layer is followed by ReLU and dropout.

    Args:
        num_features: width of ``data.x``.
        pos_enc_dim: width of ``data.pos_enc``.
        num_classes: number of output classes.
        hidden_channels: width of all hidden embeddings.
        initial_iters: number of InitialMPNN layers (stage 1).
        spectral_iters: number of SpectralMPNN layers (stage 2).
        dropout: dropout probability after each message-passing layer.

    NOTE: ``conv1``-``conv3`` are constructed, and so hold parameters and
    state_dict entries, but ``forward`` does not use them.
    """

    # Class-level default so instances pickled before ``dropout`` existed
    # still resolve the attribute.
    dropout_p = 0.2

    def __init__(self, num_features=1, pos_enc_dim=8, num_classes=10, hidden_channels=64,
                 initial_iters=2, spectral_iters=2, dropout=0.2):
        super().__init__()

        self.initial_iters = initial_iters
        self.spectral_iters = spectral_iters
        self.pos_enc_dim = pos_enc_dim
        self.dropout_p = dropout

        # Module construction order is unchanged so seeded initialisation and
        # state_dict layout stay the same.

        # Stage 1: local message passing on the input graph.
        self.initial_conv = nn.ModuleList([
            InitialMPNN(hidden_channels) for _ in range(initial_iters)
        ])

        # Projects cat([data.x, data.pos_enc]) to the hidden width.
        self.init_proj = nn.Sequential(
            Linear(num_features + pos_enc_dim, hidden_channels),
            ReLU(),
            BatchNorm1d(hidden_channels)
        )

        # Stage 2: message passing on the per-graph fully connected graph.
        self.spectral_conv = nn.ModuleList([
            SpectralMPNN(hidden_channels, pos_enc_dim) for _ in range(spectral_iters)
        ])

        # Not used by forward (see class docstring).
        self.conv1 = SuperpixelsMPNN(hidden_channels, hidden_channels, k=pos_enc_dim)
        self.conv2 = SuperpixelsMPNN(hidden_channels, hidden_channels, k=pos_enc_dim)
        self.conv3 = SuperpixelsMPNN(hidden_channels, hidden_channels, k=pos_enc_dim)

        self.classifier = nn.Sequential(
            Linear(hidden_channels, hidden_channels),
            ReLU(),
            BatchNorm1d(hidden_channels),
            Linear(hidden_channels, num_classes)
        )

    def create_fully_connected_edges(self, num_nodes, batch, data):
        """Build every ordered node pair ``(u, v)``, ``u != v``, within each graph.

        Assumes the nodes of each graph occupy a contiguous index range, in
        ascending ``batch`` order.

        Args:
            num_nodes: unused; kept for call compatibility.
            batch: 1-D integer tensor giving the graph id of each node.
            data: unused; kept for call compatibility.

        Returns:
            ``(edge_index, edge_batch)``:
            - ``edge_index`` is a ``[2, E]`` long tensor. Edges are ordered by
              graph, then row-major by (source, target).
            - ``edge_batch`` is a ``[E]`` long tensor with each edge's graph id.
            - Both are on ``batch.device``.
        """
        device = batch.device
        # A single host sync for all graph sizes instead of one per graph.
        sizes = torch.bincount(batch).tolist()

        edge_indices = []
        edge_ids = []
        offset = 0
        for graph_id, n in enumerate(sizes):
            local = torch.arange(n, device=device)
            # All (row, col) pairs in row-major order: rows 0..0,1..1,...; cols 0..n-1 repeated.
            rows = local.repeat_interleave(n)
            cols = local.repeat(n)
            off_diag = rows != cols  # drop self-loops
            edges = torch.stack([rows[off_diag], cols[off_diag]], dim=0) + offset

            edge_indices.append(edges)
            edge_ids.append(
                torch.full((edges.size(1),), graph_id, dtype=torch.long, device=device)
            )
            offset += n

        if not edge_indices:
            # Empty batch: no nodes, hence no edges.
            return (torch.empty((2, 0), dtype=torch.long, device=device),
                    torch.empty((0,), dtype=torch.long, device=device))

        return torch.cat(edge_indices, dim=1), torch.cat(edge_ids, dim=0)

    def compute_spectral_edge_features(self, edge_index, pos_enc):
        """Return per-edge features ``pos_enc[u] * pos_enc[v]`` (element-wise).

        Args:
            edge_index: ``[2, E]`` tensor of (source, target) node indices.
            pos_enc: ``[num_nodes, pos_enc_dim]`` node positional encodings.

        Returns:
            ``[E, pos_enc_dim]`` edge features.
        """
        row, col = edge_index
        return pos_enc[row] * pos_enc[col]

    def forward(self, data):
        """Classify every graph in ``data``.

        Reads ``data.x``, ``data.pos_enc``, ``data.edge_index``, ``data.batch``
        and ``data.num_nodes``. Returns ``[num_graphs, num_classes]``
        log-probabilities.
        """
        edge_index, batch = data.edge_index, data.batch
        pos_enc = data.pos_enc  # [num_nodes, pos_enc_dim]

        # Stage 1: embed node features + positional encodings, then local message passing.
        x = self.init_proj(torch.cat([data.x, pos_enc], dim=-1))
        for conv in self.initial_conv:
            x = conv(x, edge_index)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout_p, training=self.training)

        # Stage 2: message passing on a fully connected graph per example.
        fc_edge_index, _ = self.create_fully_connected_edges(data.num_nodes, batch, data)
        spectral_edge_features = self.compute_spectral_edge_features(fc_edge_index, pos_enc)
        for conv in self.spectral_conv:
            x = conv(x, fc_edge_index, spectral_edge_features)
            x = F.relu(x)
            # The dropped-out tensor feeds the next layer, as in stage 1.
            x = F.dropout(x, p=self.dropout_p, training=self.training)

        # Stage 3: per-graph readout and classification.
        x = global_mean_pool(x, batch)
        x = self.classifier(x)
        return F.log_softmax(x, dim=1)