from src import init, get_device, generate_hypergraph, run_pubo, run, from_hypergraph_to_graph_clique
from src import Layer, LayerType
from src.maxcut import build_hypergraph_maxcut, loss_max_cut, maxcut_evaluate
import torch

if __name__ == "__main__":
    # Experiment driver: solve max-cut on a randomly generated hypergraph in two
    # ways and score each result with maxcut_evaluate against `graph`:
    #   1. run_pubo on the dedicated max-cut hypergraph (`maxcut_graph`);
    #   2. run with the loss_max_cut objective directly on `graph`.
    init(cuda_index=1, reproducibility=False)

    # NOTE(review): the meaning of generate_hypergraph's positional arguments
    # (4, 1700, 3400) is not visible here — confirm against its definition.
    graph = generate_hypergraph(4, 1700, 3400).to(get_device())
    maxcut_graph = build_hypergraph_maxcut(graph.e[0], start_index=0).to(get_device())
    print(graph)
    print(maxcut_graph)

    init_feature_dim = 1024
    # Random input features, one row per vertex of the graph each model runs on.
    x = torch.rand((graph.num_v, init_feature_dim))
    maxcut_x = torch.rand((maxcut_graph.num_v, init_feature_dim))

    # GNN backbone followed by a linear head producing one output per vertex.
    layers = [
        # Alternative backbone:
        # Layer(LayerType.GRAPHSAGE, init_feature_dim, 512, hidden_channels=512, num_layers=2, jk="last", drop_rate=0),
        Layer(LayerType.GAT, init_feature_dim, 512, hidden_channels=512, num_layers=2, jk="last", drop_rate=0, v2=True),
        Layer(LayerType.LINEAR, 512, 1, use_bn=True, drop_rate=0),
    ]

    # --- Pipeline 1: PUBO formulation on the max-cut hypergraph -------------
    # NOTE(review): positional args 5000 / 1e-4 / True presumably mirror
    # run's epochs / learning rate / clip_grad below — confirm against
    # run_pubo's signature. e[1] is presumably the per-hyperedge weight list.
    loss, maxcut_outs = run_pubo(
        layers,
        maxcut_graph,
        maxcut_graph.H.to_dense(),
        torch.tensor(maxcut_graph.e[1]),
        maxcut_x,
        5000,
        1e-4,
        True
    )

    maxcut_evaluate(maxcut_outs, graph)

    # --- Pipeline 2: direct max-cut loss on the original hypergraph ---------
    # The pairwise edge list consumed by the GNN layers must be derived from
    # `graph` (the hypergraph this run trains on), not from `maxcut_graph`:
    # its indices address rows of `x`, which has graph.num_v rows, whereas
    # `maxcut_graph` has its own, independently sized vertex set.
    edge_index = torch.tensor(
        from_hypergraph_to_graph_clique(graph).e[0], dtype=torch.long
    ).t().contiguous()  # (num_pairs, 2) -> (2, num_pairs) layout

    loss, outs = run(
        layers,
        x,
        graph,
        3000,
        loss_max_cut,
        2e-4,
        H=graph.H.to_dense(),
        C=torch.tensor(graph.e[1]),
        edge_index=edge_index,
        clip_grad=True
    )

    maxcut_evaluate(outs, graph)