import torch
from dgl import function as fn
from torch import nn
from torch.nn import functional as F

class GeneralPooling(nn.Module):
    """Container for the learnable scalars of the generalized pooling.

    The module owns four one-element parameters: ``p_pos`` / ``p_neg`` (raw
    exponents) and ``q_pos`` / ``q_neg``.  It performs no pooling itself; the
    aggregation that consumes these values lives in the caller.

    Args:
        hidden_dim: Stored as ``self.hidden_dim``; not used by this class.
        general_mode: Integer flag decoded into ``use_pos``
            (``general_mode // 2 == 0``) and ``use_neg``
            (``general_mode % 2 == 0``).
        eps: Small constant stored as ``self.eps`` for callers to use as a
            numerical floor.
    """

    def __init__(self, hidden_dim, general_mode, eps=1e-12):
        super(GeneralPooling, self).__init__()
        self.eps = eps
        self.hidden_dim = hidden_dim
        self.use_pos = ((general_mode // 2) == 0)
        self.use_neg = ((general_mode % 2) == 0)
        self.use_reparameterization = True

        # With reparameterization the raw value 0 maps to 1 + log(2); without
        # it the raw value is used directly and therefore starts at 1.
        p_init = 0.0 if self.use_reparameterization else 1.0
        self.p_pos = nn.Parameter(torch.tensor([p_init], dtype=torch.float32))
        self.p_neg = nn.Parameter(torch.tensor([p_init], dtype=torch.float32))
        self.q_pos = nn.Parameter(torch.tensor([0.0], dtype=torch.float32))
        self.q_neg = nn.Parameter(torch.tensor([0.0], dtype=torch.float32))

    def _effective_p(self, raw_p):
        """Map a raw exponent parameter to the value actually used."""
        if self.use_reparameterization:
            # 1 + softplus(raw) keeps the exponent strictly above 1.
            # F.softplus is used instead of log(exp(raw) + 1) because the
            # latter overflows to inf (and yields NaN gradients) for large raw.
            return 1. + F.softplus(raw_p)
        return raw_p

    def get_p_pos(self):
        """Return the effective positive exponent as a one-element tensor."""
        return self._effective_p(self.p_pos)

    def get_p_neg(self):
        """Return the effective negative-branch exponent (before negation)."""
        return self._effective_p(self.p_neg)

    
class Conv(nn.Module):
    r"""Edge-weighted neighbor aggregation layer.

    We modified existing implementation of GraphSAGE from DGL
    (https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/conv/sageconv.py)

    Node features are first passed through ``fc_pool``, multiplied by the
    edge feature ``'weight'`` and aggregated over incoming edges; the result
    is concatenated with the input features and projected by ``fc_neigh``.

    ``model_type`` selects the aggregator:

    * ``0`` -- element-wise max of the weighted messages.
    * ``1`` -- element-wise sum of the weighted messages.
    * ``2`` -- learnable pooling using the parameters of ``GeneralPooling``.
      Per feature it evaluates
      ``(sum_e (relu(h_u * w_e) + eps) ** p) ** (1 / p) / deg ** q`` in log
      space, with ``p > 0`` for the first half of the features and ``p < 0``
      for the second half.

    Args:
        in_feats: Input feature size. Must be even when ``model_type == 2``
            because the features are split into two equally sized halves.
        out_feats: Output feature size.
        norm: Optional callable applied to the output after ``activation``.
        activation: Optional callable applied to the projected output.
        model_type: Aggregator selector, one of ``0``, ``1`` or ``2``.

    Raises:
        ValueError: If ``model_type`` is not supported, or if
            ``model_type == 2`` and ``in_feats`` is odd.
    """

    _VALID_MODEL_TYPES = (0, 1, 2)

    def __init__(self, in_feats, out_feats, norm=None, activation=None, model_type=0):
        super(Conv, self).__init__()

        # Fail at construction time: previously an unknown model_type only
        # surfaced as an AttributeError on ``agg_fn`` inside forward().
        if model_type not in self._VALID_MODEL_TYPES:
            raise ValueError(
                "model_type must be one of {}, got {!r}".format(
                    self._VALID_MODEL_TYPES, model_type))
        # The general aggregator builds per-feature exponents of length
        # 2 * (in_feats // 2); an odd in_feats cannot broadcast against h.
        if model_type == 2 and in_feats % 2 != 0:
            raise ValueError(
                "in_feats must be even when model_type == 2, got {!r}".format(in_feats))

        self._in_feats = in_feats
        self._out_feats = out_feats
        self.norm = norm
        self.activation = activation

        self.fc_pool = nn.Linear(in_feats, in_feats, bias=True)
        self.fc_neigh = nn.Linear(in_feats + in_feats, out_feats, bias=True)
        self.model_type = model_type

        if self.model_type == 2:
            self.agg_fn = GeneralPooling(in_feats, general_mode=0)

        self.reset_parameters()

    def reset_parameters(self):
        """Reinitialize learnable parameters."""
        gain = nn.init.calculate_gain('relu')
        nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
        nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)

    @staticmethod
    def _aggregate_builtin(graph, reduce_func):
        """Aggregate weighted messages with a DGL built-in reducer."""
        graph.update_all(message_func=fn.u_mul_e('h', '_w', 'm'),
                         reduce_func=reduce_func)
        h_neigh = graph.ndata['neigh']
        # Nodes without incoming edges receive no message; force them to 0.
        h_neigh[graph.in_degrees() == 0, :] = 0
        return h_neigh

    def _aggregate_general(self, graph):
        """Aggregate weighted messages with the learnable general pooling."""
        eps = self.agg_fn.eps
        half = self._in_feats // 2

        feat_src = graph.ndata.pop('h')
        graph.srcdata['mask'] = (feat_src >= eps).float()
        h = F.relu(feat_src)
        # In the second (negative-exponent) half, replace (near-)zero entries
        # by 1/eps so that raising them to a negative power yields ~0.
        # Done out of place: writing into the relu output in place would
        # invalidate the tensor autograd saved for relu's backward pass.
        neg_half = h[:, half:]
        neg_half = torch.where(neg_half < eps,
                               torch.full_like(neg_half, 1. / eps),
                               neg_half)
        h = torch.cat((h[:, :half], neg_half), dim=1)

        # Count, per destination and feature, how many sources are non-zero;
        # features with no non-zero source are zeroed at the end.
        graph.update_all(fn.copy_u('mask', 'm0'), fn.sum('m0', 'neigh0'))
        is_allzero = graph.dstdata['neigh0'] < 0.5

        # Per-feature exponents: first half +p_pos, second half -p_neg.
        p_pos, p_neg = self.agg_fn.get_p_pos(), -self.agg_fn.get_p_neg()
        ps = torch.stack((p_pos, p_neg), dim=0).repeat(1, half).view(-1)
        qs = torch.stack((self.agg_fn.q_pos, self.agg_fn.q_neg), dim=0).repeat(1, half).view(-1)

        graph.srcdata['h'] = h

        def _scaled_log(edges):
            # p * log(relu(h_u * w_e) + eps) for every edge.
            return torch.log(F.relu(edges.src['h'] * edges.data['_w']) + eps) * ps

        def _msg_max(edges):
            return {'m1': _scaled_log(edges)}

        def _msg_exp(edges):
            # Subtract the per-destination max before exp for stability.
            return {'m2': torch.exp(_scaled_log(edges) - edges.dst['mx'])}

        # Two-pass log-sum-exp over incoming edges.
        graph.update_all(_msg_max, fn.max('m1', 'mx'))
        graph.update_all(_msg_exp, fn.sum('m2', 'sumexp'))
        agg_h = torch.log(graph.dstdata['sumexp'] + eps) + graph.dstdata['mx']

        in_degs = graph.in_degrees()
        # Use degree 1 for isolated nodes so log(deg) stays finite; those
        # rows are zeroed below anyway.
        safe_degs = in_degs.clamp(min=1)
        agg_h = torch.exp((agg_h / ps) - (qs * (torch.log(safe_degs + eps).unsqueeze(-1))))

        agg_h = agg_h * ((in_degs > 0).unsqueeze(-1).float())
        agg_h[is_allzero] = 0.
        # Explicit assignment also clears any non-finite value that the
        # multiplication by the zero mask could not remove.
        agg_h[in_degs == 0, :] = 0
        return agg_h

    def forward(self, graph, feat):
        """Compute the layer output.

        Args:
            graph: DGL graph carrying an edge feature ``'weight'``.
            feat: Node feature tensor; it is concatenated with the aggregated
                neighbor features along dimension 1.

        Returns:
            The projected (and optionally activated / normalized) features.
        """
        graph = graph.local_var()
        graph.ndata['h'] = self.fc_pool(feat)
        graph.edata['_w'] = graph.edata['weight'].unsqueeze(1)

        if self.model_type == 0:
            h_neigh = self._aggregate_builtin(graph, fn.max('m', 'neigh'))
        elif self.model_type == 1:
            h_neigh = self._aggregate_builtin(graph, fn.sum('m', 'neigh'))
        else:
            h_neigh = self._aggregate_general(graph)

        rst = self.fc_neigh(torch.cat((feat, h_neigh), dim=1))
        if self.activation is not None:
            rst = self.activation(rst)
        if self.norm is not None:
            rst = self.norm(rst)
        return rst


class MONSTOR(torch.nn.Module):
    """Stack of ``Conv`` layers whose output is capped by a computed bound.

    The network maps ``in_feats`` input features through ``n_layers`` ``Conv``
    layers (hidden width ``n_hidden``, final width 1), each followed by a
    ReLU.  The prediction is the last feature column plus the network output,
    clipped from above by a bound derived from the second-to-last column.

    Args:
        in_feats: Number of input features per node.
        n_hidden: Width of the intermediate layers.
        n_layers: Number of ``Conv`` layers; must be at least 1.
        model_type: Aggregator selector forwarded to every ``Conv``.

    Raises:
        ValueError: If ``n_layers`` is smaller than 1.
    """

    def __init__(self, in_feats, n_hidden, n_layers, model_type=0):
        super(MONSTOR, self).__init__()
        # Without any layer, forward() would add the raw feature matrix to
        # the last feature column, which is never a meaningful prediction.
        if n_layers < 1:
            raise ValueError("n_layers must be >= 1, got {!r}".format(n_layers))

        self.layers = torch.nn.ModuleList()
        self.acts = torch.nn.ModuleList()
        dims = [in_feats] + [n_hidden] * (n_layers - 1) + [1]

        for dim_in, dim_out in zip(dims[:-1], dims[1:]):
            self.layers.append(Conv(dim_in, dim_out, model_type=model_type))
            self.acts.append(nn.ReLU())

    def forward(self, g, features, gt=None):
        """Return the bounded prediction for every node.

        Args:
            g: DGL graph with an edge feature ``'weight'``.
            features: Node features; the last column is used as the lower
                bound and the second-to-last column is propagated along the
                edges to form the upper bound.
            gt: Optional reference values. When given, the fraction of
                entries with ``gt <= ub`` is printed as a diagnostic.

        Returns:
            ``min(lb + h, ub)`` where ``h`` is the squeezed network output.
        """
        graph = g.local_var()
        h = features.clone()
        for act, layer in zip(self.acts, self.layers):
            h = act(layer(graph, h))

        # compute upper bound of influence
        prv_diff, now = features[:, -2], features[:, -1]
        graph.ndata['prv'] = prv_diff
        graph.update_all(fn.u_mul_e('prv', 'weight', 'm'), fn.sum('m', 'delta_ub'))
        lb = now
        ub = torch.clamp((lb + graph.ndata['delta_ub'].squeeze()), min=0, max=1)

        if gt is not None:
            # Cast the count to float before dividing so the ratio is a true
            # fraction even on PyTorch versions where integer tensors use
            # integer division.
            print((gt <= ub).long().sum().float() / gt.shape[0])

        return torch.min(lb + h.squeeze(), ub)
