import dgl.nn.pytorch as dglnn
import torch
import torch as th
import torch.nn as nn
from dgl import function as fn
from dgl._ffi.base import DGLError
from dgl.nn.pytorch.utils import Identity
from dgl.ops import edge_softmax
from dgl.utils import expand_as_pair
import torch.nn.functional as F

class GeneralPooling(nn.Module):
    """Learnable generalized (power) mean pooling over axis 1 of a tensor.

    Two pooled quantities can be produced and are concatenated on the last
    axis when both are enabled:

    * "pos": ``(sum_j h_j ** p) ** (1 / p) * (1 / n) ** q`` with ``p = p_pos``
    * "neg": ``(sum_j h_j ** -p) ** (-1 / p) * (1 / n) ** q`` with ``p = p_neg``

    where ``n = h.shape[1]``. Inputs are passed through ReLU first. Output
    positions whose inputs are all below ``eps`` along axis 1 are set to 0.

    Args:
        hidden_dim: Size of the last axis of ``h``. When both branches are
            enabled, the first ``hidden_dim // 2`` features feed the "pos"
            branch and the rest feed the "neg" branch.
        num_heads: Number of heads; one ``p``/``q`` scalar per head and branch.
        general_mode: ``0`` enables both branches, ``1`` only "pos", ``2``
            only "neg"; with ``3`` neither is enabled and ``forward`` returns
            ``None``.
        eps: Threshold below which an input counts as zero; also added inside
            the logarithm for numerical safety.
    """

    def __init__(self, hidden_dim, num_heads=1, general_mode=0, eps=1e-12):
        super(GeneralPooling, self).__init__()
        self.eps = eps
        self.hidden_dim = hidden_dim
        self.use_pos = ((general_mode // 2) == 0)
        self.use_neg = ((general_mode % 2) == 0)
        self.use_reparameterization = True
        # With reparameterization the raw value 0 maps to p = 1 + log(2);
        # without it the exponent itself is stored and starts at 1.
        p_init = 0.0 if self.use_reparameterization else 1.0
        self.p_pos = nn.Parameter(th.full((num_heads,), p_init, dtype=th.float32))
        self.p_neg = nn.Parameter(th.full((num_heads,), p_init, dtype=th.float32))
        self.q_pos = nn.Parameter(th.zeros(num_heads, dtype=th.float32))
        self.q_neg = nn.Parameter(th.zeros(num_heads, dtype=th.float32))

    def _effective_p(self, raw):
        """Map a raw parameter to the exponent actually used."""
        if self.use_reparameterization:
            # softplus(x) == log(exp(x) + 1) but does not overflow for large x;
            # adding 1 keeps the exponent strictly above 1.
            return 1. + F.softplus(raw)
        return raw

    def get_p_pos(self):
        """Return the per-head exponent of the "pos" branch, shape ``(num_heads,)``."""
        return self._effective_p(self.p_pos)

    def get_p_neg(self):
        """Return the per-head exponent of the "neg" branch, shape ``(num_heads,)``."""
        return self._effective_p(self.p_neg)

    @staticmethod
    def _per_head(param, ref):
        """Align a ``(num_heads,)`` parameter with the head axis of ``ref``.

        For 4-D input (heads on axis 2, features last) a trailing axis is
        appended so the parameter broadcasts per head instead of against the
        feature axis. Lower-rank inputs keep plain broadcasting.
        """
        return param.unsqueeze(-1) if ref.dim() >= 4 else param

    def _pool_pos(self, h_pos):
        """Power mean with positive exponent over axis 1 of a non-negative tensor."""
        allzero = (h_pos < self.eps).all(dim=1, keepdim=False)
        p = self._per_head(self.get_p_pos(), h_pos)
        q = self._per_head(self.q_pos, h_pos)
        # Computed in log space: exp(logsumexp(p * log h) / p) == (sum h**p) ** (1/p).
        pos = th.exp(th.logsumexp(th.log(h_pos + self.eps) * p, dim=1) / p)
        pos = pos * ((1. / h_pos.shape[1]) ** q)
        return pos.masked_fill(allzero, 0.)

    def _pool_neg(self, h_neg):
        """Power mean with negative exponent over axis 1 of a non-negative tensor."""
        small = h_neg < self.eps
        allzero = small.all(dim=1, keepdim=False)
        # Near-zero entries are replaced by a huge value so that they do not
        # dominate the negative-exponent sum. This is done out of place: the
        # ReLU output is needed unchanged for its own backward pass.
        h_neg = h_neg.masked_fill(small, 1. / self.eps)
        p = self._per_head(self.get_p_neg(), h_neg)
        q = self._per_head(self.q_neg, h_neg)
        neg = th.exp(-th.logsumexp(-(th.log(h_neg + self.eps)) * p, dim=1) / p)
        neg = neg * ((1. / h_neg.shape[1]) ** q)
        return neg.masked_fill(allzero, 0.)

    def forward(self, g, h):
        """Pool ``h`` over axis 1.

        Args:
            g: Unused; kept for interface compatibility.
            h: Tensor to pool. When both branches are enabled it is indexed as
                a 4-D tensor and split on its last axis at ``hidden_dim // 2``.

        Returns:
            The pooled tensor with axis 1 removed ("pos" and "neg" results
            concatenated on the last axis when both are enabled), or ``None``
            when no branch is enabled.
        """
        half = self.hidden_dim // 2
        pooled = []
        if self.use_pos:
            h_pos = h[:, :, :, :half] if self.use_neg else h
            pooled.append(self._pool_pos(F.relu(h_pos)))
        if self.use_neg:
            h_neg = h[:, :, :, half:] if self.use_pos else h
            pooled.append(self._pool_neg(F.relu(h_neg)))

        if not pooled:
            return None
        if len(pooled) == 1:
            return pooled[0]
        return th.cat(pooled, dim=-1)

class Bias(nn.Module):
    """Layer that adds a learnable bias tensor to its input.

    Args:
        size: Passed to ``torch.Tensor(size)`` to allocate the bias, which is
            then zero-initialized.
    """

    def __init__(self, size):
        super(Bias, self).__init__()
        storage = torch.Tensor(size)
        self.bias = nn.Parameter(storage)
        self.reset_parameters()

    def reset_parameters(self):
        """Reset the bias to all zeros."""
        with torch.no_grad():
            self.bias.zero_()

    def forward(self, x):
        """Return ``x`` shifted by the bias."""
        shifted = x + self.bias
        return shifted


class GCN(nn.Module):
    """Stack of ``dglnn.GraphConv`` layers with optional linear skip branches.

    Every layer except the last is followed by batch norm, the activation and
    dropout. Only the last convolution carries a bias term.

    Args:
        in_feats: Input feature size of the first layer.
        n_hidden: Output size of every layer except the last.
        n_classes: Output size of the last layer.
        n_layers: Number of graph convolution layers.
        activation: Callable applied after batch norm on non-final layers.
        dropout: Dropout probability between layers; the input dropout uses
            ``min(0.1, dropout)``.
        use_linear: If true, a bias-free ``nn.Linear`` of the layer input is
            added to each convolution output.
    """

    def __init__(self, in_feats, n_hidden, n_classes, n_layers, activation, dropout, use_linear):
        super(GCN, self).__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.use_linear = use_linear

        self.convs = nn.ModuleList()
        if use_linear:
            self.linear = nn.ModuleList()
        self.bns = nn.ModuleList()

        final_idx = n_layers - 1
        for layer_idx in range(n_layers):
            is_final = layer_idx == final_idx
            fan_in = in_feats if layer_idx == 0 else n_hidden
            fan_out = n_classes if is_final else n_hidden

            self.convs.append(dglnn.GraphConv(fan_in, fan_out, "both", bias=is_final))
            if use_linear:
                self.linear.append(nn.Linear(fan_in, fan_out, bias=False))
            if not is_final:
                self.bns.append(nn.BatchNorm1d(fan_out))

        self.dropout0 = nn.Dropout(min(0.1, dropout))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, graph, feat):
        """Run all layers on ``graph`` starting from node features ``feat``."""
        h = self.dropout0(feat)
        final_idx = self.n_layers - 1

        for layer_idx in range(self.n_layers):
            out = self.convs[layer_idx](graph, h)
            if self.use_linear:
                out = out + self.linear[layer_idx](h)
            h = out

            if layer_idx < final_idx:
                h = self.dropout(self.activation(self.bns[layer_idx](h)))

        return h


class GATConv(nn.Module):
    """Graph attention layer whose aggregation is a learnable power mean.

    Attention logits are computed as in DGL's ``GATConv`` (projection,
    ``u_add_v`` of the left/right scores, LeakyReLU, edge softmax). Instead
    of a weighted sum, messages are aggregated with a generalized mean whose
    per-head exponents ``p`` and degree-scaling exponents ``q`` are owned by
    ``self.agg_fn`` (a ``GeneralPooling`` instance). The first
    ``out_feats // 2`` output features use the positive exponent and the
    remaining ones the negated "neg" exponent.

    Args:
        in_feats: Input feature size, or a ``(src, dst)`` pair.
        out_feats: Output feature size per head.
        num_heads: Number of attention heads.
        feat_drop: Dropout probability on input features.
        attn_drop: Dropout probability on attention weights.
        negative_slope: Negative slope of the LeakyReLU on attention logits.
        residual: If true, add a residual projection of the destination input.
        activation: Optional callable applied to the output.
        allow_zero_in_degree: If false, ``forward`` raises on graphs that
            contain nodes without incoming edges.
        norm: ``"none"`` or ``"both"``; with ``"both"`` source features are
            scaled by ``out_degree ** -0.5`` and the result by
            ``in_degree ** 0.5``.

    Raises:
        DGLError: If ``norm`` is not one of the accepted values.
    """

    def __init__(
        self,
        in_feats,
        out_feats,
        num_heads=1,
        feat_drop=0.0,
        attn_drop=0.0,
        negative_slope=0.2,
        residual=False,
        activation=None,
        allow_zero_in_degree=False,
        norm="none",
    ):
        super(GATConv, self).__init__()
        if norm not in ("none", "both"):
            raise DGLError('Invalid norm value. Must be either "none", "both".' ' But got "{}".'.format(norm))
        self._num_heads = num_heads
        self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
        self._out_feats = out_feats
        self._allow_zero_in_degree = allow_zero_in_degree
        self._norm = norm
        if isinstance(in_feats, tuple):
            self.fc_src = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False)
            self.fc_dst = nn.Linear(self._in_dst_feats, out_feats * num_heads, bias=False)
        else:
            self.fc = nn.Linear(self._in_src_feats, out_feats * num_heads, bias=False)
        self.attn_l = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats)))
        self.attn_r = nn.Parameter(torch.FloatTensor(size=(1, num_heads, out_feats)))
        self.feat_drop = nn.Dropout(feat_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.leaky_relu = nn.LeakyReLU(negative_slope)
        if residual:
            if self._in_dst_feats != out_feats:
                self.res_fc = nn.Linear(self._in_dst_feats, num_heads * out_feats, bias=False)
            else:
                self.res_fc = Identity()
        else:
            self.register_buffer("res_fc", None)
        self.reset_parameters()
        self._activation = activation

        # Holds the learnable p/q exponents used by forward().
        self.agg_fn = GeneralPooling(out_feats, general_mode=0, num_heads=self._num_heads)

    def reset_parameters(self):
        """Re-initialize projection and attention weights (Xavier normal, ReLU gain)."""
        gain = nn.init.calculate_gain("relu")
        if hasattr(self, "fc"):
            nn.init.xavier_normal_(self.fc.weight, gain=gain)
        else:
            nn.init.xavier_normal_(self.fc_src.weight, gain=gain)
            nn.init.xavier_normal_(self.fc_dst.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_l, gain=gain)
        nn.init.xavier_normal_(self.attn_r, gain=gain)
        if isinstance(self.res_fc, nn.Linear):
            nn.init.xavier_normal_(self.res_fc.weight, gain=gain)

    def set_allow_zero_in_degree(self, set_value):
        """Enable or disable the zero-in-degree check performed in ``forward``."""
        self._allow_zero_in_degree = set_value

    def forward(self, graph, feat):
        """Compute the layer output.

        Args:
            graph: DGL graph (or block) to run message passing on.
            feat: Node feature tensor, or a ``(src, dst)`` pair of tensors.

        Returns:
            Tensor of shape ``(num_dst_nodes, num_heads, out_feats)``.

        Raises:
            DGLError: If the graph has zero-in-degree nodes and
                ``allow_zero_in_degree`` is false.
        """
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    # Raised explicitly (not via assert) so the check survives
                    # `python -O` and tells the caller how to recover.
                    raise DGLError(
                        "There are 0-in-degree nodes in the graph, "
                        "output for those nodes will be invalid. "
                        "Adding self-loops on the input graph by calling "
                        "`g = dgl.add_self_loop(g)` will resolve the issue. "
                        "Setting ``allow_zero_in_degree`` to `True` when "
                        "constructing this module will suppress the check."
                    )

            if isinstance(feat, tuple):
                h_src = self.feat_drop(feat[0])
                h_dst = self.feat_drop(feat[1])
                if not hasattr(self, "fc_src"):
                    self.fc_src, self.fc_dst = self.fc, self.fc
                feat_src = self.fc_src(h_src).view(-1, self._num_heads, self._out_feats)
                feat_dst = self.fc_dst(h_dst).view(-1, self._num_heads, self._out_feats)
            else:
                h_src = h_dst = self.feat_drop(feat)
                feat_src = feat_dst = self.fc(h_src).view(-1, self._num_heads, self._out_feats)
                if graph.is_block:
                    num_dst = graph.number_of_dst_nodes()
                    feat_dst = feat_src[:num_dst]
                    # Keep the residual input aligned with the destination nodes.
                    h_dst = h_dst[:num_dst]

            if self._norm == "both":
                degs = graph.out_degrees().float().clamp(min=1)
                norm = torch.pow(degs, -0.5)
                shp = norm.shape + (1,) * (feat_src.dim() - 1)
                norm = torch.reshape(norm, shp)
                feat_src = feat_src * norm

            # NOTE: The GAT paper uses "first concatenation then linear
            # projection" to compute attention scores, while this code uses
            # "first projection then addition"; the two are mathematically
            # equivalent. Decomposing the weight vector a from the paper into
            # [a_l || a_r] gives
            #     a^T [Wh_i || Wh_j] = a_l Wh_i + a_r Wh_j
            # This avoids materializing [Wh_i || Wh_j] on every edge, and the
            # addition maps onto DGL's built-in u_add_v.
            el = (feat_src * self.attn_l).sum(dim=-1).unsqueeze(-1)
            er = (feat_dst * self.attn_r).sum(dim=-1).unsqueeze(-1)
            graph.srcdata.update({"el": el})
            graph.dstdata.update({"er": er})
            # Edge attention logits: el and er are a_l Wh_i and a_r Wh_j.
            graph.apply_edges(fn.u_add_v("el", "er", "e"))
            e = self.leaky_relu(graph.edata.pop("e"))

            hidden_dim = self._out_feats
            # Count, per destination and feature, how many neighbours carry a
            # value of at least 1e-12; destinations with none are zeroed below.
            graph.srcdata['mask'] = (feat_src >= 1e-12).float()
            graph.update_all(fn.copy_u('mask', 'm0'), fn.sum('m0', 'neigh0'))
            is_allzero = graph.dstdata['neigh0'] < 0.5

            h = F.relu(feat_src) + 1e-12
            # In the negative-exponent half, (near-)zero entries are replaced
            # with a huge value so they contribute almost nothing to the sum.
            neg_half = h[:, :, self._out_feats // 2:]
            neg_half[neg_half < 1e-11] = 1e12

            # Per-head exponents tiled to (num_heads, hidden_dim): the first
            # half of each row holds the "pos" value, the second the "neg" one.
            p_pos, p_neg = self.agg_fn.get_p_pos(), -self.agg_fn.get_p_neg()
            ps = torch.stack((p_pos, p_neg), dim=-1).view(-1).unsqueeze(-1).repeat(1, hidden_dim // 2).view(-1, hidden_dim)
            qs = torch.stack((self.agg_fn.q_pos, self.agg_fn.q_neg), dim=-1).view(-1).unsqueeze(-1).repeat(1, hidden_dim // 2).view(-1, hidden_dim)

            h = h ** ps
            graph.srcdata['h'] = h
            # Attention weights are raised to the same exponent, so each
            # message equals (a * h) ** p; the 1e-6 keeps the base positive.
            graph.edata["a"] = (self.attn_drop(edge_softmax(graph, e)) + 1e-6) ** ps
            graph.update_all(fn.u_mul_e('h', 'a', 'm1'), fn.sum('m1', 'sumexp'))
            _agg_h = torch.log(graph.dstdata['sumexp'] + 1e-12)

            in_degs = graph.in_degrees()
            # Clamp so that log(degree) is defined for isolated nodes.
            degs = in_degs.clamp(min=1)
            log_degs = torch.log(degs + self.agg_fn.eps).unsqueeze(-1).unsqueeze(-1)
            # (sum) ** (1 / p) / degree ** q, evaluated in log space.
            agg_h = torch.exp((_agg_h / ps) - (qs * log_degs))
            agg_h = agg_h * ((in_degs > 0).unsqueeze(-1).unsqueeze(-1).float())
            agg_h[is_allzero] = 0.
            rst = agg_h

            if self._norm == "both":
                degs = graph.in_degrees().float().clamp(min=1)
                norm = torch.pow(degs, 0.5)
                shp = norm.shape + (1,) * (feat_dst.dim() - 1)
                norm = torch.reshape(norm, shp)
                rst = rst * norm

            # residual
            if self.res_fc is not None:
                resval = self.res_fc(h_dst).view(h_dst.shape[0], -1, self._out_feats)
                rst = rst + resval
            # activation
            if self._activation is not None:
                rst = self._activation(rst)
            return rst


class GAT(nn.Module):
    """Multi-layer GAT built from ``GATConv`` plus a linear skip branch per layer.

    Non-final layers are followed by batch norm, the activation and dropout.
    The final layer's heads are averaged and passed through a bias-free
    linear map and a learnable bias.

    Args:
        in_feats: Input feature size.
        n_classes: Per-head output size of the last layer (and final output size).
        n_hidden: Per-head output size of every other layer.
        n_layers: Number of attention layers.
        n_heads: Number of attention heads in every layer.
        activation: Callable applied after batch norm on non-final layers.
        dropout: Dropout probability between layers; the input dropout uses
            ``min(0.1, dropout)``.
        attn_drop: Attention dropout forwarded to each ``GATConv``.
        norm: Normalization mode forwarded to each ``GATConv``.
    """

    def __init__(
            self, in_feats, n_classes, n_hidden, n_layers, n_heads, activation, dropout=0.0, attn_drop=0.0, norm="none"
    ):
        super(GAT, self).__init__()
        self.in_feats = in_feats
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.n_layers = n_layers
        self.num_heads = n_heads

        self.convs = nn.ModuleList()
        self.linear = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.biases = nn.ModuleList()

        final_idx = n_layers - 1
        for layer_idx in range(n_layers):
            is_final = layer_idx == final_idx
            fan_in = in_feats if layer_idx == 0 else n_heads * n_hidden
            per_head_out = n_classes if is_final else n_hidden
            stacked_out = n_heads * per_head_out

            self.convs.append(GATConv(fan_in, per_head_out, num_heads=n_heads, attn_drop=attn_drop, norm=norm))
            self.linear.append(nn.Linear(fan_in, stacked_out, bias=False))
            if not is_final:
                self.bns.append(nn.BatchNorm1d(stacked_out))

        self.fnn = nn.Linear(n_classes, n_classes, bias=False)
        self.bias_last = Bias(n_classes)

        self.dropout0 = nn.Dropout(min(0.1, dropout))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, graph, feat, perturb=None):
        """Run the network on ``graph`` with node features ``feat``.

        ``perturb`` is accepted but not used by this method.
        """
        h = self.dropout0(feat)
        final_idx = self.n_layers - 1

        for layer_idx in range(self.n_layers):
            attn_out = self.convs[layer_idx](graph, h)
            skip_out = self.linear[layer_idx](h).view(attn_out.shape)
            h = attn_out + skip_out

            if layer_idx < final_idx:
                # Merge heads into the feature axis before batch norm.
                h = self.bns[layer_idx](h.flatten(1))
                h = self.dropout(self.activation(h))

        # Average over the head axis, then apply the output map and bias.
        return self.bias_last(self.fnn(h.mean(1)))

class GAT_embed(nn.Module):
    """GAT variant that first passes the input through a learnable linear embedding.

    An optional perturbation is added to the embedded features. ``dropout0``
    is constructed but not applied in ``forward``. The heads of the final
    layer are averaged and a learnable bias is added.

    Args:
        in_feats: Input feature size (also the embedding output size).
        n_classes: Per-head output size of the last layer.
        n_hidden: Per-head output size of every other layer.
        n_layers: Number of attention layers.
        n_heads: Number of attention heads in every layer.
        activation: Callable applied after batch norm on non-final layers.
        dropout: Dropout probability between layers.
        attn_drop: Attention dropout forwarded to each ``GATConv``.
        norm: Normalization mode forwarded to each ``GATConv``.
    """

    def __init__(
        self, in_feats, n_classes, n_hidden, n_layers, n_heads, activation, dropout=0.0, attn_drop=0.0, norm="none"
    ):
        super(GAT_embed, self).__init__()
        self.in_feats = in_feats
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.n_layers = n_layers
        self.num_heads = n_heads

        self.convs = nn.ModuleList()
        self.linear = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.biases = nn.ModuleList()

        final_idx = n_layers - 1
        for layer_idx in range(n_layers):
            is_final = layer_idx == final_idx
            fan_in = in_feats if layer_idx == 0 else n_heads * n_hidden
            per_head_out = n_classes if is_final else n_hidden
            stacked_out = n_heads * per_head_out

            self.convs.append(GATConv(fan_in, per_head_out, num_heads=n_heads, attn_drop=attn_drop, norm=norm))
            self.linear.append(nn.Linear(fan_in, stacked_out, bias=False))
            if not is_final:
                self.bns.append(nn.BatchNorm1d(stacked_out))

        self.bias_last = Bias(n_classes)

        self.dropout0 = nn.Dropout(min(0.1, dropout))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

        self.embed = nn.Linear(self.in_feats, self.in_feats)

    def forward(self, graph, feat, perturb=None):
        """Run the network; ``perturb`` (if given) is added to the embedded input."""
        h = self.embed(feat)
        if perturb is not None:
            h = h + perturb

        final_idx = self.n_layers - 1
        for layer_idx in range(self.n_layers):
            attn_out = self.convs[layer_idx](graph, h)
            skip_out = self.linear[layer_idx](h).view(attn_out.shape)
            h = attn_out + skip_out

            if layer_idx < final_idx:
                # Merge heads into the feature axis before batch norm.
                h = self.bns[layer_idx](h.flatten(1))
                h = self.dropout(self.activation(h))

        # Average over the head axis, then add the output bias.
        return self.bias_last(h.mean(1))

class GAT_no_bn(nn.Module):
    """GAT variant whose forward pass skips batch normalization.

    The ``BatchNorm1d`` modules are still constructed and registered in
    ``self.bns``, but ``forward`` does not call them. The heads of the final
    layer are averaged and a learnable bias is added.

    Args:
        in_feats: Input feature size.
        n_classes: Per-head output size of the last layer.
        n_hidden: Per-head output size of every other layer.
        n_layers: Number of attention layers.
        n_heads: Number of attention heads in every layer.
        activation: Callable applied on non-final layers.
        dropout: Dropout probability between layers; the input dropout uses
            ``min(0.1, dropout)``.
        attn_drop: Attention dropout forwarded to each ``GATConv``.
        norm: Normalization mode forwarded to each ``GATConv``.
    """

    def __init__(
        self, in_feats, n_classes, n_hidden, n_layers, n_heads, activation, dropout=0.0, attn_drop=0.0, norm="none"
    ):
        super(GAT_no_bn, self).__init__()
        self.in_feats = in_feats
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.n_layers = n_layers
        self.num_heads = n_heads

        self.convs = nn.ModuleList()
        self.linear = nn.ModuleList()
        self.bns = nn.ModuleList()
        self.biases = nn.ModuleList()

        final_idx = n_layers - 1
        for layer_idx in range(n_layers):
            is_final = layer_idx == final_idx
            fan_in = in_feats if layer_idx == 0 else n_heads * n_hidden
            per_head_out = n_classes if is_final else n_hidden
            stacked_out = n_heads * per_head_out

            self.convs.append(GATConv(fan_in, per_head_out, num_heads=n_heads, attn_drop=attn_drop, norm=norm))
            self.linear.append(nn.Linear(fan_in, stacked_out, bias=False))
            if not is_final:
                self.bns.append(nn.BatchNorm1d(stacked_out))

        self.bias_last = Bias(n_classes)

        self.dropout0 = nn.Dropout(min(0.1, dropout))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation

    def forward(self, graph, feat):
        """Run the network on ``graph`` with node features ``feat``."""
        h = self.dropout0(feat)
        final_idx = self.n_layers - 1

        for layer_idx in range(self.n_layers):
            attn_out = self.convs[layer_idx](graph, h)
            skip_out = self.linear[layer_idx](h).view(attn_out.shape)
            h = attn_out + skip_out

            if layer_idx < final_idx:
                # Merge heads into the feature axis; batch norm is skipped here.
                h = self.dropout(self.activation(h.flatten(1)))

        # Average over the head axis, then add the output bias.
        return self.bias_last(h.mean(1))