from curses import def_shell_mode
import torch
import torch.nn.functional as F
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax

class MultiHeadAttentionLayer(torch.nn.Module):
    """Learn an attention-weighted matrix over the nonzero pattern of ``adj``.

    Each head scores every nonzero entry ``(i, j)`` of ``adj`` with a linear map
    applied to the concatenated node features ``[x_i, x_j]``, passes the scores
    through LeakyReLU(0.2), and softmax-normalises them over all entries that
    share the same row ``i``. Only the sparsity pattern of ``adj`` is used; the
    magnitudes of its entries are ignored. ``forward`` averages the heads.

    Args:
        in_channels: Together with ``out_channels`` sets the input width of each
            scoring layer (``in_channels + out_channels``). Because the scoring
            input is ``[x_i, x_j]``, that sum must equal ``2 * x.size(-1)``;
            in practice ``in_channels == out_channels == x.size(-1)``.
        out_channels: See ``in_channels``; also stored as ``self.out_channels``.
        num_heads: Number of independent scoring heads.
    """

    def __init__(self, in_channels, out_channels, num_heads=1):
        super().__init__()
        self.num_heads = num_heads
        self.out_channels = out_channels
        self.attentions = torch.nn.ModuleList([
            torch.nn.Linear(in_channels + out_channels, 1) for _ in range(num_heads)
        ])
        for att in self.attentions:
            # Only the weight is re-initialised; biases keep nn.Linear's default init.
            torch.nn.init.xavier_uniform_(att.weight)

    def forward(self, x, adj):
        """Return the mean over heads of the per-head attention matrices.

        The result has the same shape as ``adj`` and dtype ``torch.float``.
        """
        attentions = [self.compute_attention(x, adj, head) for head in range(self.num_heads)]
        return torch.stack(attentions, dim=-1).mean(dim=-1)

    def compute_attention(self, x, adj, head):
        """Compute one head's attention matrix.

        Args:
            x: Node features indexed by node along dim 0.
            adj: 2-D matrix whose nonzero entries define the scored pairs.
            head: Index of the scoring layer in ``self.attentions``.

        Returns:
            A float tensor shaped like ``adj``. It is zero where ``adj`` is zero
            and holds row-wise softmax-normalised scores elsewhere.

        Raises:
            ValueError: If ``adj`` is not 2-D (batched adjacency is unsupported).
        """
        if adj.dim() != 2:
            raise ValueError(
                f"compute_attention expects a 2-D adjacency matrix, got shape {tuple(adj.shape)}"
            )
        num_nodes = x.size(0)
        row, col = adj.nonzero(as_tuple=True)
        edge_features = torch.cat([x[row], x[col]], dim=1)
        # squeeze(-1) rather than squeeze(): with a single edge, squeeze() would
        # collapse the scores to a 0-d scalar and break the per-row softmax.
        attention_scores = self.attentions[head](edge_features).squeeze(-1)
        attention_scores = F.leaky_relu(attention_scores, negative_slope=0.2)
        # Pass num_nodes explicitly so the grouping size is not inferred from row.max().
        attention_scores = softmax(attention_scores, index=row, num_nodes=num_nodes)
        adj_att = torch.zeros_like(adj, dtype=torch.float)
        adj_att[row, col] = attention_scores
        return adj_att





class GraphConvolution(torch.nn.Module):
    """Directed graph convolution with separate weights per aggregation direction.

    Computes ``W_in(adj @ x) + W_out(adj^T @ x) + W_self(x)``. Aggregating along
    ``adj`` and along its transpose therefore uses different linear maps. Input
    and output feature widths are both ``out_features``. Only ``W_self`` has a
    bias.
    """

    def __init__(self, out_features):
        super().__init__()
        dim = out_features
        # Creation order (in, out, self) is fixed so seeded initialisation matches.
        self.W_in = torch.nn.Linear(dim, dim, bias=False)
        self.W_out = torch.nn.Linear(dim, dim, bias=False)
        self.W_self = torch.nn.Linear(dim, dim)

    def forward(self, x, adj):
        """Apply the convolution.

        The transpose swaps the last two dims of ``adj``, so leading batch
        dimensions are carried through by the matmuls.
        """
        incoming = self.W_in(adj @ x)
        outgoing = self.W_out(adj.transpose(-2, -1) @ x)
        return incoming + outgoing + self.W_self(x)
    
  
