import torch
import torch_geometric
import torch_scatter

class GraphConvModel(torch.nn.Module):
    """Stack of ``GraphConv`` layers (each followed by ReLU) and a linear head.

    Args:
        in_dim: Input node-feature size.
        out_dim: Output size of the final linear layer.
        net_arch: Output width of each ``GraphConv`` layer, in order. An
            empty sequence means the linear head is applied to the raw input.
    """

    def __init__(self, in_dim, out_dim, net_arch=(192, 192, 192)):
        super().__init__()
        layers = []
        last_layer_dim = in_dim
        for layer_dim in net_arch:
            layers.append(torch_geometric.nn.GraphConv(last_layer_dim, layer_dim))
            last_layer_dim = layer_dim
        self.layers = torch.nn.ModuleList(layers)
        self.linear = torch.nn.Linear(last_layer_dim, out_dim)

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """Return the linear head applied to the final ``GraphConv`` node states.

        Args:
            x: Node features of width ``in_dim``.
            edge_index: Graph connectivity passed to every ``GraphConv``.
            edge_attr: Accepted but unused.
            batch: Accepted but unused.
        """
        for layer in self.layers:
            x = layer(x, edge_index)
            x = torch.nn.functional.relu(x)
        x = self.linear(x)
        return x

class GraphConvGlobalModel(torch.nn.Module):
    """``GraphConv`` stack whose node states are augmented with a per-graph mean.

    After the convolution stack, the last ``global_dim`` hidden channels of
    every node are mean-pooled per graph. The result is gathered back to each
    node, concatenated onto the node state, and mixed by ``global_layer``
    (+ ReLU) before the final linear head.

    Args:
        in_dim: Input node-feature size.
        out_dim: Output size of the final linear layer; also sets
            ``global_dim = out_dim // 2``.
        net_arch: Output width of each ``GraphConv`` layer, in order.

    Raises:
        ValueError: If ``global_dim`` exceeds the final hidden width. The
            pooled slice could then never match ``global_layer``'s input size.
    """

    def __init__(self, in_dim, out_dim, net_arch=(192, 192, 192)):
        super().__init__()
        layers = []
        last_layer_dim = in_dim
        for layer_dim in net_arch:
            layers.append(torch_geometric.nn.GraphConv(last_layer_dim, layer_dim))
            last_layer_dim = layer_dim
        self.layers = torch.nn.ModuleList(layers)
        # Exact integer halving (int(out_dim / 2) round-trips through a float).
        self.global_dim = out_dim // 2
        if self.global_dim > last_layer_dim:
            raise ValueError(
                f"global_dim ({self.global_dim}) must not exceed the final hidden "
                f"width ({last_layer_dim})"
            )
        self.global_layer = torch.nn.Linear(last_layer_dim + self.global_dim, last_layer_dim)
        self.linear = torch.nn.Linear(last_layer_dim, out_dim)

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """Return the linear head applied to globally-augmented node states.

        Args:
            x: Node features of width ``in_dim``.
            edge_index: Graph connectivity passed to every ``GraphConv``.
            edge_attr: Accepted but unused.
            batch: Graph id of every node. ``None`` treats all nodes as one graph.
        """
        for layer in self.layers:
            x = torch.nn.functional.relu(layer(x, edge_index))
        if batch is None:
            # ``global_messages[None]`` would add a dimension instead of gathering
            # per node, so build an explicit single-graph assignment.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)
        # Explicit start index: ``x[:, -0:]`` would select every column when
        # global_dim == 0 (out_dim < 2), breaking global_layer's input size.
        pooled_channels = x[:, x.size(1) - self.global_dim:]
        global_messages = torch_geometric.nn.global_mean_pool(pooled_channels, batch)
        # Gather each graph's summary back to the nodes of that graph.
        x = torch.cat((x, global_messages[batch]), dim=1)
        x = torch.nn.functional.relu(self.global_layer(x))
        return self.linear(x)

