import torch
from einops import rearrange
from torch import nn

from .base import Embedding


class RESCALEmbedding(Embedding):
    """Embedding tables for a RESCAL-style model.

    Each entity is represented by a vector of length ``dimension`` and each
    relation by a ``dimension x dimension`` matrix. Relation matrices are
    stored flattened (one row of ``dimension * dimension`` values per
    relation) in a regular ``nn.Embedding`` and reshaped on demand by
    :meth:`get_relation_matrix`.

    NOTE(review): ``self.dimension`` is read in :meth:`get_relation_matrix`
    but is not assigned in this class; it is assumed to be set by the
    ``Embedding`` base class constructor -- confirm against ``.base``.
    """

    def __init__(self, num_entities: int, num_relations: int, dimension: int) -> None:
        """Create and initialise the entity and relation tables.

        Args:
            num_entities: Number of rows in the entity table.
            num_relations: Number of rows in the relation table.
            dimension: Length of an entity vector; each relation matrix is
                ``dimension x dimension``.
        """
        super().__init__(num_entities, num_relations, dimension)

        # Entity embeddings: (num_entities x dimension)
        self.entity_embeddings = nn.Embedding(num_entities, dimension)

        # Relation embeddings: (num_relations x (dimension * dimension)).
        # Each row is one relation matrix in flattened form, so lookups go
        # through a plain embedding table.
        self.relation_embeddings = nn.Embedding(num_relations, dimension * dimension)

        # Xavier-uniform initialisation is applied to both weight tables as
        # stored, i.e. the relation table is initialised in its flattened
        # 2-D layout rather than matrix by matrix.
        nn.init.xavier_uniform_(self.entity_embeddings.weight)
        nn.init.xavier_uniform_(self.relation_embeddings.weight)

    def embed_entities(self, entities_batch: torch.Tensor) -> torch.Tensor:
        """Look up entity vectors.

        Args:
            entities_batch: Integer tensor of entity indices (any shape).

        Returns:
            Tensor of shape ``(*entities_batch.shape, dimension)``.
        """
        return self.entity_embeddings(entities_batch)

    def embed_relations(self, relations_batch: torch.Tensor) -> torch.Tensor:
        """Look up relation embeddings in their flattened form.

        Args:
            relations_batch: Integer tensor of relation indices (any shape).

        Returns:
            Tensor of shape ``(*relations_batch.shape, dimension * dimension)``.
            Use :meth:`get_relation_matrix` to obtain square matrices.
        """
        return self.relation_embeddings(relations_batch)

    def get_all_entity_embeddings(self) -> torch.Tensor:
        """Return the full entity weight table, ``(num_entities, dimension)``.

        This is the underlying parameter itself, not a copy.
        """
        return self.entity_embeddings.weight

    def regularization_term(self, lmbda: float) -> torch.Tensor:
        """Return the L2 penalty over all entity and relation parameters.

        The penalty is ``lmbda`` times the sum of squared entries of both
        complete weight tables (not just the rows used in a batch).

        Args:
            lmbda: Regularisation strength.

        Returns:
            Scalar tensor holding the penalty.
        """
        entity_reg = torch.sum(self.entity_embeddings.weight**2)
        relation_reg = torch.sum(self.relation_embeddings.weight**2)
        return lmbda * (entity_reg + relation_reg)

    def get_relation_matrix(self, relation_idx: torch.Tensor) -> torch.Tensor:
        """Look up relation embeddings and reshape them into square matrices.

        Args:
            relation_idx: Integer tensor of relation indices. Any shape is
                accepted: a scalar, a 1-D batch, or a multi-dimensional
                batch such as ``(batch_size, num_samples)``.

        Returns:
            Tensor of shape ``(*relation_idx.shape, dimension, dimension)``;
            for a 1-D batch this is ``(batch_size, dimension, dimension)``.

        """
        flat_matrices = self.embed_relations(relation_idx)
        # The ellipsis keeps every leading index dimension untouched and only
        # splits the trailing flattened axis, so the reshape is not limited
        # to a single batch dimension.
        return rearrange(
            flat_matrices,
            "... (d1 d2) -> ... d1 d2",
            d1=self.dimension,
            d2=self.dimension,
        )
