import torch
import torch.nn as nn
from torch_geometric.nn import MessagePassing, global_max_pool
from torch_geometric.data import Data

from utils.train_utils import weight_init
from utils.layers import MLPLayer

class SubGraphLayer(MessagePassing):
    """Encode node features, max-aggregate neighbour encodings, concatenate.

    Each node feature vector is projected by a linear encoder. The encodings
    of neighbouring nodes (``x_j``) are max-aggregated at each receiving node
    along ``edge_index``. The node's own encoding is then concatenated with
    that aggregate, so the output's last dimension is ``2 * hidden_dim``.

    Args:
        input_dim: Size of the last dimension of the incoming node features.
        hidden_dim: Output size of the linear encoder. Defaults to ``1``,
            which reproduces the original hard-coded width (output width 2).
    """

    def __init__(self, input_dim: int, hidden_dim: int = 1):
        super().__init__(aggr='max')
        self.encoder = nn.Linear(input_dim, hidden_dim)

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and run one round of max message passing.

        Args:
            x: Node features whose last dimension is ``input_dim``
                (presumably shaped (num_nodes, input_dim) -- confirm with callers).
            edge_index: Graph connectivity, forwarded unchanged to ``propagate``.

        Returns:
            The per-node concatenation ``[encoding, max-aggregated neighbour
            encoding]``, with last dimension ``2 * hidden_dim``.
        """
        x = self.encoder(x)
        return self.propagate(edge_index, x=x)

    def message(self, x_j: torch.Tensor) -> torch.Tensor:
        # Forward each neighbour's encoding unchanged; the 'max' aggregation
        # configured in __init__ performs the pooling.
        return x_j

    def update(self, aggr_out: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # The parameter must be named ``x`` so that PyG routes the encoded
        # features passed as ``propagate(..., x=x)`` into this method.
        return torch.cat([x, aggr_out], dim=-1)

class SubGraph(nn.Module):
    """Apply a stack of ``SubGraphLayer`` modules in sequence to one graph.

    Args:
        input_dim: Size of the last dimension of ``data.x`` fed to the first
            layer. Defaults to ``1`` (the original hard-coded value).
        num_layers: Number of stacked layers. Defaults to ``2`` (the original
            hard-coded value).
    """

    # A SubGraphLayer built from ``input_dim`` alone encodes to 1 feature, and
    # its update() concatenates that with the 1-feature max aggregate. Every
    # layer therefore emits 2 features, which is the input width of each
    # subsequent layer.
    _LAYER_OUTPUT_DIM = 2

    def __init__(self, input_dim: int = 1, num_layers: int = 2):
        super().__init__()
        self.subgraph_layers = nn.ModuleList(
            SubGraphLayer(input_dim if idx == 0 else self._LAYER_OUTPUT_DIM)
            for idx in range(num_layers)
        )

    def forward(self, data):
        """Run every layer over ``data.x`` using the shared ``data.edge_index``.

        Args:
            data: Any object exposing ``x`` (node features) and ``edge_index``
                attributes, e.g. a torch_geometric ``Data`` instance.

        Returns:
            The node features produced by the last layer, or ``data.x``
            unchanged when ``num_layers`` is 0.
        """
        x, edge_index = data.x, data.edge_index
        for layer in self.subgraph_layers:
            x = layer(x, edge_index)
        return x


def _check_subgraph():
    """Run ``SubGraph`` on a 2-node toy graph with every parameter set to 1.0.

    Hand calculation: layer 1 encodes [1, 5] -> [2, 6]. Each node's
    aggregate is its single neighbour's encoding, giving [[2, 6], [6, 2]].
    Layer 2 encodes each row to 2 + 6 + 1 = 9, so the result is
    [[9, 9], [9, 9]].
    """
    data = Data(x=torch.tensor([[1.0], [5.0]]),
                edge_index=torch.tensor([[0, 1], [1, 0]]))
    print(data)
    model = SubGraph()
    # Overwrite the parameters under no_grad: this is the supported way to
    # mutate them. Writing into state_dict() tensors only works because those
    # happen to share storage with the parameters.
    with torch.no_grad():
        for param in model.parameters():
            param.fill_(1.0)

    out = model(data)
    print(out)
    expected = torch.tensor([[9.0, 9.0], [9.0, 9.0]])
    assert torch.equal(out, expected), f"unexpected SubGraph output: {out}"


def _check_global_max_pool():
    """Max-pool 10 values into 20 clusters and compare with the expected result.

    Clusters that receive no values are expected to come out as 0.
    """
    values = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
    clusters = torch.tensor([0, 8, 10, 11, 11, 12, 12, 12, 12, 15])
    pooled = global_max_pool(values, clusters, size=20)
    print(pooled)
    expected = torch.tensor(
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 4, 8, 0, 0, 9, 0, 0, 0, 0]
    )
    assert torch.equal(pooled, expected), f"unexpected pooled output: {pooled}"


# Guarded so that importing this module (e.g. for SubGraph) does not run the demo.
if __name__ == "__main__":
    _check_subgraph()
    _check_global_max_pool()