import torch.nn.functional as F
from torch import nn
import numpy as np
import torch
from distributions import bijector, MADE, NLSq, Logistic, Mixture
from distributions.flow import NVPTransform
from actnorm import ActNorm, ActNormFlow
import argmax_flow_encoding
from utils import assert_close


def soft_clamp(x, min=None, max=None):
    """Smoothly bound ``x`` with softplus instead of a hard clip.

    The upper bound (if given) is applied first as
    ``max - softplus(max - x)``; the lower bound (if given) is then applied
    to that result as ``softplus(x - min) + min``. A bound left as ``None``
    is skipped, so with neither bound ``x`` itself is returned.

    Because the lower bound is applied last, the output can exceed ``max``
    by a tiny softplus tail when both bounds are given.
    """
    upper_bounded = x if max is None else max - F.softplus(max - x)
    if min is None:
        return upper_bounded
    return F.softplus(upper_bounded - min) + min

class TruncatedLogistic(nn.Module):
    """Logistic base distribution restricted to a sub-interval of its CDF.

    The interval is described in CDF ("uniform") space by a centre and a
    width, both derived from ``trunc_params`` in ``params2bounds``. Sampling
    draws a uniform value inside that interval and hands it to the base
    distribution; the log-density is the base log-density minus the summed
    log-widths.
    """

    def __init__(self):
        super(TruncatedLogistic, self).__init__()
        self.base = Logistic()
        self.register_buffer('log_2', torch.tensor(np.log(2)))
        # Added to the raw width logit before the sigmoid, biasing the
        # interval towards its maximum width at initialisation.
        self.register_buffer('width_bias', torch.tensor(4.))

    def params2bounds(self, trunc_params):
        """Convert raw truncation parameters to log-centre and log-width.

        Args:
            trunc_params: tensor whose last axis holds the softmax logits
                (all entries but the last) followed by one width logit.

        Returns:
            ``(log_uni_centre, log_uni_width)``, each with the shape of
            ``trunc_params`` minus its last axis. The centre is the first
            softmax component; the width is
            ``sigmoid(width_logit + width_bias) * 2 * min(softmax)``. With
            two softmax logits this keeps ``centre +/- width / 2`` inside
            ``[0, 1]``.
        """
        log_uni_lr = torch.log_softmax(trunc_params[..., :-1], dim=-1)
        log_uni_centre = log_uni_lr[..., 0]
        log_uni_width = (F.logsigmoid(trunc_params[..., -1] + self.width_bias) +
                         torch.min(log_uni_lr, dim=-1)[0] + self.log_2)
        return log_uni_centre, log_uni_width

    def log_probability(self, z: torch.Tensor,
                        trunc_params: torch.Tensor,
                        return_zero_mask=False):
        """Log-density of ``z`` under the truncated distribution.

        Args:
            z: points to evaluate.
            trunc_params: raw truncation parameters (see ``params2bounds``).
            return_zero_mask: if true, return the *unmasked* log-density
                together with a boolean mask of entries that fall outside
                the truncation interval, leaving the masking to the caller.

        Returns:
            Either the log-density with out-of-support entries filled with
            ``-64``, or ``(unmasked_log_density, no_density_mask)``.
        """
        log_uni_centre, log_uni_width = self.params2bounds(trunc_params)
        p = self.base.cdf(z)
        centre = torch.exp(log_uni_centre)
        # Position of ``z`` inside the interval, rescaled so that the
        # interval itself maps onto [0, 1].
        p0 = (p - centre) * torch.exp(-log_uni_width) + 0.5
        # some imprecision when inversion happens, 1e-6 allows some tolerance
        # on either side
        no_density = torch.any((p0 <= 0 - 1e-6) | (p0 >= 1 + 1e-6), dim=-1)
        log_p_z_ = self.base.log_probability(z) - log_uni_width.sum(dim=-1)
        if return_zero_mask:
            return log_p_z_, no_density
        return log_p_z_.masked_fill(no_density, -64)

    def sample(self, trunc_params, eps=None):
        """Draw a sample restricted to the truncation interval.

        Args:
            trunc_params: raw truncation parameters (see ``params2bounds``).
            eps: optional noise in ``[0, 1)`` with the shape of
                ``trunc_params[..., 0]``; drawn uniformly when omitted.

        Returns:
            ``(z, log_p)`` where ``log_p`` is the log-probability returned
            by the base distribution minus the summed log-widths.
        """
        if eps is None:
            eps = torch.rand_like(trunc_params[..., 0])
        log_uni_centre, log_uni_width = self.params2bounds(trunc_params)
        # Map the unit-interval noise into [centre - w/2, centre + w/2).
        u = (torch.exp(log_uni_centre) +
             torch.exp(log_uni_width) * (eps - 0.5))
        z, log_p = self.base.sample(eps=u)
        log_p = log_p - log_uni_width.sum(-1)
        return z, log_p


