from pathlib import Path
import torch
import numpy as np


# Directory whose subdirectories main() scans for model checkpoints; located
# at <parent of this script's directory>/models/linked.
MODELS_DIR = Path(__file__).parents[1].joinpath("models", "linked")


def main(models_dir=None):
    """Export entity embeddings from every model checkpoint under a directory.

    Each immediate subdirectory of ``models_dir`` is inspected. If it contains
    ``model.pt`` but not yet ``entity_embeddings.npy``, the checkpoint is
    loaded and the tensor stored at
    ``checkpoint["state_dict"]["entity_representations.0._embeddings.weight"]``
    is written to ``entity_embeddings.npy`` in the same subdirectory.
    Existing ``.npy`` files are never overwritten.

    Args:
        models_dir: Directory (``str`` or ``Path``) whose subdirectories hold
            model checkpoints. Defaults to ``MODELS_DIR`` when ``None``.

    Raises:
        FileNotFoundError: If ``models_dir`` does not exist.
        KeyError: If a checkpoint lacks the expected keys.
    """
    # Resolve the default lazily so the signature stays usable without the
    # module-level constant and callers can point at any directory.
    if models_dir is None:
        models_dir = MODELS_DIR
    models_dir = Path(models_dir)

    # iterdir() order is filesystem-dependent; sort for reproducible runs/logs.
    for folder in sorted(models_dir.iterdir()):
        if not folder.is_dir():
            continue

        model_path = folder / "model.pt"
        embedding_path = folder / "entity_embeddings.npy"

        # Report the two skip reasons separately so the log is unambiguous.
        if not model_path.exists():
            print(f"No model file found in {folder}.")
            continue
        if embedding_path.exists():
            print(f"Embeddings already exist in {folder}, skipping.")
            continue

        # map_location="cpu" lets checkpoints saved from a GPU load on a
        # CPU-only machine; weights_only=True restricts unpickling to tensors
        # and plain containers.
        checkpoint = torch.load(model_path, map_location="cpu", weights_only=True)
        entity_embeddings = (
            checkpoint["state_dict"]["entity_representations.0._embeddings.weight"]
            .detach()
            .cpu()
            .numpy()
        )

        np.save(embedding_path, entity_embeddings)
        print(f"Saved embeddings from {model_path}, shape: {entity_embeddings.shape}")


if __name__ == "__main__":
    main()
