import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from config import args
from utils import sparse_mx_to_torch_sparse_tensor, homo_adj_to_symmetric_norm, hete_adj_to_symmetric_norm


class GGCNlayer(nn.Module):
    """Dense GGCN convolution with optional degree correction and signed attention.

    Node features are first projected with a linear map (``self.fcn``).  With
    ``use_sign`` the projected neighbours are weighted by the cosine
    similarity between projected endpoints (self-similarity removed), the
    weighted messages are split into their positive and negative parts, and
    the two parts plus the node's own projection are mixed with learned
    softmax weights and a learned softplus scale.  Without ``use_sign`` the
    layer is a plain ``adj @ fcn(h)`` aggregation.

    Args:
        in_features: width of each input feature row.
        out_features: width of each output feature row.
        use_degree: multiply edge weights by
            ``softplus(deg_coeff[0] * degree_precompute + deg_coeff[1])``.
        use_sign: use the signed cosine-similarity attention described above.
        use_decay: when True, ``deg_coeff`` starts at ``[0.5, 0.0]`` and the
            raw scale at ``2``; when False they start at
            ``[deg_intercept_init, 0.0]`` and ``scale_init``.
        scale_init: initial raw (pre-softplus) scale, used when ``use_decay``
            is False.
        deg_intercept_init: initial value of ``deg_coeff[0]`` (the multiplier
            of ``degree_precompute``), used when ``use_decay`` is False.
    """

    def __init__(self, in_features, out_features, use_degree=True, use_sign=True, use_decay=True, scale_init=0.5, deg_intercept_init=0.5):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.fcn = nn.Linear(in_features, out_features)
        self.use_degree = use_degree
        self.use_sign = use_sign
        if use_degree:
            first_coeff = 0.5 if use_decay else deg_intercept_init
            self.deg_coeff = nn.Parameter(torch.tensor([first_coeff, 0.0]))
        if use_sign:
            # All-zero logits -> softmax starts with the three branches equally weighted.
            self.coeff = nn.Parameter(torch.zeros(3))
            raw_scale = 2 if use_decay else scale_init
            self.scale = nn.Parameter(raw_scale * torch.ones([1]))
        self.sftmax = nn.Softmax(dim=-1)
        self.sftpls = nn.Softplus(beta=1)

    def forward(self, h, adj, degree_precompute):
        """Propagate ``h`` over the dense adjacency ``adj``.

        Args:
            h: node feature matrix, one row per node.
            adj: dense adjacency matrix multiplied element-wise with the
                attention/degree weights and then with the projected features.
            degree_precompute: dense tensor combined element-wise with
                ``deg_coeff``; only read when ``use_degree`` is True.

        Returns:
            The propagated features, ``out_features`` columns per node.
        """
        if self.use_degree:
            sc = self.sftpls(self.deg_coeff[0] * degree_precompute + self.deg_coeff[1])

        Wh = self.fcn(h)

        if not self.use_sign:
            edge_weights = adj * sc if self.use_degree else adj
            return torch.matmul(edge_weights, Wh)

        # Pairwise cosine similarity of projected rows; the 1e-9 floor guards
        # against division by zero for all-zero rows.
        gram = Wh @ Wh.t()
        self_dot = gram.diagonal().unsqueeze(1)
        norm_prod = self_dot @ self_dot.t()
        cos_sim = gram / norm_prod.sqrt().clamp_min(1e-9)
        # Remove each node's similarity with itself.
        cos_sim = cos_sim - torch.diag(cos_sim.diagonal())

        attention = cos_sim * adj
        if self.use_degree:
            attention = attention * sc

        positive_msg = torch.matmul(F.relu(attention), Wh)
        negative_msg = torch.matmul(-F.relu(-attention), Wh)

        mix = self.sftmax(self.coeff)
        gain = self.sftpls(self.scale)
        return gain * (mix[0] * positive_msg + mix[1] * negative_msg + mix[2] * Wh)


