import pprint

import torch
import yaml
from torch_geometric.data.data import Data

from hmpn.common.hmpn_util import make_batch
from hmpn.pyg_wrappers.gat import GATWrapper


def main():
    """Run a GAT on a tiny hand-built graph and print every intermediate step.

    Reads ``config.yaml`` from the current working directory, builds a
    two-node example graph without global features, constructs a
    ``GATWrapper`` whose input sizes are taken from that graph, and prints the
    config, the selected device, the graph, the network, the batch and the
    network output.

    Raises:
        KeyError: If the loaded configuration has no ``base`` section.
    """
    # load the config file; be explicit about the encoding instead of relying
    # on the platform default (which is not UTF-8 on every OS)
    with open("config.yaml", "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    pprint.pprint(config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}\n\n")

    # build a graph with 2 nodes (3 features each) and 4 edges (2 features each).
    # No global features are attached, matching in_global_features=None below.
    graph = Data(
        x=torch.tensor(
            [[1, 2, 3], [4, 5, 6]], dtype=torch.float32
        ),  # 2 nodes with 3 features each
        # edges 0->1, 1->0, 1->0, 0->1: both directions between the two nodes,
        # each direction listed twice
        edge_index=torch.tensor([[0, 1, 1, 0], [1, 0, 0, 1]]),
        edge_attr=torch.tensor(
            [[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.float32
        ),  # 4 edges, 2 features
    )
    # rebind to the returned object rather than relying on in-place movement
    graph = graph.to(device)
    print(f"Graph: {graph}\n\n")

    base_config = config.get("base")
    if base_config is None:
        # fail with a clear message instead of an AttributeError on None below
        raise KeyError("config.yaml has no 'base' section")

    gat = GATWrapper(
        # input sizes are read off the example graph's feature matrices
        in_node_features=graph.x.shape[1],
        in_edge_features=graph.edge_attr.shape[1],
        in_global_features=None,
        latent_dimension=config.get("latent_dimension"),
        scatter_reduce_strs=base_config.get("scatter_reduce"),
        stack_config=base_config.get("stack"),
        embedding_config=base_config.get("embedding"),
        unpack_output=True,
        create_graph_copy=base_config.get("create_graph_copy"),
        assert_graph_shapes=base_config.get("assert_graph_shapes"),
        flip_edges_for_nodes=base_config.get("flip_edges_for_nodes", False),
        node_name="node",
    ).to(device)
    print(f"Network: {gat}\n\n")

    # run the network on the batched graph.
    # NOTE(review): the output is presumably the latent node/edge features
    # (there are no global features here) -- confirm against GATWrapper.
    batch = make_batch(graph)
    print(f"Batch: {batch}")
    out = gat(batch)
    print(f"Output: {out}\n\n")


# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
