import torch
import torch.nn as nn
import torch.nn.functional as F
from layers import JointAttentionLayer, StructuralSubspaceLayer


class SSGAT(nn.Module):
    """Two-stage joint-attention network with averaged output heads.

    Stage 1 runs ``nheads`` concatenating ``JointAttentionLayer`` heads and
    concatenates their outputs along dim 1.  Stage 2 runs ``n_out_heads``
    non-concatenating heads on that result, averages them, applies ELU and
    returns ``log_softmax`` over dim 1.

    Output heads are registered as ``out_att1`` .. ``out_att{n_out_heads}`` so
    that the default (8) yields exactly the same submodule names, and hence the
    same ``state_dict`` keys, as the original hand-written version.
    """

    def __init__(self, nfeat, nhid, nclass, dropout, alpha, scale, nheads, roughconstrain,
                 n_out_heads=8):
        """Build the attention heads.

        Args:
            nfeat: input feature size of the first-stage heads.
            nhid: output size of each first-stage head; the second stage
                receives ``nhid * nheads`` features.
            nclass: output size of each second-stage head.
            dropout: dropout probability, used by ``F.dropout`` in ``forward``
                and passed through to every ``JointAttentionLayer``.
            alpha, scale, roughconstrain: forwarded unchanged to every
                ``JointAttentionLayer``.
            nheads: number of first-stage (concatenated) heads.
            n_out_heads: number of second-stage heads whose outputs are
                averaged.  Defaults to 8, the previously hard-coded count.

        Raises:
            ValueError: if ``n_out_heads`` is less than 1 (there would be
                nothing to average).
        """
        super(SSGAT, self).__init__()
        if n_out_heads < 1:
            raise ValueError("n_out_heads must be >= 1, got {}".format(n_out_heads))
        self.dropout = dropout
        self.roughconstrain = roughconstrain
        self.n_out_heads = n_out_heads

        # Plain list for convenient iteration; each head is also registered via
        # add_module so its parameters are tracked as attention_{i}.
        self.attentions = [
            JointAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, scale=scale,
                                roughconstrain=roughconstrain, concat=True) for _ in range(nheads)]

        for i, attention in enumerate(self.attentions):
            self.add_module('attention_{}'.format(i), attention)

        # 1-based names keep state_dict keys identical to the former
        # out_att1 .. out_att8 attributes; construction order is unchanged too.
        for k in range(1, n_out_heads + 1):
            self.add_module(
                'out_att{}'.format(k),
                JointAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, scale=scale,
                                    roughconstrain=roughconstrain, concat=False))

    def _output_heads(self):
        """Return the second-stage heads in index order (out_att1, out_att2, ...)."""
        # Instances pickled before n_out_heads existed always had eight heads.
        count = getattr(self, 'n_out_heads', 8)
        return [getattr(self, 'out_att{}'.format(k)) for k in range(1, count + 1)]

    def forward(self, x, edge, cadinality, ssc_coef1):
        """Run both attention stages and return log-probabilities over dim 1.

        ``edge``, ``cadinality`` and ``ssc_coef1`` are passed unchanged to
        every attention head.  NOTE(review): x is presumably
        (num_nodes, nfeat); confirm against the JointAttentionLayer contract.
        """
        x = F.dropout(x, self.dropout, training=self.training)
        # All heads use the same subspace coefficients (ssc_coef1).
        x = torch.cat([att(x, edge, cadinality, ssc_coef1) for att in self.attentions], dim=1)
        x = F.dropout(x, self.dropout, training=self.training)

        heads = self._output_heads()
        # Sequential left-to-right sum, same order as the original unrolled code.
        out = heads[0](x, edge, cadinality, ssc_coef1)
        for head in heads[1:]:
            out = out + head(x, edge, cadinality, ssc_coef1)
        out = out.div(float(len(heads)))
        out = F.elu(out)
        return F.log_softmax(out, dim=1)


class Structuralsubspace(nn.Module):
    """Module wrapper around a single ``StructuralSubspaceLayer``.

    The wrapped layer is stored as ``subspace1``.  ``forward`` returns the
    layer's two outputs unchanged, as a 2-tuple.
    """

    def __init__(self, nsample):
        """Create the wrapped layer.

        Args:
            nsample: stored on the instance and passed to the
                ``StructuralSubspaceLayer`` constructor.
        """
        super(Structuralsubspace, self).__init__()
        self.nsample = nsample
        layer = StructuralSubspaceLayer(self.nsample)
        self.subspace1 = layer

    def forward(self, edget, edget_v):
        """Apply the wrapped layer to ``(edget, edget_v)``.

        Returns:
            ``(selfrep, ssc_coef)``: the two values produced by the wrapped
            layer.  Unpacking requires the layer to yield exactly two items.
        """
        representation, coefficients = self.subspace1(edget, edget_v)
        return (representation, coefficients)
