import sys
import numpy as np
import logging

import tensorflow as tf
import matplotlib.pyplot as plt

from mvhg.tf_heaviside import heaviside

# Module-level side effect: importing this module switches the whole process
# to TF1-style graph mode. The __main__ demo below builds a graph and runs it
# through tf.compat.v1.Session, which requires eager execution to be off.
tf.compat.v1.disable_eager_execution()


def sample_gumbel(shape, eps=1e-20):
    """Return standard Gumbel noise of the requested shape.

    Applies the inverse-CDF transform ``-log(-log(U))`` to uniform noise
    ``U`` drawn from ``[0, 1)``. ``eps`` is added inside both logarithms so
    neither sees an exact zero.
    """
    uniform_noise = tf.random.uniform(shape, minval=0, maxval=1)
    log_u = tf.math.log(uniform_noise + eps)
    return -tf.math.log(eps - log_u)


def gumbel_softmax(logits, temp, hard=True):
    """Draw a (straight-through) Gumbel-softmax sample over the last axis.

    Args:
        logits: unnormalised log-probabilities; the categories lie on the
            last axis (the axis ``tf.nn.softmax`` normalises over).
        temp: softmax temperature; broadcast against ``logits``.
        hard: if True, the forward value is an exact one-hot vector while
            gradients flow through the soft sample (straight-through trick).

    Returns:
        Tensor with the same shape and dtype as the softmax of ``logits``.
    """
    gs_sample = logits + sample_gumbel(tf.shape(logits))
    y = tf.nn.softmax(gs_sample / temp)
    if hard:
        # one_hot(argmax) guarantees exactly one 1 per row, even if several
        # entries tie for the maximum (an equality test against reduce_max
        # would mark all of them). Using the last axis keeps this consistent
        # with tf.nn.softmax for inputs of any rank.
        y_hard = tf.one_hot(
            tf.math.argmax(y, axis=-1), depth=tf.shape(y)[-1], dtype=y.dtype
        )
        # Forward: y_hard. Backward: gradient of the soft y.
        y = tf.stop_gradient(y_hard - y) + y
    return y


def get_logits(n, m, m2, log_w1, log_w2, hside_check, eps=1e-20):
    """Masked log-weights of splitting ``n`` draws between two groups.

    For every candidate count ``x1`` in ``0 .. m`` of the first group (and
    ``x2 = n - x1`` for the second), returns

        x1*log_w1 + x2*log_w2
          - [lgamma(x1+1) + lgamma(x2+1) + lgamma(m-x1+1) + lgamma(m2-x2+1)]
          + log(hside_check),

    i.e. ``log[C(m, x1) C(m2, x2) w1^x1 w2^x2]`` up to an additive constant.
    Inside the lgamma terms ``x2`` is clipped at 0. Entries where
    ``hside_check`` is 0 become ``-inf``.

    Args:
        n: draws to split; must broadcast against the candidate grid.
        m, m2: sizes of the first and second group.
        log_w1, log_w2: log weights of the two groups.
        hside_check: feasibility mask multiplied into the probabilities.
        eps: unused; kept for interface compatibility.

    Returns:
        Tensor whose last axis indexes ``x1 = 0 .. m``. The candidate grid is
        built with shape ``(batch, 1, m + 1)`` before broadcasting.
    """
    batch = tf.shape(log_w1)[0]
    support = tf.repeat(tf.expand_dims(tf.range(m + 1), 0), repeats=batch, axis=0)
    support = tf.expand_dims(support, 1)
    x2 = tf.nn.relu(n - support)
    log_norm = (
        tf.math.lgamma(support + 1.0)
        + tf.math.lgamma(x2 + 1.0)
        + tf.math.lgamma(m - support + 1.0)
        + tf.math.lgamma(m2 - x2 + 1.0)
    )
    log_weight = support * log_w1 + (n - support) * log_w2
    return log_weight - log_norm + tf.math.log(hside_check)


