import torch
from src.maxcut.utils import build_hypergraph_maxcut, maxcut_evaluate, split_polynomial
from src.loss import loss_pubo
from src.train import run_pubo
from src.utils import from_file_to_hypergraph, get_H, compact_vertex_indices

from src import Layer, LayerType

def main(argv=None):
    """Sanity-check the PUBO MaxCut pipeline on one hypergraph, then train a model.

    1. Load a hypergraph from a text file and build the derived MaxCut
       hypergraph with ``build_hypergraph_maxcut``.
    2. Score a fixed, hand-written 0/1 assignment with ``maxcut_evaluate`` and
       ``loss_pubo`` and print both results.
    3. Train an HGNN+ -> linear model with ``run_pubo`` on random input
       features and print the cut edges of the model's output.

    Args:
        argv: Optional argument list for ``argparse``; ``None`` means
            ``sys.argv[1:]``. The only (optional) positional argument is the
            hypergraph file path.
    """
    import argparse

    # Kept as the default so the original invocation (no arguments) still works;
    # pass a path on the command line to run on another machine/dataset.
    default_path = "/home/exs/work/binary-programming/data/test.txt"
    parser = argparse.ArgumentParser(
        description="Evaluate and train the PUBO MaxCut model on a hypergraph file."
    )
    parser.add_argument(
        "path",
        nargs="?",
        default=default_path,
        help="hypergraph text file (default: %(default)s)",
    )
    args = parser.parse_args(argv)

    org_hg = from_file_to_hypergraph(args.path, True)
    dst_hg = build_hypergraph_maxcut(org_hg.e[0], True, 0)

    # Dense incidence matrix and coefficient tensor, built once and reused below.
    H = dst_hg.H.to_dense()
    C = torch.tensor(dst_hg.e[1])

    # Hand-written assignment (one 0/1 entry per row); presumably one row per
    # vertex of the example graph in test.txt -- TODO confirm.
    outs = torch.tensor([[0], [1], [1], [0], [1], [1], [1], [0], [1]])
    res = maxcut_evaluate(outs, org_hg)
    print(res["cut_edges"])
    # NOTE(review): loss_pubo receives C with an extra leading dim (as the
    # original `torch.tensor([dst_hg.e[1]])` did) while run_pubo receives the
    # 1-D C. Behavior preserved here; confirm which shape each callee expects.
    print(loss_pubo(outs, H=H, C=C.unsqueeze(0)))

    init_feature_dim = 1024
    layers = [
        Layer(LayerType.HGNNPCONV, init_feature_dim, 1024),
        Layer(LayerType.HGNNPCONV, 1024, 512, last_conv=True),
        Layer(LayerType.LINEAR, 512, 1, drop_rate=0.5),
    ]
    # Random input features: one row of size init_feature_dim per vertex.
    maxcut_x = torch.rand((dst_hg.num_v, init_feature_dim))

    def acl(e):
        """Linear schedule ``-0.1 + e / 1000`` handed to ``run_pubo`` as ``acl``.

        The meaning of ``e`` (presumably the epoch index) is defined by
        ``run_pubo`` -- TODO confirm.
        """
        return -0.1 + e / 1000

    # Positional args after maxcut_x are passed through unchanged; their
    # meaning (presumably epochs=1000, lr=3e-4, a boolean flag) is defined by
    # run_pubo -- TODO confirm.
    loss, maxcut_outs = run_pubo(
        layers, dst_hg, H, C, maxcut_x, 1000, 3e-4, False, acl=acl
    )
    res = maxcut_evaluate(maxcut_outs, org_hg)
    print(res["cut_edges"])


if __name__ == "__main__":
    main()
