import torch
from torch.nn import Softmax
from torch.nn import Linear, ModuleList, ReLU, Sequential
import torch.nn.functional as F

from torch_geometric.nn import GCNConv, GINConv, EdgeConv
from layers import RelativeEdgeConv, DistEdgeConv

class MLP(torch.nn.Module):
    """Feed-forward stack of ``Linear`` layers stored in ``self.lin_list``.

    Two construction modes:

    * ``varied_hidden_channels is None`` (default):
      ``Linear(input_channels, hidden_channels, bias=False)``, then
      ``layers - 2`` blocks of ``Linear -> ReLU -> Linear`` (width
      ``hidden_channels``), then
      ``Linear(hidden_channels, output_channels, bias=False)``.
    * ``varied_hidden_channels`` given: ``Linear -> ReLU`` blocks whose widths
      follow the sequence (the first block's Linear has a bias, the later ones
      do not), then
      ``Linear(varied_hidden_channels[-1], output_channels, bias=False)``.
      In this mode ``hidden_channels`` and ``layers`` do not affect the layers
      that are built (``layers`` is still validated and stored).

    Args:
        input_channels: feature width of the input ``x``.
        hidden_channels: hidden width used in the default mode.
        output_channels: feature width of the output.
        layers: total layer count in the default mode; must be at least 2.
        varied_hidden_channels: optional sequence of hidden widths; must be
            non-empty when given.

    Raises:
        ValueError: if ``layers < 2`` or ``varied_hidden_channels`` is empty.
    """

    def __init__(self, input_channels, hidden_channels, output_channels, layers=3, varied_hidden_channels=None):
        super().__init__()
        # Raise instead of assert: asserts are stripped under `python -O`.
        if layers < 2:
            raise ValueError(f"MLP needs at least 2 layers, got layers={layers}")
        self.lin_list = ModuleList()
        self.num_layers = layers
        if varied_hidden_channels is None:
            # NOTE(review): there is no activation between the input projection
            # and the first block, nor between the last block and the output
            # projection; with layers=2 the whole MLP is a single linear map.
            # Kept as-is to preserve existing checkpoints/behaviour — confirm intent.
            self.lin_list.append(Linear(input_channels, hidden_channels, bias=False))
            for _ in range(layers - 2):
                self.lin_list.append(
                    Sequential(
                        Linear(hidden_channels, hidden_channels),
                        ReLU(),
                        Linear(hidden_channels, hidden_channels),
                    )
                )
            self.lin_list.append(Linear(hidden_channels, output_channels, bias=False))
        else:
            widths = list(varied_hidden_channels)
            if not widths:
                raise ValueError("varied_hidden_channels must contain at least one width")
            self.lin_list.append(Sequential(Linear(input_channels, widths[0]), ReLU()))
            for width_in, width_out in zip(widths, widths[1:]):
                self.lin_list.append(
                    Sequential(Linear(width_in, width_out, bias=False), ReLU())
                )
            self.lin_list.append(Linear(widths[-1], output_channels, bias=False))

    def forward(self, x, edge_index=None):
        """Apply every entry of ``lin_list`` to ``x`` in order.

        ``edge_index`` is accepted but unused, so the MLP can be called with the
        same ``(x, edge_index)`` signature as the graph models in this file.
        """
        for block in self.lin_list:
            x = block(x)
        return x
    
class GCN(torch.nn.Module):
    """Stack of ``GCNConv`` layers with ``ReLU`` between consecutive layers.

    The last layer's output is returned without an activation, so outputs
    (e.g. class logits) may be negative. This matches ``EGConv``/``REGConv``.

    Args:
        input_channels: feature width of the input node features.
        hidden_channels: width of every intermediate layer.
        output_channels: feature width returned by ``forward``.
        layers: total number of ``GCNConv`` layers; must be at least 2.

    Raises:
        ValueError: if ``layers < 2``.
    """

    def __init__(self, input_channels, hidden_channels, output_channels, layers=3):
        super().__init__()
        # Raise instead of assert: asserts are stripped under `python -O`.
        if layers < 2:
            raise ValueError(f"GCN needs at least 2 layers, got layers={layers}")
        self.conv_list = ModuleList()
        self.num_layers = layers
        self.conv_list.append(GCNConv(input_channels, hidden_channels))
        for _ in range(layers - 2):
            self.conv_list.append(GCNConv(hidden_channels, hidden_channels))
        self.conv_list.append(GCNConv(hidden_channels, output_channels))

    def forward(self, x, edge_index):
        """Apply all convolutions, with ReLU after every layer except the last."""
        for conv in self.conv_list[:-1]:
            x = conv(x, edge_index).relu()
        # No activation on the output layer: a ReLU here would clamp every
        # negative output to 0.
        return self.conv_list[-1](x, edge_index)
    