class StochasticEmbedding(nn.Module):
    """Embedding that maps each token id to a distribution over ``z``.

    Every vocabulary entry owns ``5 * z_dim`` parameters in
    ``self.embedding``: a shift ``b`` (``z_dim``), a log-scale ``log_w``
    (``z_dim``) and truncation parameters (``3 * z_dim``). ``encode``
    samples from the base distribution and pushes the sample through
    ``scale_shift`` (and the optional flow). ``decode`` inverts that
    transform for every vocabulary entry and normalises over the vocabulary.

    Args:
        num_embeddings: vocabulary size.
        embedding_dim: dimensionality of the latent ``z``.
        hidden_dim: hidden size used by the optional flow and unk encoder.
        unk_idx: id of the unknown token (used only with ``unk_encoder``).
        unk_encoder: if true, parameters at unknown-token positions are
            predicted by an ``RNNLM`` instead of read from the table.
        pad_idx: forwarded to ``nn.Embedding`` as ``padding_idx``.
        flow: if true, build a ``RealNVP`` flow on top of the scale/shift.
        truncated: choose ``TruncatedLogistic`` over ``Logistic`` as base.
    """

    def __init__(self, num_embeddings, embedding_dim, hidden_dim,
                 unk_idx=None, unk_encoder=False,
                 pad_idx=None, flow=False, truncated=True):
        super(StochasticEmbedding, self).__init__()

        self.z_dim = embedding_dim
        # NOTE: despite the name, ``embedding_dim`` stores the hidden size;
        # it is the input width of ``unk_transform`` below.
        self.embedding_dim = hidden_dim
        # ``h_dim`` is read by the flow constructor; it was previously never
        # assigned, so ``flow=True`` failed with an AttributeError.
        self.h_dim = hidden_dim
        self.params_dim = 2 * self.z_dim + 3 * self.z_dim
        self.embedding = nn.Embedding(num_embeddings, self.params_dim,
                                      padding_idx=pad_idx)
        nn.init.xavier_uniform_(self.embedding.weight)

        self.scale_shift = bijector.ScaleShift()
        if flow:
            def nvp():
                return NVPTransform(self.z_dim // 2, self.h_dim, self.z_dim // 2,
                                    hidden_size=self.h_dim)
            self.flow = bijector.RealNVP(transform=nvp(), transform_2=nvp())
        else:
            self.flow = None

        # Unnormalised log-prior over tokens; see the ``x_prior`` property.
        self._x_prior = nn.Parameter(torch.zeros(num_embeddings))
        self.truncated = truncated

        if truncated:
            self.base_distribution = TruncatedLogistic()
        else:
            self.base_distribution = Logistic()

        self.unk_idx = unk_idx
        self.unk_encoder = unk_encoder

        if unk_encoder:
            self.register_buffer('unk_bias', torch.tensor(0.))
            self.unk_encoder = RNNLM(self.z_dim, hidden_dim)
            self.unk_transform = nn.Linear(self.embedding_dim, self.params_dim)
            nn.init.zeros_(self.unk_transform.bias)
            nn.init.xavier_uniform_(self.unk_transform.weight)
            self.unk_encoder_b = RNNLM(self.z_dim, hidden_dim)

    @property
    def x_prior(self):
        """Normalised log-prior over the vocabulary."""
        return F.log_softmax(self._x_prior, dim=-1)

    def get_parameters(self, emb: torch.Tensor, unk_locations=None):
        """Split raw embedding rows into transform parameters.

        Args:
            emb: tensor whose last axis has size ``params_dim``.
            unk_locations: optional boolean mask over the leading axes of
                ``emb``; where true (and an unk encoder is configured) the
                parameters are replaced by the unk encoder's prediction.

        Returns:
            ``(log_w, b, None, trunc_params)``. The third slot is the
            flow-parameter placeholder and is always ``None``.
            ``trunc_params`` is reshaped to ``(..., z_dim, 3)``.
        """
        b, log_w, trunc_params = emb.split((
            self.z_dim, self.z_dim, self.z_dim * 3
        ), dim=-1)

        if self.unk_encoder and unk_locations is not None:
            unk_emb = self.unk_encoder(b)
            unk_params = self.unk_transform(unk_emb[unk_locations])
            # ``unk_transform`` emits ``params_dim`` values per position, so
            # it is split exactly like the table rows above. The earlier
            # four-way split (with an ``h_dim`` flow slice) did not sum to
            # ``params_dim`` and referenced an unassigned ``flow_params``.
            b_, log_w_, trunc_params_ = unk_params.split((
                self.z_dim, self.z_dim, self.z_dim * 3
            ), dim=-1)
            u = unk_locations[..., None]
            b = b.masked_scatter(u, b_)
            log_w = log_w.masked_scatter(u, log_w_)
            trunc_params = trunc_params.masked_scatter(u, trunc_params_)

        b = 4. * b
        log_w = soft_clamp(log_w, min=-3., max=10)
        trunc_params = trunc_params.view(*(trunc_params.size()[:-1]), -1, 3)
        return log_w, b, None, trunc_params

    def encode(self, x: torch.LongTensor, eps=None):
        """Sample ``z`` for token ids ``x``.

        Args:
            x: token ids.
            eps: optional noise forwarded to the base distribution's
                ``sample``.

        Returns:
            ``(z, log_q_z)``: the sample and the accumulated log-density
            (base log-probability plus the log-det terms reported by the
            scale/shift and, if present, the flow).
        """
        if self.unk_encoder:
            unk_locations = x == self.unk_idx
        else:
            unk_locations = None
        log_w, b, flow_params, trunc_params = self.get_parameters(
            self.embedding(x), unk_locations=unk_locations)
        z0, log_p_z0 = self.base_distribution.sample(trunc_params, eps=eps)
        z1, ldji1 = self.scale_shift.forward_and_invlogdet(z0, log_w, b)

        if self.flow is not None:
            z2, ldji2 = self.flow.forward_and_invlogdet(z1, flow_params)
        else:
            z2, ldji2 = z1, 0.
        log_q_z2 = log_p_z0 + ldji1 + ldji2
        return z2, log_q_z2

    def _decode(self, z: torch.Tensor):
        """Score ``z`` under every vocabulary entry's distribution.

        Returns:
            ``(logits, zeros_mask)``; both gain a trailing vocabulary axis.
            Entries flagged in ``zeros_mask`` (outside the base
            distribution's support) carry a logit of ``-64``.
        """
        log_w, b, flow_params, trunc_params = \
            self.get_parameters(self.embedding.weight)

        # Insert a vocabulary axis so ``z`` broadcasts against all rows.
        z2 = z[..., None, :]
        if self.flow is not None:
            z1, ldji2 = self.flow.inverse_and_invlogdet(z2, flow_params)
        else:
            z1, ldji2 = z2, 0.
        z0, ldji1 = self.scale_shift.inverse_and_invlogdet(z1, log_w, b)
        log_q_all_, zeros_mask = self.base_distribution.log_probability(
            z0, trunc_params, return_zero_mask=True
        )

        log_q_all = (log_q_all_ + ldji1 + ldji2).masked_fill(zeros_mask, -64.)
        if self.unk_encoder:
            logits = log_q_all.clone()
            logits[..., self.unk_idx] = self.unk_bias
        else:
            logits = log_q_all

        # ``DEBUG`` is only set when the file runs as a script; default to
        # off so importing the module does not raise NameError.
        if globals().get('DEBUG', False):
            active_freq = (~zeros_mask).float().sum(-1).mean()
            print("active:", active_freq)
        return logits, zeros_mask

    def decode(self, z: torch.Tensor,
               x_prior: torch.Tensor=None,
               x: torch.LongTensor=None,
               log_q_z_x: torch.Tensor=None,
               return_zeros_mask=False):
        """Compute ``log p(x | z)`` by normalising over the vocabulary.

        Args:
            z: latent codes.
            x_prior: optional log-prior over tokens; defaults to
                ``self.x_prior``.
            x: optional labels. Without them the full log-distribution over
                the vocabulary is returned.
            log_q_z_x: optional ``log q(z | x)`` from ``encode``. When given
                with ``x`` it replaces the label's own logit, so the
                numerator does not depend on the numerical inversion.
            return_zeros_mask: also return, per position, whether *every*
                vocabulary entry was outside its support.

        Returns:
            ``log_p_x_z``, or ``(log_p_x_z, all_zero_mask)``.
        """
        debug = globals().get('DEBUG', False)

        log_q_all, zeros_mask = self._decode(z)
        if x_prior is None:
            x_prior = self.x_prior

        x_prior = x_prior.masked_fill(zeros_mask, -64.)
        logits = log_q_all + x_prior

        if x is None:
            # If no label given, return entire distribution.
            log_p_x_z = F.log_softmax(logits, dim=-1)
        elif log_q_z_x is not None:
            # If label given return probability of label.
            if self.unk_encoder:
                log_q_z_x = log_q_z_x.clone()
                log_q_z_x[x == self.unk_idx] = self.unk_bias
            x_flat = x.flatten()
            idxs = torch.arange(x_flat.size(0),
                                dtype=torch.long,
                                device=x.device)

            if debug:
                log_q_z_ = log_q_all.flatten(0, -2)[idxs, x_flat].view(x.size())
                assert_close(log_q_z_x, log_q_z_)
                print("log_q_z before and after inversion are close")
            # If log q(z|x) is given, use that as numerator.
            x_prior_ = x_prior.flatten(0, -2)[idxs, x_flat].view(x.size())

            log_numer = log_q_z_x + x_prior_
            logits_flat = logits.flatten(0, -2)
            logits_flat[idxs, x_flat] = log_numer.flatten()
            logits = logits_flat.view(logits.size())

            log_denom = torch.logsumexp(logits, dim=-1)
            log_p_x_z = log_numer - log_denom
        else:
            # Otherwise just calculate the log normalisation.
            log_p_x_z = -F.cross_entropy(logits.flatten(0, -2), x.flatten(),
                                         reduction='none').view(x.size())

        if return_zeros_mask:
            return log_p_x_z, zeros_mask.all(-1)
        return log_p_x_z

    def forward(self, x: torch.LongTensor, eps=None):
        """Encode ``x`` then decode it; return ``(z, log_q_z, log_p_x_z)``."""
        z, log_q_z = self.encode(x, eps=eps)
        log_p_x_z = self.decode(z, x=x, log_q_z_x=log_q_z)
        return z, log_q_z, log_p_x_z


