import torch
import torch.nn as nn
import dgl.function as fn
from dgl.nn.pytorch.softmax import edge_softmax
from config import args

class distance(nn.Module):
    """Per-edge feature affinity, normalised with ``edge_softmax``.

    For every edge the element-wise difference between the feature of its
    source node and the feature of its destination node is taken, reduced to
    an L1 distance ``d`` and mapped to ``exp(-d / scale)``.  The resulting
    edge scores are then passed through ``edge_softmax``.

    Args:
        scale: Positive divisor applied to the L1 distance before the
            exponential.  Defaults to ``100.0``, the value that used to be
            hard-coded, so ``distance()`` behaves exactly as before.

    Raises:
        ValueError: If ``scale`` is not strictly positive.
    """

    def __init__(self, scale=100.0):
        super(distance, self).__init__()
        if scale <= 0:
            raise ValueError(
                "scale must be strictly positive, got {!r}".format(scale)
            )
        self.scale = scale

    def forward(self, graph, feats):
        """Return the softmax-normalised affinity of every edge.

        Args:
            graph: DGL graph whose nodes ``feats`` belongs to.  A local copy
                is moved to ``args.device``; the features written below go
                to that copy.
            feats: Node features.  Assumes a 2-D ``(num_nodes, feat_dim)``
                tensor already on ``args.device`` -- TODO confirm with
                callers.

        Returns:
            Tensor with one normalised score per edge (shape
            ``(num_edges, 1)`` for 2-D ``feats``).
        """
        # Work on a local view so the temporary 'ftl'/'ftr' features are not
        # attached to the graph object the caller handed in.
        graph = graph.local_var().to(args.device)

        # (N, D) -> (N, 1, D): keeps a singleton axis that survives the sum
        # over the feature axis below.
        feats = feats.view(-1, 1, feats.shape[1])
        graph.ndata.update({'ftl': feats, 'ftr': feats})

        # Source-minus-destination feature difference for every edge.
        graph.apply_edges(fn.u_sub_v('ftl', 'ftr', 'diff'))
        e = graph.edata.pop('diff')

        # L1 distance over the feature axis, turned into a similarity in
        # (0, 1] by the exponential kernel.
        e = torch.exp((-1.0 / self.scale) * torch.sum(torch.abs(e), dim=-1))

        e = edge_softmax(graph, e)
        return e


def KLDiv(graph, edgex, edgey):
    """Mean over nodes of the summed per-edge KL terms.

    Each edge contributes ``edgey * (log(edgey) - log(edgex))``.  The
    contributions are summed by ``update_all`` into a per-node value, and
    the mean of those per-node values over all nodes of ``graph`` is
    returned.

    All temporary features ('kldiv', 'diff') are written inside
    ``graph.local_scope()``.

    Args:
        graph: DGL graph the edge weights belong to.
        edgex: Per-edge weights used as the reference in the log-ratio.
        edgey: Per-edge weights that also scale the log-ratio.

    Returns:
        Scalar tensor.

    Note:
        Both inputs go straight through ``torch.log``; a zero weight
        produces non-finite values, so they are expected to be strictly
        positive.
    """
    with graph.local_scope():
        # Per-edge KL contribution.
        per_edge = edgey * (edgey.log() - edgex.log())
        graph.edata['diff'] = per_edge

        # A constant one on every node, so multiplying it into the edge
        # value leaves the edge value unchanged in the message.
        node_count = graph.number_of_nodes()
        graph.ndata['kldiv'] = torch.ones(node_count, 1, device=edgex.device)

        graph.update_all(
            fn.u_mul_e('kldiv', 'diff', 'm'),
            fn.sum('m', 'kldiv'),
        )

        per_node = graph.ndata['kldiv']
        return per_node.flatten().mean()


def middle_loss(g, middle_feats_t, middle_feats_s):
    """KL-based loss between the edge affinities of two feature sets.

    A fresh ``distance`` module turns each feature tensor into normalised
    per-edge scores on ``g``; the two score sets are then compared with
    ``KLDiv``, using the ``middle_feats_s`` scores as ``edgex`` and the
    ``middle_feats_t`` scores as ``edgey``.

    Args:
        g: DGL graph shared by both feature tensors.
        middle_feats_t: Node features of the first model (presumably the
            teacher -- confirm against the caller).
        middle_feats_s: Node features of the second model (presumably the
            student -- confirm against the caller).

    Returns:
        Scalar tensor produced by ``KLDiv``.
    """
    # NOTE(review): ``distance`` moves its own copy of the graph to
    # ``args.device`` while ``KLDiv`` receives ``g`` itself, so ``g`` looks
    # like it is expected to live on that device already -- confirm.
    edge_affinity = distance()
    scores_from_t = edge_affinity(g, middle_feats_t)
    scores_from_s = edge_affinity(g, middle_feats_s)
    return KLDiv(g, edgex=scores_from_s, edgey=scores_from_t)
