from torch_geometric.nn import GCNConv
import torch.nn.functional as F
import torch
from torch_geometric.utils import softmax
from torch_geometric.nn import MessagePassing

class MultiHeadAttentionLayer(MessagePassing):
    """Learn a scalar weight for every edge from node features and an edge value.

    Each head scores an edge ``(row, col)`` with a linear layer applied to
    ``[x[row], x[col], edge_weight]``, passes the score through LeakyReLU
    (slope 0.2) and normalises it with a softmax over all edges that share the
    same ``row`` node. The per-head results are averaged.

    Args:
        in_channels, out_channels: only their sum is used, as the input width
            of each scoring layer. It must equal ``2 * x.size(1) + 1`` (two
            node feature vectors plus one edge value); ``GCN`` passes
            ``(F + 1, F)`` for node feature width ``F``.
        num_heads: number of independent scoring heads to average.

    Note:
        The class subclasses ``MessagePassing`` but never calls ``propagate``;
        it only produces edge weights.
    """

    def __init__(self, in_channels, out_channels, num_heads=1):
        super().__init__(aggr='add')
        self.num_heads = num_heads
        self.out_channels = out_channels
        # One scoring layer per head, each mapping an edge feature vector to one logit.
        self.attentions = torch.nn.ModuleList([
            torch.nn.Linear(in_channels + out_channels, 1) for _ in range(num_heads)
        ])
        for att in self.attentions:
            torch.nn.init.xavier_uniform_(att.weight)

    def forward(self, x, edge_index, edge_weight):
        """Return one weight per edge, averaged over all heads.

        Args:
            x: node feature matrix, indexed by the entries of ``edge_index``.
            edge_index: pair ``(row, col)`` of node index tensors, one entry per edge.
            edge_weight: one value per edge, either flat ``[E]`` or a column ``[E, 1]``.
        """
        attentions = [
            self.compute_attention(x, edge_index, edge_weight, head)
            for head in range(self.num_heads)
        ]
        return torch.stack(attentions, dim=-1).mean(dim=-1)

    def compute_attention(self, x, edge_index, edge_weight, head):
        """Return the softmax-normalised scores of head ``head`` for every edge."""
        row, col = edge_index
        # reshape(-1, 1) accepts both a flat [E] tensor and an [E, 1] column;
        # unsqueeze(1) would turn [E, 1] into [E, 1, 1] and break the cat below.
        edge_weight = edge_weight.reshape(-1, 1)

        # Concatenate both endpoint features and the edge value.
        edge_features = torch.cat([x[row], x[col], edge_weight], dim=1)
        # squeeze(-1) rather than squeeze(): a single-edge graph must keep
        # shape [1] instead of collapsing to a 0-d scalar.
        attention_scores = self.attentions[head](edge_features).squeeze(-1)
        attention_scores = F.leaky_relu(attention_scores, negative_slope=0.2)
        # Normalise over all edges leaving the same `row` node.
        return softmax(attention_scores, index=row)