class RNNLM(nn.Module):
    """LSTM that summarises the *preceding* inputs at every position.

    ``forward`` prepends a learned start vector and drops the final input
    step, so the output at position ``t`` is computed from the start vector
    and ``x[:, :t]`` only.

    Args:
        in_dim: size of each input vector.
        h_dim: LSTM hidden size (and output size).
        num_layers: number of stacked LSTM layers.
        dropout: dropout rate after the input projection, between LSTM
            layers, and on the LSTM output.
        word_dropout: probability of replacing an input position with
            uniform noise during training.
    """

    def __init__(self, in_dim, h_dim, num_layers=3,
                 dropout=0.1, word_dropout=0.25):
        super(RNNLM, self).__init__()
        self.in_dim = in_dim
        self.h_dim = h_dim
        self.in2h = nn.Sequential(
            nn.Linear(in_dim, h_dim),
            nn.Tanh(),
            nn.Dropout(dropout)
        )
        # nn.LSTM only applies dropout *between* layers; with a single layer
        # a non-zero rate has no effect and only triggers a UserWarning.
        inter_layer_dropout = dropout if num_layers > 1 else 0.
        self.rnn = nn.LSTM(h_dim, h_dim, num_layers=num_layers,
                           batch_first=True,
                           dropout=inter_layer_dropout)
        self.dropout = nn.Dropout(dropout)
        self.init_token = nn.Parameter(torch.zeros(self.in_dim))
        self.word_drop = word_dropout

    def drop_words(self, x):
        """Replace random positions of ``x`` with uniform noise.

        Only active in training mode; in eval mode ``x`` is returned as is.
        Each position (every axis but the last) is selected independently
        with probability ``word_drop`` and its whole vector is overwritten
        with samples from ``U[0, 1)``. The input tensor is not modified.
        """
        if self.training:
            eps = torch.rand_like(x[..., 0])
            drop = eps < self.word_drop
            x = x.clone()
            x[drop] = torch.rand_like(x[drop])
        return x

    def forward(self, x):
        """Run the LSTM over the start vector followed by ``x[:, :-1]``.

        Args:
            x: tensor of shape ``(batch, time, in_dim)``.

        Returns:
            Tensor of shape ``(batch, time, h_dim)``.
        """
        start = self.init_token[None, None, :].expand(x.size(0), -1, -1)
        rnn_in = torch.cat((start, self.drop_words(x[:, :-1])), dim=1)
        rnn_out, _ = self.rnn(self.in2h(rnn_in))
        return self.dropout(rnn_out)

