import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.autograd as autograd

from torch_geometric.nn import global_mean_pool

class RWLayer(nn.Module):
    """One application of the map whose fixed point is solved for by a DEQ.

    The layer holds ``n_graphs`` learnable "hidden graphs" of ``n_nodes`` nodes
    each:

    * ``w_adj`` stores the strict upper triangle of every hidden adjacency.
      In ``forward`` it is squashed with ``relu(tanh(.))`` into ``[0, 1)`` and
      mirrored, giving a symmetric, non-negative, zero-diagonal matrix.
    * ``w_x`` stores one ``hidden_dim`` feature vector per hidden node.

    Given a flat state ``z`` with one entry per (input node, hidden graph,
    hidden node) triple, ``forward`` returns
    ``px + lamda * S * A(S * z)``. In this expression:

    * ``S`` is a sigmoid similarity between encoded input nodes and hidden-node
      features.
    * ``A`` propagates along the hidden adjacency and then sums messages from
      edge sources into edge targets of the input graph.

    The class name suggests a random-walk-kernel style update; confirm against
    the accompanying paper/README.

    Args:
        node_encoder: callable ``node_encoder(features, depth)``. It must return
            an ``(N, hidden_dim)`` tensor, as required by the einsum with
            ``w_x``.
        n_graphs: number of hidden graphs.
        n_nodes: number of nodes per hidden graph.
        hidden_dim: width of the encoded node features and of ``w_x``.
        lamda: scale applied to the propagated term.
    """

    def __init__(self, node_encoder, n_graphs, n_nodes, hidden_dim, lamda):
        super().__init__()
        self.n_nodes = n_nodes
        self.n_graphs = n_graphs
        self.node_encoder = node_encoder
        self.lamda = lamda
        # Strict upper triangle (n_nodes choose 2 entries) of each hidden adjacency.
        self.w_adj = nn.Parameter(
            torch.empty(n_graphs, (n_nodes * (n_nodes - 1)) // 2, dtype=torch.float32)
        )
        # One feature vector per hidden node.
        self.w_x = nn.Parameter(torch.empty(n_graphs, n_nodes, hidden_dim, dtype=torch.float32))
        self.b = nn.Parameter(torch.tensor(0.))
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()
        self.tanh = nn.Tanh()
        self.init_weights()

    def init_weights(self):
        """Draw ``w_adj`` then ``w_x`` uniformly from ``[-1, 1]`` (in that order)."""
        nn.init.uniform_(self.w_adj, -1, 1)
        nn.init.uniform_(self.w_x, -1, 1)

    def forward(self, z, edge_index, features, px, node_depth):
        """Apply one update step.

        Args:
            z: flat state of length ``N * n_graphs * n_nodes``, where
                ``N = features.size(0)``.
            edge_index: ``(2, E)`` index tensor. Messages are gathered from
                ``edge_index[0]`` and summed into ``edge_index[1]``.
            features: raw node features passed to ``node_encoder``.
            px: flat tensor broadcastable against the ``N * n_graphs * n_nodes``
                result; it is added unchanged.
            node_depth: flattened with ``view(-1)`` and passed to ``node_encoder``.

        Returns:
            Flat tensor ``px + lamda * S * A(S * z)`` of length
            ``N * n_graphs * n_nodes``.
        """
        n_rows = features.size(0)

        # Build the symmetric hidden adjacencies. Allocate them on the
        # parameters' device and dtype so the layer also works after
        # .double()/.half(), and so triu indices are not created on the host.
        w_adj = self.w_adj.new_zeros(self.n_graphs, self.n_nodes, self.n_nodes)
        iu = torch.triu_indices(self.n_nodes, self.n_nodes, 1, device=self.w_adj.device)
        w_adj[:, iu[0], iu[1]] = self.relu(self.tanh(self.w_adj))
        w_adj = w_adj + w_adj.transpose(1, 2)

        x = self.node_encoder(features, node_depth.view(-1))
        # S[v, g, k]: similarity of input node v to hidden node k of graph g.
        S = self.sigmoid(torch.einsum("ab,cdb->acd", x, self.w_x) + self.b)
        Z = z.view(n_rows, self.n_graphs, self.n_nodes) * S

        # Propagate inside each hidden graph:
        # msg[v, g, d] = sum_c Z[v, g, c] * w_adj[g, d, c].
        msg = torch.einsum("abc,bdc->abd", Z, w_adj)
        # Gather at edge sources and sum into edge targets. index_add_ avoids
        # materialising an (E, n_graphs * n_nodes) repeated index tensor.
        agg = msg.new_zeros(msg.shape).index_add_(0, edge_index[1], msg[edge_index[0]])

        return px + self.lamda * (agg * S).reshape(-1)


def forward_iteration(f, x0, max_iter=50, tol=1e-2):
    """Approximate a fixed point of ``f`` by plain repeated application.

    Starts from ``f(x0)`` and keeps applying ``f`` for at most ``max_iter``
    steps. After each step it records the relative change
    ``||f(x) - x|| / (1e-5 + ||f(x)||)``. Iteration stops early once that value
    drops below ``tol``.

    Args:
        f: callable mapping a tensor to a tensor of the same shape.
        x0: initial guess.
        max_iter: maximum number of applications after the first one.
        tol: relative-residual threshold for early stopping.

    Returns:
        ``(x, residuals)``, where ``x`` is the last iterate and ``residuals``
        holds one Python float per step. With ``max_iter == 0`` the result is
        ``(f(x0), [])``.
    """
    residuals = []
    current = f(x0)
    for _ in range(max_iter):
        previous, current = current, f(current)
        relative_change = (current - previous).norm().item() / (1e-5 + current.norm().item())
        residuals.append(relative_change)
        if relative_change < tol:
            break
    return current, residuals


class DEQFixedPoint(nn.Module):
    """Deep-equilibrium wrapper that returns an approximate fixed point of ``f``.

    The forward solve runs under ``torch.no_grad()``. Gradients are recovered
    implicitly in two steps:

    1. One extra differentiable call to ``f`` at the solution connects the
       parameters of ``f`` to the output.
    2. A backward hook replaces the incoming gradient ``grad`` with the solution
       ``g`` of ``g = J^T g + grad``. Here ``J`` is the Jacobian of ``f``
       w.r.t. ``z`` at the solution, and the hook uses the same ``solver``.

    Args:
        f: callable ``f(z, edge_index, x, px, node_depth)`` returning a tensor
            shaped like ``z``.
        solver: callable ``solver(fn, z_init, **kwargs) -> (z, residuals)``,
            such as ``forward_iteration``.
        n_graphs, n_nodes: together with ``x.size(0)`` they give the length
            ``n_graphs * n_nodes * x.size(0)`` of the flat state.
        **kwargs: forwarded to every ``solver`` call, e.g. ``tol`` or
            ``max_iter``.
    """

    def __init__(self, f, solver, n_graphs, n_nodes, **kwargs):
        super().__init__()
        self.f = f
        self.solver = solver
        self.n_graphs = n_graphs
        self.n_nodes = n_nodes
        self.kwargs = kwargs

    def forward(self, edge_index, x, px, node_depth):
        """Solve for the fixed point and return it, wired for implicit backprop.

        Side effects: stores the solver's residual list in ``self.forward_res``.
        When backward runs through the result, the backward solve's residuals
        are stored in ``self.backward_res``.
        """
        def fn(z):
            return self.f(z, edge_index, x, px, node_depth)

        # Find the fixed point without recording the solver iterations.
        with torch.no_grad():
            z_init = torch.zeros(self.n_graphs * self.n_nodes * x.size(0), device=x.device)
            z, self.forward_res = self.solver(fn, z_init, **self.kwargs)
        # Re-engage the autograd tape with one differentiable application.
        z = fn(z)

        # Under torch.no_grad()/inference (or when nothing requires grad)
        # there is no graph to attach to. register_hook would raise, and the
        # Jacobian set-up below would be wasted work.
        if not z.requires_grad:
            return z

        # Graph for vector-Jacobian products w.r.t. z at the solution, built
        # once here so backward needs no additional forward calls.
        z0 = z.clone().detach().requires_grad_()
        f0 = fn(z0)

        def backward_hook(grad):
            # autograd.grad(f0, z0, y) computes J^T y. Iterating y -> J^T y + grad
            # solves g = (I - J^T)^{-1} grad when the solver converges.
            g, self.backward_res = self.solver(
                lambda y: autograd.grad(f0, z0, y, retain_graph=True)[0] + grad,
                grad,
                **self.kwargs,
            )
            return g

        z.register_hook(backward_hook)
        return z


class GRWNN(nn.Module):
    """Sequence-of-tokens predictor on top of a DEQ-solved ``RWLayer``.

    Per row of ``data.x``, the readout works as follows:

    1. The DEQ fixed-point state is reshaped to ``(n_graphs, n_nodes)`` and
       layer-normalised.
    2. It is summed over hidden nodes and normalised again.
    3. The result is mean-pooled per graph with ``global_mean_pool``,
       normalised, and passed through dropout.
    4. It is fed to ``max_seq_len`` independent linear heads, each producing
       ``num_vocab`` logits.

    ``data`` is presumably a PyTorch Geometric batch. Confirm the exact field
    semantics against the caller.
    """

    def __init__(self, num_vocab, max_seq_len, node_encoder, n_graphs, n_nodes, hidden_dim, lamda, dropout):
        super().__init__()
        self.n_nodes = n_nodes
        self.n_graphs = n_graphs
        self.max_seq_len = max_seq_len

        # The RW layer is built before the linear heads so parameter
        # initialisation consumes the RNG in a fixed order.
        rw_layer = RWLayer(node_encoder, n_graphs, n_nodes, hidden_dim, lamda)
        self.deq = DEQFixedPoint(rw_layer, forward_iteration, n_graphs, n_nodes, tol=1e-4, max_iter=100)

        # ln[0]: over each row's (n_graphs, n_nodes) block.
        # ln[1]: over per-row n_graphs vectors.
        # ln[2]: over pooled per-graph n_graphs vectors.
        self.ln = nn.ModuleList(
            [nn.LayerNorm([n_graphs, n_nodes])]
            + [nn.LayerNorm(n_graphs) for _ in range(2)]
        )
        # One independent linear head per output sequence position.
        self.graph_pred_linear_lst = nn.ModuleList(
            [nn.Linear(n_graphs, num_vocab) for _ in range(max_seq_len)]
        )

        self.dropout = nn.Dropout(p=dropout)
        self.relu = nn.ReLU()

    def forward(self, data):
        """Return a list of ``max_seq_len`` logit tensors, one per head."""
        # edge_attr is read from data but not used by this model.
        x, edge_attr, edge_index, batch, node_depth = (
            data.x,
            data.edge_attr,
            data.edge_index,
            data.batch,
            data.node_depth,
        )
        n_rows = x.size(0)

        # All-ones tensor: its flattened form is the additive term px of the
        # fixed-point map, and it also weights the readout below.
        qx = torch.ones(n_rows, self.n_graphs, self.n_nodes, device=x.device)

        z = self.deq(edge_index, x, qx.view(-1), node_depth)
        Z = self.ln[0](z.view(n_rows, self.n_graphs, self.n_nodes) * qx)
        row_repr = self.ln[1](Z.sum(dim=2))
        graph_repr = self.dropout(self.ln[2](global_mean_pool(row_repr, batch)))

        return [self.graph_pred_linear_lst[pos](graph_repr) for pos in range(self.max_seq_len)]