class GCN(torch.nn.Module):
    """Residual stack of ``GCNConv`` layers with optional learned edge weights.

    Layout:
        1. ``conv1``, with a linear skip connection (``sc_lin``) from the input.
        2. ``num_layers`` hidden convolutions, each with an identity skip.
        3. A final ``conv3``, with an identity skip.

    ReLU follows ``conv1`` and every hidden convolution; ``conv3`` has no
    activation.

    Args:
        in_channels: width of the input node features ``x``.
        hidden_channels: width of every hidden and output representation.
        num_layers: number of hidden convolutions between ``conv1`` and ``conv3``.
        use_batch_norm: apply ``BatchNorm1d`` after each convolution.
        shared_weights: reuse one ``conv2`` for every hidden layer instead of a
            separate ``GCNConv`` per layer. Batch-norm layers are never shared.
        var_distance: pass edge weights to ``conv1`` and the hidden
            convolutions. Without ``tr_dist``, the given weights are inverted
            (``1 / edge_weight``) first.
        tr_dist: replace the edge weights with a trainable transform. With
            ``tr_att`` this is ``MultiHeadAttentionLayer``; otherwise it is
            ``1 / (1 + sum_i exp(W_dist[i]) * edge_weight**i)`` for ``i`` in 0..3.
        tr_att: selects the attention transform; only consulted when
            ``tr_dist`` is set.

    NOTE(review): ``tr_dist`` without ``var_distance`` computes the transformed
    weights, but they are never passed to any convolution. Confirm this
    combination is not intended.
    """

    def __init__(self, in_channels, hidden_channels, num_layers, use_batch_norm=True,
                 shared_weights=False, var_distance=False, tr_dist=False, tr_att=False):
        super().__init__()

        self.shared_weights = shared_weights
        self.use_batch_norm = use_batch_norm
        self.num_layers = num_layers
        self.var_distance = var_distance
        self.tr_dist = tr_dist
        self.tr_att = tr_att
        if self.var_distance:
            print("Attention for Edge Lengths:", self.tr_att)

        # Submodules are created in a fixed order; reordering them would change
        # random initialisation and state_dict layout.
        if self.tr_dist:
            if self.tr_att:
                # Scoring input is [x_row, x_col, edge_weight]: 2 * in_channels + 1 wide.
                self.attention = MultiHeadAttentionLayer(in_channels + 1, in_channels)
            else:
                # Coefficients of the polynomial in edge_weight; 4 is a hyperparameter.
                self.W_dist = torch.nn.Parameter(torch.FloatTensor(4, 1))
                torch.nn.init.xavier_uniform_(self.W_dist)

        # Linear projection used as the skip connection around conv1.
        self.sc_lin = torch.nn.Linear(in_channels, hidden_channels)
        self.conv1 = GCNConv(in_channels, hidden_channels)
        if self.shared_weights:
            self.conv2 = GCNConv(hidden_channels, hidden_channels)
        else:
            self.conv_layers = torch.nn.ModuleList([
                GCNConv(hidden_channels, hidden_channels) for _ in range(num_layers)])
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        if self.use_batch_norm:
            self.bn1 = torch.nn.BatchNorm1d(hidden_channels)
            self.bn_layers = torch.nn.ModuleList([
                torch.nn.BatchNorm1d(hidden_channels) for _ in range(num_layers)])
            self.bn3 = torch.nn.BatchNorm1d(hidden_channels)

    def forward(self, x, edge_index, edge_weight=None):
        """Compute node representations of width ``hidden_channels``.

        Args:
            x: input node features of width ``in_channels``.
            edge_index: graph connectivity passed to every ``GCNConv``.
            edge_weight: per-edge values. Required when ``var_distance`` or
                ``tr_dist`` is set, because they are transformed
                arithmetically; unused otherwise. Presumably a flat tensor
                with one entry per edge (it is ``unsqueeze(1)``-ed below).
        """
        if self.var_distance and not self.tr_dist:
            edge_weight = 1 / edge_weight
        if self.tr_dist:
            edge_weight = self._learned_edge_weight(x, edge_index, edge_weight)
        if self.var_distance:
            # The convolutions receive the weights as a column, one row per edge.
            edge_weight = edge_weight.unsqueeze(1)

        sc = self.sc_lin(x)
        z = self._apply_conv(self.conv1, x, edge_index, edge_weight)
        if self.use_batch_norm:
            z = self.bn1(z)
        z = F.relu(z) + sc

        for i in range(self.num_layers):
            conv = self.conv2 if self.shared_weights else self.conv_layers[i]
            # No clone needed: nothing below modifies z in place.
            z_prev = z
            z = self._apply_conv(conv, z, edge_index, edge_weight)
            if self.use_batch_norm:
                z = self.bn_layers[i](z)
            z = F.relu(z) + z_prev

        z_prev = z
        # NOTE(review): conv3 never receives edge_weight, even with var_distance
        # set, unlike conv1 and the hidden layers. Confirm this is intentional.
        z = self.conv3(z, edge_index)
        if self.use_batch_norm:
            z = self.bn3(z)
        return z + z_prev

    def _apply_conv(self, conv, z, edge_index, edge_weight):
        """Run ``conv``, forwarding edge weights only when ``var_distance`` is set."""
        if self.var_distance:
            return conv(z, edge_index, edge_weight=edge_weight)
        return conv(z, edge_index)

    def _learned_edge_weight(self, x, edge_index, edge_weight):
        """Map raw edge weights through the trainable transform selected by ``tr_att``."""
        if self.tr_att:
            return self.attention(x, edge_index, edge_weight)
        # Polynomial with positive coefficients exp(W_dist[i]); the result is
        # 1 / (1 + polynomial).
        weighted_tensor = torch.stack(
            [(edge_weight ** i) * torch.exp(w) for i, w in enumerate(self.W_dist)])
        return 1 / (torch.sum(weighted_tensor, dim=0) + 1)