class StochasticEmbeddingLM(nn.Module):
    """Latent-variable language model built on a stochastic embedding.

    Tokens are encoded to latent codes ``z`` by ``self.embed``, passed
    through an ``ActNormFlow``, and scored by an autoregressive prior made
    of an ``RNNLM`` (context) and a ``MADE`` (per-step density).
    ``forward`` returns the per-token negative log-likelihood and KL terms.

    Args:
        vocab: number of real tokens; one extra id (``vocab``) is reserved
            as the end-of-sequence token.
        unk_idx: id of the unknown token.
        pad_idx: id used for padding.
        argmax_flow: use ``argmax_flow_encoding.StochasticEmbedding``
            instead of the local ``StochasticEmbedding``.
        truncated: forwarded to the local ``StochasticEmbedding``.
        h_dim: hidden size of the prior networks.
        z_dim: latent size (overridden by ``embed.K`` with ``argmax_flow``).
        prior_layers: number of LSTM layers in the prior RNN.
        prior_made_upscale: hidden upscale factor of the MADE.
        prior_word_drop: word-dropout rate for the prior RNN.
        prior_dropout: dropout rate inside the MADE activation stack.
        conditional_word_prior: if true, predict a context-dependent
            log-prior over tokens from the RNN state.
        unk_encoder: if true, share the prior RNN as the embedding's
            unk encoder.
    """

    def __init__(self, vocab, unk_idx, pad_idx,
                 argmax_flow, truncated,
                 h_dim, z_dim,
                 prior_layers, prior_made_upscale,
                 prior_word_drop, prior_dropout,
                 conditional_word_prior,
                 unk_encoder):
        super(StochasticEmbeddingLM, self).__init__()
        self.unk_idx = unk_idx
        self.h_dim = h_dim
        self.end_idx = vocab
        self.pad_idx = pad_idx
        self.word_drop = prior_word_drop
        self.truncated = truncated

        if not argmax_flow:
            self.embed = StochasticEmbedding(num_embeddings=vocab + 1,
                                             embedding_dim=z_dim,
                                             hidden_dim=h_dim,
                                             unk_idx=self.unk_idx,
                                             unk_encoder=unk_encoder,
                                             pad_idx=self.pad_idx,
                                             truncated=self.truncated)
            self.z_dim = z_dim
        else:
            self.embed = argmax_flow_encoding.StochasticEmbedding(
                num_embeddings=vocab + 1,
                embedding_dim=z_dim,
                hidden_dim=h_dim,
            )
            self.z_dim = int(self.embed.K.item())
        self.actnorm_flow = ActNormFlow(self.z_dim)
        self.prior_rnn = RNNLM(self.z_dim, self.h_dim, num_layers=prior_layers,
                               word_dropout=prior_word_drop)
        self.prior_made = MADE(self.z_dim, self.h_dim,
                               activation=lambda: nn.Sequential(
                                   ActNorm(self.z_dim * prior_made_upscale, frozen=True),
                                   nn.GELU(),
                                   nn.Dropout(prior_dropout)
                               ),
                               base_distribution=Mixture(
                                   input_size=self.z_dim,
                                   components=2,
                                   base_distribution=Logistic()
                               ),
                               params_size=2 * 3,
                               hidden_upscale=prior_made_upscale)

        if conditional_word_prior:
            self.prior_w = nn.Sequential(
                nn.Linear(self.h_dim, vocab + 1),
                nn.LogSoftmax(dim=-1)
            )
            # Zero-initialise the linear layer so the prior starts uniform.
            nn.init.zeros_(self.prior_w[-2].weight)
            nn.init.zeros_(self.prior_w[-2].bias)
        else:
            self.prior_w = None

        if unk_encoder:
            self.embed.unk_encoder = self.prior_rnn

    def log_prior(self, rnn_out, z):
        """Return ``(log p(z | context), log-prior over tokens or None)``."""
        log_p_z = self.prior_made(z, rnn_out)
        if self.prior_w is not None:
            return log_p_z, self.prior_w(rnn_out)
        return log_p_z, None

    def log_probability(self, x, k=1):
        """Importance-weighted estimate of ``log p(x)`` using ``k`` samples.

        Each sequence is replicated ``k`` times, ``forward`` is run once on
        the expanded batch, and the per-sample log-ratios
        ``-(nll + kl)`` (summed over time) are combined with
        ``logsumexp(...) - log k``.

        Returns:
            Tensor of shape ``(batch,)``.
        """
        _x = x
        x = _x[:, None].expand(-1, k, -1).flatten(0, 1)
        nll, kl = self.forward(x)
        log_ratios = -(nll.sum(1) + kl.sum(1)).view(_x.size(0), k)
        log_k = torch.log(torch.tensor(k, dtype=torch.float, device=x.device))
        log_p_x = torch.logsumexp(log_ratios, dim=1) - log_k
        return log_p_x

    def append_end(self, x: torch.LongTensor):
        """Return a copy of ``x`` one step longer with the end token added.

        The sequence length is the count of non-pad entries per row; the
        end token is written at that index and any remaining positions keep
        the pad id. The input tensor is left untouched.
        """
        lengths = (x != self.pad_idx).sum(1)
        x = F.pad(x, (0, 1), 'constant', self.pad_idx)
        idxs = torch.arange(x.size(0), dtype=torch.long, device=x.device)
        x[idxs, lengths] = self.end_idx
        return x

    def drop_words(self, x):
        """Return a copy of ``x`` with random non-pad ids replaced by unk.

        Each non-pad position is replaced independently with probability
        ``self.word_drop``.
        """
        x = x.clone()
        mask = torch.rand_like(x, dtype=torch.float) < self.word_drop
        mask = mask & (x != self.pad_idx)
        x[mask] = self.unk_idx
        return x

    def forward(self, x: torch.LongTensor):
        """Compute per-token reconstruction and KL terms.

        Args:
            x: padded batch of token ids, shape ``(batch, time)``.

        Returns:
            ``(nll, kl)``, each shaped like ``x`` after the end token has
            been appended, with padded positions set to zero.
        """
        x = self.append_end(x)
        pad_mask = x == self.pad_idx

        # Encode
        z0, log_q_z0 = self.embed.encode(x)
        z1, ldji1 = self.actnorm_flow.forward_and_invlogdet(z0)
        z, log_q_z = z1, log_q_z0 + ldji1

        # Compute log p
        rnn_out = self.prior_rnn(z)
        log_p_z, log_prior_w = self.log_prior(rnn_out, z)

        # Decode (kinda)
        log_p_x_z = self.embed.decode(
            z0, x_prior=log_prior_w, log_q_z_x=log_q_z0, x=x)

        kl = (log_q_z - log_p_z).masked_fill(pad_mask, 0.)
        nll = (-log_p_x_z).masked_fill(pad_mask, 0.)

        # ``DEBUG`` is only set when the file runs as a script; default to
        # off so importing the module does not raise NameError.
        if globals().get('DEBUG', False):
            # Compare by doing crossentropy
            log_p_x_z_ = self.embed.decode(z0, log_q_z_x=log_q_z0)
            log_p_x_z_ = -F.cross_entropy(log_p_x_z_.flatten(0, -2),
                                          x.flatten(),
                                          reduction='none').view(x.size())
            assert_close(log_p_x_z, log_p_x_z_)

            # Compare without numerator trick
            log_p_x_z_ = self.embed.decode(z0, x=x, log_q_z_x=log_q_z0)
            assert_close(log_p_x_z, log_p_x_z_)

            print("log_q", log_q_z.mean().item(),
                  "log_p", log_p_z.mean().item())
            print("nll", nll.mean().item())

        return nll, kl


