from torch_geometric.nn import GCNConv, GATConv
import torch.nn.functional as F
import torch


class GCN_GAT(torch.nn.Module):
    """Residual stack of ``GATConv`` layers with optional distance-based edge features.

    Architecture, as implemented in ``forward``:

    1. ``conv1`` maps ``in_channels -> hidden_channels``. Its output passes
       through optional batch norm and ReLU, then is added to a linear skip
       projection of the input (``sc_lin``).
    2. ``num_layers`` residual blocks ``z = relu(bn(conv(z))) + z``. These use
       either one shared ``conv2`` (``shared_weights=True``) or a separate
       conv per block (``conv_layers``).
    3. A final ``conv3`` with optional batch norm and a residual connection,
       but no activation.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Size of hidden and output node features.
        num_layers: Number of intermediate residual GAT blocks.
        use_batch_norm: Apply ``BatchNorm1d`` after every convolution.
        shared_weights: Reuse a single conv for all intermediate blocks.
        var_distance: Pass per-edge distances to the attention layers as a
            1-dimensional edge attribute. Without ``tr_dist``, the attribute
            is ``1 / edge_weight``.
        tr_dist: Use a learned distance transform as the edge attribute:
            ``1 / (1 + sum_i exp(w_i) * edge_weight**i)``, for
            ``i in range(num_dist_terms)``, with trainable ``w_i``.
        num_dist_terms: Number of polynomial terms in the learned distance
            transform. Only used when ``tr_dist`` is set.
    """

    def __init__(self, in_channels, hidden_channels, num_layers, use_batch_norm=True, shared_weights=True, var_distance=False, tr_dist=False, num_dist_terms=4):
        super().__init__()

        self.shared_weights = shared_weights
        self.use_batch_norm = use_batch_norm
        self.num_layers = num_layers
        self.var_distance = var_distance
        print( f'var_distance: {self.var_distance}')
        self.tr_dist = tr_dist
        print(f'tr_dist: {tr_dist}')
        if self.tr_dist:
            # Kept as a Linear module so the state_dict keys stay
            # ``W_dist.weight`` / ``W_dist.bias``. Only ``W_dist.weight``
            # (shape (1, num_dist_terms)) is used, as the log-coefficients
            # of the distance polynomial in forward(); the bias is unused.
            self.W_dist = torch.nn.Linear(num_dist_terms, 1)

        # Linear projection of the raw input, used as the first skip connection.
        self.sc_lin = torch.nn.Linear(in_channels, hidden_channels)
        # Convolutional layers; edge_dim=1 matches the single scalar edge
        # attribute produced in forward().
        self.conv1 = GATConv(in_channels, hidden_channels, edge_dim=1)
        if self.shared_weights:
            # One conv reused by every intermediate block.
            self.conv2 = GATConv(hidden_channels, hidden_channels, edge_dim=1)
        else:
            self.conv_layers = torch.nn.ModuleList([
                GATConv(hidden_channels, hidden_channels, edge_dim=1) for _ in range(num_layers)])
        self.conv3 = GATConv(hidden_channels, hidden_channels, edge_dim=1)
        # Batch normalization, if chosen. Each intermediate block always gets
        # its own BatchNorm, even when the conv weights are shared.
        if self.use_batch_norm:
            self.bn1 = torch.nn.BatchNorm1d(hidden_channels)
            self.bn_layers = torch.nn.ModuleList([
                torch.nn.BatchNorm1d(hidden_channels) for _ in range(num_layers)])
            self.bn3 = torch.nn.BatchNorm1d(hidden_channels)

    def forward(self, x, edge_index, edge_weight=None):
        """Compute node embeddings.

        Args:
            x: Node feature tensor whose last dimension is ``in_channels``
                (presumably (num_nodes, in_channels) — TODO confirm).
            edge_index: Graph connectivity in the format expected by
                ``GATConv`` (presumably (2, num_edges)).
            edge_weight: Per-edge distances. Expected to be 1-D with one entry
                per edge, because it is unsqueezed to (num_edges, 1) to match
                ``edge_dim=1``. Required when ``var_distance`` or ``tr_dist``
                is enabled; ignored otherwise.

        Returns:
            Node embeddings with ``hidden_channels`` features.

        Raises:
            ValueError: If distance features are enabled but ``edge_weight``
                is None.
        """
        use_edge_attr = self.var_distance or self.tr_dist
        if use_edge_attr and edge_weight is None:
            raise ValueError(
                "edge_weight is required when var_distance or tr_dist is enabled")

        if self.tr_dist:
            # Learned, strictly positive coefficients; exp keeps them > 0, so
            # the denominator below is >= 1 for non-negative distances.
            coeffs = torch.exp(self.W_dist.weight.view(-1))
            weighted_terms = torch.stack(
                [(edge_weight ** i) * coeffs[i] for i in range(coeffs.numel())])
            edge_weight = 1 / (torch.sum(weighted_terms, dim=0) + 1)
        elif self.var_distance:
            # Closer edges get larger attributes.
            edge_weight = 1 / edge_weight

        # Passing edge_attr=None is equivalent to omitting it in GATConv.
        edge_attr = edge_weight.unsqueeze(1) if use_edge_attr else None

        # First conv plus a linear skip connection from the raw input.
        sc = self.sc_lin(x)
        z = self.conv1(x, edge_index, edge_attr=edge_attr)
        if self.use_batch_norm:
            z = self.bn1(z)
        z = F.relu(z) + sc

        # Intermediate residual blocks. No in-place ops follow, so keeping a
        # reference to the previous activation is safe without cloning.
        for i in range(self.num_layers):
            conv = self.conv2 if self.shared_weights else self.conv_layers[i]
            z_prev = z
            z = conv(z, edge_index, edge_attr=edge_attr)
            if self.use_batch_norm:
                z = self.bn_layers[i](z)
            z = F.relu(z) + z_prev

        # Final residual block, without activation.
        z_prev = z
        # NOTE(review): conv3 is built with edge_dim=1 but, as before, receives
        # no edge attributes. This is left unchanged so outputs of
        # already-trained models are preserved; confirm whether it is intended.
        z = self.conv3(z, edge_index)
        if self.use_batch_norm:
            z = self.bn3(z)
        return z + z_prev


