import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing, global_mean_pool
from torch.nn import Linear, ReLU, BatchNorm1d
import torch.nn as nn
from torch_geometric.transforms import BaseTransform
from torch_geometric.utils import get_laplacian, to_dense_adj, degree
import torch_geometric

class SuperpixelsMPNN(MessagePassing):
    """Edge-conditioned message-passing layer with sum aggregation.

    Node features are first embedded by ``mlp1``. Every edge produces the
    message ``[x_i, x_j, edge_transform(edge_attr)]`` (width
    ``3 * hidden_channels``); messages are summed per node and the aggregate
    is mapped back to ``hidden_channels`` by ``mlp2``.

    Args:
        in_channels: Width of the incoming node features.
        hidden_channels: Width of the node embedding and of the layer output.
        k: Width of the raw edge attributes consumed by ``edge_transform``.
    """

    def __init__(self, in_channels, hidden_channels, k=3):
        super().__init__(aggr='sum')
        # Remembered so forward() can build placeholder edge features.
        self.edge_dim = k

        self.mlp1 = nn.Sequential(
            Linear(in_channels, hidden_channels),
            ReLU(),
            BatchNorm1d(hidden_channels),
        )

        # Input is the summed message [x_i, x_j, edge_features]: 3 * hidden.
        self.mlp2 = nn.Sequential(
            Linear(hidden_channels * 2 + hidden_channels, hidden_channels),
            ReLU(),
            BatchNorm1d(hidden_channels),
        )

        self.edge_transform = nn.Sequential(
            Linear(k, hidden_channels),
            ReLU(),
            BatchNorm1d(hidden_channels),
        )

    def forward(self, x, edge_index, edge_attr=None):
        """Embed nodes, exchange messages along ``edge_index`` and update.

        Args:
            x: Node features of width ``in_channels``.
            edge_index: Graph connectivity handed to ``propagate``.
            edge_attr: Edge features of width ``k``. When None, all-zero
                edge features are used instead of crashing in ``message``.

        Returns:
            Node features of width ``hidden_channels``.
        """
        x = self.mlp1(x)
        if edge_attr is None:
            # message() always runs edge_attr through edge_transform, so the
            # advertised None default needs a concrete placeholder.
            # NOTE(review): assumes edge_index is a [2, num_edges] tensor.
            edge_attr = x.new_zeros((edge_index.size(1), self.edge_dim))
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        """Concatenate both endpoint embeddings with the embedded edge features."""
        edge_features = self.edge_transform(edge_attr)
        return torch.cat([x_i, x_j, edge_features], dim=1)

    def update(self, aggr_out):
        """Project the summed messages back to ``hidden_channels``."""
        return self.mlp2(aggr_out)

