import torch
import torch.nn as nn
import yaml
from torch_geometric.data.data import Data

from hmpn.get_hmpn import get_hmpn_from_graph
from hmpn.hmpn_util.calculate_max_batch_size import estimate_max_batch_size


def _print_batch_size_estimates(model, sample, model_name, device, num_trials=3):
    """Estimate the maximum batch size ``num_trials`` times and print each result.

    Printing several independent trials makes any run-to-run variation of the
    estimate visible.

    Args:
        model: Model whose maximum batch size is estimated.
        sample: A single example input for ``model`` (batch size 1).
        model_name: Label inserted into the printed message (e.g. ``"MLP"``).
        device: Device string forwarded to ``estimate_max_batch_size``.
        num_trials: Number of estimates to run and print.
    """
    for trial in range(1, num_trials + 1):
        max_batch_size = estimate_max_batch_size(
            model, sample, device=device, verbose=False
        )
        print(
            f"Estimated maximum batch size trial #{trial} for {model_name} model: "
            f"{max_batch_size}"
        )


def main():
    """Demonstrate ``estimate_max_batch_size`` on an MLP and on an HMPN model.

    Requires a CUDA device and a ``config.yaml`` in the current working
    directory. The config's ``latent_dimension`` and ``base`` entries are
    passed to ``get_hmpn_from_graph``.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. Please check your setup.")
    device = "cuda"

    # Example usage for "regular" max batch size estimation on a plain MLP.
    # The model is placed on the same device as its input, matching the HMPN
    # model below, which is built directly on ``device``.
    simple_model = nn.Sequential(
        nn.Linear(1000, 10000), nn.ReLU(), nn.Linear(10000, 2)
    ).to(device)
    input_sample = torch.randn(1, 1000, device=device)  # batch size 1
    _print_batch_size_estimates(simple_model, input_sample, "MLP", device)

    # Example usage for the HMPN model: hyperparameters come from the YAML config.
    with open("config.yaml", "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)

    num_nodes, num_edges = 100, 1000
    graph = Data(
        x=torch.randn(num_nodes, 3),  # 3 features per node
        # Random edges (NOT fully connected): each endpoint is drawn uniformly
        # from [0, num_nodes), so duplicates and self-loops are possible.
        edge_index=torch.randint(0, num_nodes, (2, num_edges)),
        edge_attr=torch.randn(num_edges, 2),  # 2 features per edge
        u=torch.randn(1, 4),  # shape (1, num_features)
    )
    graph = graph.to(device)

    hmpn_model = get_hmpn_from_graph(
        example_graph=graph,
        latent_dimension=config.get("latent_dimension"),
        base_config=config.get("base"),
        unpack_output=True,
        device=device,
    )

    _print_batch_size_estimates(hmpn_model, graph, "HMPN", device)


if __name__ == "__main__":
    # Entry point: run the demo only when executed as a script, not on import.
    main()