class GRUGraphModel(torch.nn.Module):
    """Linear embedding + ReLU, a ``GatedGraphConv`` block, then a linear head.

    Args:
        in_dim: Input node-feature size.
        out_dim: Output size of the final linear layer.
        net_arch: Only ``net_arch[0]`` (hidden width) and ``len(net_arch)``
            (passed as ``GatedGraphConv``'s ``num_layers``) are used. The
            values of the remaining entries are ignored.
    """

    def __init__(self, in_dim, out_dim, net_arch=(128, 128, 128, 128)):
        super().__init__()
        hidden_dim = net_arch[0]
        self.embedding_layer = torch.nn.Linear(in_dim, hidden_dim)
        self.gru = torch_geometric.nn.GatedGraphConv(hidden_dim, len(net_arch))
        self.out_layer = torch.nn.Linear(hidden_dim, out_dim)

    def forward(self, x, edge_index, edge_attr=None, batch=None):
        """Return the output head applied to the gated-graph node states.

        Args:
            x: Node features of width ``in_dim``.
            edge_index: Graph connectivity passed to ``GatedGraphConv``.
            edge_attr: Accepted but unused.
            batch: Accepted but unused.
        """
        x = self.embedding_layer(x)
        x = torch.nn.functional.relu(x)
        x = self.gru(x, edge_index)
        x = self.out_layer(x)
        return x


class MLP(torch.nn.Module):
    """Feed-forward network: ``Linear -> act`` per hidden width, then a final ``Linear``.

    Args:
        in_dim: Number of input features.
        out_dim: Number of output features (no activation after the last layer).
        layers: Hidden-layer widths, in order. An empty sequence yields a
            single ``Linear(in_dim, out_dim)``.
        act: Activation to insert after each hidden ``Linear``. Either the
            name of a class in ``torch.nn`` (e.g. ``'ReLU'``) or a zero-argument
            callable returning a module (e.g. ``torch.nn.Tanh``).

    Raises:
        KeyError: If ``act`` is a string that is not a name in ``torch.nn``.
    """

    def __init__(self, in_dim, out_dim, layers=(128, 128), act='ReLU'):
        super().__init__()
        # String names keep the original lookup (and its KeyError on unknown
        # names); anything else is used directly as the activation factory.
        act_factory = vars(torch.nn)[act] if isinstance(act, str) else act
        model = []
        last_dim = in_dim
        for width in layers:
            model.append(torch.nn.Linear(last_dim, width))
            model.append(act_factory())
            last_dim = width
        model.append(torch.nn.Linear(last_dim, out_dim))
        self.model = torch.nn.Sequential(*model)

    def forward(self, x):
        """Apply the network to ``x`` (features along the last dimension)."""
        return self.model(x)

class EdgeModel(torch.nn.Module):
    """Edge update: an MLP over ``[src, dest, edge_attr, u[batch]]``.

    Args:
        node_dim: Feature width of ``src`` and ``dest``.
        edge_dim: Feature width of ``edge_attr``.
        global_dim: Feature width of the global tensor ``u``.
        out_dim: Width of the returned edge features.
        layers: Hidden widths of the MLP.
    """

    def __init__(self, node_dim, edge_dim, global_dim, out_dim, layers=(128,)):
        super().__init__()
        self.mlp = MLP(2*node_dim + edge_dim + global_dim, out_dim, layers=layers, act='ReLU')

    def forward(self, src, dest, edge_attr, u, batch):
        """Return updated edge features, one row per edge.

        Args:
            src: Features of each edge's source node.
            dest: Features of each edge's target node.
            edge_attr: Current edge features.
            u: Per-graph global features, indexed by ``batch``.
            batch: Graph id of each edge (e.g. ``batch[row]`` from a node-level batch).
        """
        out = torch.cat([src, dest, edge_attr, u[batch]], dim=1)
        return self.mlp(out)

class NodeModel(torch.nn.Module):
    """Node update: aggregate edge messages, then combine with node and global state.

    ``mlp1`` turns ``[x[row], edge_attr]`` into a per-edge message. The
    messages are mean-aggregated at each target node ``col``. ``mlp2`` maps
    ``[x, aggregated, u[batch]]`` to the new node features.

    Args:
        node_dim: Feature width of ``x``.
        edge_dim: Feature width of ``edge_attr``.
        global_dim: Feature width of ``u``.
        out_dim: Width of the message and of the returned node features.
        layers: Hidden widths of both MLPs.
    """

    def __init__(self, node_dim, edge_dim, global_dim, out_dim, layers=(128,)):
        super().__init__()
        self.mlp1 = MLP(node_dim + edge_dim, out_dim, layers=layers, act='ReLU')
        self.mlp2 = MLP(node_dim + out_dim + global_dim, out_dim, layers=layers, act='ReLU')

    def forward(self, x, edge_index, edge_attr, u, batch):
        """Return updated node features, one row per node.

        Args:
            x: Node features.
            edge_index: Connectivity as ``(row, col)``; messages flow row -> col.
            edge_attr: Edge features, one row per edge.
            u: Per-graph global features, indexed by ``batch``.
            batch: Graph id of every node.
        """
        row, col = edge_index
        messages = self.mlp1(torch.cat([x[row], edge_attr], dim=1))
        # dim_size keeps one row per node even when the highest-indexed nodes
        # receive no edges; such rows are left at zero by scatter_mean.
        aggregated = torch_scatter.scatter_mean(messages, col, dim=0, dim_size=x.size(0))
        out = torch.cat([x, aggregated, u[batch]], dim=1)
        return self.mlp2(out)