class LaplacianPE(BaseTransform):
    """Attach Laplacian-eigenvector positional encodings to a graph.

    The symmetric-normalised Laplacian of ``data.edge_index`` is decomposed
    densely. The eigenvectors of the 2nd..(k+1)-th smallest eigenvalues form
    a ``[num_nodes, k]`` encoding. Missing columns are zero-padded when the
    graph has fewer than ``k + 1`` nodes.

    The transform sets:
        * ``data.pos_enc``: the encoding, or zeros of the same shape when
          ``with_pos_enc`` is False.
        * ``data.edge_attr``: ``enc[src] * enc[dst]`` per edge when
          ``with_proj`` is True, otherwise zeros of shape ``[num_edges, k]``.
          Any existing ``edge_attr`` is overwritten.
        * optionally, a virtual node linked to every original node in both
          directions (``with_virtual``).

    ``data`` is modified in place and returned.

    Args:
        k: Number of eigenvectors to keep.
        with_pos_enc: Store the real encoding (True) or zeros (False).
        with_proj: Build edge attributes from products of endpoint encodings.
        with_virtual: Append a virtual node connected to all original nodes.
    """

    def __init__(self, k=8, with_pos_enc=True, with_proj=False, with_virtual=False):
        self.k = k
        self.with_pos_enc = with_pos_enc
        self.with_proj = with_proj
        self.with_virtual = with_virtual

    def _laplacian_eigvecs(self, edge_index, num_nodes):
        """Return the ``[num_nodes, k]`` Laplacian eigenvector encoding."""
        lap_index, lap_weight = get_laplacian(
            edge_index, normalization='sym', num_nodes=num_nodes
        )
        # Dense normalised Laplacian. max_num_nodes pins it to
        # num_nodes x num_nodes instead of inferring the size from the
        # largest index present.
        laplacian = to_dense_adj(
            lap_index, edge_attr=lap_weight, max_num_nodes=num_nodes
        )[0]

        try:
            # eigh returns eigenvalues in ascending order, so the columns of
            # eigvecs are already sorted by eigenvalue.
            _, eigvecs = torch.linalg.eigh(laplacian)
        except RuntimeError:
            print(f"Warning: Eigendecomposition failed for a graph with {num_nodes} nodes. Using zeros.")
            # Match the Laplacian's device/dtype so later indexing with
            # edge_index (and concatenation) still works.
            return laplacian.new_zeros((num_nodes, self.k))

        missing = self.k + 1 - eigvecs.size(1)
        if missing > 0:
            # Too few nodes for k non-trivial eigenvectors: zero-pad columns.
            padding = eigvecs.new_zeros((eigvecs.size(0), missing))
            eigvecs = torch.cat([eigvecs, padding], dim=1)

        # Drop the eigenvector of the smallest eigenvalue. For a connected
        # graph and this sym-normalised Laplacian it is proportional to
        # sqrt(degree), not constant, and carries no positional information.
        return eigvecs[:, 1:self.k + 1]

    @staticmethod
    def _add_virtual_node(data, pos_enc, num_nodes):
        """Append a virtual node connected to every original node.

        Updates ``data.edge_index``, ``data.x`` and ``data.num_nodes`` in
        place and returns ``pos_enc`` extended by one row for the virtual
        node (the mean over all original nodes).
        """
        pos_enc = torch.cat([pos_enc, pos_enc.mean(dim=0, keepdim=True)], dim=0)

        device = data.edge_index.device
        dtype = data.edge_index.dtype
        originals = torch.arange(num_nodes, device=device, dtype=dtype)
        # The virtual node takes the next free index, num_nodes.
        virtual = torch.full((num_nodes,), num_nodes, device=device, dtype=dtype)

        # Edges virtual -> every node, then every node -> virtual.
        virtual_edges = torch.stack([
            torch.cat([virtual, originals]),
            torch.cat([originals, virtual]),
        ], dim=0)
        data.edge_index = torch.cat([data.edge_index, virtual_edges], dim=1)

        if getattr(data, 'x', None) is not None:
            # The virtual node's features are the mean over all nodes.
            data.x = torch.cat([data.x, data.x.mean(dim=0, keepdim=True)], dim=0)
        else:
            # No node features: use the degree over edge_index[0], counted
            # after the virtual edges were added.
            data.x = degree(data.edge_index[0], num_nodes + 1).view(-1, 1)

        data.num_nodes = num_nodes + 1
        return pos_enc

    def __call__(self, data):
        num_nodes = data.num_nodes
        eigvecs_k = self._laplacian_eigvecs(data.edge_index, num_nodes)

        if self.with_virtual:
            eigvecs_k = self._add_virtual_node(data, eigvecs_k, num_nodes)

        data.pos_enc = eigvecs_k if self.with_pos_enc else torch.zeros_like(eigvecs_k)

        if self.with_proj:
            # Per-edge element-wise product of the endpoint encodings.
            src, dst = data.edge_index
            data.edge_attr = eigvecs_k[src] * eigvecs_k[dst]
        else:
            data.edge_attr = eigvecs_k.new_zeros((data.edge_index.size(1), self.k))

        return data
        
