import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import numpy as np

sys.path.append("../../")

from general.mutils import get_param_val, one_hot
from layers.flows.flow_layer import FlowLayer
from layers.flows.permutation_layers import InvertibleConv
from layers.flows.activation_normalization import ExtActNormFlow
from layers.flows.coupling_layer import CouplingLayer
from layers.flows.distributions import LogisticDistribution
from layers.networks.help_layers import SimpleLinearLayer, LinearNet
from layers.categorical_encoding.decoder import create_decoder, create_embed_layer
import traceback

# TODO
# Check if they're computing ELBO or estimating log p(x)

# Scale factor applied between the logistic-space value and the latent handed to the
# flows: sample_trunc_log divides by it, truncated_density multiplies by it again.
# NOTE(review): 1.81 is close to pi / sqrt(3), the standard deviation of a standard
# logistic distribution -- presumably chosen to give the latent roughly unit scale; confirm.
corr_f = torch.tensor(1.81)
# Module-wide switch for the diagnostic prints / stack traces in this file.
DEBUG = False
# torch.autograd.set_detect_anomaly(DEBUG)


class NoSupportException(Exception):
    """Error signalling that some inputs have no component with non-zero density.

    Attributes:
        invalid_mask: whatever mask the raiser passed in to identify the offending
            entries; it is stored untouched.
    """

    _MESSAGE = "ZERO components activated. Fell into sea."

    def __init__(self, invalid_mask):
        super().__init__(self._MESSAGE)
        # Kept on the exception so a handler can locate the entries without support.
        self.invalid_mask = invalid_mask

def sigmoid_inv(p):
    """Numerically guarded inverse of the sigmoid (logit) of ``p``.

    Computes ``-log(1 / p - 1)``. Both ``p`` and ``1 / p - 1`` are clamped from below
    at 1e-8, so inputs at (or beyond) 0 and 1 yield large finite values instead of
    ``inf`` / ``nan``.
    """
    eps = 1e-8
    safe_p = torch.clamp(p, min=eps)
    # 1/p - 1 == (1 - p) / p, i.e. the inverse odds.
    inverse_odds = torch.clamp(torch.reciprocal(safe_p) - 1, min=eps)
    return -torch.log(inverse_odds)

def params2bounds(trunc_params):
    """Turn raw truncation parameters into the log-centre and log-width of an interval.

    All entries of the last axis except the final one are normalised with a
    log-softmax; the first normalised entry is the log-centre. The final entry is a
    width logit: the log-width is ``logsigmoid(logit + 4)`` plus the smallest
    normalised log-entry plus ``log(2)``.

    Returns:
        Tuple ``(log_uni_centre, log_uni_width)``; each has the shape of
        ``trunc_params`` without its last axis.
    """
    split_logits = trunc_params[..., :-1]
    width_logit = trunc_params[..., -1]

    log_split = torch.log_softmax(split_logits, dim=-1)
    log_uni_centre = log_split[..., 0]
    smallest_log_share, _ = torch.min(log_split, dim=-1)
    log_uni_width = F.logsigmoid(width_logit + 4.) + smallest_log_share + np.log(2)

    if DEBUG:
        # Report mean/std of both interval ends for a quick sanity check.
        centre = torch.exp(log_uni_centre)
        width = torch.exp(log_uni_width)
        start = centre - 0.5 * width
        end = centre + 0.5 * width
        print((start.mean().item(), start.std().item()),
              (end.mean().item(), end.std().item()))

    return log_uni_centre, log_uni_width


def density(z_0):
    """Return ``log(sigmoid(z_0) * sigmoid(-z_0))`` element-wise.

    This is the log-density of a standard logistic distribution, evaluated as a sum
    of two log-sigmoids.
    """
    return F.logsigmoid(z_0) + F.logsigmoid(-z_0)

