import os
import torch
import torch.nn as nn
from dhg.structure.hypergraphs import Hypergraph


class hds_ode(nn.Module):
    """Hypergraph dynamic-system classifier with coupled vertex/hyperedge updates.

    Vertex features ``X`` and hyperedge features ``E`` are evolved for
    ``layer_num`` iterations. At every iteration whose index is a multiple of
    ``step`` (including iteration 0), each receives a learnable residual term
    ``ReLU(theta(.))``. Every iteration then applies the explicit update::

        X <- X - alpha_v * (X - W_ve @ E)
        E <- E - alpha_e * (E - W_ev @ X)

    where ``W_ev = D_e_neg_1 @ H_T`` and ``W_ve = D_v_neg_1 @ H`` come from the
    hypergraph. Both updates read ``X`` and ``E`` as they stand before either
    is overwritten. The final vertex features are fed to a linear classifier.

    Args:
        in_channels: Width of the vertex features; also used as the hidden width.
        num_classes: Size of the classifier output (last dimension of the result).
        layer_num: Number of update iterations.
        step: Period, in iterations, of the learnable residual update.
        alpha_v: Step size of the vertex update.
        alpha_e: Step size of the hyperedge update.
        use_bn: If True, apply ``BatchNorm1d`` after each residual update.
        bias: Whether the two ``theta`` layers use a bias term.
        drop_rate: Dropout probability applied to ``X`` at every iteration.
    """

    def __init__(self, in_channels, num_classes, layer_num=2, step=50, alpha_v=0.05, alpha_e=0.9,
                 use_bn=False, bias=True, drop_rate=0.15):
        super().__init__()
        # The model keeps the input width throughout; there is no projection
        # layer before the dynamics.
        hid_channels = in_channels
        self.step = step
        self.alpha_v = alpha_v
        self.alpha_e = alpha_e
        # NOTE(review): one BatchNorm1d instance (and its running statistics)
        # is shared by vertex and hyperedge features in forward(). It is kept
        # as a single module so existing state_dicts still load.
        self.bn = nn.BatchNorm1d(hid_channels) if use_bn else None
        self.act = nn.ReLU(inplace=True)
        self.drop = nn.Dropout(drop_rate)
        self.layer_num = layer_num
        self.theta_vertex = nn.Linear(hid_channels, hid_channels, bias=bias)
        self.theta_hyperedge = nn.Linear(hid_channels, hid_channels, bias=bias)
        self.classifier = nn.Linear(hid_channels, num_classes)

    def forward(self, X: torch.Tensor, hg: "Hypergraph") -> torch.Tensor:
        """Run the coupled dynamics on ``hg`` and classify the vertices.

        Args:
            X: Vertex features; expected shape ``(num_vertices, in_channels)``.
            hg: Hypergraph exposing ``H``, ``H_T``, ``D_v_neg_1`` and
                ``D_e_neg_1`` (dhg naming).

        Returns:
            ``self.classifier`` applied to the evolved vertex features; the last
            dimension is ``num_classes``.
        """
        # Propagation operators. W_ev is built once and reused for the initial
        # hyperedge features rather than recomputing D_e^-1 @ H^T.
        W_ev = hg.D_e_neg_1.mm(hg.H_T)
        W_ve = hg.D_v_neg_1.mm(hg.H)
        # Initial hyperedge features are W_ev applied to the input vertices.
        E = W_ev.mm(X)

        for i in range(self.layer_num):
            if i % self.step == 0:
                # Learnable residual update, only every `step` iterations.
                X = X + self.act(self.theta_vertex(X))
                E = E + self.act(self.theta_hyperedge(E))
                if self.bn is not None:
                    X = self.bn(X)
                    E = self.bn(E)
            X = self.drop(X)
            # Simultaneous update: both right-hand sides use the current X and E.
            newX = X - self.alpha_v * (X - W_ve.mm(E))
            newE = E - self.alpha_e * (E - W_ev.mm(X))
            X, E = newX, newE

        return self.classifier(X)