def calc_group_w_m(m_group, log_w_group):
    """Pool several classes into a single class.

    Args:
        m_group: sizes of the pooled classes (a 1-D slice in this file's
            callers).
        log_w_group: log weights of those classes, class axis = axis 2.

    Returns:
        ``(log_w_G, m_G)``. ``m_G`` is the total size. ``log_w_G`` is the log
        of the size-weighted mean weight, ``log(sum_j m_j * w_j / m_G)``,
        computed with logsumexp over axis 2.
    """
    total_size = tf.math.reduce_sum(m_group)
    weighted_terms = log_w_group + tf.math.log(m_group) - tf.math.log(total_size)
    pooled_log_w = tf.reduce_logsumexp(weighted_terms, axis=2)
    return pooled_log_w, total_size


def get_hside_check(n, x_i, m_r):
    """Feasibility mask for assigning ``x_i`` of ``n`` draws to one class.

    Returns ``heaviside(n - x_i) * heaviside(m_r - relu(n - x_i))``. The
    first factor checks that the class does not take more than ``n`` draws.
    The second checks that the remainder ``n - x_i`` (clipped at 0) fits into
    the remaining capacity ``m_r``. The value at exactly 0 depends on
    ``mvhg.tf_heaviside.heaviside``.
    """
    leftover = tf.nn.relu(n - x_i)
    fits_draws = heaviside(n - x_i)
    fits_rest = heaviside(m_r - leftover)
    return fits_draws * fits_rest


def pmf_noncentral_fmvhg(m_all, n, log_w_all, temperature, n_c):
    """Draw a relaxed sample of per-class counts, one class at a time.

    Class ``i`` is split against the pool of classes ``i+1 .. n_c-1`` (merged
    by ``calc_group_w_m``). ``get_logits`` builds a two-group distribution
    over ``0 .. m_all[i]``, which is sampled with the straight-through
    ``gumbel_softmax``. Draws already assigned to earlier classes are
    subtracted from ``n`` before the next class is sampled. The last class is
    not sampled: its indicator is the ``get_hside_check`` mask with an empty
    remainder.

    Args:
        m_all: per-class sizes, indexable and sliceable by class. The
            ``__main__`` demo passes a float32 1-D tensor.
        n: number of draws per sample. The demo passes shape
            ``(n_samples, 1)``; it is expanded to ``(n_samples, 1, 1)`` here.
        log_w_all: per-class log weights, shape ``(n_samples, >= n_c)``.
        temperature: Gumbel-softmax temperature.
        n_c: number of classes to sample.

    Returns:
        ``(y_all, x_all, y_mask_all)``, lists with one entry per class:
        ``y_all[i]`` is the (straight-through) one-hot over ``0 .. m_all[i]``.
        ``y_mask_all[i]`` has entry ``k`` equal to ``sum_{j >= k} y_all[i][j]``,
        i.e. 1 for every ``k <= x_i`` when ``y_all[i]`` is one-hot.
        ``x_all[i]`` is the sampled count, shape ``(n_samples, 1, 1)``.
    """
    n = tf.expand_dims(n, 1)
    n_samples = tf.shape(log_w_all)[0]
    # Running total of draws already assigned to classes 0 .. i-1.
    n_out = tf.zeros((n_samples, 1, 1), dtype=tf.float32)
    log_w_all = tf.expand_dims(log_w_all, 1)
    y_all = []
    y_mask_all = []
    x_all = []
    for i in range(n_c):
        n_new = n - n_out
        m_i = m_all[i]
        # Candidate counts 0 .. m_i, shape (n_samples, 1, m_i + 1).
        support = tf.range(m_i + 1)
        support = tf.repeat(
            tf.expand_dims(tf.expand_dims(support, 0), 1), repeats=n_samples, axis=0
        )
        if i + 1 < n_c:
            log_w_i = tf.expand_dims(log_w_all[:, :, i], 1)
            log_w_rest, m_rest = calc_group_w_m(
                m_all[i + 1 :], log_w_all[:, :, i + 1 :]
            )
            # (n_samples, 1) -> (n_samples, 1, 1) so it broadcasts against
            # the candidate grid instead of across samples.
            log_w_rest = tf.expand_dims(log_w_rest, 2)
            hside_check = get_hside_check(n_new, support, m_rest)
            logits_p_x_i = get_logits(
                n_new, m_i, m_rest, log_w_i, log_w_rest, hside_check
            )
            y_i = gumbel_softmax(logits_p_x_i, temperature)
        else:
            # Last class: nothing left to split against, so the feasibility
            # mask with an empty remainder is used directly. Presumably this
            # is one-hot at x == n_new if heaviside(0) == 1 -- confirm against
            # mvhg.tf_heaviside.
            hside_check = get_hside_check(
                n_new, support, tf.zeros((1), dtype=tf.float32)
            )
            y_i = tf.dtypes.cast(hside_check, tf.float32)
        y_all.append(y_i)
        # A reverse cumulative sum equals y_i @ (lower-triangular ones). It
        # avoids materialising an (n_samples, m_i + 1, m_i + 1) ones tensor.
        mask_filled = tf.math.cumsum(y_i, axis=-1, reverse=True)
        y_mask_all.append(mask_filled)
        # For one-hot y_i the mask holds x_i + 1 ones, so sum - 1 is the count.
        x_i = tf.expand_dims(tf.reduce_sum(mask_filled, axis=2), 2) - 1.0
        x_all.append(x_i)
        n_out = n_out + x_i
    return y_all, x_all, y_mask_all