class GGCNlayer_SP(nn.Module):
    """Sparse (COO) GGCN convolution with optional degree correction and signed attention.

    Mirrors the dense layer but keeps the adjacency, degree weights and
    attention as ``torch`` sparse COO tensors, so only the stored edges are
    touched.  With ``use_sign`` each stored off-diagonal edge is weighted by
    the cosine similarity of its projected endpoints; the positive and
    negative parts are propagated separately and mixed with the node's own
    projection via learned softmax weights and a learned softplus scale.

    All sparse tensors are built with ``torch.sparse_coo_tensor``. That
    constructor infers device and dtype from its inputs, so the layer works
    on CPU and CUDA alike. The legacy ``torch.sparse.FloatTensor``
    constructor is CPU-only.

    Args:
        in_features: width of each input feature row.
        out_features: width of each output feature row.
        use_degree: multiply edge weights by
            ``softplus(deg_coeff[0] * degree_precompute + deg_coeff[1])``.
        use_sign: use the signed cosine-similarity attention.
        use_decay: when True, ``deg_coeff`` starts at ``[0.5, 0.0]`` and the
            raw scale at ``2``; otherwise at ``[deg_intercept_init, 0.0]`` and
            ``scale_init``.
        scale_init: initial raw (pre-softplus) scale when ``use_decay`` is False.
        deg_intercept_init: initial ``deg_coeff[0]`` (multiplier of
            ``degree_precompute``) when ``use_decay`` is False.
    """

    def __init__(self, in_features, out_features, use_degree=True, use_sign=True, use_decay=True, scale_init=0.5, deg_intercept_init=0.5):
        super(GGCNlayer_SP, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.fcn = nn.Linear(in_features, out_features)
        self.use_degree = use_degree
        self.use_sign = use_sign
        if use_degree:
            if use_decay:
                self.deg_coeff = nn.Parameter(torch.tensor([0.5, 0.0]))
            else:
                self.deg_coeff = nn.Parameter(torch.tensor([deg_intercept_init, 0.0]))
        if use_sign:
            self.coeff = nn.Parameter(0 * torch.ones([3]))
            # Filled by precompute_adj_wo_diag on every forward call.
            self.adj_remove_diag = None
            if use_decay:
                self.scale = nn.Parameter(2 * torch.ones([1]))
            else:
                self.scale = nn.Parameter(scale_init * torch.ones([1]))
        self.sftmax = nn.Softmax(dim=-1)
        self.sftpls = nn.Softplus(beta=1)

    def precompute_adj_wo_diag(self, adj):
        """Store in ``self.adj_remove_diag`` a copy of sparse ``adj`` without its diagonal entries."""
        adj_i = adj._indices()
        adj_v = adj._values()
        adj_wo_diag_ind = adj_i[0, :] != adj_i[1, :]
        self.adj_remove_diag = torch.sparse_coo_tensor(
            adj_i[:, adj_wo_diag_ind], adj_v[adj_wo_diag_ind], adj.size())

    def non_linear_degree(self, a, b, s):
        """Return a sparse tensor with the pattern of ``s`` and values ``softplus(a * s_values + b)``."""
        i = s._indices()
        v = s._values()
        return torch.sparse_coo_tensor(i, self.sftpls(a * v + b), s.size())

    def get_sparse_att(self, adj, Wh):
        """Compute per-edge cosine similarity between projected endpoints.

        Returns:
            Two sparse tensors with the same index pattern as ``adj``: the
            positive part and the (non-positive) negative part of the
            similarity on each stored edge.
        """
        i = adj._indices()
        Wh_1 = Wh[i[0, :], :]
        Wh_2 = Wh[i[1, :], :]
        sim_vec = F.cosine_similarity(Wh_1, Wh_2)
        sim_vec_pos = F.relu(sim_vec)
        sim_vec_neg = -F.relu(-sim_vec)
        return (torch.sparse_coo_tensor(i, sim_vec_pos, adj.size()),
                torch.sparse_coo_tensor(i, sim_vec_neg, adj.size()))

    def forward(self, h, adj, degree_precompute):
        """Propagate ``h`` over the sparse COO adjacency ``adj``.

        Args:
            h: node feature matrix, one row per node.
            adj: sparse COO adjacency.
            degree_precompute: sparse COO tensor fed to ``non_linear_degree``;
                only read when ``use_degree`` is True.

        Returns:
            Dense propagated features, ``out_features`` columns per node.
        """
        if self.use_degree:
            sc = self.non_linear_degree(self.deg_coeff[0], self.deg_coeff[1], degree_precompute)

        Wh = self.fcn(h)
        if self.use_sign:
            # Recomputed on every call so the cached attribute always matches
            # the adj passed in.
            self.precompute_adj_wo_diag(adj)
            e_pos, e_neg = self.get_sparse_att(adj, Wh)
            if self.use_degree:
                attention_pos = self.adj_remove_diag * sc * e_pos
                attention_neg = self.adj_remove_diag * sc * e_neg
            else:
                attention_pos = self.adj_remove_diag * e_pos
                attention_neg = self.adj_remove_diag * e_neg

            prop_pos = torch.sparse.mm(attention_pos, Wh)
            prop_neg = torch.sparse.mm(attention_neg, Wh)

            coeff = self.sftmax(self.coeff)
            scale = self.sftpls(self.scale)
            result = scale * (coeff[0] * prop_pos + coeff[1] * prop_neg + coeff[2] * Wh)
        else:
            if self.use_degree:
                prop = torch.sparse.mm(adj * sc, Wh)
            else:
                prop = torch.sparse.mm(adj, Wh)
            result = prop
        return result
        
        
class GGCN(nn.Module):
    """GGCN model operating on a dense normalised adjacency matrix.

    Stacks ``GGCNlayer`` convolutions. There are always at least two: the
    first maps ``feat_dim -> hidden_dim``, the last maps
    ``hidden_dim -> output_dim``, and ``num_layers - 2`` hidden layers sit in
    between.

    Between convolutions the hidden representation is optionally normalised
    (BatchNorm/LayerNorm), passed through ELU and dropout, and accumulated
    into a residual sum. Later terms of that sum are weighted by
    ``log(decay_rate / (i + 2) ** exponent + 1)`` when ``use_decay`` is True.

    Call ``preprocess`` before ``forward``/``model_forward``. It stores the
    node features and the normalised dense adjacency on the model.
    """

    def __init__(self, feat_dim, hidden_dim, output_dim, num_layers, dropout, decay_rate, exponent, use_degree=True, use_sign=True, use_decay=True, scale_init=0.5, deg_intercept_init=0.5, bn=False, ln=False):
        super(GGCN, self).__init__()
        self.dropout = dropout

        self.use_graph_op = True
        self.pre_graph_op = None

        self.convs = nn.ModuleList()
        model_sel = GGCNlayer
        self.convs.append(model_sel(feat_dim, hidden_dim, use_degree, use_sign, use_decay, scale_init, deg_intercept_init))
        for _ in range(num_layers - 2):
            self.convs.append(model_sel(hidden_dim, hidden_dim, use_degree, use_sign, use_decay, scale_init, deg_intercept_init))
        self.convs.append(model_sel(hidden_dim, output_dim, use_degree, use_sign, use_decay, scale_init, deg_intercept_init))
        # Linear projection of the input that seeds the residual sum in forward().
        self.fcn = nn.Linear(feat_dim, hidden_dim)
        self.act_fn = F.elu
        self.use_decay = use_decay
        if self.use_decay:
            self.decay = decay_rate
            self.exponent = exponent
        self.use_degree = use_degree

        self.bn = bn
        self.ln = ln
        self.norms = nn.ModuleList()
        # NOTE(review): if both bn and ln are set, both lists are appended but
        # forward() indexes only the first num_layers - 1 entries (the BatchNorms).
        if bn:
            for _ in range(num_layers - 1):
                self.norms.append(nn.BatchNorm1d(hidden_dim))
        if ln:
            for _ in range(num_layers - 1):
                self.norms.append(nn.LayerNorm(hidden_dim))

        self.post_graph_op = None

    def precompute_degree_d(self, adj):
        """Return ``adj[i, i] / max(adj[i, j], 1e-9) - 1`` for every entry (i, j) of dense ``adj``.

        Entries where ``adj`` is zero get large values. Those are harmless
        because the layers multiply the result element-wise with ``adj``.
        """
        diag_adj = torch.diag(adj)
        diag_adj = torch.unsqueeze(diag_adj, dim=1)
        degree_precompute = diag_adj / torch.max(adj, 1e-9 * torch.ones_like(adj)) - 1
        return degree_precompute

    def preprocess(self, adj, feature, homo=args.homo):
        """Cache ``feature`` and a normalised dense copy of ``adj`` on the model.

        ``homo`` selects ``homo_adj_to_symmetric_norm`` or
        ``hete_adj_to_symmetric_norm`` (both called with ``r=0.5``). Its
        default is read from ``args.homo`` when the class is defined.
        """
        self.pre_msg_learnable = False
        self.processed_feature = feature

        if homo:
            adj = homo_adj_to_symmetric_norm(adj, r=0.5)
        else:
            adj = hete_adj_to_symmetric_norm(adj, r=0.5)

        self.adj = sparse_mx_to_torch_sparse_tensor(adj).to_dense()

    def postprocess(self, adj, output):
        """Optionally post-propagate softmax outputs; a no-op while ``post_graph_op`` is None.

        NOTE(review): ``post_graph_op`` / ``post_msg_op`` are never set to
        anything else in this class; they must be attached externally.
        """
        if self.post_graph_op is not None:
            output = F.softmax(output, dim=1)
            # .cpu() so this also works for outputs computed on a GPU.
            output = output.detach().cpu().numpy()
            output = self.post_graph_op.propagate(adj, output)
            output = self.post_msg_op.aggregate(output)

        return output

    # a wrapper of the forward function
    def model_forward(self, idx, device):
        return self.forward(idx, device)

    def forward(self, idx, device):
        """Run the model on the cached graph and return final outputs for rows ``idx``.

        Args:
            idx: any index accepted by tensor indexing, applied to the rows of
                the last convolution's output.
            device: device the cached features and adjacency are moved to.
        """
        processed_feature = None
        if self.pre_msg_learnable is False:
            processed_feature = self.processed_feature.to(device)
        else:
            # NOTE(review): processed_feat_list / pre_msg_op are never set in
            # this class; this branch relies on them being attached externally.
            transferred_feat_list = [feat.to(
                device) for feat in self.processed_feat_list]
            processed_feature = self.pre_msg_op.aggregate(
                transferred_feat_list)

        adj = self.adj.to(device)
        x = processed_feature
        # Must be bound even when use_degree is False: every layer receives it,
        # but a layer only reads it when its own use_degree flag is set.
        degree_precompute = self.precompute_degree_d(adj) if self.use_degree else None
        x = F.dropout(x, self.dropout, training=self.training)
        layer_previous = self.fcn(x)
        layer_previous = self.act_fn(layer_previous)
        layer_inner = self.convs[0](x, adj, degree_precompute)

        for i, con in enumerate(self.convs[1:]):
            if self.bn or self.ln:
                layer_inner = self.norms[i](layer_inner)
            layer_inner = self.act_fn(layer_inner)
            layer_inner = F.dropout(layer_inner, self.dropout, training=self.training)
            if i == 0:
                layer_previous = layer_inner + layer_previous
            else:
                if self.use_decay:
                    coeff = math.log(self.decay / (i + 2) ** self.exponent + 1)
                else:
                    coeff = 1
                layer_previous = coeff * layer_inner + layer_previous
            layer_inner = con(layer_previous, adj, degree_precompute)
        return layer_inner[idx]


class GGCNSP(nn.Module):
    """GGCN model operating on a sparse (COO) normalised adjacency.

    Same architecture as the dense model but built from ``GGCNlayer_SP``
    convolutions. There are always at least two: ``feat_dim -> hidden_dim``
    first, ``hidden_dim -> output_dim`` last, and ``num_layers - 2`` hidden
    layers in between.

    The residual sum works as in the dense model. Later terms are weighted by
    ``log(decay_rate / (i + 2) ** exponent + 1)`` when ``use_decay`` is True.

    Call ``preprocess`` before ``forward``/``model_forward``.
    """

    def __init__(self, feat_dim, hidden_dim, output_dim, num_layers, dropout, decay_rate, exponent, use_degree=True, use_sign=True, use_decay=True, scale_init=0.5, deg_intercept_init=0.5, bn=False, ln=False):
        super(GGCNSP, self).__init__()
        self.dropout = dropout

        self.use_graph_op = True
        self.pre_graph_op = None

        self.convs = nn.ModuleList()
        model_sel = GGCNlayer_SP
        self.convs.append(model_sel(feat_dim, hidden_dim, use_degree, use_sign, use_decay, scale_init, deg_intercept_init))
        for _ in range(num_layers - 2):
            self.convs.append(model_sel(hidden_dim, hidden_dim, use_degree, use_sign, use_decay, scale_init, deg_intercept_init))
        self.convs.append(model_sel(hidden_dim, output_dim, use_degree, use_sign, use_decay, scale_init, deg_intercept_init))
        # Linear projection of the input that seeds the residual sum in forward().
        self.fcn = nn.Linear(feat_dim, hidden_dim)
        self.act_fn = F.elu
        self.use_decay = use_decay
        if self.use_decay:
            self.decay = decay_rate
            self.exponent = exponent
        self.use_degree = use_degree

        self.bn = bn
        self.ln = ln
        self.norms = nn.ModuleList()
        # NOTE(review): if both bn and ln are set, both lists are appended but
        # forward() indexes only the first num_layers - 1 entries (the BatchNorms).
        if bn:
            for _ in range(num_layers - 1):
                self.norms.append(nn.BatchNorm1d(hidden_dim))
        if ln:
            for _ in range(num_layers - 1):
                self.norms.append(nn.LayerNorm(hidden_dim))

        self.post_graph_op = None

    def precompute_degree_s(self, adj):
        """Return a sparse tensor with ``adj``'s pattern holding ``adj[i, i] / adj[i, j] - 1`` per stored entry.

        Vectorised: each stored entry looks up its row's self-loop weight by
        node id. Rows with no stored diagonal entry use a self-loop weight of
        0. The sparse result is built with ``torch.sparse_coo_tensor`` so it
        stays on ``adj``'s device.
        """
        adj_i = adj._indices()
        adj_v = adj._values()
        rows, cols = adj_i[0], adj_i[1]
        on_diag = rows == cols
        self_weight = torch.zeros(adj.size(0), dtype=adj_v.dtype, device=adj_v.device)
        self_weight[rows[on_diag]] = adj_v[on_diag]
        v_new = self_weight[rows] / adj_v - 1
        return torch.sparse_coo_tensor(adj_i, v_new, adj.size())

    def preprocess(self, adj, feature, homo=args.homo):
        """Cache ``feature`` and a normalised sparse torch copy of ``adj`` on the model.

        ``homo`` selects ``homo_adj_to_symmetric_norm`` or
        ``hete_adj_to_symmetric_norm`` (both called with ``r=0.5``). Its
        default is read from ``args.homo`` when the class is defined.
        """
        self.pre_msg_learnable = False
        self.processed_feature = feature

        if homo:
            adj = homo_adj_to_symmetric_norm(adj, r=0.5)
        else:
            adj = hete_adj_to_symmetric_norm(adj, r=0.5)

        self.adj = sparse_mx_to_torch_sparse_tensor(adj)

    def postprocess(self, adj, output):
        """Optionally post-propagate softmax outputs; a no-op while ``post_graph_op`` is None.

        NOTE(review): ``post_graph_op`` / ``post_msg_op`` are never set to
        anything else in this class; they must be attached externally.
        """
        if self.post_graph_op is not None:
            output = F.softmax(output, dim=1)
            # .cpu() so this also works for outputs computed on a GPU.
            output = output.detach().cpu().numpy()
            output = self.post_graph_op.propagate(adj, output)
            output = self.post_msg_op.aggregate(output)

        return output

    # a wrapper of the forward function
    def model_forward(self, idx, device):
        return self.forward(idx, device)

    def forward(self, idx, device):
        """Run the model on the cached graph and return final outputs for rows ``idx``.

        Args:
            idx: any index accepted by tensor indexing, applied to the rows of
                the last convolution's output.
            device: device the cached features and adjacency are moved to.
        """
        processed_feature = None
        if self.pre_msg_learnable is False:
            processed_feature = self.processed_feature.to(device)
        else:
            # NOTE(review): processed_feat_list / pre_msg_op are never set in
            # this class; this branch relies on them being attached externally.
            transferred_feat_list = [feat.to(
                device) for feat in self.processed_feat_list]
            processed_feature = self.pre_msg_op.aggregate(
                transferred_feat_list)

        adj = self.adj.to(device)
        x = processed_feature
        # Must be bound even when use_degree is False: every layer receives it,
        # but a layer only reads it when its own use_degree flag is set.
        degree_precompute = self.precompute_degree_s(adj) if self.use_degree else None
        x = F.dropout(x, self.dropout, training=self.training)
        layer_previous = self.fcn(x)
        layer_previous = self.act_fn(layer_previous)
        layer_inner = self.convs[0](x, adj, degree_precompute)

        for i, con in enumerate(self.convs[1:]):
            if self.bn or self.ln:
                layer_inner = self.norms[i](layer_inner)
            layer_inner = self.act_fn(layer_inner)
            layer_inner = F.dropout(layer_inner, self.dropout, training=self.training)
            if i == 0:
                layer_previous = layer_inner + layer_previous
            else:
                if self.use_decay:
                    coeff = math.log(self.decay / (i + 2) ** self.exponent + 1)
                else:
                    coeff = 1
                layer_previous = coeff * layer_inner + layer_previous
            layer_inner = con(layer_previous, adj, degree_precompute)
        return layer_inner[idx]