import torch, numpy as np
import torch.nn as nn, torch.nn.functional as F

from torch.autograd import Variable
from model import utils 
import ipdb


class HyperTensorGCN(nn.Module):
    """Two-layer hypergraph tensor GCN with a log-softmax output.

    Layer widths are ``d -> hidden -> c``. ``d`` and ``c`` are read from
    ``args``, and ``hidden`` defaults to 16, the previously hard-coded width.
    """

    def __init__(self, E, X, args, hidden=16):
        """Build the two hypergraph convolution layers.

        Args:
            E: accepted for interface compatibility; not used by this class.
            X: accepted for interface compatibility; not used by this class.
            args: namespace providing ``d`` (initial node-feature dimension),
                ``c`` (number of classes), ``depth``, ``dropout``, ``power``
                and ``sample``.
            hidden: number of hidden units between the two convolution layers.
        """
        super().__init__()
        d, c = args.d, args.c

        self.hgc1 = utils.HyperTensorGraphConvolution(d, hidden)
        self.hgc2 = utils.HyperTensorGraphConvolution(hidden, c)
        # NOTE(review): ``depth`` is stored as ``self.l`` but forward() always
        # applies exactly two layers; a depth-dependent stack is not implemented.
        self.do, self.l = args.dropout, args.depth
        self.power = args.power
        self.num_sample = args.sample

    def forward(self, structure, H):
        """Apply both hypergraph convolutions and return log-probabilities.

        Args:
            structure: hypergraph structure, passed unchanged to each
                convolution layer (its format is defined by ``utils``).
            H: node-feature input to the first layer.

        Returns:
            ``log_softmax`` of the second layer's output along dim 1.
            Presumably dim 1 is the class axis of an (N, c) output; confirm.
        """
        # ``power`` and ``num_sample`` are forwarded verbatim to both layers;
        # their meaning is defined by utils.HyperTensorGraphConvolution.
        H = F.relu(self.hgc1(structure, H, self.power, self.num_sample))
        # Dropout is applied only while the module is in training mode.
        H = F.dropout(H, self.do, training=self.training)
        H = self.hgc2(structure, H, self.power, self.num_sample)
        return F.log_softmax(H, dim=1)
