import pprint

import torch
import yaml
from torch_geometric.data.hetero_data import HeteroData
from torch_geometric.data import Data

from hmpn.common.hmpn_util import make_batch
from hmpn.get_hmpn import get_hmpn_from_graph
from hmpn.hierarchical.util import build_hierarchical_graph


def main():
    """Run a small end-to-end example of the hmpn package.

    Steps:
    - Loads ``config.yaml`` from the current working directory.
    - Builds a toy homogeneous graph: a path of 8 nodes.
    - Derives a hierarchical graph from it with ``build_hierarchical_graph``.
    - Constructs a message passing network from the toy graph.
    - Runs the network on a batch made from that graph.

    Intermediate objects are printed to stdout.
    """
    # Load the network configuration. safe_load prevents the YAML file from
    # instantiating arbitrary Python objects.
    with open("config.yaml", "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    pprint.pprint(config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}\n\n")

    # Build a homogeneous graph (a single node type and a single edge type).
    # It has 8 nodes with 1 feature each and 7 edges with 1 feature each.
    # Each edge links node i to node i + 1, so the graph is a path.
    # There are also 4 global features stored in ``u``.
    normal_graph = Data(
        x=torch.tensor(
            [[1], [2], [3], [4], [5], [6], [7], [8]], dtype=torch.float32
        ),  # shape (8, 1): 8 nodes with 1 feature each
        # shape (2, 7): COO connectivity, edge k connects node k to node k + 1
        edge_index=torch.tensor([[0, 1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 5, 6, 7]]),
        edge_attr=torch.tensor(
            [[1], [3], [6], [8], [2], [4], [3]], dtype=torch.float32
        ),  # shape (7, 1): 7 edges with 1 feature each
        # shape (1, 4): a single row of 4 global features
        u=torch.tensor([[1, 2, 3, 4]], dtype=torch.float32),
    )
    # Move all tensor attributes to the selected device. Assigning the result
    # keeps this correct whether or not ``to`` mutates the object in place.
    normal_graph = normal_graph.to(device)

    # Derive a 3-level hierarchical graph from the flat graph. It is only
    # printed here. The network below is built from and run on ``normal_graph``.
    hierarchical_graph = build_hierarchical_graph(normal_graph, num_levels=3)
    # NOTE(review): presumably the storage of the ("x", "level1", "x") relation
    # created by build_hierarchical_graph — confirm against its implementation.
    print(hierarchical_graph["x", "level1", "x"])

    print(f"Graph: {hierarchical_graph}\n\n")

    # Build the message passing network from the example graph. Since
    # ``normal_graph`` is a homogeneous ``Data`` object, the resulting network
    # is a homogeneous message passing network.
    hmpn = get_hmpn_from_graph(
        example_graph=normal_graph,
        latent_dimension=config.get("latent_dimension"),
        base_config=config.get("base"),
        unpack_output=True,
        device=device,
    )
    print(f"Network: {hmpn}\n\n")

    # Wrap the graph in a batch and run the network on it. The output is a
    # latent representation of the node, edge and global features.
    batch = make_batch(normal_graph)
    print(f"Batch: {batch}")
    out = hmpn(batch)
    print(f"Output: {out}\n\n")


# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