class GlobalModel(torch.nn.Module):
    """Global update: an MLP over ``[u, per-graph mean of node features]``.

    Args:
        node_dim: Feature width of ``x``.
        edge_dim: Accepted for a signature parallel to the edge/node models; unused.
        global_dim: Feature width of ``u``.
        out_dim: Width of the returned global features.
        layers: Hidden widths of the MLP.
    """

    def __init__(self, node_dim, edge_dim, global_dim, out_dim, layers=(128,)):
        super().__init__()
        self.mlp = MLP(global_dim + node_dim, out_dim, layers=layers, act='ReLU')

    def forward(self, x, edge_index, edge_attr, u, batch):
        """Return updated global features, one row per graph in ``u``.

        Args:
            x: Node features.
            edge_index: Unused.
            edge_attr: Unused.
            u: Current per-graph global features.
            batch: Graph id of every node.
        """
        # Size the pooled tensor to match u. Without dim_size, scatter_mean
        # yields batch.max() + 1 rows, which is fewer than u.size(0) when the
        # trailing graphs have no nodes, and the concatenation would then fail.
        node_means = torch_scatter.scatter_mean(x, batch, dim=0, dim_size=u.size(0))
        return self.mlp(torch.cat([u, node_means], dim=1))

class MetaGraphModel_Depth2(torch.nn.Module):
    """Two ``MetaLayer`` blocks followed by an edge-level ``EdgeModel`` head.

    Widths: hidden node/edge features are ``hidden_dim`` (64). Global
    features are ``in_dim_node`` on entry (all zeros) and ``global_dim`` (32)
    after layer1.

    Args:
        in_dim_node: Input node-feature size.
        in_dim_edge: Input edge-feature size.
        out_dim: Output size per edge.
        net_arch: Hidden widths used by every internal MLP.
    """

    def __init__(self, in_dim_node, in_dim_edge, out_dim, net_arch=(128,)):
        super().__init__()
        self.global_dim = 32
        self.hidden_dim = 64
        # forward() starts from a zero u of width in_dim_node, and MetaLayer
        # replaces u only after its global model has run. All three layer1
        # sub-models must therefore expect a global input of width in_dim_node.
        # Using self.global_dim here only works when in_dim_node == 32.
        self.layer1 = torch_geometric.nn.MetaLayer(
            EdgeModel(in_dim_node, in_dim_edge, in_dim_node, self.hidden_dim, layers=net_arch),
            NodeModel(in_dim_node, self.hidden_dim, in_dim_node, self.hidden_dim, layers=net_arch),
            GlobalModel(self.hidden_dim, self.hidden_dim, in_dim_node, self.global_dim, layers=net_arch)
        )
        self.layer2 = torch_geometric.nn.MetaLayer(
            EdgeModel(self.hidden_dim, self.hidden_dim, self.global_dim, self.hidden_dim, layers=net_arch),
            NodeModel(self.hidden_dim, self.hidden_dim, self.global_dim, self.hidden_dim, layers=net_arch),
            GlobalModel(self.hidden_dim, self.hidden_dim, self.global_dim, self.global_dim, layers=net_arch)
        )

        self.output_layer = EdgeModel(self.hidden_dim, self.hidden_dim, self.global_dim, out_dim, layers=net_arch)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return per-edge outputs of width ``out_dim``.

        Args:
            x: Node features of width ``in_dim_node``.
            edge_index: Connectivity as ``(row, col)``.
            edge_attr: Edge features of width ``in_dim_edge``.
            batch: Graph id of every node.
        """
        # Initial global state: zeros of shape (batch.max() + 1, in_dim_node).
        u = torch_scatter.scatter_mean(torch.zeros_like(x), batch, dim=0)
        x, edge_attr, u = self.layer1(x, edge_index, edge_attr, u, batch)
        x, edge_attr, u = self.layer2(x, edge_index, edge_attr, u, batch)
        row, col = edge_index
        # The head sees both endpoints, the edge state and the global state of
        # the graph the edge's source node belongs to.
        return self.output_layer(x[row], x[col], edge_attr, u, batch[row])

class MetaGraphModel_Depth1(torch.nn.Module):
    """One ``MetaLayer`` block followed by an edge-level ``EdgeModel`` head.

    Widths: hidden node/edge features are ``hidden_dim`` (80). Global
    features are ``in_dim_node`` on entry (all zeros) and ``global_dim`` (32)
    after the layer.

    Args:
        in_dim_node: Input node-feature size.
        in_dim_edge: Input edge-feature size.
        out_dim: Output size per edge.
        net_arch: Hidden widths used by every internal MLP.
    """

    def __init__(self, in_dim_node, in_dim_edge, out_dim, net_arch=(128,)):
        super().__init__()
        self.global_dim = 32
        self.hidden_dim = 80
        # u still has width in_dim_node for every sub-model of this layer; only
        # the global model's output (global_dim wide) becomes the new u.
        self.layer1 = torch_geometric.nn.MetaLayer(
            EdgeModel(in_dim_node, in_dim_edge, in_dim_node, self.hidden_dim, layers=net_arch),
            NodeModel(in_dim_node, self.hidden_dim, in_dim_node, self.hidden_dim, layers=net_arch),
            GlobalModel(self.hidden_dim, self.hidden_dim, in_dim_node, self.global_dim, layers=net_arch)
        )

        self.output_layer = EdgeModel(self.hidden_dim, self.hidden_dim, self.global_dim, out_dim, layers=net_arch)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return per-edge outputs of width ``out_dim``.

        Args:
            x: Node features of width ``in_dim_node``.
            edge_index: Connectivity as ``(row, col)``.
            edge_attr: Edge features of width ``in_dim_edge``.
            batch: Graph id of every node.
        """
        # Initial global state: zeros of shape (batch.max() + 1, in_dim_node).
        u = torch_scatter.scatter_mean(torch.zeros_like(x), batch, dim=0)
        x, edge_attr, u = self.layer1(x, edge_index, edge_attr, u, batch)
        row, col = edge_index
        # Per-edge graph ids come from each edge's source node.
        return self.output_layer(x[row], x[col], edge_attr, u, batch[row])

