import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.utils import to_dense_batch
from torch_geometric.nn import global_max_pool, global_add_pool, global_mean_pool, AttentionalAggregation, JumpingKnowledge


class ConvNet(nn.Module):
    """Pointwise 1-D convolutional regressor over fixed-length sequences.

    Two ``Conv1d`` layers with ``kernel_size=1`` (each followed by ReLU)
    project every time step from ``in_channels`` to ``embed_size`` features
    independently. The result is flattened and passed through two
    ReLU-activated linear layers, then a final linear layer that produces one
    value per sample.

    Args:
        embed_size: Number of output channels of both convolution layers.
        hidden_size: Width of the two hidden fully connected layers.
        seq_length: Number of time steps every input sequence must have; the
            first linear layer expects ``embed_size * seq_length`` features.
        in_channels: Number of features per time step in the input. Defaults
            to 14, the value that was previously hard-coded.
    """

    def __init__(self, embed_size, hidden_size, seq_length=18, in_channels=14):
        super().__init__()
        self.embed_size = embed_size
        self.hidden_size = hidden_size
        self.seq_length = seq_length
        self.in_channels = in_channels
        # kernel_size=1 leaves the sequence length unchanged, so the flattened
        # conv output has exactly embed_size * seq_length features.
        self.layer1 = nn.Sequential(
            nn.Conv1d(in_channels, embed_size, kernel_size=1, stride=1),
            nn.ReLU())
        self.layer2 = nn.Sequential(
            nn.Conv1d(embed_size, embed_size, kernel_size=1, stride=1),
            nn.ReLU())
        # NOTE(review): defined but not applied in forward(); kept so code that
        # references ``model.dp`` keeps working.
        self.dp = nn.Dropout(0.5)
        self.relu1 = nn.Sequential(
            nn.Linear(embed_size * seq_length, hidden_size),
            nn.ReLU())
        self.relu2 = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU())
        self.fc1 = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Map a batch of sequences to one scalar output per sequence.

        Args:
            x: Tensor of shape ``(batch, seq_length, in_channels)``.

        Returns:
            Tensor of shape ``(batch, 1, 1)``.
        """
        # Conv1d expects channels before length: (batch, in_channels, seq_length).
        out = x.permute(0, 2, 1)
        out = self.layer1(out)
        out = self.layer2(out)
        # flatten(1) is reshape-based: unlike view(size(0), -1) it also works
        # for an empty batch and for non-contiguous tensors.
        out = out.flatten(1)
        out = self.relu1(out)
        out = self.relu2(out)
        out = self.fc1(out)
        # Trailing singleton dimension preserved for existing callers.
        return out.unsqueeze(-1)

# model = ConvNet(embed_size=128, hidden_size=128, seq_length=14)
# from torchsummary import summary
# summary(model)
