import torch

from src.loss import loss_pubo
from src.models import Layer, LayerType
from src.train import run, run_pubo
from src.utils import generate_C, get_H, generate_hypergraph


def main(
    init_feature_dim: int = 2048,
    k: int = 9,
    num_nodes: int = 10000,
    num_hyperedges: int = 12000,
    num_epochs: int = 600,
    lr: float = 1e-3,
):
    """Train a hypergraph network against the PUBO loss on a random instance.

    Builds a random hypergraph and adds one single-vertex hyperedge per
    vertex. Derives ``H`` (via ``get_H``) and coefficients ``C`` (via
    ``generate_C``, sized by the hyperedge count). Trains the network on
    random Gaussian node features and prints the PUBO loss of the
    thresholded (0/1) outputs.

    Args:
        init_feature_dim: Width of the random input features and the input
            width of the first convolution layer.
        k: First argument to ``generate_hypergraph`` (presumably the maximum
            hyperedge order -- confirm).
        num_nodes: Number of vertices requested from ``generate_hypergraph``;
            also the number of rows in the feature matrix.
        num_hyperedges: Number of hyperedges requested from
            ``generate_hypergraph`` (before the single-vertex edges are added).
        num_epochs: 4th positional argument to ``run`` (presumably the number
            of training epochs -- confirm).
        lr: 6th positional argument to ``run`` (presumably the learning rate
            -- confirm).

    Returns:
        The value returned by ``loss_pubo`` for the binarized outputs.
    """
    # Network: two hypergraph convolutions followed by a linear head with a
    # single output per vertex.
    layers = [
        Layer(LayerType.HGNNPConv, init_feature_dim, 1024),
        Layer(LayerType.HGNNPConv, 1024, 512, last_conv=True),
        Layer(LayerType.LINEAR, 512, 1, drop_rate=0.5),
    ]

    hg = generate_hypergraph(k, num_nodes, num_hyperedges)
    # Add one single-vertex hyperedge per vertex (presumably giving every
    # variable a first-order term in the PUBO objective -- confirm).
    for v in hg.v:
        hg.add_hyperedges([v])

    H = get_H(hg)
    # Assuming H is the (num_v, num_e) incidence matrix, its column sums are
    # the hyperedge sizes, i.e. the orders of the polynomial terms present.
    print("The nth power terms are: ", set(H.sum(dim=0).tolist()))
    C = generate_C(hg.num_e)
    # Debug output. NOTE(review): if hg is a dhg Hypergraph, hg.e is
    # (edge_list, weights), so hg.e[0] is the whole edge list rather than the
    # first hyperedge -- confirm which was intended.
    print(hg.e[0])
    print(C)

    # Random node features (one row per vertex).
    X = torch.randn(size=(num_nodes, init_feature_dim))

    # Train with the PUBO loss. NOTE(review): this file also imports
    # src.train.run_pubo but never uses it; confirm `run` is the intended
    # training entry point.
    _, outs = run(layers, X, hg, num_epochs, loss_pubo, lr, H=H, C=C)

    # Score the hard 0/1 assignment obtained by thresholding the outputs
    # (assumes outs are probabilities in [0, 1] and live on the same device
    # as H and C -- confirm). Evaluation only, so no autograd graph is needed.
    with torch.no_grad():
        loss = loss_pubo((outs >= 0.5).float(), H, C)
    print(loss)
    return loss


if __name__ == "__main__":
    main()