class GCN_DIR(torch.nn.Module):
    """Residual stack of directed graph convolutions with optional normalisation.

    Pipeline in ``forward``:
      1. Optionally re-weight ``adj`` (see ``tr_dist`` / ``tr_att``).
      2. Project ``x`` to ``hidden_channels`` with ``sc_lin``.
      3. ``conv_init`` -> norm -> ReLU, plus a residual from the projection.
      4. ``num_layers`` hidden convolutions, each: conv -> norm -> ReLU + residual.
      5. ``conv_final`` -> norm (no ReLU) + residual.

    Args:
        in_channels: Width of the input node features.
        hidden_channels: Width of every hidden representation and of the output.
        num_layers: Number of hidden convolutions between the first and last one.
        num_nodes: Only used to shape the LayerNorm modules, which normalise
            jointly over the trailing ``(num_nodes, hidden_channels)`` dims.
        use_layer_norm: Apply LayerNorm after each convolution.
        use_batch_norm: Apply BatchNorm1d after each convolution. Each row of the
            flattened ``(-1, hidden_channels)`` view is treated as one sample.
        shared_weights: Reuse one convolution (``conv_mid``) for every hidden
            layer instead of ``num_layers`` separate ones. Norm layers are never
            shared.
        tr_dist: Replace ``adj`` with a learned transform of its entries.
        tr_att: Used only with ``tr_dist``. Use ``MultiHeadAttentionLayer``
            instead of the polynomial distance transform.
        num_dist_terms: Number of polynomial terms in the distance transform.
            Defaults to the previously hard-coded 4.

    Raises:
        ValueError: If both ``use_layer_norm`` and ``use_batch_norm`` are set.
    """

    def __init__(self, in_channels, hidden_channels, num_layers, num_nodes,
                 use_layer_norm=False, use_batch_norm=True, shared_weights=True,
                 tr_dist=False, tr_att=False, num_dist_terms=4):
        super().__init__()

        self.shared_weights = shared_weights
        self.use_layer_norm = use_layer_norm
        self.use_batch_norm = use_batch_norm
        # Truthiness instead of ``== True``: a truthy non-bool flag must not slip
        # past this guard and build both norm stacks.
        if use_layer_norm and use_batch_norm:
            raise ValueError("Unfeasible configuration")
        self.num_layers = num_layers
        self.tr_dist = tr_dist
        self.tr_att = tr_att

        print("Attention for Edge Lengths:", self.tr_att)
        # Submodule creation order below matches the original implementation so
        # seeded initialisation and state_dict keys are unchanged.
        if self.tr_dist:
            if self.tr_att:
                self.attention = MultiHeadAttentionLayer(in_channels, in_channels)
            else:
                # One learnable coefficient per polynomial term of the distance transform.
                self.W_dist = torch.nn.Parameter(torch.empty(num_dist_terms, 1, dtype=torch.float))
                torch.nn.init.xavier_uniform_(self.W_dist)

        self.sc_lin = torch.nn.Linear(in_channels, hidden_channels)

        self.conv_init = GraphConvolution(hidden_channels)
        if self.shared_weights:
            self.conv_mid = GraphConvolution(hidden_channels)
        else:
            self.conv_layers = torch.nn.ModuleList([
                GraphConvolution(hidden_channels) for _ in range(num_layers)])
        self.conv_final = GraphConvolution(hidden_channels)

        if self.use_layer_norm:
            self.ln_init = torch.nn.LayerNorm([num_nodes, hidden_channels])
            self.ln_layers = torch.nn.ModuleList([
                torch.nn.LayerNorm([num_nodes, hidden_channels]) for _ in range(num_layers)])
            self.ln_final = torch.nn.LayerNorm([num_nodes, hidden_channels])

        if self.use_batch_norm:
            self.bn1 = torch.nn.BatchNorm1d(hidden_channels)
            self.bn_layers = torch.nn.ModuleList([
                torch.nn.BatchNorm1d(hidden_channels) for _ in range(num_layers)])
            self.bn3 = torch.nn.BatchNorm1d(hidden_channels)

    def _reweight_adj(self, x, adj):
        """Return the adjacency that the convolutions should use.

        Without ``tr_dist`` this is ``adj`` unchanged. With ``tr_att`` it is the
        output of ``self.attention(x, adj)``. Otherwise each entry ``d`` becomes
        ``1 / (1 + sum_i exp(W_dist[i]) * d**i)``.
        """
        if not self.tr_dist:
            return adj
        if self.tr_att:
            return self.attention(x, adj)
        # exp() keeps every coefficient positive, so for non-negative adj each
        # output entry lies in (0, 1).
        # NOTE(review): the i == 0 term is exp(W_dist[0]) * d**0, which is nonzero
        # even where d == 0. Zero entries of adj therefore become
        # 1 / (1 + exp(W_dist[0])) instead of staying zero. Confirm that
        # densifying the graph is intended.
        weighted_tensor = torch.stack(
            [(adj ** i) * torch.exp(self.W_dist[i]) for i in range(self.W_dist.size(0))]
        )
        return 1 / (torch.sum(weighted_tensor, dim=0) + 1)

    def _normalize(self, z, layer_norm, batch_norm):
        """Apply the given normalisation module(s) to ``z``; either may be None."""
        if layer_norm is not None:
            z = layer_norm(z)
        if batch_norm is not None:
            # BatchNorm1d expects (samples, channels): fold all leading dims into one,
            # then restore the original shape.
            z = batch_norm(z.reshape(-1, z.size(-1))).view(z.size())
        return z

    def forward(self, x, adj):
        """Embed the nodes.

        Args:
            x: Node features with ``in_channels`` in the last dim.
            adj: Matrix that is matmul'd with the node features; it is re-weighted
                first when ``tr_dist`` is set.

        Returns:
            Node embeddings with ``hidden_channels`` in the last dim.
        """
        adj = self._reweight_adj(x, adj)
        use_ln = self.use_layer_norm
        use_bn = self.use_batch_norm

        sc = self.sc_lin(x)
        z = self.conv_init(sc, adj)
        z = self._normalize(z, self.ln_init if use_ln else None, self.bn1 if use_bn else None)
        z = F.relu(z) + sc

        for i in range(self.num_layers):
            conv = self.conv_mid if self.shared_weights else self.conv_layers[i]
            # No clone needed: nothing below modifies z_prev in place.
            z_prev = z
            z = conv(z, adj)
            z = self._normalize(
                z,
                self.ln_layers[i] if use_ln else None,
                self.bn_layers[i] if use_bn else None,
            )
            z = F.relu(z) + z_prev

        z_prev = z
        z = self.conv_final(z, adj)
        z = self._normalize(z, self.ln_final if use_ln else None, self.bn3 if use_bn else None)
        # No ReLU after the final convolution; only the residual is added.
        return z + z_prev



    # print(f"Trainable Distance: {self.tr_dist}")
            # print("Initial adj:", adj)
            # print("W_dist before update:", self.W_dist)
            # print("sc_lin:", self.sc_lin.weight[:10, :10])
