import torch
from torch_geometric.nn import GCNConv, GATConv, GINConv
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import to_undirected


class TemporalGNN(torch.nn.Module):
    """Spatio-temporal regressor: a shared GCN per time step followed by a GRU.

    At every time step the scalar signal of each node is embedded with one shared
    ``GCNConv`` over a fixed graph. The per-node embeddings of that step are
    flattened into a single vector. The resulting sequence runs through a 2-layer
    GRU, and the GRU output at the last step is mapped to one value per sample by
    a small MLP head.

    Args:
        embed_size: Output channels of the GCN layer (per-node embedding size).
        seq_length: Number of leading time steps of the input that are processed.
        batch_size: Unused; kept so existing call sites keep working.
    """

    # Undirected graph topology, written with 1-based node ids (nodes 1..14).
    _EDGES_1BASED = (
        (1, 2), (1, 3), (1, 4), (1, 5), (1, 9), (1, 12),
        (2, 4), (2, 7), (2, 8), (2, 13),
        (3, 6), (3, 10), (3, 13), (3, 14),
        (4, 7), (4, 8),
        (5, 9), (5, 11),
        (6, 10),
        (7, 8),
        (8, 13),
        (9, 11),
    )

    def __init__(self, embed_size, seq_length, batch_size):
        super().__init__()
        self.embed_size = embed_size
        self.seq_length = seq_length

        # Add the reverse of every edge, then shift to 0-based node ids.
        edge_index = torch.tensor(self._EDGES_1BASED, dtype=torch.long).t()
        edge_index = to_undirected(edge_index) - 1
        self.num_nodes = int(edge_index.max()) + 1
        # A buffer follows the module through .to()/.cuda()/.cpu(), so the graph
        # always lives on the same device as the weights (previously it was pinned
        # to cuda:0 at construction). persistent=False keeps it out of the
        # state_dict, so existing checkpoints still load unchanged.
        self.register_buffer("edge_index", edge_index, persistent=False)

        # Submodules are created in the original order, so random initialization
        # and parameters() ordering (relevant for optimizer state) are unchanged.
        self.conv = GCNConv(in_channels=1, out_channels=self.embed_size)
        self.gru = nn.GRU(self.embed_size * self.num_nodes, 32, 2, batch_first=True)
        self.dp = nn.Dropout(0.5)
        self.linear1 = nn.Sequential(
            nn.Linear(32, 32),
            nn.ReLU())
        self.linear2 = nn.Sequential(
            nn.Linear(32, 32),
            nn.ReLU())
        self.fc1 = nn.Linear(32, 1)

    def forward(self, x):
        """Predict one value per sample from a window of node signals.

        Args:
            x: Rank-3 tensor. It is assumed to be indexed as (batch, time, node):
                after the permute below, the original dim 1 is the axis iterated
                as time, and the original dim 2 is the axis handed to the GCN as
                nodes. The node count must equal ``self.num_nodes``, because the
                GRU input size is ``embed_size * num_nodes``. At least
                ``self.seq_length`` time steps are required; only the first
                ``self.seq_length`` are used.

        Returns:
            Tensor of shape (batch, 1).
        """
        # (batch, time, node) -> (batch, node, 1, time): one input channel per node.
        x = x.permute(0, 2, 1).unsqueeze(2)
        # Apply the same GCN to each time step independently; each output is
        # presumably (batch, node, embed).
        steps = [self.conv(x[..., t], self.edge_index) for t in range(self.seq_length)]
        # Stack to (batch, time, node, embed), then flatten to
        # (batch, time, node * embed). The node-major / embed-minor feature order
        # matches the original stack(-1) + view + transpose layout, so trained GRU
        # weights remain valid.
        h = torch.stack(steps, dim=1).flatten(start_dim=2)
        out, _ = self.gru(h)
        # Only the GRU output at the final time step feeds the regression head.
        h = self.dp(out[:, -1, :])
        h = self.linear1(h)
        h = self.linear2(h)
        return self.fc1(h)


# model = TemporalGNN(embed_size=128, seq_length=30, batch_size=128)
# from torchsummary import summary
# summary(model)