def truncated_density(z, trunc_params,
                      log_uni_start=None,
                      log_uni_end=None,
                      log_uni_width=None,
                      return_zero_mask=False):
    """Log-density of ``z`` under the truncated logistic defined by ``trunc_params``.

    ``z`` is first mapped back to logistic space (``z * corr_f``), mirroring the
    division by ``corr_f`` in ``sample_trunc_log``. The per-dimension log-density
    ``density(z * corr_f) - log_uni_width + log(corr_f)`` is summed over the last
    axis. Entries for which any dimension falls outside the interval are assigned
    the placeholder log-density ``-64``.

    Args:
        z: latent values; the last axis is summed over.
        trunc_params: raw parameters handed to ``params2bounds``. Only used when at
            least one of the three ``log_uni_*`` arguments is missing.
        log_uni_start, log_uni_end, log_uni_width: optional precomputed logs of the
            interval start, end and width. They are only honoured when all three
            are given; otherwise everything is recomputed from ``trunc_params``.
        return_zero_mask: if True, also return the boolean out-of-support mask.

    Returns:
        ``log_p_z``, or ``(log_p_z, no_density)`` when ``return_zero_mask`` is True.
    """
    if log_uni_start is None or log_uni_end is None or log_uni_width is None:
        log_uni_centre, log_uni_width = params2bounds(trunc_params)
        centre = torch.exp(log_uni_centre)
        width = torch.exp(log_uni_width)
        start = centre - 0.5 * width
        end = centre + 0.5 * width
    else:
        # Previously this path hit a NameError because the centre was never defined.
        start = torch.exp(log_uni_start)
        end = torch.exp(log_uni_end)

    z_logistic = z * corr_f
    log_p_z_ = (density(z_logistic) - log_uni_width + torch.log(corr_f)).sum(dim=-1)

    # The interval lives in the space of sigmoid(z * corr_f): the sampler draws a
    # uniform value u, takes logit(u) and divides by corr_f. The support test must
    # therefore undo the same scaling that the density above uses.
    p = torch.sigmoid(z_logistic)
    no_density = torch.any((p < start) | (p > end), dim=-1)
    log_p_z = log_p_z_.masked_fill(no_density, -64)

    if return_zero_mask:
        return log_p_z, no_density
    return log_p_z

def sample_trunc_log(trunc_params):
    """Draw one sample per entry from the truncated logistic and return its log-density.

    A uniform value is drawn inside the interval given by ``params2bounds``, mapped
    through the inverse sigmoid and divided by ``corr_f``.

    Returns:
        Tuple ``(z, log_dens)``: ``z`` has the shape of ``trunc_params`` without its
        last axis; ``log_dens`` is the log-density summed over the last axis of ``z``.
    """
    log_uni_centre, log_uni_width = params2bounds(trunc_params)

    centre = torch.exp(log_uni_centre)
    width = torch.exp(log_uni_width)
    # Uniform offset in [-0.5, 0.5), scaled by the interval width.
    offset = torch.rand_like(trunc_params[..., 0]) - 0.5
    uniform_sample = centre + width * offset

    logistic_sample = sigmoid_inv(uniform_sample)
    scaled_sample = logistic_sample / corr_f

    per_dim_log_dens = density(logistic_sample) - log_uni_width + torch.log(corr_f)
    return scaled_sample, per_dim_log_dens.sum(-1)