class NetWithPE(torch.nn.Module):
    """Three-layer SuperpixelsMPNN graph classifier with positional encodings.

    Node inputs are ``[x, data.pos_enc]`` (width ``num_features +
    pos_enc_dim``). They pass through three message-passing layers and are
    mean-pooled per graph via ``data.batch``, then classified. Returns
    log-probabilities of shape ``[num_graphs, num_classes]``.

    Args:
        num_features: Width of ``data.x`` (and of the all-ones fallback used
            when ``data.x`` is missing).
        hidden_channels: Hidden width of every layer.
        pos_enc_dim: Width of ``data.pos_enc`` and of ``data.edge_attr``.
        num_classes: Number of output classes.
        dropout: Dropout probability after the first two layers.
    """

    def __init__(self, num_features=1, hidden_channels=64, pos_enc_dim=8,
                 num_classes=10, dropout=0.2):
        super().__init__()

        self.num_features = num_features
        self.pos_enc_dim = pos_enc_dim
        self.dropout = dropout
        input_dim = num_features + pos_enc_dim

        self.conv1 = SuperpixelsMPNN(input_dim, hidden_channels, k=pos_enc_dim)
        self.conv2 = SuperpixelsMPNN(hidden_channels, hidden_channels, k=pos_enc_dim)
        self.conv3 = SuperpixelsMPNN(hidden_channels, hidden_channels, k=pos_enc_dim)

        self.classifier = nn.Sequential(
            Linear(hidden_channels, hidden_channels),
            ReLU(),
            BatchNorm1d(hidden_channels),
            Linear(hidden_channels, num_classes),
        )

    def forward(self, data):
        """Classify each graph in ``data``; the input batch is not modified."""
        x = getattr(data, 'x', None)
        if x is None:
            # Featureless graphs: constant ones on pos_enc's device/dtype.
            x = data.pos_enc.new_ones((data.num_nodes, self.num_features))
        elif x.dim() == 1:
            x = x.unsqueeze(1)

        x = torch.cat([x, data.pos_enc], dim=-1)

        edge_index, edge_attr = data.edge_index, data.edge_attr
        batch = data.batch

        x = self.conv1(x, edge_index, edge_attr)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.conv2(x, edge_index, edge_attr)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = self.conv3(x, edge_index, edge_attr)
        x = global_mean_pool(x, batch)
        x = self.classifier(x)

        return F.log_softmax(x, dim=1)

class NetWithPEandMP(torch.nn.Module):
    """Graph classifier with a configurable stack of SuperpixelsMPNN layers.

    ``initial_iters`` layers are followed by ``spectral_iters`` layers. Each
    layer is followed by ReLU and dropout. Node inputs are ``[x,
    data.pos_enc]``; the result is mean-pooled per graph and classified into
    log-probabilities of shape ``[num_graphs, num_classes]``.

    Args:
        num_features: Width of ``data.x`` (and of the all-ones fallback used
            when ``data.x`` is missing).
        hidden_channels: Hidden width of every layer.
        pos_enc_dim: Width of ``data.pos_enc`` and of ``data.edge_attr``.
        num_classes: Number of output classes.
        initial_iters: Number of layers in the first stage.
        spectral_iters: Number of layers in the second stage.
        dropout: Dropout probability after every layer.
    """

    def __init__(self, num_features=1, hidden_channels=64, pos_enc_dim=8,
                 num_classes=10, initial_iters=0, spectral_iters=3, dropout=0.2):
        super().__init__()

        self.num_features = num_features
        self.pos_enc_dim = pos_enc_dim
        self.initial_iters = initial_iters
        self.spectral_iters = spectral_iters
        self.dropout = dropout
        input_dim = num_features + pos_enc_dim

        def in_width(position):
            # Only the very first layer of the whole stack sees raw inputs.
            return input_dim if position == 0 else hidden_channels

        self.initial_mp = nn.ModuleList([
            SuperpixelsMPNN(in_width(i), hidden_channels, k=pos_enc_dim)
            for i in range(initial_iters)
        ])

        # With initial_iters == 0 the first spectral layer is the first layer
        # overall and must accept input_dim features, not hidden_channels.
        self.spectral_mp = nn.ModuleList([
            SuperpixelsMPNN(in_width(initial_iters + i), hidden_channels, k=pos_enc_dim)
            for i in range(spectral_iters)
        ])

        self.classifier = nn.Sequential(
            Linear(hidden_channels, hidden_channels),
            ReLU(),
            BatchNorm1d(hidden_channels),
            Linear(hidden_channels, num_classes),
        )

    def forward(self, data):
        """Classify each graph in ``data``; the input batch is not modified."""
        x = getattr(data, 'x', None)
        if x is None:
            # Featureless graphs: constant ones, as wide as the model expects.
            x = data.pos_enc.new_ones((data.num_nodes, self.num_features))
        elif x.dim() == 1:
            x = x.unsqueeze(1)

        x = torch.cat([x, data.pos_enc], dim=-1)
        edge_index, edge_attr = data.edge_index, data.edge_attr
        batch = data.batch

        # Initial layers first, then spectral layers; identical per-layer ops.
        for mp in (*self.initial_mp, *self.spectral_mp):
            x = mp(x, edge_index, edge_attr)
            x = F.relu(x)
            x = F.dropout(x, p=self.dropout, training=self.training)

        x = global_mean_pool(x, batch)
        x = self.classifier(x)

        return F.log_softmax(x, dim=1)