import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GraphUNet, global_mean_pool
from torch_geometric.utils import dropout_edge

class GraphClassificationUNet(torch.nn.Module):
    """Graph-level predictor built on a Graph U-Net encoder.

    Node features are encoded by ``GraphUNet``, averaged per graph with
    ``global_mean_pool`` and mapped to ``out_channels`` raw scores by a final
    linear layer. No output activation is applied here. With the default
    ``out_channels=1`` the model emits one unnormalised score per graph.

    Args:
        dataset: Object exposing ``num_features`` (width of the node features).
        hidden: Hidden width of the U-Net; also used as its output width.
        depth: Number of pooling/unpooling levels in the U-Net.
        pool_ratios: Pooling ratio(s) forwarded to ``GraphUNet``. Defaults to
            ``[0.5, 0.5]`` when ``None``.
        out_channels: Number of scores predicted per graph (default 1).
        dropout: Dropout probability on the input node features, active only
            in training mode. NOTE(review): the 0.92 default is kept for
            backward compatibility; it is very aggressive for most feature
            sets and is worth tuning.
        edge_dropout: Probability of dropping each edge (symmetrically) in
            training mode.
    """

    def __init__(self, dataset, hidden=32, depth=3, pool_ratios=None,
                 out_channels=1, dropout=0.92, edge_dropout=0.2):
        super().__init__()
        if pool_ratios is None:
            # Build a fresh list per instance rather than sharing a mutable default.
            pool_ratios = [0.5, 0.5]
        self.dropout = dropout
        self.edge_dropout = edge_dropout
        self.unet = GraphUNet(dataset.num_features, hidden, hidden,
                              depth=depth, pool_ratios=pool_ratios)
        self.lin = Linear(hidden, out_channels)

    def forward(self, data):
        """Compute per-graph scores.

        Args:
            data: Batch object providing ``x`` (node features), ``edge_index``
                and ``batch`` (node-to-graph assignment).

        Returns:
            Tensor of per-graph scores with ``out_channels`` columns, one row
            per graph produced by ``global_mean_pool``.
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # Regularisation: drop edges (keeping the graph undirected) and input
        # features. Both are no-ops outside training mode.
        edge_index, _ = dropout_edge(edge_index, p=self.edge_dropout,
                                     force_undirected=True,
                                     training=self.training)
        x = F.dropout(x, p=self.dropout, training=self.training)
        # Passing `batch` makes the U-Net's top-k pooling select nodes within
        # each graph instead of ranking nodes across the whole mini-batch.
        x = self.unet(x, edge_index, batch)
        x = global_mean_pool(x, batch)
        return self.lin(x)
