import torch
import torch.nn as nn
import yaml
from torch_geometric.data.data import Data

from hmpn.get_hmpn import get_hmpn_from_graph
from hmpn.hmpn_util.calculate_max_batch_size import estimate_max_batch_size


def _report_estimates(model, sample, label, device, num_trials=3):
    """Run ``estimate_max_batch_size`` ``num_trials`` times and print each result.

    Args:
        model: Model passed unchanged to ``estimate_max_batch_size``.
        sample: Single example input for ``model`` (a tensor or a graph).
        label: Model name used in the printed message, e.g. ``"MLP"``.
        device: Device string forwarded to ``estimate_max_batch_size``.
        num_trials: How many times to repeat the estimate.
    """
    for trial in range(1, num_trials + 1):
        max_batch_size = estimate_max_batch_size(model, sample, device=device, verbose=False)
        print(f"Estimated maximum batch size trial #{trial} for {label} model: {max_batch_size}")


def main():
    """Demonstrate ``estimate_max_batch_size`` on an MLP and on an HMPN model.

    Each estimate is repeated several times, so run-to-run variation in the
    result is visible.

    Requires a CUDA device. Also requires a ``config.yaml`` in the current
    working directory; its ``latent_dimension`` and ``base`` entries are passed
    to ``get_hmpn_from_graph``.

    Raises:
        RuntimeError: If CUDA is not available.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if not torch.cuda.is_available():
        raise RuntimeError("CUDA is not available. Please check your setup.")
    device = "cuda"
    num_trials = 3

    # Example usage for "regular" max batch size estimation:
    simple_model = nn.Sequential(nn.Linear(1000, 10000),
                                 nn.ReLU(), nn.Linear(10000, 2))
    input_sample = torch.randn(1, 1000).to(device)  # Example input sample (batch size 1)
    # NOTE(review): the model itself is not moved to `device` here; presumably
    # estimate_max_batch_size handles placement via its `device` argument — confirm.
    _report_estimates(simple_model, input_sample, label="MLP", device=device,
                      num_trials=num_trials)

    # Example usage for HMPN model:
    # load the config file
    with open('config.yaml', 'r', encoding='utf-8') as file:
        config = yaml.safe_load(file)

    graph = Data(x=torch.randn(100, 3),  # 100 nodes with 3 features each
                 # 1000 random (source, target) index pairs; may contain
                 # duplicates and self-loops, so this is not fully connected.
                 edge_index=torch.randint(0, 100, (2, 1000)),
                 edge_attr=torch.randn(1000, 2),  # 1000 edges, 2 features
                 u=torch.randn(1, 4),  # shape (1, num_features)
                 )
    # Rebind the result instead of relying on `.to` mutating the graph in place.
    graph = graph.to(device)

    hmpn_model = get_hmpn_from_graph(example_graph=graph,
                                     latent_dimension=config.get("latent_dimension"),
                                     base_config=config.get("base"),
                                     unpack_output=True,
                                     device=device)

    _report_estimates(hmpn_model, graph, label="HMPN", device=device,
                      num_trials=num_trials)


if __name__ == "__main__":
    # Only run the demo when executed as a script, not when imported.
    main()