class GIN(torch.nn.Module):
    """Stack of ``GINConv`` layers, each wrapping an ``MLP``, with ReLU between layers.

    The last layer's output is returned without an activation, so outputs
    (e.g. class logits) may be negative. This matches ``EGConv``/``REGConv``.

    Args:
        input_channels: feature width of the input node features.
        hidden_channels: width between layers and inside each MLP.
        output_channels: feature width returned by ``forward``.
        layers: total number of ``GINConv`` layers; must be at least 2.

    Raises:
        ValueError: if ``layers < 2``.
    """

    def __init__(self, input_channels, hidden_channels, output_channels, layers=3):
        super().__init__()
        # Raise instead of assert: asserts are stripped under `python -O`.
        if layers < 2:
            raise ValueError(f"GIN needs at least 2 layers, got layers={layers}")
        self.conv_list = ModuleList()
        self.num_layers = layers
        # Each MLP is created immediately before its GINConv (same order as
        # before), so seeded parameter initialisation is unchanged.
        self.conv_list.append(GINConv(nn=MLP(input_channels, hidden_channels, hidden_channels)))
        for _ in range(layers - 2):
            self.conv_list.append(GINConv(nn=MLP(hidden_channels, hidden_channels, hidden_channels)))
        self.conv_list.append(GINConv(nn=MLP(hidden_channels, hidden_channels, output_channels)))

    def forward(self, x, edge_index):
        """Apply all convolutions, with ReLU after every layer except the last."""
        for conv in self.conv_list[:-1]:
            x = conv(x, edge_index).relu()
        # No activation on the output layer: a ReLU here would clamp every
        # negative output to 0.
        return self.conv_list[-1](x, edge_index)
    
class EGConv(torch.nn.Module):
    """Three stacked ``EdgeConv`` layers with ReLU after the first two.

    Each ``EdgeConv`` gets an ``MLP`` whose input width is twice the incoming
    node-feature width: ``2 * input_channels`` for the first layer and
    ``2 * hidden_channels`` afterwards.

    Args:
        input_channels: feature width of the input node features.
        hidden_channels: width between convolutions and inside each MLP.
        output_channels: feature width returned by ``forward``.
    """

    def __init__(self, input_channels, hidden_channels, output_channels):
        super().__init__()
        # (incoming width, outgoing width) for each of the three layers.
        layer_widths = (
            (input_channels, hidden_channels),
            (hidden_channels, hidden_channels),
            (hidden_channels, output_channels),
        )
        # All MLPs are built before any EdgeConv. Keep this order: changing
        # the sequence of module construction can change seeded initial weights.
        edge_mlps = [MLP(2 * c_in, hidden_channels, c_out) for c_in, c_out in layer_widths]
        self.conv1, self.conv2, self.conv3 = [EdgeConv(nn=mlp) for mlp in edge_mlps]

    def forward(self, x, edge_index):
        """Run conv1 -> ReLU -> conv2 -> ReLU -> conv3 and return the raw result."""
        h = x
        for conv in (self.conv1, self.conv2):
            h = conv(h, edge_index).relu()
        return self.conv3(h, edge_index)
    
class REGConv(torch.nn.Module):
    """Three stacked ``RelativeEdgeConv`` layers with ReLU after the first two.

    Unlike ``EGConv``, each edge MLP's input width equals the incoming
    node-feature width (``input_channels`` for the first layer,
    ``hidden_channels`` afterwards).

    Args:
        input_channels: feature width of the input node features.
        hidden_channels: width between convolutions and inside each MLP.
        output_channels: feature width returned by ``forward``.
    """

    def __init__(self, input_channels, hidden_channels, output_channels):
        super().__init__()
        # (incoming width, outgoing width) for each of the three layers.
        layer_widths = (
            (input_channels, hidden_channels),
            (hidden_channels, hidden_channels),
            (hidden_channels, output_channels),
        )
        # All MLPs are built before any RelativeEdgeConv. Keep this order:
        # changing the sequence of module construction can change seeded
        # initial weights.
        edge_mlps = [MLP(c_in, hidden_channels, c_out) for c_in, c_out in layer_widths]
        self.conv1, self.conv2, self.conv3 = [RelativeEdgeConv(nn=mlp) for mlp in edge_mlps]

    def forward(self, x, edge_index):
        """Run conv1 -> ReLU -> conv2 -> ReLU -> conv3 and return the raw result."""
        h = x
        for conv in (self.conv1, self.conv2):
            h = conv(h, edge_index).relu()
        return self.conv3(h, edge_index)
    
class DistConv(torch.nn.Module):
    """Single ``DistEdgeConv`` layer whose edge MLP maps 1 input channel to ``output_channels``.

    Args:
        input_channels: accepted for signature compatibility with the other
            models in this file; not used to build any layer here.
        hidden_channels: hidden width of the edge MLP.
        output_channels: output width of the edge MLP.
    """

    def __init__(self, input_channels, hidden_channels, output_channels):
        super().__init__()
        # NOTE(review): the MLP input width is 1, so DistEdgeConv presumably
        # feeds it one scalar per edge (e.g. a distance) — confirm in layers.py.
        edge_mlp = MLP(1, hidden_channels, output_channels)
        self.conv1 = DistEdgeConv(nn=edge_mlp)

    def forward(self, x, edge_index):
        """Apply the single ``DistEdgeConv`` layer and return its raw output.

        Only ``conv1`` is constructed, so it is the only layer applied (calling
        ``conv2``/``conv3`` would raise ``AttributeError``). As the final layer
        it gets no activation, matching ``EGConv``/``REGConv``.
        """
        return self.conv1(x, edge_index)
    
    
    