import torch

from torch_geometric.nn import MessagePassing
import torch.nn as nn
# Sentinel default for PGNConv's `edge_weight` argument, meaning "no edge
# weights supplied"; PGNConv.message compares against it by identity (`is not`).
NONE_DUMY = "NONE_DUMY"

class GRUConv(MessagePassing):
    """Message-passing layer with a GRU-cell update and a residual connection.

    No ``message`` override is defined here, so whatever default the
    ``MessagePassing`` base class provides is used to build per-edge messages.

    Args:
        emb_dim: Used as both the input size and the hidden size of the GRU cell.
        aggr: Aggregation scheme forwarded to ``MessagePassing``.
    """

    def __init__(self, emb_dim, aggr):
        super().__init__(aggr=aggr)
        self.rnn = nn.GRUCell(input_size=emb_dim, hidden_size=emb_dim)

    def forward(self, x, edge_index):
        """Return ``x + GRUCell(aggregated_messages, x)``.

        Args:
            x: Node features; also used as the hidden state of the GRU cell.
            edge_index: Graph connectivity handed to ``propagate``.
        """
        aggregated = self.propagate(edge_index, x=x)
        # Aggregated messages are the GRU input, current features the hidden state.
        updated = self.rnn(aggregated, x)
        # Residual connection around the recurrent update.
        return x + updated

class GRUMLPConv(MessagePassing):
    """Message-passing layer with MLP-computed edge messages and a GRU-cell update.

    Unlike ``GRUConv`` there is no residual connection: the output is the
    GRU cell's new hidden state.

    Args:
        emb_dim: Used as both the input size and the hidden size of the GRU cell.
        mlp_edge: Callable applied to the concatenated endpoint features of
            every edge; its output is what gets aggregated.
        aggr: Aggregation scheme forwarded to ``MessagePassing``.
    """

    def __init__(self, emb_dim, mlp_edge, aggr):
        super().__init__(aggr=aggr)
        self.rnn = nn.GRUCell(input_size=emb_dim, hidden_size=emb_dim)
        self.mlp_edge = mlp_edge

    def forward(self, x, edge_index):
        """Return ``GRUCell(aggregated_messages, x)``.

        Args:
            x: Node features; also used as the hidden state of the GRU cell.
            edge_index: Graph connectivity handed to ``propagate``.
        """
        aggregated = self.propagate(edge_index, x=x)
        return self.rnn(aggregated, x)

    def message(self, x_j, x_i):
        """Build one message per edge from both endpoints' features.

        The features are concatenated along dim 1 (``x_j`` first, then
        ``x_i``) and passed through ``mlp_edge``.
        """
        endpoint_features = torch.cat([x_j, x_i], dim=1)
        return self.mlp_edge(endpoint_features)


class GINMLPConv(MessagePassing):
    """GIN-style layer with MLP-computed edge messages and a residual connection.

    Computes ``x + mlp((1 + eps) * x + aggregated_messages)`` where ``eps``
    is a learnable scalar initialised to zero.

    Args:
        mlp: Callable applied to the combined node/neighbourhood representation.
        mlp_edge: Callable applied to the concatenated endpoint features of
            every edge; its output is what gets aggregated.
        aggr: Aggregation scheme forwarded to ``MessagePassing``.
    """

    def __init__(self, mlp, mlp_edge, aggr):
        super().__init__(aggr=aggr)
        self.mlp = mlp
        self.mlp_edge = mlp_edge
        # Learnable self-loop weighting, starts at 0 so (1 + eps) == 1.
        self.eps = nn.Parameter(torch.Tensor([0]))

    def forward(self, x, edge_index):
        """Return the residual GIN update of the node features ``x``.

        Args:
            x: Node features.
            edge_index: Graph connectivity handed to ``propagate``.
        """
        self_term = (1 + self.eps) * x
        neighbourhood = self.propagate(edge_index, x=x)
        transformed = self.mlp(self_term + neighbourhood)
        return x + transformed

    def message(self, x_j, x_i):
        """Build one message per edge from both endpoints' features.

        The features are concatenated along dim 1 (``x_j`` first, then
        ``x_i``) and passed through ``mlp_edge``.
        """
        endpoint_features = torch.cat([x_j, x_i], dim=1)
        return self.mlp_edge(endpoint_features)

    def update(self, aggr_out):
        """Pass the aggregated messages through unchanged."""
        return aggr_out




class BFSConv(MessagePassing):
    """One relaxation step of a breadth-first distance computation.

    Every edge proposes "sender's distance + 1"; proposals are aggregated
    (``"min"`` by default) and the result is combined with the current
    distances through an element-wise minimum.

    Args:
        aggr: Aggregation scheme forwarded to ``MessagePassing``.
    """

    def __init__(self, aggr = "min"):
        super().__init__(aggr=aggr)

    def forward(self, distances, edge_index):
        """Return the distances after one propagation step.

        Args:
            distances: Current per-node distance estimates.
            edge_index: Graph connectivity handed to ``propagate``.
        """
        # NOTE(review): nodes that receive no message get whatever fill value
        # the aggregation uses for empty neighbourhoods (presumably 0 in PyG);
        # confirm this cannot overwrite an `inf` distance with a smaller value.
        proposed = self.propagate(edge_index, x=distances)
        return torch.minimum(proposed, distances)

    def message(self, x_j):
        """Propose the sending node's distance plus one hop."""
        one_hop_further = x_j + 1
        return one_hop_further