def get_probability(X_all, m_all, n, log_w_all, n_c):
    """Probability of given class counts under the sequential model.

    For classes ``0 .. n_c-2``, rebuilds the same two-group distribution that
    ``pmf_noncentral_fmvhg`` samples from and looks up the probability of the
    count stored in ``X_all``. It conditions on the counts of earlier
    classes. The last class is not scored.

    Args:
        X_all: list of per-class counts, each of shape ``(n_samples, 1, 1)``
            (the ``x_all`` returned by ``pmf_noncentral_fmvhg``).
        m_all, n, log_w_all, n_c: as for ``pmf_noncentral_fmvhg``.

    Returns:
        Tensor of shape ``(n_samples, 1)``: the product over scored classes
        of ``p(x_i | counts of classes 0 .. i-1)``.
    """
    n = tf.expand_dims(n, 1)
    n_samples = tf.shape(log_w_all)[0]
    if n_c < 2:
        # No class is scored; the empty product is 1 (tf.concat of an empty
        # list would otherwise fail).
        return tf.ones((n_samples, 1), dtype=tf.float32)
    n_out = tf.zeros((n_samples, 1, 1), dtype=tf.float32)
    log_w_all = tf.expand_dims(log_w_all, 1)
    row_ids = tf.expand_dims(tf.range(n_samples), 1)
    log_p_x_all = []
    for i in range(n_c - 1):
        # gather_nd indices [sample, observed count]. reshape (not a bare
        # squeeze) keeps the batch axis when n_samples == 1. Rounding before
        # the cast stops a value like 2.9999 truncating to 2.
        x_sel = tf.cast(tf.round(tf.reshape(X_all[i], [-1, 1])), tf.int32)
        x_sel = tf.concat([row_ids, x_sel], axis=1)

        n_new = n - n_out
        m_i = m_all[i]
        log_w_i = tf.expand_dims(log_w_all[:, :, i], 1)
        support = tf.range(m_i + 1)
        support = tf.repeat(
            tf.expand_dims(tf.expand_dims(support, 0), 1), repeats=n_samples, axis=0
        )
        log_w_rest, m_rest = calc_group_w_m(
            m_all[i + 1 :], log_w_all[:, :, i + 1 :]
        )
        # (n_samples, 1) -> (n_samples, 1, 1), as in pmf_noncentral_fmvhg.
        # Without this, (n - x) * log_w_rest inside get_logits broadcasts to
        # (n_samples, n_samples, m_i + 1) and the gather returns whole rows.
        log_w_rest = tf.expand_dims(log_w_rest, 2)
        hside_check = get_hside_check(n_new, support, m_rest)
        logits_p_x_i = get_logits(
            n_new, m_i, m_rest, log_w_i, log_w_rest, hside_check
        )
        # log_softmax is the numerically stable form of log(softmax(.)).
        # Only the singleton axis 1 is squeezed, so a batch or support of
        # size 1 is not collapsed.
        log_p_x_i = tf.squeeze(tf.nn.log_softmax(logits_p_x_i), axis=1)
        log_p_x_sel = tf.expand_dims(tf.gather_nd(log_p_x_i, x_sel), 1)
        log_p_x_all.append(log_p_x_sel)
        n_out = n_out + X_all[i]
    log_p_x_all = tf.concat(log_p_x_all, axis=1)
    log_p = tf.math.reduce_sum(log_p_x_all, axis=1, keepdims=True)
    return tf.math.exp(log_p)


