from torch import nn
import torch
from torch.nn import functional as F
import numpy as np
from distributions import bijector, Logistic


def integer_to_base(idx_tensor, dims, base=2):
    '''
    Expand integer indices into their fixed-width positional digits.

    Each index is written in the given base using exactly ``dims`` digits,
    most significant digit first. Digits above position ``dims - 1`` are
    dropped, so indices >= base ** dims are not distinguishable.

    Args:
        idx_tensor (LongTensor): Indices to encode, shape (...).
        dims (int): Number of digits to produce per index.
        base (int): Radix of the positional encoding.
    Returns:
        LongTensor: Digits of every index, shape (..., dims).
    '''
    # Exponents run dims-1, ..., 1, 0 so the leading digit comes first.
    exponents = torch.arange(dims - 1, -1, -1, device=idx_tensor.device)
    place_values = base ** exponents
    # Broadcast every index against all place values, then keep one digit each.
    shifted = idx_tensor.unsqueeze(-1) // place_values
    return shifted % base

class StochasticEmbedding(nn.Module):
    '''
    Embedding that returns a random sample per index instead of a fixed vector.

    Every index owns ``2 * K`` learnable numbers, where
    ``K = ceil(log2(num_embeddings))`` is the number of bits needed to write
    the index in binary. A sample from the base distribution is pushed
    through a scale/shift bijector and a softplus, and the sign of each of the
    K output dimensions is set from the matching bit of the index (bit 1 ->
    positive, bit 0 -> negative).

    Args:
        num_embeddings (int): Number of distinct indices.
        embedding_dim: Accepted for interface compatibility; not used here.
        hidden_dim: Accepted for interface compatibility; not used here.
        unk_idx (int, optional): Stored on the module; not used here.
        pad_idx (int, optional): Forwarded to ``nn.Embedding`` as
            ``padding_idx``.
        flow (bool): Accepted for interface compatibility; not used here.
        truncated (bool): Accepted for interface compatibility; not used here.
    '''

    def __init__(self, num_embeddings, embedding_dim, hidden_dim,
                 unk_idx=None, pad_idx=None, flow=False, truncated=True):
        super().__init__()
        self.unk_idx = unk_idx
        # Number of binary digits needed to tell all indices apart.
        num_bits = int(np.ceil(np.log2(num_embeddings)))
        # Kept as a buffer (and params_dim as a tensor) so the state dict and
        # the public attributes stay as they were.
        self.register_buffer("K", torch.tensor(num_bits))
        self.params_dim = self.K * 2
        # Use a plain int for the layer width rather than a 0-dim tensor.
        self.embedding = nn.Embedding(num_embeddings, 2 * num_bits,
                                      padding_idx=pad_idx)
        # NOTE(review): this re-initialises the whole weight, including the
        # row nn.Embedding zeroed for padding_idx - confirm that is intended.
        nn.init.xavier_uniform_(self.embedding.weight)
        self.base_distribution = Logistic()
        self.scale_shift = bijector.ScaleShift()

    def get_parameters(self, emb: torch.Tensor):
        '''
        Split an embedding lookup into its two K-wide halves.

        Args:
            emb (Tensor): Embedding output, shape (..., 2 * K).
        Returns:
            tuple: ``(log_w, b)`` - the first and the second K entries along
            the last dimension, in that order.
        '''
        # NOTE(review): these halves used to be named (b, log_w) here while
        # encode() unpacked them as (log_w, b). The returned tensors and their
        # order are unchanged; only the names now follow encode()'s usage.
        k = int(self.K)
        log_w, b = emb.split((k, k), dim=-1)
        return log_w, b

    def encode(self, x: torch.LongTensor, eps=None):
        '''
        Draw a sample z for each index in ``x`` and its log-density term.

        Args:
            x (LongTensor): Indices, shape (...).
            eps (Tensor, optional): Noise handed to the base distribution's
                ``sample``. When ``None``, uniform noise with the shape
                (..., K) is drawn via ``torch.rand_like``.
        Returns:
            tuple: ``(z, log_q_z)`` where ``z`` has shape (..., K) and
            ``log_q_z`` is the base log-density plus the log-determinant
            terms of the two transformations.
        '''
        log_w, b = self.get_parameters(self.embedding(x))
        binary_x = integer_to_base(x, int(self.K))
        # Map bits {0, 1} to signs {-1, +1}.
        sign = binary_x * 2 - 1
        # Honour caller-supplied noise; previously ``eps`` was accepted but
        # silently ignored and fresh noise was always drawn.
        if eps is None:
            eps = torch.rand_like(binary_x, dtype=torch.float)
        z0, log_q_z0 = self.base_distribution.sample(eps=eps)
        # presumably ldji* are inverse log-det-Jacobian terms - confirm
        # against bijector.ScaleShift.
        z1, ldji1 = self.scale_shift.forward_and_invlogdet(z0, log_w, b)
        # softplus has derivative sigmoid(z1), so the negated log-derivative
        # is -logsigmoid(z1).
        z2, ldji2 = F.softplus(z1), -F.logsigmoid(z1)
        # softplus output is positive, so the bit alone decides the sign.
        z2 = z2 * sign
        return z2, log_q_z0 + ldji1 + ldji2.sum(-1)

    def decode(self, z: torch.Tensor,
               x_prior: torch.Tensor = None,
               x: torch.LongTensor = None,
               log_q_z_x: torch.Tensor = None,
               return_zeros_mask=False):
        '''
        Return a zero tensor shaped like ``z`` without its last dimension.

        All arguments other than ``z`` are accepted but not used.
        '''
        return torch.zeros_like(z[..., 0])

    def forward(self, x: torch.LongTensor, eps=None):
        '''
        Encode ``x`` and return ``(z, log_q_z, log_p_x_z)``.

        ``log_p_x_z`` is always a zero tensor shaped like ``log_q_z``.
        '''
        z, log_q_z = self.encode(x, eps=eps)
        log_p_x_z = torch.zeros_like(log_q_z)
        return z, log_q_z, log_p_x_z

if __name__ == "__main__":
    # Smoke test: encode every index of a small vocabulary once and print
    # the sampled embeddings.
    num_categories = 50
    module = StochasticEmbedding(num_categories, -1, -1)
    all_indices = torch.arange(num_categories, dtype=torch.long)
    samples, log_density = module.encode(all_indices)
    print(samples)
