import math
import torch
import torch.nn as nn
import torch.nn.functional as F

# Eagerly initialise CUDA at import time: ``current_device()`` forces the lazy
# CUDA initialisation, after which the private ``_initialized`` flag is set
# explicitly.  NOTE(review): the flag write is a workaround carried over from
# the original code, presumably for a lazy-init issue in some PyTorch version —
# confirm it is still needed.  Guarded so that importing this module on a
# CPU-only machine no longer raises (the models below also run on CPU).
if torch.cuda.is_available():
    torch.cuda.current_device()
    torch.cuda._initialized = True


class SHSC(nn.Module):
    """Dropout, then SHSC propagation over ``G``, then a bias-free projection.

    Args:
        in_dim: input width of the linear projection ``W``.
        out_dim: output width of the linear projection ``W``.
        degree: propagation depth ``K`` handed to ``SHSC_layer``.
        alpha: per-hop scaling factor handed to ``SHSC_layer``.
        args: configuration object. ``args.dropout`` is read here to build the
            input dropout, and the object is also passed to ``SHSC_layer``.
            Despite the ``None`` default, it is required.
    """

    def __init__(self, in_dim, out_dim, degree, alpha, args=None):
        super().__init__()
        self.out_dim = out_dim
        # Keep this order: it fixes the submodule / parameter registration
        # order (propagation, projection, dropout) seen by state_dict etc.
        self.SHSC_layer = SHSC_layer(degree=degree, alpha=alpha, args=args)
        self.W = nn.Linear(in_dim, out_dim, bias=False)
        self.degree, self.alpha = degree, alpha
        self.dropout = nn.Dropout(p=args.dropout)

    def forward(self, input, adj=None, G=None):
        """Apply dropout, propagate over ``G``, then project with ``W``.

        ``adj`` is accepted for interface compatibility but is not used.
        """
        dropped = self.dropout(input)
        smoothed = self.SHSC_layer(dropped, G=G)
        return self.W(smoothed)

    def get_outdim(self):
        """Return the output width passed to the constructor."""
        return self.out_dim

    def __repr__(self):
        name = type(self).__name__
        return f"{name} (degree(K): {self.degree!s} alpha: {self.alpha!s})"


class SHSC_layer(nn.Module):
    """Parameter-free multi-hop propagation followed by a residual mix.

    With ``K = degree`` and a scalar ``alpha``, ``forward`` computes::

        emb = (x + sum_{k=1..K} alpha**k * G^k x) / K
        out = beta * emb + (1 - beta) * x

    where ``G^k x`` means ``k`` successive ``torch.spmm(G, .)`` products and
    ``beta`` is read from ``args.beta``.

    Args:
        degree: number of propagation hops ``K``; must be >= 1.
        alpha: factor applied to the features at every hop.
        args: configuration object exposing ``beta``. It is read on every
            forward call, so it is required despite the ``None`` default.

    Raises:
        ValueError: if ``degree`` is smaller than 1.
    """

    def __init__(self, degree, alpha, args=None):
        super(SHSC_layer, self).__init__()
        # The propagated sum is divided by ``degree``. A value of 0 would
        # silently produce inf/NaN, and a negative value would skip the loop
        # and flip the sign of the input, so reject both up front.
        if degree < 1:
            raise ValueError(f"degree must be >= 1, got {degree}")

        self.degree = degree
        self.args = args
        self.alpha = alpha

    def forward(self, input, G=None, adj=None):
        """Propagate ``input`` over ``G`` for ``degree`` hops and mix with it.

        Args:
            input: feature matrix to propagate.
            G: propagation operator, passed as the first argument to
                ``torch.spmm`` (presumably a sparse square matrix — confirm
                with callers).
            adj: accepted for interface compatibility, unused.

        Returns:
            ``beta * emb + (1 - beta) * input``, with ``emb`` as in the class
            docstring.
        """
        ori_features = input
        emb = input
        features = input

        for _ in range(self.degree):
            # Each hop multiplies by alpha again, so hop k is alpha**k-scaled.
            features = self.alpha * torch.spmm(G, features)
            emb = emb + features
        # NOTE(review): ``emb`` now holds degree + 1 terms (the input plus one
        # per hop) but is divided by ``degree``; confirm this is intended.
        emb = emb / self.degree

        beta = self.args.beta
        return beta * emb + (1 - beta) * ori_features

    def __repr__(self):
        return self.__class__.__name__ + ' (degree(K): ' \
               + str(self.degree) + ' alpha: ' \
               + str(self.alpha) + ')'
