# ---------------------------
# _, _ -- 2019
# The University of _, The _ Institute
# contact: _, _
# ---------------------------
"""Defines ELBO loss functions
"""
import tensorflow as tf
import tensorflow_probability as tfp
from utils.tensor_utils import nunique_cols
import math
import os
import numpy as np

tfd = tfp.distributions


def estimate_entropies(logqz_i, N):
    """Estimate entropies of the aggregate posterior from log-densities.

    Approximates
        E_{p(x)} E_{q(z_j|x)} [-log q(z_j)]
    where q(z) = 1/N sum_{n=1}^{N} q(z|x_n), using the numerically stable
    form
        -log q(z) = log N - logsumexp_{n=1}^{N} log q(z|x_n)

    Inputs:
    -------
        logqz_i: tensor of log-densities. The log-sum-exp is taken over
            axis 1 and the result is averaged over axis 0.
            NOTE(review): presumably axis 1 indexes the conditioning data
            points x_n and axis 0 the drawn samples -- confirm with caller.
        N: dataset size (Python float or scalar float tensor).

    Returns:
    --------
        The negated mean over axis 0 of
        (logsumexp over axis 1) - (log normaliser).
    """
    # --- Number of entries along the axis that gets log-sum-exp'ed
    axis1_size = tf.cast(tf.shape(logqz_i)[1], tf.float32)

    def _log_norm_full_dataset():
        # --- axis 1 covers (at least) N entries: normalise by N
        return tf.log(N)

    def _log_norm_minibatch():
        # --- axis 1 covers fewer than N entries: normalise by size * N
        return tf.log(axis1_size * N)

    log_norm = tf.cond(axis1_size >= N, _log_norm_full_dataset,
                       _log_norm_minibatch)

    # --- log q(z_i) ~= logsumexp_m(log q(z_i|x_m)) - log normaliser
    log_qz = tf.reduce_logsumexp(logqz_i, axis=1) - log_norm

    return -tf.reduce_mean(log_qz, axis=0)


def MIG(mi_normed):
    """Mutual information gap: mean over rows of (column 0 - column 1).

    The caller is expected to have sorted each row in descending order so
    that columns 0 and 1 hold the largest and second-largest normalised
    mutual information for that row.
    """
    largest = mi_normed[:, 0]
    runner_up = mi_normed[:, 1]
    return tf.reduce_mean(largest - runner_up)


def compute_metric_shapes(marginal_entropies, cond_entropies, factor_support):
    """Compute the MIG metric from marginal and conditional entropies.

    Inputs:
    -------
        marginal_entropies: entropies of the latent dimensions; broadcast
            against ``cond_entropies`` in the subtraction.
        cond_entropies: conditional entropies, one row per factor and one
            column per latent dimension (rows are sorted along axis 1).
        factor_support: number of distinct values per factor, one entry per
            row of ``cond_entropies``; its log normalises each row.

    Returns:
    --------
        Scalar tensor: mean over factors of the gap between the largest and
        second-largest normalised mutual information.
    """

    mutual_infos = marginal_entropies - cond_entropies

    # --- Clip negative MIs (can occur because of estimation error). MI is
    # --- non-negative by definition, so a negative estimate must not be
    # --- allowed to widen the gap between the top two latents.
    mutual_infos = tf.maximum(mutual_infos, tf.zeros_like(mutual_infos))

    # --- Largest MI first, so columns 0 and 1 are the two best latents
    mutual_infos = tf.sort(mutual_infos, axis=1, direction='DESCENDING')

    mi_normed = mutual_infos / tf.reshape(tf.log(factor_support), (-1, 1))

    # --- If a factor has support 1 then log(support) = 0 and we get NaNs
    # --- (0 / 0), e.g. in dsprites where a factor takes only one value
    mi_normed = tf.where(tf.is_nan(mi_normed), tf.zeros_like(mi_normed),
                         mi_normed)

    metric = MIG(mi_normed)

    return metric


def assign_val(x, i, v):
    """Store ``v`` at index ``i`` of ``x`` in place and return ``x``.

    The container is mutated rather than copied, so the returned object is
    the very same ``x`` that was passed in.
    """
    x[i] = v
    return x


def mutual_info_metric_shapes(logqz_i, factors_batch, latent_size,
                              dataset_size):
    """Build the graph that computes the MIG metric for a batch.

    Inputs:
    -------
        logqz_i: log-density tensor passed to ``estimate_entropies``. It is
            masked along axis 1 with one boolean per row of
            ``factors_batch``, so axis 1 must have one entry per batch item.
            NOTE(review): only axis 1 is restricted to the items sharing a
            factor value; axis 0 is left untouched -- confirm this matches
            the intended layout of ``logqz_i``.
        factors_batch: 2-D tensor of ground-truth factor values, one row per
            batch item and one column per factor.
        latent_size: number of latent dimensions (length of each row of
            conditional entropies).
        dataset_size: total dataset size, converted to ``float``.

    Returns:
    --------
        (metric, marginal_entropies, cond_entropies) where ``cond_entropies``
        has shape (num_factors, latent_size).
    """

    N = float(dataset_size)
    # --- marginal entropies
    marginal_entropies = estimate_entropies(logqz_i, N)

    # --- Number of unique values per factor column.
    # --- Here we estimate the factor support for the batch, not the entire
    # --- dataset
    factor_support = tf.py_func(nunique_cols, [factors_batch], tf.float32)

    num_factors = tf.shape(factors_batch)[1]

    def estimate_cond_entropies(i, cond_entropies_ta):
        relevant_col = factors_batch[:, i]
        # --- Extract unique factor values for that column and, for every
        # --- batch item, the index of its value: unique_vals[idx[n]] = col[n]
        unique_vals, unique_idx = tf.unique(relevant_col)
        num_unique = tf.shape(unique_vals)[0]
        # --- Named differently from the outer `factor_support` tensor to
        # --- avoid shadowing it
        col_support = tf.cast(num_unique, dtype=tf.float32)
        num_samples_factor = N // col_support

        def calc_cond_entropy_dim(j, e_dim):
            # --- Select the batch items whose factor value is the j-th
            # --- unique value
            mask = tf.equal(unique_idx, j)
            valid_samples = tf.boolean_mask(logqz_i, mask, axis=1)

            # --- Accumulate entropies conditioned on that factor value
            e_dim += estimate_entropies(valid_samples, num_samples_factor)

            return (j + 1, e_dim)

        # --- Sum the conditional entropy of every latent dimension over all
        # --- unique values of the factor
        _, e_dim = tf.while_loop(lambda j, _: j < num_unique,
                                 calc_cond_entropy_dim,
                                 (0, tf.zeros(latent_size, dtype=tf.float32)))

        # --- Average over the unique values (uniform weight per value)
        e_dim = e_dim / col_support

        # --- Record row i with a native TensorArray write instead of
        # --- round-tripping through Python and mutating a buffer in place
        return (i + 1, cond_entropies_ta.write(i, e_dim))

    _, cond_entropies_ta = tf.while_loop(
        lambda i, _: i < num_factors,
        estimate_cond_entropies,
        (0, tf.TensorArray(tf.float32, size=num_factors)))

    # --- Stack rows into (num_factors, latent_size); the reshape pins the
    # --- shape for downstream ops
    cond_entropies = tf.reshape(cond_entropies_ta.stack(),
                                (num_factors, latent_size))

    metric = compute_metric_shapes(marginal_entropies, cond_entropies,
                                   factor_support)

    return metric, marginal_entropies, cond_entropies