class MLP_Local(torch.nn.Module):
    """Node-wise MLP producing ``num_classes * num_heads`` outputs per row of ``x``.

    Args:
        in_dim: Input feature size.
        num_classes: Classes per head.
        num_heads: Number of output heads.
        net_arch: Hidden widths of the MLP.
    """

    def __init__(self, in_dim, num_classes, num_heads, net_arch=(256, 256)):
        super().__init__()
        self.num_heads = num_heads
        self.num_classes = num_classes
        self.mlp = MLP(in_dim, num_classes*num_heads, layers=net_arch)

    def forward(self, x, edge_index, edge_attr, batch):
        """Return MLP outputs reshaped to ``(-1, num_classes, num_heads)``.

        Only ``x`` is used; ``edge_index``, ``edge_attr`` and ``batch`` are
        accepted but ignored (graph structure plays no role here).
        """
        x = self.mlp(x)
        return x.view(-1, self.num_classes, self.num_heads)

class MLP_Global(torch.nn.Module):
    """MLP producing ``num_classes`` outputs per row of ``x``, with a single head.

    Args:
        in_dim: Input feature size.
        num_classes: Output size.
        num_heads: Accepted but ignored; ``self.num_heads`` is always 1.
        net_arch: Hidden widths of the MLP.
    """

    def __init__(self, in_dim, num_classes, num_heads, net_arch=(256, 256)):
        super().__init__()
        self.mlp = MLP(in_dim, num_classes, layers=net_arch)
        self.num_heads = 1
        # Exposed for attribute parity with MLP_Local.
        self.num_classes = num_classes

    def forward(self, x, edge_index, edge_attr, batch):
        """Return the raw MLP output for ``x`` (no reshape).

        Only ``x`` is used; ``edge_index``, ``edge_attr`` and ``batch`` are
        accepted but ignored.
        """
        x = self.mlp(x)
        return x
        