import torch
import torch_geometric as pyg

import MC_GIN
from MCGC import MCGC


class Convolution(torch.nn.Module):
    """Stack of graph-convolution layers of one type, each mapping d -> d features.

    Args:
        conv_type: Name of the convolution to build. Supported values are
            'GCN', 'GAT', 'GATv2', 'SAGE', 'MCGC', 'MC-GIN', 'MC-GIN+softmax',
            'MC-GIN(k=1)', 'MC-GIN(k=3)', 'MC-GIN (1 layer MLP)',
            'MC-GIN (2 layer MLP)' and 'GIN'.
        n: Passed only to ``MCGC(n, d)``; its meaning is defined by that class.
        d: Input and output feature width of every layer.
        k: Number of heads for the GAT / GATv2 / MC-GIN variants, and the output
            width of the custom MLPs passed to the MC-GIN MLP variants. The
            'MC-GIN(k=1)' and 'MC-GIN(k=3)' variants ignore it.
        n_layers: Number of stacked layers. Every type except 'MCGC' honours it.

    Raises:
        ValueError: If ``conv_type`` is not one of the supported names.
    """

    # Types stacked without a matching BatchNorm1d entry in ``self.norms``.
    # This matches the original construction, so saved state_dicts keep
    # their keys.
    _TYPES_WITHOUT_NORM = frozenset({'SAGE'})

    def __init__(self, conv_type, n, d, k=4, n_layers=1):
        super().__init__()
        self.conv_type = conv_type
        self.convs = torch.nn.ModuleList()
        self.norms = torch.nn.ModuleList()

        if conv_type == 'MCGC':
            # NOTE(review): MCGC is built once regardless of n_layers, as it
            # always was. Confirm whether stacking MCGC layers is intended.
            self.convs.append(MCGC(n, d))
            return

        builders = self._layer_builders(d, k)
        if conv_type not in builders:
            raise ValueError(f"Unknown convolution type {conv_type}")
        build_layer = builders[conv_type]
        with_norm = conv_type not in self._TYPES_WITHOUT_NORM

        # Build each conv before its norm, so parameter initialisation
        # consumes the RNG in the same order as before.
        for _ in range(n_layers):
            self.convs.append(build_layer())
            if with_norm:
                self.norms.append(torch.nn.BatchNorm1d(d))

    @staticmethod
    def _layer_builders(d, k):
        """Map each stackable conv-type name to a zero-arg factory for one fresh layer.

        The lambdas defer construction (and any access to ``pyg`` / ``MC_GIN``)
        until a layer is actually requested. A new instance is created per
        call, so stacked layers never share parameters.
        """
        return {
            'GCN': lambda: pyg.nn.GCNConv(d, d, add_self_loops=False),
            'GAT': lambda: pyg.nn.GATConv(
                d, d, heads=k, concat=False, add_self_loops=False),
            'GATv2': lambda: pyg.nn.GATv2Conv(
                d, d, heads=k, concat=False, add_self_loops=False),
            'SAGE': lambda: pyg.nn.SAGEConv(d, d),
            'MC-GIN': lambda: MC_GIN.MC_GIN(d, d, heads=k),
            'MC-GIN+softmax': lambda: MC_GIN.MC_GIN(
                d, d, heads=k, use_softmax=True),
            'MC-GIN(k=1)': lambda: MC_GIN.MC_GIN(d, d, heads=1),
            'MC-GIN(k=3)': lambda: MC_GIN.MC_GIN(d, d, heads=3),
            'MC-GIN (1 layer MLP)': lambda: MC_GIN.MC_GIN(
                d, d, heads=k, nn=torch.nn.Linear(d * 2, k)),
            'MC-GIN (2 layer MLP)': lambda: MC_GIN.MC_GIN(
                d, d, heads=k,
                nn=torch.nn.Sequential(
                    torch.nn.Linear(d * 2, d),
                    torch.nn.ReLU(),
                    torch.nn.Linear(d, k),
                )),
            'GIN': lambda: pyg.nn.GINConv(torch.nn.Sequential(
                torch.nn.Linear(d, d),
                torch.nn.ReLU(),
                torch.nn.Linear(d, d),
                torch.nn.ReLU(),
                torch.nn.Linear(d, d),
            )),
        }

    def forward(self, x, edge_index):
        """Apply every layer in ``self.convs`` in order.

        Each layer is called as ``layer(x, edge_index)``. A ReLU is applied
        between consecutive layers but not after the last one, so the final
        layer's raw output is returned.

        NOTE(review): ``self.norms`` is populated in ``__init__`` but never
        applied here. Confirm whether batch normalisation was meant to be used.
        """
        last = len(self.convs) - 1
        for i, layer in enumerate(self.convs):
            x = layer(x, edge_index)
            if i < last:
                x = torch.nn.functional.relu(x)
        return x