# def forward(self, x, edge_index, edge_weight = None):
#         # skip connection
#         # print(x.size(), edge_index.size(), edge_weight.size())
#         # print(edge_index)
#         # print(edge_weight)
#         # if self.var_distance: 
#         #     edge_weight = 1/edge_weight
#         #     if self.tr_dist: 
#         #         # weighted_tensor = torch.stack([(edge_weight**i) * self.W_dist[i] for i in range(self.W_dist.size(0))])
#         #         weighted_tensor = torch.stack([(edge_weight**i) * torch.exp(self.W_dist[i]) for i in range(self.W_dist.size(0))])
#         #         edge_weight = torch.sum(weighted_tensor, dim=0) + 1

#         if self.var_distance == True and self.tr_dist == False: 
#             edge_weight = 1/edge_weight
#             # print(edge_weight.unsqueeze(1).size())
#             # print(x.size())
#             # print(edge_index.size())
#             # print(edge_weight)
#         if self.tr_dist: 
#             # weighted_tensor = torch.stack([(edge_weight**i) * self.W_dist[i] for i in range(self.W_dist.size(0))])
#             weighted_tensor = torch.stack([(edge_weight**i) * torch.exp(self.W_dist[i]) for i in range(self.W_dist.size(0))])
#             edge_weight = torch.sum(weighted_tensor, dim=0) + 1
#             edge_weight = 1/edge_weight
            
#         sc = self.sc_lin(x)
#         if self.var_distance: 
#             # print(edge_weight)
#             z = self.conv1(x, edge_index, edge_weight = edge_weight.unsqueeze(1)) # .unsqueeze(1)
#             # print(z.size())
#             # print(z)
#         else:
#             z = self.conv1(x, edge_index)
        

#         if self.use_batch_norm:
#             z = self.bn1(z)
#         z = F.relu(z)
#         z = z + sc
        
#         for i in range(self.num_layers):
#             if self.shared_weights:
#                 z_prev = z.clone()
#                 if self.var_distance: 
#                     z = self.conv2(z, edge_index, edge_weight = edge_weight.unsqueeze(1))
#                 else: 
#                     z = self.conv2(z, edge_index)
#                 if self.use_batch_norm:
#                     z = self.bn_layers[i](z)
#                 z = F.relu(z)
#                 z = z + z_prev
#             else:
#                 z_prev = z.clone()
#                 if self.var_distance: 
#                     z = self.conv_layers[i](z, edge_index, edge_weight = edge_weight.unsqueeze(1) )
#                 else:
#                     z = self.conv_layers[i](z, edge_index)
#                 if self.use_batch_norm:
#                     z = self.bn_layers[i](z)
#                 z = F.relu(z)
#                 z = z + z_prev

#         z_prev = z.clone()
#         z = self.conv3(z, edge_index)
#         if self.use_batch_norm:
#             z = self.bn3(z)
#         z = z + z_prev
        
#         return z

