import pprint

import torch
import yaml
from torch_geometric.data.data import Data

from hmpn.common.hmpn_util import make_batch
from hmpn.pyg_wrappers.gat import GATWrapper


def main(config_path: str = "config.yaml") -> None:
    """Run a GATWrapper forward pass on a tiny hand-built example graph.

    Loads hyperparameters from a YAML file and builds a two-node graph with
    node and edge features (no global features). It then constructs a
    ``GATWrapper`` from the top-level ``latent_dimension`` key and the
    ``base`` section of the config. Finally it prints the config, device,
    graph, network, batch and network output.

    Args:
        config_path: Path to the YAML config file. Defaults to
            ``config.yaml`` in the current working directory, so calling
            ``main()`` with no arguments behaves as before.

    Raises:
        KeyError: If the loaded config has no ``base`` section.
    """
    # load the config file (safe_load: no arbitrary Python object construction)
    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    pprint.pprint(config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}\n\n")

    # Build a graph with 2 nodes (3 features each) and 4 directed edges
    # (2 features each). No global features are attached; the network below
    # is built with in_global_features=None accordingly.
    graph = Data(x=torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32),  # 2 nodes with 3 features each
                 # edges 0->1, 1->0, 1->0, 0->1: both directions between the two
                 # nodes, each listed twice so there is one column per edge_attr row
                 edge_index=torch.tensor([[0, 1, 1, 0], [1, 0, 0, 1]]),
                 edge_attr=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.float32),  # 4 edges, 2 features
                 )
    # Rebind rather than relying on .to() mutating the graph in place.
    graph = graph.to(device)
    print(f"Graph: {graph}\n\n")

    latent_dimension = config.get("latent_dimension")
    # Index directly so a missing section fails with a clear KeyError('base')
    # instead of an AttributeError on None further down.
    base_config = config["base"]
    in_node_features = graph.x.shape[1]
    in_edge_features = graph.edge_attr.shape[1]
    create_graph_copy = base_config.get("create_graph_copy")
    assert_graph_shapes = base_config.get("assert_graph_shapes")
    stack_config = base_config.get("stack")
    embedding_config = base_config.get("embedding")
    scatter_reduce_strs = base_config.get("scatter_reduce")
    flip_edges_for_nodes = base_config.get("flip_edges_for_nodes", False)
    node_name = "node"

    gat = GATWrapper(in_node_features=in_node_features,
                     in_edge_features=in_edge_features,
                     in_global_features=None,
                     latent_dimension=latent_dimension,
                     scatter_reduce_strs=scatter_reduce_strs,
                     stack_config=stack_config,
                     embedding_config=embedding_config,
                     unpack_output=True,
                     create_graph_copy=create_graph_copy,
                     assert_graph_shapes=assert_graph_shapes,
                     flip_edges_for_nodes=flip_edges_for_nodes,
                     node_name=node_name).to(device)
    print(f"Network: {gat}\n\n")

    # Run the network on the batched graph. With unpack_output=True the output
    # is presumably the unpacked latent node/edge features; there are no global
    # features here. NOTE(review): confirm the output structure against GATWrapper.
    batch = make_batch(graph)
    print(f"Batch: {batch}\n\n")
    out = gat(batch)
    print(f"Output: {out}\n\n")


if __name__ == "__main__":
    # Run the demo only when executed as a script, not when imported.
    main()
