import torch
from torch_geometric.nn import global_max_pool, global_mean_pool, global_add_pool
import numpy as np
import torch.multiprocessing as mp
from os import getpid
import torch.nn.functional as F
from torch.nn.modules.normalization import LayerNorm
from collections import defaultdict
from torch import nn


class AMP2Graph(torch.nn.Module):
    def __init__(self, in_features, hidden, out_features, message_size, graph_class, num_starts, num_messages,
                 use_skip):
        super(AMP2Graph, self).__init__()

        self.encoder = torch.nn.Linear(in_features, hidden)
        self.newstate = nn.Sequential(nn.Linear(hidden+message_size, hidden), nn.ReLU(), nn.Linear(hidden, hidden))
        self.new_message = torch.nn.Linear(hidden + message_size, message_size)
        self.decoder = torch.nn.Linear(hidden, out_features)
        self.skipdecoder = torch.nn.Linear(4*hidden, out_features)

        self.message_size = message_size
        self.graph_class = graph_class
        self.num_starts = num_starts
        self.num_messages = num_messages
        self.hidden = hidden
        self.use_skip = use_skip

    def neighbors(self, data):
        nbs = defaultdict(set)
        source, dest = data.edge_index
        for node in range(data.num_nodes):
            for i in range(len(source)):
                if int(source[i]) == node:
                    nbs[node].add(int(dest[i]))
                if int(dest[i]) == node:
                    nbs[node].add(int(source[i]))
        return nbs

    def forward(self, data):
        pred = []
        for graph in data.to_data_list():
            pred.append(self.run_on_graph(graph, None))
        return torch.cat(pred, dim=0), 0

    def run_on_graph(self, graph, opt):
        neighbors = self.neighbors(graph)
        nodestates = torch.zeros(graph.num_nodes, self.hidden)
        skipnodestates = torch.zeros(graph.num_nodes, 4 * self.hidden)

        queue = []
        for i in range(graph.num_nodes):
            queue.append((i, torch.ones(1, self.message_size)))
        messages = 0
        encoded_nodes = self.encoder(graph.x)
        predictions = defaultdict(list)
        for node in range(graph.num_nodes):
            node_encoded = encoded_nodes[node:node + 1, :]
            predictions[node].append(node_encoded)


        while queue and messages < self.num_messages:
            messages += 1
            node, message = queue.pop(0)
            features = predictions[node][-1]
            newstate = torch.relu(self.newstate(torch.cat([features, message], dim=1)))
            newmessage = self.new_message(torch.cat([newstate, message], dim=1))
            for nb in neighbors[node]:
                queue.append((nb, newmessage))
            predictions[node].append(newstate)
        for i in range(graph.num_nodes):
            if self.use_skip:
                skipstates = predictions[i][-4:]
                while len(skipstates) < 4:
                    skipstates = [torch.zeros(1, self.hidden)] + skipstates
                skipnodestates[i:i+1, :] = torch.cat(skipstates, dim=1)
            else:
                nodestates[i, :] = predictions[i][-1]
        if self.use_skip:
            predictions = global_add_pool(skipnodestates, batch=torch.zeros(graph.num_nodes).long())
            return torch.log_softmax(self.skipdecoder(predictions), dim=-1)
        else:
            predictions = global_add_pool(nodestates, batch=torch.zeros(graph.num_nodes).long())
            return torch.log_softmax(self.decoder(predictions), dim=-1)


    def reset_parameters(self):
        self.encoder.reset_parameters()
        self.decoder.reset_parameters()
        self.newstate[0].reset_parameters()
        self.newstate[2].reset_parameters()
        self.new_message.reset_parameters()
        self.skipdecoder.reset_parameters()