class BFS(torch.nn.Module):
    """Iterates ``BFSConv`` until no distance is infinite or a step budget runs out."""

    def __init__(self) -> None:
        super().__init__()
        self.conv = BFSConv()

    def forward(self, data, starts, starting_nodes_per_graph=1):
        """Compute hop distances from the start nodes.

        Args:
            data: Object exposing ``edge_index`` and ``num_nodes``.
            starts: Index into the first dimension of the distance matrix;
                the selected rows are initialised to 0.
            starting_nodes_per_graph: Number of columns of the distance matrix.

        Returns:
            Tensor of shape ``(data.num_nodes, starting_nodes_per_graph)`` on
            the device of ``data.edge_index``. Entries may still be ``inf``
            if the step budget is exhausted first.
        """
        edge_index = data.edge_index
        distances = torch.full(
            (data.num_nodes, starting_nodes_per_graph),
            float('Inf'),
            device=edge_index.device,
        )
        distances[starts] = 0

        # At most num_nodes + 2 relaxation steps: the budget guarantees
        # termination when some nodes can never be reached (disconnected
        # components), where `inf` entries would otherwise persist forever.
        for _ in range(data.num_nodes + 2):
            if float('Inf') not in distances:
                break
            distances = self.conv(distances, edge_index)
        return distances
    




class PGNConv(MessagePassing):
    """Pointer-graph-network style message-passing layer.

    Adapted from https://github.com/google-deepmind/clrs/blob/64e016998f14305f94cf3f6d19ac9d7edc39a185/clrs/_src/processors.py#L330

    Per edge, the two endpoint features are each linearly projected and summed
    (plus a projection of the scalar edge weight, when one is given), passed
    through a small ReLU MLP and optionally through ``mid_act``. Aggregated
    messages are then combined with a linear skip connection of the node
    features and optionally passed through ``activation``.

    Graph-level features are not supported.

    Args:
        in_channels: Size of the input node features.
        out_channels: Size of the output node features; also used as the
            width of the message representation (``mid_channels``).
        aggr: Aggregation scheme forwarded to ``MessagePassing``.
        mid_act: Optional callable applied to every message after the
            message MLP; skipped when ``None``.
        activation: Optional callable applied to the layer output; skipped
            when ``None``.
    """

    def __init__(self, in_channels, out_channels, aggr, mid_act=None, activation=nn.ReLU()):
        super().__init__(aggr=aggr)
        self.in_channels = in_channels
        self.mid_channels = out_channels
        self.mid_act = mid_act
        self.out_channels = out_channels
        self.activation = activation

        # Message projections, one per edge endpoint.
        self.m_1 = nn.Linear(in_channels, self.mid_channels)  # source node
        self.m_2 = nn.Linear(in_channels, self.mid_channels)  # target node

        self.msg_mlp = nn.Sequential(
            nn.ReLU(),
            nn.Linear(self.mid_channels, self.mid_channels),
            nn.ReLU(),
            nn.Linear(self.mid_channels, self.mid_channels)
        )

        # Lifts a scalar edge weight to the message width.
        self.edge_weight_scaler = nn.Linear(1, self.mid_channels)

        # Output projections.
        self.o1 = nn.Linear(in_channels, out_channels)  # skip connection
        self.o2 = nn.Linear(self.mid_channels, out_channels)

    def forward(self, x, edge_index, edge_weight=NONE_DUMY):
        """Run one round of message passing.

        Args:
            x: Node features with ``in_channels`` columns.
            edge_index: Graph connectivity handed to ``propagate``.
            edge_weight: Optional per-edge scalar weights. Omitting the
                argument (``NONE_DUMY``) or passing ``None`` disables the
                edge-weight term.

        Returns:
            Updated node features with ``out_channels`` columns.
        """
        aggregated = self.propagate(edge_index, x=x, edge_weight=edge_weight)
        out = self.o1(x) + self.o2(aggregated)
        if self.activation is not None:
            out = self.activation(out)
        return out

    def message(self, x_j, x_i, edge_weight=NONE_DUMY):
        """Compute one message per edge (j is source, i is target)."""
        msg = self.m_1(x_j) + self.m_2(x_i)

        # Accept both the module sentinel and None as "no edge weights";
        # previously None crashed on `None.reshape`.
        if edge_weight is not NONE_DUMY and edge_weight is not None:
            # One scalar per edge -> column vector -> message width.
            msg = msg + self.edge_weight_scaler(edge_weight.reshape(-1, 1))

        msg = self.msg_mlp(msg)

        if self.mid_act is not None:
            msg = self.mid_act(msg)

        return msg
