import pprint

import torch
import yaml
from torch_geometric.data.hetero_data import HeteroData

from hmpn.common.hmpn_util import get_default_edge_relation, make_batch
from hmpn.get_hmpn import get_hmpn_from_graph


def _build_example_graph() -> HeteroData:
    """Build a small heterogeneous example graph on the CPU.

    The graph has:

    * two node types: "c1" (3 nodes, 2 features each) and "c2" (2 nodes,
      3 features each);
    * three edge types: (c1, c1) with 4 edges, (c1, c2) with 3 edges, both
      carrying 2 features, and (c2, c2) with 2 edges carrying 5 features;
    * 4 global features stored under the key "u".

    Returns:
        A ``HeteroData`` built from the dictionary described above.
    """
    c1_edge = get_default_edge_relation("c1", "c1")
    c1c2_edge = get_default_edge_relation("c1", "c2")
    c2_edge = get_default_edge_relation("c2", "c2")
    graph_dict = {
        "c1": {"x": torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32)},  # c1 has 2 features
        "c2": {"x": torch.tensor([[7, 8, 9], [10, 11, 12]], dtype=torch.float32)},  # c2 has 3 features
        c1_edge: {"edge_index": torch.tensor([[0, 1, 2, 0], [1, 0, 0, 1]]),  # 4 edges, 2 features
                  "edge_attr": torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.float32)},
        c1c2_edge: {"edge_index": torch.tensor([[0, 1, 2], [1, 0, 0]]),  # 3 edges, 2 features
                    "edge_attr": torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float32)},
        c2_edge: {"edge_index": torch.tensor([[0, 1], [0, 0]]),  # 2 edges, 5 features
                  "edge_attr": torch.tensor([[1, 2, 3, 4, 5], [6, 7, 8, 9, 10]],
                                            dtype=torch.float32)},
        "u": torch.tensor([[1, 2, 3, 4]], dtype=torch.float32),  # 4 global features
    }
    return HeteroData(graph_dict)


def main(config_path: str = "config.yaml") -> None:
    """Run a heterogeneous message passing network on a small example graph.

    Loads a YAML config and prints it. Builds an example heterogeneous graph
    on CUDA if available (otherwise CPU). Constructs the network from that
    graph and prints the network, the batch and the network output.

    Args:
        config_path: Path to the YAML configuration file. Relative paths are
            resolved against the current working directory. The entries
            ``"latent_dimension"`` and ``"base"`` are forwarded to
            ``get_hmpn_from_graph``; missing entries are passed as ``None``.

    Raises:
        FileNotFoundError: If ``config_path`` does not exist.
        ValueError: If the YAML file does not contain a mapping at the top
            level (e.g. it is empty or holds a list).
    """
    # load the config file
    with open(config_path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    # An empty file yields None and a list/scalar has no .get(); fail with a
    # clear message instead of an AttributeError further down.
    if not isinstance(config, dict):
        raise ValueError(f"Expected a mapping at the top level of {config_path!r}, "
                         f"got {type(config).__name__}")
    pprint.pprint(config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}\n\n")

    # Build a heterogeneous graph with 2 node types "c1" and "c2", 3 edge types
    # and 4 global features, then move all of its tensors to the chosen device.
    graph = _build_example_graph()
    graph = graph.to(device)
    print(f"Graph: {graph}\n\n")

    # Build the message passing network from this graph. The example graph has
    # several node and edge types, so the resulting network is a heterogeneous
    # message passing network.
    network = get_hmpn_from_graph(example_graph=graph,
                                  latent_dimension=config.get("latent_dimension"),
                                  base_config=config.get("base"),
                                  unpack_output=True,
                                  device=device)
    print(f"Network: {network}\n\n")

    # Run the network on the graph, getting a latent representation of all node, edge and global features as an output
    batch = make_batch(graph)
    print(f"Batch: {batch}")
    out = network(batch)
    print(f"Output: {out}\n\n")


if __name__ == "__main__":
    # Only run the demo when executed as a script, not when imported.
    main()
