
import numpy as np

import torch
from torch_geometric.nn import GAT, GCN, GraphSAGE
from torch_geometric.loader import NeighborLoader


def compute_embeddings(graph, saved_model_file: str, config: dict,
                       num_neighbors: list[int] = None, batch_size: int = 1024) -> np.ndarray:
    """Compute L2-normalised node embeddings with a trained GNN checkpoint.

    The model architecture (``GAT``, ``GCN`` or ``GraphSAGE`` from PyTorch
    Geometric) is rebuilt from ``config``, its weights are loaded from
    ``saved_model_file``, and every node of ``graph`` is embedded in
    mini-batches produced by a ``NeighborLoader``.

    Args:
        graph: Graph exposing ``x``, ``edge_index``, ``num_nodes`` and, when
            edge features are enabled, ``edge_attr`` (presumably a
            ``torch_geometric.data.Data`` — TODO confirm).
        saved_model_file: Path to a model ``state_dict`` readable by
            ``torch.load``.
        config: Nested dict; reads ``config["model"]["name"]``,
            ``config["model"]["params"]`` (extra model constructor kwargs) and
            ``config["training"]["use_edge_features"]``.
        num_neighbors: Neighbour fan-out per hop passed to ``NeighborLoader``;
            defaults to ``[15, 5]``. Presumably one entry per GNN layer —
            TODO confirm against the training setup.
        batch_size: Number of seed nodes per loader batch.

    Returns:
        Array of shape ``(graph.num_nodes, model.out_channels)`` whose rows
        have unit L2 norm; rows that are entirely zero are left as zeros.

    Raises:
        ValueError: If the model name is unknown, or edge features are
            requested for a model other than GAT.
    """
    if num_neighbors is None:
        num_neighbors = [15, 5]

    model_name = config["model"]["name"]
    use_edge_features = config["training"]["use_edge_features"]

    # Validate the config before touching the graph or the checkpoint so that
    # misconfiguration fails fast (and is not stripped under ``python -O``).
    if model_name not in ("GAT", "GCN", "GraphSAGE"):
        raise ValueError(f"Unknown model name: {model_name!r}")
    if use_edge_features and model_name != "GAT":
        # Only the GAT branch is built with an ``edge_dim``.
        raise ValueError(f"Model {model_name!r} does not support edge features")

    # Build the model.
    in_channels = graph.x.shape[1]
    model_params = config["model"]["params"]
    if model_name == "GAT":
        edge_dim = graph.edge_attr.size(-1) if use_edge_features else None
        model = GAT(in_channels=in_channels, edge_dim=edge_dim, **model_params)
    elif model_name == "GCN":
        model = GCN(in_channels=in_channels, **model_params)
    else:
        model = GraphSAGE(in_channels=in_channels, **model_params)

    # The model is created on CPU; map the checkpoint there too so weights
    # saved from a GPU still load on a machine without CUDA.
    model.load_state_dict(torch.load(saved_model_file, map_location="cpu"))
    model.eval()

    # One batch per ``batch_size`` seed nodes, covering every node exactly once.
    loader = NeighborLoader(data=graph, num_neighbors=num_neighbors,
                            input_nodes=torch.arange(graph.num_nodes),
                            batch_size=batch_size,
                            shuffle=False)

    # Fill the embeddings.
    embeddings = torch.zeros(graph.num_nodes, model.out_channels)
    with torch.no_grad():
        for batch in loader:
            if use_edge_features:
                out = model(x=batch.x, edge_index=batch.edge_index, edge_attr=batch.edge_attr)
            else:
                out = model(x=batch.x, edge_index=batch.edge_index)
            # Keep only the seed-node rows (the first ``batch.batch_size``);
            # the remaining rows belong to sampled neighbours.
            embeddings[batch.input_id] = out[:batch.batch_size].cpu()

    embeddings = embeddings.numpy()
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Divide only where the norm is positive: an all-zero row would otherwise
    # turn into NaN (0 / 0). Zero rows keep their original (zero) values.
    np.divide(embeddings, norms, out=embeddings, where=norms > 0)

    return embeddings

