import torch
import torch.nn as nn
import torch.nn.functional as F

class NNHyperNetwork(nn.Module):
    """MLP hypernetwork that maps input features to a flat vector of generated parameters.

    Layout: ``Linear -> ReLU -> Dropout -> Linear -> ReLU -> Dropout -> Linear``.
    The last layer emits ``total_params`` values per input row. The flat output is
    returned as-is; splitting it into a target network's weights is not done here.

    Args:
        feature_dim: Size of the last dimension of ``features``.
        hidden_dim: Width of both hidden layers.
        total_params: Number of values produced per input row.
        hn_dropout: Dropout probability applied after each hidden activation
            (validated to lie in [0, 1] by ``nn.Dropout``).
    """

    def __init__(self, feature_dim: int, hidden_dim: int, total_params: int, hn_dropout: float):
        super().__init__()
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.total_params = total_params
        # The registered Dropout module is the single source of truth for the
        # dropout rate and train/eval state used by ``forward``. Edits to
        # ``self.dropout.p`` and per-module ``train()``/``eval()`` toggles
        # (e.g. MC-dropout) therefore take effect. ``hn_dropout`` is kept as a
        # read/write alias (see the property below) for backward compatibility.
        self.dropout = nn.Dropout(p=hn_dropout)

        self.layer1 = nn.Linear(feature_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
        # In-place is safe: it only overwrites a freshly computed Linear output.
        self.relu = nn.ReLU(inplace=True)
        self.param_generator = nn.Linear(hidden_dim, total_params)

    @property
    def hn_dropout(self) -> float:
        """Dropout probability used in ``forward``; alias for ``self.dropout.p``."""
        return self.dropout.p

    @hn_dropout.setter
    def hn_dropout(self, p: float) -> None:
        self.dropout.p = p

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Generate parameters from ``features``.

        Args:
            features: Tensor whose last dimension is ``feature_dim`` (any leading
                dimensions are passed through by ``nn.Linear``).

        Returns:
            Tensor with the same leading dimensions and a last dimension of
            ``total_params``.
        """
        hidden = self.dropout(self.relu(self.layer1(features)))
        hidden = self.dropout(self.relu(self.layer2(hidden)))
        # Raw generated parameters; any reshaping is the caller's responsibility.
        return self.param_generator(hidden)
    