# Module-level default for the debug switch read by the classes above.
# Without it, importing this file (instead of running it as a script) left
# ``DEBUG`` undefined and the first ``if DEBUG`` check raised NameError.
DEBUG = False

if __name__ == "__main__":
    # Smoke test: enable the extra consistency checks and debug prints.
    DEBUG = True
    # The string below is a disabled stand-alone check of
    # StochasticEmbedding, kept for reference; it is never executed.
    """
    se = StochasticEmbedding(
        num_embeddings=4,
        embedding_dim=2,
        hidden_dim=256, # unk_idx=0
    )
    x = torch.arange(4, dtype=torch.long)[:, None]
    print("x", x)
    z, log_q_z, log_p_x_z = se(x, eps=torch.zeros((1,1,2)) + 1)
    z, log_q_z, log_p_x_z = se(x, eps=torch.zeros((1,1,2)))
    print("log_q_z", log_q_z)
    print("log_p_x_z",log_p_x_z)
    DEBUG = False
    """
    # NOTE(review): pad_idx=-1 is also forwarded to nn.Embedding as
    # padding_idx, where a negative value is resolved relative to the table
    # size, i.e. the last row, which here is the end token. Confirm that is
    # intended.
    model = StochasticEmbeddingLM(vocab=100,
                                  unk_idx=0, pad_idx=-1,
                                  h_dim=20, z_dim=5,
                                  prior_layers=1,
                                  prior_made_upscale=2,
                                  prior_word_drop=0.,
                                  prior_dropout=0.,
                                  conditional_word_prior=False,
                                  unk_encoder=False, argmax_flow=False)
    # One sequence holding a single random token.
    x = torch.randint(0, 100, size=(1, 1))
    log_p_x = model.log_probability(x, k=1)
    print(log_p_x.size())
