import torch
from tqdm import tqdm

from src.train import train, run
from src.constraints import ConstraintManager, HighOrderMutexConstraint, OneHotConstraint, UpperBoundConstraint, loss_pubo_with_constraints
from src.loss import loss_pubo
from src.models import Layer, LayerType, Net
from src.utils import generate_C, generate_hypergraph, get_H


if __name__ == "__main__":
    # Demo driver: register a few constraints, generate a random hypergraph,
    # train the network through `run`, then print the loss of the thresholded
    # output and run the constraint check on the raw output.

    # ---- constraints ------------------------------------------------------
    cons_manager = ConstraintManager()
    constraints = [
        HighOrderMutexConstraint([0, 1, 8], 1.0),
        OneHotConstraint([0, 1, 9, 266, 255, 666, 1061, 233], 1, 0, "onehot"),
        UpperBoundConstraint(
            [0, 1, 9],
            coefficient=200,
            expect=0,
            cofs=torch.tensor([3, 9, 15]),
            upper_bound=7,
        ),
        OneHotConstraint([0, 1, 9], 1, 0, "onehot"),
    ]
    # Each call is made on whatever the previous `add_constraint` returned,
    # which is exactly the call sequence of a fluent `.add_constraint(...)` chain.
    chain = cons_manager
    for constraint in constraints:
        chain = chain.add_constraint(constraint)

    # ---- model layout -----------------------------------------------------
    init_feature_dim = 4096
    layers = [
        Layer(LayerType.HGNNPConv, init_feature_dim, 1024),
        Layer(LayerType.HGNNPConv, 1024, 256, last_conv=True),
        Layer(LayerType.LINEAR, 256, 1, drop_rate=0.5),
    ]

    # ---- problem instance -------------------------------------------------
    k, num_nodes, num_hyperedges = 9, 10000, 12000
    hg = generate_hypergraph(k, num_nodes, num_hyperedges)

    # One extra `add_hyperedges` call per vertex, each passing just that vertex
    # (presumably a singleton hyperedge per vertex -- confirm against the
    # hypergraph API).
    for vertex in hg.v:
        hg.add_hyperedges([vertex])

    H = get_H(hg)
    column_sums = H.sum(dim=0).tolist()
    print("The nth power terms are: ", set(column_sums))
    C = generate_C(hg.num_e)

    # Random initial features: one row per node, `init_feature_dim` columns.
    X = torch.randn(size=(num_nodes, init_feature_dim))

    # ---- optimisation -----------------------------------------------------
    # NOTE(review): 250 and 1e-3 are passed positionally; presumably the epoch
    # count and learning rate -- confirm against `src.train.run`.
    _, outs = run(
        layers,
        X,
        hg,
        250,
        loss_pubo_with_constraints,
        1e-3,
        cons_manager=cons_manager,
        H=H,
        C=C,
    )

    # ---- report -----------------------------------------------------------
    # The final loss is computed on the output binarised at 0.5.
    binarised = (outs >= 0.5).float()
    loss = loss_pubo(binarised, H, C)
    print(f"----final loss: {loss:.2f}----")

    # NOTE(review): the return value is discarded and the raw (un-thresholded)
    # `outs` is passed -- confirm that `is_valid` reports its result itself.
    cons_manager.is_valid(outs)
