import pprint

import numpy as np
import torch
import yaml
from torch_geometric.data.data import Data

from hmpn.common.hmpn_util import make_batch
from hmpn.get_hmpn import get_hmpn_from_graph


def _build_example_graph(device):
    """Return the small demo graph used to configure the network, on ``device``.

    The graph has:
      * 2 nodes with 3 features each;
      * 4 directed edges with 2 features each (0->1 and 1->0, each listed twice);
      * one 4-dimensional global feature vector ``u`` of shape (1, 4).
    """
    graph = Data(
        x=torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32),
        edge_index=torch.tensor([[0, 1, 1, 0], [1, 0, 0, 1]]),
        edge_attr=torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.float32),
        u=torch.tensor([[1, 2, 3, 4]], dtype=torch.float32),
    )
    # Rebind rather than rely on ``to`` mutating in place.
    return graph.to(device)


def _random_graph(template, num_nodes, num_edges, device):
    """Return a random graph whose feature sizes match ``template``.

    Node, edge and global features are drawn from a standard normal. Edge
    endpoints are sampled uniformly with replacement. The graph is therefore
    random rather than fully connected, and may contain self-loops and
    duplicate edges.

    Feature widths are read from ``template`` so that benchmark graphs always
    fit the network built from that template.
    """
    graph = Data(
        x=torch.randn(num_nodes, template.x.size(1)),
        edge_index=torch.randint(0, num_nodes, (2, num_edges)),
        edge_attr=torch.randn(num_edges, template.edge_attr.size(1)),
        u=torch.randn(1, template.u.size(1)),
    )
    return graph.to(device)


def _time_forward(model, batch, device):
    """Return the wall-clock seconds taken by a single ``model(batch)`` call.

    CUDA kernel launches are asynchronous. On a CUDA device the stream is
    therefore synchronized twice:
      * before the clock starts, to exclude earlier queued work;
      * after the call, to include the forward pass itself.
    On any other device no synchronization is done, because
    ``torch.cuda.synchronize`` raises when CUDA is unavailable.
    """
    from time import perf_counter

    is_cuda = device.type == "cuda"
    if is_cuda:
        torch.cuda.synchronize(device)
    start = perf_counter()
    model(batch)
    if is_cuda:
        torch.cuda.synchronize(device)
    return perf_counter() - start


def main(
    *,
    config_path="config.yaml",
    num_repeats=100,
    max_batch_size=100,
    num_nodes=100,
    num_edges=1000,
):
    """Build an example HMPN and benchmark its forward pass over batch sizes.

    Steps:
      1. Load the YAML config from ``config_path``.
      2. Build a homogeneous message passing network from a small example
         graph, using the config's ``latent_dimension`` and ``base`` entries.
      3. Time one forward pass on batches of 1..``max_batch_size`` random
         graphs, repeating the full sweep ``num_repeats`` times.
      4. Print the overall mean time and plot the mean time per batch size.

    Args:
        config_path: Path to the YAML configuration file.
        num_repeats: Number of timing sweeps over all batch sizes.
        max_batch_size: Largest number of graphs per batch.
        num_nodes: Number of nodes in each random benchmark graph.
        num_edges: Number of edges in each random benchmark graph.

    Raises:
        ValueError: If ``num_repeats`` or ``max_batch_size`` is less than 1.
    """
    import matplotlib.pyplot as plt
    import tqdm

    if num_repeats < 1 or max_batch_size < 1:
        raise ValueError("num_repeats and max_batch_size must both be >= 1")

    with open(config_path, "r", encoding="utf-8") as config_file:
        config = yaml.safe_load(config_file)
    pprint.pprint(config)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}\n\n")

    example_graph = _build_example_graph(device)
    print(f"Graph: {example_graph}\n\n")

    # The example graph is homogeneous, so this yields a homogeneous message
    # passing network.
    network = get_hmpn_from_graph(
        example_graph=example_graph,
        latent_dimension=config.get("latent_dimension"),
        base_config=config.get("base"),
        unpack_output=True,
        device=device,
    )
    print(f"Network: {network}\n\n")
    # One untimed forward pass on the example graph as a sanity check / warm-up.
    network(make_batch(example_graph))

    all_timings = np.empty((num_repeats, max_batch_size))
    for repeat in tqdm.tqdm(range(num_repeats)):
        for batch_size in range(1, max_batch_size + 1):
            random_graphs = [
                _random_graph(example_graph, num_nodes, num_edges, device)
                for _ in range(batch_size)
            ]
            batch = make_batch(random_graphs)
            all_timings[repeat, batch_size - 1] = _time_forward(network, batch, device)

    # Column j holds the timings for batch size j + 1.
    mean_per_batch_size = all_timings.mean(axis=0)
    print(
        f"Average time for batch size 1-{max_batch_size}: "
        f"{all_timings.mean():.6f} s"
    )

    batch_sizes = np.arange(1, max_batch_size + 1)
    plt.plot(batch_sizes, mean_per_batch_size)
    plt.xlabel("Batch size")
    plt.ylabel("Time (s)")
    plt.show()


if __name__ == "__main__":
    # Run the demo/benchmark only when executed as a script, not on import.
    main()
