from collections.abc import Callable

import torch
from torch import nn

from kge.types import EntityIDTensor, RelationIDTensor

from .base import Embedding


def l2_deviation(x: torch.Tensor) -> torch.Tensor:
    """Return the mean squared distance of each vector's L2 norm from 1.

    Norms are taken along dim 1, i.e. per row of a 2-D embedding weight
    matrix. The result is a scalar soft penalty that is zero when every row
    lies on the unit sphere.
    """
    row_norms = x.norm(p=2, dim=1)
    gap = row_norms - 1.0
    return gap.pow(2).mean()


def l2_norm(x: torch.Tensor) -> torch.Tensor:
    """Return the *squared* L2 (Frobenius) norm of all entries of ``x``.

    This is the sum of squares over every element, returned as a scalar
    tensor. Computing it directly avoids the sqrt-then-square round trip of
    ``torch.norm(x) ** 2``.
    """
    return x.pow(2).sum()


def l2_constrainer(x: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Rescale every embedding vector to unit L2 norm.

    Normalization is along the last axis, so it is correct for both a 2-D
    weight matrix ``(n, dim)`` and batched lookups such as ``(B, K, dim)``.
    For 2-D input this matches normalizing along dim 1.

    Args:
        x: Tensor whose last axis is the embedding dimension.
        eps: Lower bound on the norm used as the divisor, so all-zero vectors
            map to zero instead of NaN (same convention as
            ``torch.nn.functional.normalize``).

    Returns:
        Tensor of the same shape as ``x`` with unit-norm (or zero) vectors.
    """
    norms = torch.norm(x, p=2, dim=-1, keepdim=True)
    # Clamp so a zero vector yields 0/eps == 0 rather than 0/0 == NaN.
    return x / norms.clamp_min(eps)


def identity_constrainer(x: torch.Tensor) -> torch.Tensor:
    """No-op constrainer: hand back the very same tensor object, untouched."""
    unchanged = x
    return unchanged


class RealEmbedding(Embedding):
    """Real-valued entity and relation embeddings stored in ``nn.Embedding`` tables."""

    def __init__(
        self,
        num_entities: int,
        num_relations: int,
        dimension: int,
        entity_init_fn: Callable[[torch.Tensor], None] = nn.init.xavier_uniform_,
        relation_init_fn: Callable[[torch.Tensor], None] = nn.init.xavier_uniform_,
        entity_constrainer: Callable[[torch.Tensor], torch.Tensor] = identity_constrainer,
        relation_regularization: Callable[[torch.Tensor], torch.Tensor] = l2_deviation,
        entity_regularization_weight: float = 1e-3,
        relation_regularization_weight: float = 1e-5,
    ):
        """Initialize real-valued embeddings for entities and relations.

        Args:
            num_entities: Number of entities in the knowledge graph.
            num_relations: Number of relation types in the knowledge graph.
            dimension: Dimension of the embedding vectors.
            entity_init_fn: Function to initialize entity embeddings.
                Should be in-place. (default: xavier_uniform_)
            relation_init_fn: Function to initialize relation embeddings.
                Should be in-place. (default: xavier_uniform_)
            entity_constrainer: Function applied to entity embeddings every time
                they are read (``embed_entities`` / ``get_all_entity_embeddings``).
                (default: identity_constrainer, i.e. no constraint)
            relation_regularization: Penalty function used by
                ``regularization_term``. Despite its name, it is applied to BOTH
                the entity and the relation weight tables. (default: l2_deviation)
            entity_regularization_weight: Scale of the penalty on the entity
                table in ``regularization_term``. (default: 1e-3)
            relation_regularization_weight: Scale of the penalty on the relation
                table in ``regularization_term``. (default: 1e-5)

        """
        super().__init__(num_entities, num_relations, dimension)
        self.entity_embeddings = nn.Embedding(num_entities, dimension)
        self.relation_embeddings = nn.Embedding(num_relations, dimension)
        entity_init_fn(self.entity_embeddings.weight)
        relation_init_fn(self.relation_embeddings.weight)

        self.entity_constrainer = entity_constrainer
        self.relation_regularization = relation_regularization
        self.entity_regularization_weight = entity_regularization_weight
        self.relation_regularization_weight = relation_regularization_weight

    def embed_entities(self, entities_batch: EntityIDTensor) -> torch.Tensor:
        """Look up entity embeddings for ``entities_batch`` and apply the constrainer."""
        embeddings = self.entity_embeddings(entities_batch)
        return self.entity_constrainer(embeddings)

    def embed_relations(self, relations_batch: RelationIDTensor) -> torch.Tensor:
        """Look up relation embeddings for ``relations_batch`` (no constraint applied)."""
        return self.relation_embeddings(relations_batch)

    def get_all_entity_embeddings(self) -> torch.Tensor:
        """Return the constrained entity table, shape ``(num_entities, dimension)``."""
        return self.entity_constrainer(self.entity_embeddings.weight)

    def regularization_term(self) -> torch.Tensor:
        """Return the weighted soft-penalty term to add to the training loss.

        ``relation_regularization`` is evaluated on the entity table and on the
        relation table, scaled by ``entity_regularization_weight`` and
        ``relation_regularization_weight`` respectively.
        """
        entity_penalty = self.relation_regularization(self.entity_embeddings.weight)
        relation_penalty = self.relation_regularization(self.relation_embeddings.weight)
        return (
            self.entity_regularization_weight * entity_penalty
            + self.relation_regularization_weight * relation_penalty
        )


class ComplexEmbedding(Embedding):
    """Complex-valued embeddings stored as separate real and imaginary ``nn.Embedding`` tables."""

    def __init__(
        self,
        num_entities: int,
        num_relations: int,
        dimension: int,
        entity_init_fn: Callable[[torch.Tensor], None] = nn.init.xavier_uniform_,
        relation_init_fn: Callable[[torch.Tensor], None] = nn.init.xavier_uniform_,
        entity_constrainer: Callable[[torch.Tensor], torch.Tensor] = identity_constrainer,
        relation_regularization: Callable[[torch.Tensor], torch.Tensor] = l2_deviation,
    ):
        """Initialize complex-valued embeddings for entities and relations.

        Args:
            num_entities: Number of entities in the knowledge graph.
            num_relations: Number of relation types in the knowledge graph.
            dimension: Dimension of the embedding vectors (for both real and
                imaginary parts).
            entity_init_fn: Function to initialize entity embeddings; called
                separately on the real and imaginary tables.
                Should be in-place. (default: xavier_uniform_)
            relation_init_fn: Function to initialize relation embeddings; called
                separately on the real and imaginary tables.
                Should be in-place. (default: xavier_uniform_)
            entity_constrainer: Function applied independently to the real and
                the imaginary entity parts every time they are read.
                (default: identity_constrainer, i.e. no constraint)
            relation_regularization: Penalty function summed over the real and
                imaginary relation tables in ``regularization_term``.
                (default: l2_deviation)

        """
        super().__init__(num_entities, num_relations, dimension)
        self.entity_embeddings_real = nn.Embedding(num_entities, dimension)
        self.entity_embeddings_im = nn.Embedding(num_entities, dimension)
        self.relation_embeddings_real = nn.Embedding(num_relations, dimension)
        self.relation_embeddings_im = nn.Embedding(num_relations, dimension)
        entity_init_fn(self.entity_embeddings_real.weight)
        entity_init_fn(self.entity_embeddings_im.weight)
        relation_init_fn(self.relation_embeddings_real.weight)
        relation_init_fn(self.relation_embeddings_im.weight)

        self.entity_constrainer = entity_constrainer
        self.relation_regularization = relation_regularization

    def embed_entities(self, entities_batch: EntityIDTensor) -> torch.Tensor:
        """Return constrained entity embeddings as a complex tensor.

        NOTE(review): the constrainer runs on the real and imaginary parts
        separately, so with ``l2_constrainer`` each part has unit norm and the
        combined complex vector has norm sqrt(2) — confirm this is intended.
        """
        real = self.entity_constrainer(self.entity_embeddings_real(entities_batch))
        im = self.entity_constrainer(self.entity_embeddings_im(entities_batch))
        return torch.complex(real, im)

    def embed_relations(self, relations_batch: RelationIDTensor) -> torch.Tensor:
        """Return relation embeddings as a complex tensor (no constraint applied)."""
        return torch.complex(
            self.relation_embeddings_real(relations_batch),
            self.relation_embeddings_im(relations_batch),
        )

    def get_all_entity_embeddings(self) -> torch.Tensor:
        """Return all constrained entity embeddings as a REAL tensor.

        Unlike ``embed_entities`` (complex output), the real and imaginary parts
        are concatenated along the last axis, giving shape
        ``(num_entities, 2 * dimension)``.
        NOTE(review): presumably the all-entity scoring path expects this
        layout — confirm against callers before changing it.
        """
        real = self.entity_constrainer(
            self.entity_embeddings_real.weight
        )  # (num_entities, dimension)
        im = self.entity_constrainer(self.entity_embeddings_im.weight)  # (num_entities, dimension)
        return torch.cat([real, im], dim=-1)  # (num_entities, 2 * dimension)

    def regularization_term(self) -> torch.Tensor:
        """Return the unweighted relation penalty (real part + imaginary part).

        Entity tables are not regularized here, unlike ``RealEmbedding``.
        """
        return self.relation_regularization(
            self.relation_embeddings_real.weight,
        ) + self.relation_regularization(self.relation_embeddings_im.weight)