if __name__ == "__main__":
    # Demo: draw num_samples samples and either verify per-sample totals
    # (create_plot=False) or plot the per-class mean one-hot vectors.
    import os

    log_level = logging.INFO
    logging.basicConfig(level=log_level)
    class_names = ["violet", "yellow", "green"]
    colors = ["blueviolet", "yellow", "green"]
    m = tf.Variable(tf.constant([7, 7, 7], dtype=tf.float32))
    n = tf.Variable(tf.constant([10], dtype=tf.float32))
    w = tf.math.log(tf.Variable(tf.constant([1.0, 1.0, 1.0], dtype=tf.float32)))
    tau = tf.Variable(tf.constant([1.0]))
    num_classes = 3
    num_samples = 10000
    create_plot = True
    n = tf.expand_dims(tf.repeat(n, repeats=num_samples), 1)
    w = tf.repeat(tf.expand_dims(w, 0), repeats=num_samples, axis=0)
    n_repeats = 1

    for h in range(n_repeats):
        init = tf.compat.v1.global_variables_initializer()
        # A single session, closed by the context manager.
        with tf.compat.v1.Session() as sess:
            sess.run(init)
            y, x, _ = pmf_noncentral_fmvhg(m, n, w, tau, num_classes)
            # Built but not evaluated in this demo.
            p = get_probability(x, m, n, w, num_classes)
            # Evaluate everything in ONE run: every separate .eval()/run would
            # redraw the Gumbel noise, so values from different calls would
            # come from different samples.
            y_val, x_val, m_val, n_val, w_val = sess.run([y, x, m, n, w])
            if not create_plot:
                print()
                for i in range(num_samples):
                    print(h, i)
                    sum_sample = 0
                    for j in range(num_classes):
                        print(y_val[j][i], x_val[j][i])
                        sum_sample += x_val[j][i]
                    # Compare concrete numpy values; a symbolic tensor cannot
                    # be used as a Python bool in graph mode.
                    if np.any(sum_sample != n_val[i]):
                        print("ERROR !!!! ")
                        sys.exit(1)
                    else:
                        print("correct: ", sum_sample)
            if create_plot:
                str_ws = [str(w_j) for w_j in list(w_val[0])]
                str_weights = "_".join(str_ws)
                os.makedirs("./results", exist_ok=True)
                fn_plot = "./results/tf_samples_f_" + str_weights + ".png"
                fig = plt.figure()
                ax = fig.add_subplot(1, 1, 1)
                for j in range(num_classes):
                    ind_j = np.arange(m_val[j] + 1)
                    y_avg_j = np.asarray(np.mean(y_val[j], axis=0)).flatten()
                    ax.bar(
                        ind_j, y_avg_j, alpha=0.5, color=colors[j], label=class_names[j]
                    )
                plt.title(str_ws)
                plt.legend()
                fig.tight_layout()
                plt.draw()
                plt.savefig(fn_plot, format="png")
                plt.close(fig)
