from encoder.encoder import Encoder
from torch_geometric.nn import Node2Vec as PyGNode2Vec
import torch.nn as nn
import numpy as np


class Node2Vec(PyGNode2Vec, Encoder):
    """Node2Vec encoder backed by PyTorch Geometric's ``Node2Vec`` model.

    Random-walk hyperparameters are read from ``args.encoder.node2vec.model``;
    the embedding width comes from ``args.embedding_dim``.
    """

    def __init__(self, data, args):
        """Build the Node2Vec model over the graph in ``data``.

        Args:
            data: Graph object exposing ``edge_index`` and, optionally,
                ``num_nodes`` (presumably a PyG ``Data`` instance — confirm
                against callers).
            args: Config namespace providing ``embedding_dim`` and
                ``encoder.node2vec.model`` with the fields ``walk_length``,
                ``context_size``, ``walks_per_node``, ``num_negative_samples``,
                ``p``, ``q`` and ``sparse``.
        """
        model_cfg = args.encoder.node2vec.model
        super().__init__(
            edge_index=data.edge_index,
            embedding_dim=args.embedding_dim,
            walk_length=model_cfg.walk_length,
            context_size=model_cfg.context_size,
            walks_per_node=model_cfg.walks_per_node,
            num_negative_samples=model_cfg.num_negative_samples,
            p=model_cfg.p,
            q=model_cfg.q,
            sparse=model_cfg.sparse,
            # Without an explicit node count PyG infers it from
            # ``edge_index.max() + 1``, so trailing isolated nodes would get
            # no embedding row. Passing ``None`` falls back to that inference.
            num_nodes=getattr(data, "num_nodes", None),
        )

    def get_embeddings(self):
        """Return the node embedding matrix of the underlying embedding table.

        The returned tensor is ``self.embedding.weight`` itself — not detached
        or copied — so it keeps tracking gradients and reflects any further
        training. Presumably shaped ``(num_nodes, embedding_dim)``, one row per
        node id (per PyG's ``Node2Vec`` embedding table — confirm).
        """
        return self.embedding.weight