# class GraphConvolution(torch.nn.Module):
#     def __init__(self, in_features, out_features, num_nodes):
#         super(GraphConvolution, self).__init__()
#         self.init_map = torch.nn.Linear(in_features, out_features)
#         self.W_in = torch.nn.Linear(out_features, out_features)
#         self.W_out = torch.nn.Linear(out_features, out_features)
#         self.W_self = torch.nn.Linear(out_features, out_features)
#         self.norm = torch.nn.LayerNorm([num_nodes, out_features])


#     def forward(self, x, adj):
#         x = self.init_map(x)
#         adj_transpose = torch.transpose(adj, -1, -2)
#         x_conv = (self.W_in(torch.matmul(adj, x))
#                 + self.W_out(torch.matmul(adj_transpose, x))
#                 + self.W_self(x))
    
#         fin_output = self.norm(F.relu(x_conv) + x)        
#         return fin_output
    
  
# class GCN_DIR(torch.nn.Module):
#     def __init__(self, in_channels, hidden_channels, num_layers, num_nodes, use_layer_norm=True, shared_weights=True):
#         super(GCN_DIR, self).__init__()

#         self.shared_weights = shared_weights
#         self.use_layer_norm = use_layer_norm
#         self.num_layers = num_layers
        

#         self.sc_lin = torch.nn.Linear(in_channels, hidden_channels)

#         self.conv_init = GraphConvolution(in_channels, hidden_channels, num_nodes)
#         if self.shared_weights:
#             self.conv_mid = GraphConvolution(hidden_channels, hidden_channels, num_nodes)
#         else:
#             self.conv_layers = torch.nn.ModuleList([
#                 GraphConvolution(hidden_channels, hidden_channels, num_nodes) for _ in range(num_layers)])
#         self.conv_final = GraphConvolution(hidden_channels, hidden_channels, num_nodes)

#         if self.use_layer_norm:
#             self.ln_init = torch.nn.LayerNorm([num_nodes, hidden_channels])
#             self.ln_layers = torch.nn.ModuleList([
#                 torch.nn.LayerNorm([num_nodes, hidden_channels]) for _ in range(num_layers)])
#             self.ln_final = torch.nn.LayerNorm([num_nodes, hidden_channels])

#     def forward(self, x, adj):
    
#         sc = self.sc_lin(x)
#         z = self.conv_init(x, adj)
#         if self.use_layer_norm:
#             z = self.ln_init(z)
#         z = F.relu(z)
#         z = z + sc

#         # z = F.relu(sc)
#         for i in range(self.num_layers):
#             if self.shared_weights:
#                 z_prev = z.clone()
#                 z = self.conv_mid(z, adj)
#                 if self.use_layer_norm:
#                     z = self.ln_layers[i](z)
#                 z = F.relu(z)
#                 z = z + z_prev
#             else:
#                 z_prev = z.clone()
#                 z = self.conv_layers[i](z, adj)
#                 if self.use_layer_norm:
#                     z = self.ln_layers[i](z)
#                 z = F.relu(z)
#                 z = z + z_prev

#         z_prev = z.clone()
#         z = self.conv_final(z, adj)
#         if self.use_layer_norm:
#             z = self.ln_final(z)
#         z = z + z_prev
#         return z