class LinearCategoricalEncoding(FlowLayer):
    """
    Class for implementing the mixture model and linear flow encoding scheme of Categorical Normalizing Flows,
    using a truncated logistic as base distribution.
    A mixture model can be achieved by using a single activation normalization layer as "linear flow".
    Hence, this class combines both encoding schemes.

    Every category owns an embedding (conditioning input of the flow layers) and a row of
    ``bounds_emb`` with three raw parameters per latent dimension, which ``z_bounds`` reshapes
    for ``sample_trunc_log`` / ``truncated_density``.
    """

    def __init__(self, num_dimensions, flow_config,
                 dataset_class=None,
                 vocab=None, vocab_size=-1,
                 use_decoder=False, decoder_config=None,
                 default_embed_layer_dims=64,
                 category_prior=None,
                 **kwargs):
        """
        Parameters:
            num_dimensions - Number of latent dimensions D per categorical variable
            flow_config - Config read by `_create_flows` (e.g. "num_flows", "hidden_layers", "hidden_size")
            dataset_class - Stored on the instance, not used inside this class
            vocab, vocab_size - Forwarded to `create_embed_layer`, which returns the embedding layer
                                and the vocabulary size used for everything else in this class
            use_decoder - If True, a decoder from `create_decoder` replaces the exact posterior
            decoder_config - Config forwarded to `create_decoder`
            default_embed_layer_dims - Forwarded to `create_embed_layer`
            category_prior - Optional torch tensor or numpy array of size [num_categories]; it is
                             passed through a log-softmax. If None, a uniform prior is used.
        """
        super().__init__()
        self.use_decoder = use_decoder
        self.dataset_class = dataset_class
        self.D = num_dimensions
        print("Using truncated logistic")
        self.embed_layer, self.vocab_size = create_embed_layer(vocab, vocab_size, default_embed_layer_dims)
        # Size the table with the resolved vocabulary size. The raw `vocab_size` argument keeps
        # its default of -1 when only `vocab` is passed, which is not a valid embedding size.
        self.bounds_emb = nn.Embedding(self.vocab_size, self.D * 3)
        # Zero init => every category starts with the same truncation parameters
        torch.nn.init.zeros_(self.bounds_emb.weight)

        self.num_categories = self.vocab_size

        self.prior_distribution = LogisticDistribution(mu=0.0, sigma=1.0)  # Prior distribution in encoding flows
        self.flow_layers = _create_flows(num_dims=num_dimensions,
                                         embed_dims=self.embed_layer.weight.shape[1],
                                         config=flow_config)
        # Create decoder if needed
        if self.use_decoder:
            self.decoder = create_decoder(num_categories=self.vocab_size,
                                          num_dims=self.D,
                                          config=decoder_config)

        # Prior over the categories. If not given, a uniform prior is assumed
        if category_prior is None:
            category_prior = torch.zeros(self.vocab_size, dtype=torch.float32)
        else:
            assert category_prior.shape[
                       0] == self.num_categories, "[!] ERROR: Category prior needs to be of size [%i] but is %s" % (
            self.num_categories, str(category_prior.shape))
            if isinstance(category_prior, np.ndarray):
                category_prior = torch.from_numpy(category_prior)
        # Stored as a buffer holding log-probabilities (not a trainable parameter)
        self.register_buffer("category_prior", F.log_softmax(category_prior, dim=-1))

    def z_bounds(self, z_categ):
        """
        Looks up the truncation parameters of the given categories.
        Returns a tensor of shape [*z_categ.shape, D, 3].
        """
        z_trunc_params = self.bounds_emb(z_categ)
        leading_dims = z_trunc_params.size()[:-1]
        return z_trunc_params.view(*leading_dims, self.bounds_emb.embedding_dim // 3, 3)

    def forward(self, z, ldj=None, reverse=False, beta=1, delta=0.0, channel_padding_mask=None, **kwargs):
        """
        Encodes categories into continuous latents (reverse=False) or decodes latents back
        into categories (reverse=True).

        Parameters:
            z - reverse=False: category indices of shape [Batch, SeqLength].
                reverse=True: continuous latents whose last dimension must be D.
            ldj - Optional running log-determinant per batch element; the local term is added to it
            reverse - Direction, see above
            beta - Weight of the class log-probability term in the returned LDJ
            delta - Not used in this implementation
            channel_padding_mask - Optional mask, reshaped to [Batch * SeqLength, 1, -1];
                                   masked positions contribute neither LDJ nor latent values
        Returns:
            z_out - [Batch, SeqLength, D] latents (reverse=False) or [Batch, SeqLength] categories
            ldj - Log-determinant summed over the sequence dimension (zero contribution in reverse)
            detailed_ldj - Dict with monitoring statistics; only filled when encoding in training mode
        """
        ## We reshape z into [batch, 1, ...] as every categorical variable is considered to be independent.
        batch_size, seq_length = z.size(0), z.size(1)
        z = z.reshape((batch_size * seq_length, 1) + z.shape[2:])
        if channel_padding_mask is not None:
            channel_padding_mask = channel_padding_mask.reshape(batch_size * seq_length, 1, -1)
        else:
            channel_padding_mask = z.new_ones((batch_size * seq_length, 1, 1), dtype=torch.float32)

        ldj_loc = z.new_zeros(z.size(0), dtype=torch.float32)
        detailed_ldj = {}

        if not reverse:
            # z is of shape [Batch, SeqLength]
            z_categ = z  # Renaming here for better readability (what is discrete and what is continuous)
            ## 1.) Sample from the truncated base distribution and run the class-conditional flow
            z_trunc_params = self.z_bounds(z_categ)
            z_cont, init_log_p = sample_trunc_log(z_trunc_params)
            init_log_p = init_log_p[..., 0]
            z_cont, ldj_forward = self._flow_forward(z_cont, z_categ, reverse=False)

            ## 2.) Log-probability of the true class given the latent
            if not self.use_decoder:
                class_prior_log = torch.take(self.category_prior, z_categ.squeeze(dim=-1))
                log_point_prob = init_log_p - ldj_forward + class_prior_log
                class_prob_log = self._calculate_true_posterior(z_cont, z_categ, log_point_prob)
            else:
                class_prob_log = self._decoder_forward(z_cont, z_categ)

            ## 3.) Calculate final LDJ
            ldj_loc = (beta * class_prob_log - (init_log_p - ldj_forward))
            ldj_loc = ldj_loc * channel_padding_mask.squeeze()
            z_cont = z_cont * channel_padding_mask
            z_out = z_cont

            ## 4.) Statistics for debugging/monitoring
            if self.training:
                with torch.no_grad():
                    flat_mask = channel_padding_mask.squeeze()
                    num_valid = flat_mask.sum()
                    avg_token_prob = (class_prob_log.exp() * flat_mask).sum() / num_valid
                    # Factor log2(e) converts nats to bits
                    avg_token_bpd = -(class_prob_log * flat_mask).sum() / num_valid * np.log2(np.exp(1))
                    detailed_ldj = {"avg_token_prob": avg_token_prob,
                                    "avg_token_bpd": avg_token_bpd,
                                    "z_min": z_out.min(),
                                    "z_max": z_out.max(),
                                    "z_std": z_out.view(-1, z_out.shape[-1]).std(0).mean()}
                    detailed_ldj = {key: val.detach() for key, val in detailed_ldj.items()}

        else:
            # z is of shape [Batch * seq_len, 1, D]
            assert z.size(
                -1) == self.D, "[!] ERROR in categorical decoding: Input must have %i latent dimensions but got %i" % (
            self.D, z.shape[-1])

            z_cont = z

            if not self.use_decoder:
                z_out = self._posterior_sample(z_cont)
            else:
                z_out = self._decoder_sample(z_cont)

        # Reshape output back to original shape
        if not reverse:
            z_out = z_out.reshape(batch_size, seq_length, -1)
        else:
            z_out = z_out.reshape(batch_size, seq_length)
        ldj_loc = ldj_loc.reshape(batch_size, seq_length).sum(dim=-1)

        # Add LDJ
        if ldj is not None:
            ldj = ldj + ldj_loc
        else:
            ldj = ldj_loc

        return z_out, ldj, detailed_ldj

    def _flow_forward(self, z_cont, z_categ, reverse, **kwargs):
        """
        Runs z_cont through all flow layers (in reversed order if reverse=True), conditioned on the
        embedding of z_categ. Returns the transformed latents and the accumulated LDJ.
        """
        ldj = z_cont.new_zeros(z_cont.size(0), dtype=torch.float32)
        embed_features = self.embed_layer(z_categ)

        for flow in (self.flow_layers if not reverse else reversed(self.flow_layers)):
            z_cont, ldj = flow(z_cont, ldj, ext_input=embed_features, reverse=reverse, **kwargs)

        return z_cont, ldj

    def _decoder_forward(self, z_cont, z_categ, **kwargs):
        ## Applies the decoder on every continuous variable independently and return probability of GT class
        class_prob_log = self.decoder(z_cont)
        class_prob_log = class_prob_log.gather(dim=-1, index=z_categ.view(-1, 1))
        return class_prob_log

    def _class_conditional_log_probs(self, z_cont, **kwargs):
        """
        Runs the backward pass of *all* class-conditional flows on every latent point.

        Returns:
            log_joint - [N, num_categories] truncated base log-density + backward LDJ + log category prior
            activated - [N, num_categories] bool, True where the inverted point lies inside the class' support
        """
        num_points = z_cont.size(0)
        # Pair every latent point with every category: [N, 1, D] -> [N * num_categories, 1, D]
        z_back_in = z_cont.expand(-1, self.num_categories, -1).reshape(-1, 1, z_cont.size(2))
        sample_categ = torch.arange(self.num_categories, dtype=torch.long, device=z_cont.device)
        sample_categ = sample_categ[None, :].expand(num_points, -1).reshape(-1, 1)

        z_back, ldj_backward = self._flow_forward(z_back_in, sample_categ, reverse=True, **kwargs)
        z_trunc_params = self.z_bounds(sample_categ)
        back_log_p, zero_mask = truncated_density(z_back, z_trunc_params,
                                                  return_zero_mask=True)
        back_log_p = back_log_p[..., 0]
        zero_mask = zero_mask[..., 0]

        flow_log_prob = back_log_p + ldj_backward
        log_joint = flow_log_prob.view(num_points, self.num_categories) + self.category_prior[None, :]
        activated = (~zero_mask).view(num_points, self.num_categories)
        return log_joint, activated

    def _calculate_true_posterior(self, z_cont, z_categ, log_point_prob, **kwargs):
        """
        Returns log p(true class | z_cont), normalising over all class-conditional flows.
        log_point_prob is the forward-pass log joint of the true class.
        """
        log_prob_denominator, activated = self._class_conditional_log_probs(z_cont, **kwargs)

        if DEBUG and not self.training:
            activated_count = activated.sum(-1)
            print("Activated max:", activated_count.max())
            print("Activated min:", activated_count.min())
            traceback.print_stack()

        # Replace log_prob of original class with forward probability
        # This improves stability and prevents the model to exploit numerical errors during inverting the flows
        orig_class_mask = one_hot(z_categ.squeeze(), num_classes=log_prob_denominator.size(1))
        log_prob_denominator = log_prob_denominator * (1 - orig_class_mask) + log_point_prob.unsqueeze(dim=-1) * orig_class_mask
        # Denominator is the sum of probability -> turn log to exp, and back to log
        log_denominator = torch.logsumexp(log_prob_denominator, dim=-1)
        # Combine nominator and denominator for final prob log
        class_prob_log = (log_point_prob - log_denominator)
        if DEBUG:
            print(class_prob_log.mean())
        return class_prob_log

    def _decoder_sample(self, z_cont, **kwargs):
        ## Sampling from decoder by taking the argmax.
        # We could also sample from the probabilities, however experienced that the argmax gives more stable results.
        # Presumably because the decoder has also seen values sampled from the encoding distributions and not anywhere besides that.
        return self.decoder(z_cont).argmax(dim=-1)

    def _posterior_sample(self, z_cont, **kwargs):
        """
        Returns, for every latent point, the category with the highest posterior log-probability.
        Raises NoSupportException if some point lies outside the support of every category.
        """
        log_prob_denominator, activated = self._class_conditional_log_probs(z_cont, **kwargs)

        if DEBUG and not self.training:
            activated_count = activated.sum(-1)
            print("Activated max:", activated_count.max())
            print("Activated min:", activated_count.min())

        no_support = (~activated).all(-1)
        if no_support.any():
            raise NoSupportException(invalid_mask=no_support)

        return log_prob_denominator.argmax(dim=-1)

    def info(self):
        """Returns a human-readable summary of the encoding and its flow layers."""
        s = ""
        if len(self.flow_layers) > 1:
            s += "Linear Encodings of categories, with %i dimensions and %i flows.\n" % (self.D, len(self.flow_layers))
        else:
            s += "Mixture model encoding of categories with %i dimensions\n" % (self.D)
        # s += "-> Prior distribution: %s\n" % self.prior_distribution.info()
        if self.use_decoder:
            s += "-> Decoder network: %s\n" % self.decoder.info()
        s += "\n".join(
            ["-> [%i] " % (flow_index + 1) + flow.info() for flow_index, flow in enumerate(self.flow_layers)])
        return s


def _create_flows(num_dims, embed_dims, config):
    """Build the list of flow layers used to encode one categorical variable.

    Reads "num_flows" (default 0), "hidden_layers" (default 2) and "hidden_size"
    (default 256) from ``config``. With ``num_flows == 0`` or ``num_dims == 1`` a single
    activation normalization layer is returned (mixture model); otherwise every flow
    contributes an activation normalization, an invertible convolution and a coupling
    layer, in that order.
    """
    num_flows = get_param_val(config, "num_flows", 0)
    num_hidden_layers = get_param_val(config, "hidden_layers", 2)
    hidden_size = get_param_val(config, "hidden_size", 256)

    def make_coupling_net(c_out):
        # We apply a linear net in the coupling layers for linear flows
        return LinearNet(c_in=num_dims,
                         c_out=c_out,
                         num_layers=num_hidden_layers,
                         hidden_size=hidden_size,
                         ext_input_dims=embed_dims)

    def make_actnorm():
        # For the activation normalization, we map an embedding to scaling and bias with a single layer
        return ExtActNormFlow(c_in=num_dims,
                              net=SimpleLinearLayer(c_in=embed_dims, c_out=2 * num_dims, data_init=True))

    def make_permutation():
        return InvertibleConv(c_in=num_dims)

    def make_coupling():
        # We do not use mixture coupling layers here as we need the inverse to be differentiable as well
        return CouplingLayer(c_in=num_dims,
                             mask=CouplingLayer.create_channel_mask(c_in=num_dims),
                             block_type="LinearNet",
                             model_func=make_coupling_net)

    # Num_flows == 0 => mixture model, num_dims == 1 => coupling layers have no effect
    if num_flows == 0 or num_dims == 1:
        return nn.ModuleList([make_actnorm()])

    flow_layers = []
    for _ in range(num_flows):
        flow_layers.extend([make_actnorm(), make_permutation(), make_coupling()])
    return nn.ModuleList(flow_layers)


if __name__ == '__main__':
    ## Example for using linear encoding
    torch.manual_seed(42)
    np.random.seed(42)

    batch_size, seq_len = 3, 6
    vocab_size, D = 4, 3
    # Keys must match what _create_flows reads: "num_flows", "hidden_layers", "hidden_size".
    # The previous key "num_hidden_layers" was never read, so the setting had no effect.
    flow_config = {
        "num_flows": 0,
        "hidden_layers": 1,
        "hidden_size": 128
    }

    categ_encod = LinearCategoricalEncoding(num_dimensions=D, flow_config=flow_config, vocab_size=vocab_size)
    print(categ_encod.info())
    # Random category indices of shape [batch_size, seq_len]
    rand_inp = torch.randint(high=vocab_size, size=(batch_size, seq_len), dtype=torch.long)
    z_out, ldj, detail_ldj = categ_encod(rand_inp)
    print("Z out", z_out)
    print("Detail ldj", detail_ldj)
