import math

import torch
from torch import Tensor, device, nn


class SubEmbedding(nn.Module):
    """
    k-sub-embedding layer.

    Every token index is decomposed into ``num_spaces`` base-``n`` digits.
    Each digit selects a row of its own small ``nn.Embedding`` table and the
    selected rows are concatenated to form the final embedding vector.
    """

    def __init__(self, num_embeddings, embedding_dim, num_spaces=2, **kwargs):
        """
        k-sub-embedding layer

        :param num_embeddings: the number of whole embeddings
        :param embedding_dim: dimension of the embedding(not sub-embedding)
        :param num_spaces: the number of splitted sub-embedding. The paper denoted as k
        :param kwargs: parameters that used in nn.Embedding. ``padding_idx`` is
            interpreted as an index into the *whole* vocabulary; ``None`` means
            "no padding", exactly as in nn.Embedding.
        :raises ValueError: if ``num_spaces`` is smaller than 1
        """

        super(SubEmbedding, self).__init__()

        if num_spaces < 1:
            raise ValueError("num_spaces must be a positive integer, got {}".format(num_spaces))

        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.num_spaces = num_spaces
        # The number of rows of each sub-embedding.
        # n = floor(N ** (1 / k)) + 1, hence n ** k >= N and every token index
        # gets a distinct tuple of digits.
        n = int(math.pow(num_embeddings, 1 / num_spaces)) + 1
        self.n = n

        # Split embedding_dim evenly; the last sub-embedding absorbs the remainder.
        base_dim = embedding_dim // num_spaces
        sub_embedding_dims = [base_dim] * num_spaces
        sub_embedding_dims[-1] = embedding_dim - base_dim * (num_spaces - 1)

        # When a padding index is requested, every sub-embedding gets one extra
        # row (index n) that serves as its padding row, and the padding token is
        # mapped onto that row in every space.
        padding_idx = kwargs.pop("padding_idx", None)
        has_padding = padding_idx is not None

        embeddings = nn.ModuleList()
        for sub_dim in sub_embedding_dims:
            if has_padding:
                embeddings.append(nn.Embedding(n + 1, sub_dim, padding_idx=n, **kwargs))
            else:
                embeddings.append(nn.Embedding(n, sub_dim, **kwargs))
        self.embeddings = embeddings

        # all sub-embedding layer initialized independently
        for embedding in self.embeddings:
            nn.init.uniform_(embedding.weight, -0.1, 0.1)
            if has_padding:
                # The uniform re-initialisation above also overwrote the padding
                # row; restore it to zeros, which is what nn.Embedding sets up.
                with torch.no_grad():
                    embedding.weight[n].fill_(0)

        mapper = self.set_mapper()
        if has_padding:
            mapper[padding_idx, :] = n
        self.register_buffer("mapper", mapper)

    def set_mapper(self) -> Tensor:
        """
        Build the lookup table from token index to sub-embedding indices.

        :return: long tensor of shape (num_embeddings, num_spaces) whose column
            ``s`` holds the s-th base-``n`` digit (least significant first) of
            each token index
        """
        mapper = torch.zeros(self.num_embeddings, self.num_spaces, dtype=torch.long)
        idx = torch.arange(self.num_embeddings)

        for space in range(self.num_spaces):
            mapper[:, space] = torch.remainder(idx, self.n)
            # Explicit floor division keeps idx an integer tensor; a plain
            # torch.div would switch to floating point here.
            idx = torch.div(idx, self.n, rounding_mode="floor")
        return mapper

    def forward(self, ipt: Tensor) -> Tensor:
        """
        Look up each sub-embedding and concatenate them.

        :param ipt: integer tensor of token indices, any shape
        :return: tensor of shape (*ipt.shape, embedding_dim)
        """
        # Index the mapper directly: shape (*ipt.shape, num_spaces).
        # (F.embedding expects a 2-D weight, so it is not used for this lookup.)
        sub_indices = self.mapper[ipt]
        pieces = [emb(sub_indices[..., space]) for space, emb in enumerate(self.embeddings)]
        return torch.cat(pieces, dim=-1)


class TiledLinear(nn.Linear):
    """
    Linear output layer whose weight is tied to an input embedding layer.

    With a plain ``nn.Embedding`` the embedding weight is shared directly.
    With a ``SubEmbedding`` the full (num_embeddings, embedding_dim) weight is
    rebuilt from the sub-embedding tables on every forward pass.
    """

    def __init__(self, input_embeddings, *args, **kwargs):
        """
        This layer is used for tie-weight in sub-embedding

        :param input_embeddings: nn.Embedding or SubEmbedding layer
        :param args: extra positional arguments forwarded to nn.Linear
        :param kwargs: extra keyword arguments forwarded to nn.Linear
        :raises TypeError: if input_embeddings is neither nn.Embedding nor SubEmbedding
        """
        in_features = input_embeddings.embedding_dim
        out_features = input_embeddings.num_embeddings
        super(TiledLinear, self).__init__(in_features, out_features, *args, **kwargs)
        self.input_embeddings = input_embeddings
        if isinstance(input_embeddings, nn.Embedding):
            self.weight = input_embeddings.weight
        else:
            # Called for validation only: fails early when the layer type is
            # unsupported. The result is discarded because forward() rebuilds it.
            # NOTE: the ``weight`` parameter created by nn.Linear stays
            # registered in this case but is not read by forward().
            self.tile_weight()

    def tile_weight(self):
        """
        Assemble the full projection weight from the sub-embedding tables.

        Row ``i`` is the concatenation, over all spaces, of the sub-embedding
        row selected by the corresponding base-``n`` digit of ``i``.

        :return: tensor of shape (num_embeddings, embedding_dim)
        :raises TypeError: if the tied layer is not a SubEmbedding
        """
        source = self.input_embeddings
        if not isinstance(source, SubEmbedding):
            raise TypeError(
                "tile_weight() requires a SubEmbedding, got {}".format(type(source).__name__)
            )

        tables = source.embeddings
        n = source.n
        # Build the indices on the same device as the tables they index.
        idx = torch.arange(source.num_embeddings, device=tables[0].weight.device)

        pieces = []
        for table in tables:
            pieces.append(table.weight[torch.remainder(idx, n)])
            # Explicit floor division keeps idx an integer tensor.
            idx = torch.div(idx, n, rounding_mode="floor")
        return torch.cat(pieces, dim=-1)

    def forward(self, ipt: Tensor) -> Tensor:
        """
        Apply the tied linear projection.

        :param ipt: tensor whose last dimension is embedding_dim
        :return: tensor whose last dimension is num_embeddings
        """
        if isinstance(self.input_embeddings, nn.Embedding):
            weight = self.weight
        else:
            # update the weight each train step
            weight = self.tile_weight()
        return torch.nn.functional.linear(ipt, weight, self.bias)
