import sys
import numpy as np
import logging

import tensorflow as tf
import matplotlib.pyplot as plt

from mvhg.tf_heaviside import heaviside


def sample_gumbel(shape, eps=1e-20):
    """Draw standard Gumbel noise of the given shape.

    Uses the inverse-CDF transform ``-log(-log(U))`` on uniform samples
    ``U`` from [0, 1); ``eps`` is added inside both logarithms so neither
    of them is evaluated at exactly zero.

    Args:
        shape: shape of the noise tensor to generate.
        eps: small positive constant guarding both logarithms.

    Returns:
        A tensor of the requested shape holding the Gumbel samples.
    """
    uniform_noise = tf.random.uniform(shape, minval=0, maxval=1)
    inner_log = tf.math.log(uniform_noise + eps)
    outer_log = tf.math.log(tf.negative(inner_log) + eps)
    return tf.negative(outer_log)


def gumbel_softmax(logits, temp, hard=True):
    """Draw a Gumbel-softmax sample from ``logits`` along the last axis.

    Gumbel noise is added to the logits, the result is divided by ``temp``
    and pushed through a softmax. With ``hard=True`` the forward value is
    replaced by a 0/1 indicator of the maximum entry while the gradient
    still flows through the soft sample (straight-through estimator).

    Args:
        logits: unnormalized log-probabilities; the categories are expected
            on the last axis.
        temp: temperature dividing the perturbed logits.
        hard: if true, return the discretized sample in the forward pass.

    Returns:
        A tensor with the same shape as ``logits``.
    """
    gs_sample = logits + sample_gumbel(tf.shape(logits))
    # Softmax and the maximum below must reduce over the same axis; the last
    # axis is used for both so the function is consistent for any rank
    # (for rank-3 input this is the former hard-coded axis 2).
    y = tf.nn.softmax(gs_sample / temp, axis=-1)
    if hard:
        y_max = tf.reduce_max(y, axis=-1, keepdims=True)
        # NOTE: exact ties in `y` mark every tied entry with 1.
        y_hard = tf.cast(tf.equal(y, y_max), y.dtype)
        # Forward value is y_hard, gradient is that of the soft sample y.
        y = tf.stop_gradient(y_hard - y) + y
    return y


def get_logits(n, m, m2, log_w1, log_w2, hside_check, eps=1e-20):
    """Return masked, unnormalized log-probabilities for the counts 0..m.

    For every candidate count ``x1`` in ``0..m`` (with ``x2 = n - x1`` taken
    from the second group) this evaluates

        x1 * log_w1 + x2 * log_w2
            - [lgamma(x1 + 1) + lgamma(m - x1 + 1)
               + lgamma(x2 + 1) + lgamma(m2 - x2 + 1)]

    where ``x2`` is clamped at zero inside the lgamma terms, and then adds
    ``log(hside_check)`` so that entries with ``hside_check == 0`` become
    ``-inf``.

    Args:
        n: number of elements still to draw.
        m: size of the first group; ``tf.range(m + 1)`` defines the support.
        m2: size of the second (merged) group.
        log_w1: log-weight of the first group; its leading dimension is used
            as the number of samples.
        log_w2: log-weight of the second group.
        hside_check: 0/1 mask over the support (presumably marking feasible
            counts; confirm against `get_hside_check`).
        eps: currently unused; kept for interface compatibility.

    Returns:
        The masked log-probabilities; the support built here has shape
        ``(n_samples, 1, m + 1)`` and the remaining arguments are assumed
        to broadcast against it.
    """
    n_samples = tf.shape(log_w1)[0]
    # Support 0..m replicated per sample: (n_samples, 1, m + 1).
    x_all = tf.range(m + 1)
    x_all = tf.repeat(tf.expand_dims(x_all, 0), repeats=n_samples, axis=0)
    x_all = tf.expand_dims(x_all, 1)
    # relu keeps the count taken from the second group non-negative inside
    # the log-factorial terms.
    x_rest = tf.nn.relu(n - x_all)
    log_x1_fact = tf.math.lgamma(x_all + 1.0)
    log_m1_x1_fact = tf.math.lgamma(m - x_all + 1.0)
    log_x2_fact = tf.math.lgamma(x_rest + 1.0)
    log_m2_x2_fact = tf.math.lgamma(m2 - x_rest + 1.0)
    log_m_x = log_x1_fact + log_x2_fact + log_m1_x1_fact + log_m2_x2_fact
    x1_log_w1 = x_all * log_w1
    # Was assigned to `x1_log_w2` but read as `x2_log_w2` (NameError).
    x2_log_w2 = (n - x_all) * log_w2
    log_p_shifted = x1_log_w1 + x2_log_w2 - log_m_x
    # log(0) = -inf removes masked counts from any subsequent softmax.
    log_p_shifted_sup = log_p_shifted + tf.math.log(hside_check)
    return log_p_shifted_sup


def calc_group_w_m(m_group, log_w_group):
    """Merge several classes into one group with a joint size and log-weight.

    The joint size is the sum of all entries of ``m_group``. The joint
    log-weight is the log-sum-exp over axis 2 of
    ``log_w_group + log(m_group) - log(m_total)``, i.e. the log of the
    size-weighted average of the individual weights.

    Args:
        m_group: sizes of the classes to merge.
        log_w_group: log-weights of the classes to merge; the classes are
            reduced along axis 2, which is dropped from the result.

    Returns:
        Tuple ``(log_w_G, m_G)`` of the merged log-weight and merged size.
    """
    total_size = tf.math.reduce_sum(m_group)
    log_sizes = tf.math.log(m_group)
    log_total_size = tf.math.log(total_size)
    weighted = log_w_group + log_sizes
    normalized = weighted - log_total_size
    merged_log_w = tf.reduce_logsumexp(normalized, axis=2)
    return merged_log_w, total_size


def get_hside_check(n, x_i, m_r):
    """Build the mask used to rule out infeasible counts ``x_i``.

    The result is the product of two `heaviside` evaluations:
    one of ``n - x_i`` and one of ``m_r - relu(n - x_i)``.
    (Presumably these encode "do not draw more than n" and "the remaining
    draws fit into the rest group"; the exact value at 0 depends on the
    imported `heaviside`, so confirm there.)

    Args:
        n: number of elements still to draw.
        x_i: candidate counts for the current class.
        m_r: size of the remaining classes.

    Returns:
        Element-wise product of the two heaviside terms.
    """
    remaining = n - x_i
    within_draw = heaviside(remaining)
    within_rest = heaviside(m_r - tf.nn.relu(remaining))
    return within_draw * within_rest


def pmf_noncentral_fmvhg(m_all, n, log_w_all, temperature, n_c):
    """Sample class counts sequentially with Gumbel-softmax relaxations.

    The classes are visited in order. For every class but the last, the
    remaining classes are merged into one group (`calc_group_w_m`), logits
    over the counts ``0..m_i`` are computed (`get_logits`) and a sample is
    drawn with `gumbel_softmax`. The last class receives the feasibility
    mask from `get_hside_check` directly. Counts drawn so far are subtracted
    from ``n`` before the next class is handled.

    Args:
        m_all: class sizes, indexable by class (``m_all[i]``).
        n: number of elements to draw per sample; a dimension is inserted at
            axis 1 (assumed to arrive as ``(n_samples, 1)``, confirm with
            the caller).
        log_w_all: log-weights per sample and class, assumed to be
            ``(n_samples, n_c)``; a dimension is inserted at axis 1.
        temperature: Gumbel-softmax temperature.
        n_c: number of classes.

    Returns:
        Tuple ``(y_all, x_all, y_mask_all)`` of per-class lists holding the
        sampled indicator over ``0..m_i``, the resulting count, and the
        filled mask the count is derived from.
    """
    n = tf.expand_dims(n, 1)
    # The original read an undefined `w_all`; the log-weights are the
    # `log_w_all` argument.
    n_samples = tf.shape(log_w_all)[0]
    log_w_all = tf.expand_dims(log_w_all, 1)
    n_out = tf.zeros((n_samples, 1, 1), dtype=tf.float32)
    y_all = []
    y_mask_all = []
    x_all = []
    for i in range(n_c):
        n_new = n - n_out
        m_i = m_all[i]
        # Support 0..m_i replicated per sample: (n_samples, 1, m_i + 1).
        support = tf.range(m_i + 1)
        support = tf.repeat(
            tf.expand_dims(tf.expand_dims(support, 0), 1),
            repeats=n_samples,
            axis=0,
        )
        if i + 1 < n_c:
            log_w_i = tf.expand_dims(log_w_all[:, :, i], 1)
            log_w_rest, m_rest = calc_group_w_m(
                m_all[i + 1 :], log_w_all[:, :, i + 1 :]
            )
            # calc_group_w_m reduces the class axis away; restore rank 3 so
            # the merged weight broadcasts per sample instead of expanding
            # the logits to (n_samples, n_samples, m_i + 1).
            log_w_rest = tf.reshape(log_w_rest, (-1, 1, 1))
            hside_check = get_hside_check(n_new, support, m_rest)
            logits_p_x_i = get_logits(
                n_new, m_i, m_rest, log_w_i, log_w_rest, hside_check
            )
            y_i = gumbel_softmax(logits_p_x_i, temperature)
        else:
            # Nothing is left after the last class, so the mask alone
            # determines its count.
            m_rest = tf.zeros((1,), dtype=tf.float32)
            hside_check = get_hside_check(n_new, support, m_rest)
            y_i = tf.dtypes.cast(hside_check, tf.float32)
        y_all.append(y_i)
        # Lower-triangular ones: (y_i @ lt)[..., j] sums y_i over indices
        # >= j, so a one-hot y_i at k becomes ones at positions 0..k.
        k = int(m_i + 1)
        ones = tf.ones((n_samples, k, k))
        lt = tf.linalg.LinearOperatorLowerTriangular(ones).to_dense()
        mask_filled = tf.linalg.matmul(y_i, lt)
        y_mask_all.append(mask_filled)
        # Number of ones in the mask minus one is the selected count.
        x_i = tf.reduce_sum(mask_filled, axis=2, keepdims=True) - 1.0
        x_all.append(x_i)
        n_out = n_out + x_i
    return y_all, x_all, y_mask_all


if __name__ == "__main__":
    # Demo: draw samples for three equally weighted classes and either
    # verify every sample sums to n or plot the averaged indicators.
    import os

    log_level = logging.INFO
    logging.basicConfig(level=log_level)
    class_names = ["violet", "yellow", "green"]
    colors = ["blueviolet", "yellow", "green"]
    m = tf.Variable(tf.constant([7, 7, 7], dtype=tf.float32))
    n = tf.Variable(tf.constant([10], dtype=tf.float32))
    w = tf.math.log(tf.Variable(tf.constant([1.0, 1.0, 1.0], dtype=tf.float32)))
    tau = tf.Variable(tf.constant([1.0]))
    num_classes = 3
    num_samples = 10000
    create_plot = True
    n = tf.expand_dims(tf.repeat(n, repeats=num_samples), 1)
    w = tf.repeat(tf.expand_dims(w, 0), repeats=num_samples, axis=0)
    n_repeats = 1
    for h in range(n_repeats):
        # The sampler is defined in this module and returns three lists;
        # the filled masks are not needed here.
        y, x, _ = pmf_noncentral_fmvhg(m, n, w, tau, num_classes)
        if not create_plot:
            print()
            for i in range(num_samples):
                print(h, i)
                sum_sample = 0
                for j in range(num_classes):
                    print(y[j][i], x[j][i])
                    sum_sample += x[j][i]
                if sum_sample != n[i]:
                    print("ERROR !!!! ")
                    # Non-zero status so the failure is visible to callers.
                    sys.exit(1)
                else:
                    print("correct: ", sum_sample)
        if create_plot:
            # Tensors are eager here (the sampler relies on int(tensor)),
            # so values are read with .numpy() rather than .eval().
            str_ws = [str(w_j) for w_j in list(w[0].numpy())]
            str_weights = "_".join(str_ws)
            fn_plot = "./results/tf_samples_f_" + str_weights + ".png"
            # savefig does not create missing directories.
            os.makedirs(os.path.dirname(fn_plot), exist_ok=True)
            fig = plt.figure()
            ax = fig.add_subplot(1, 1, 1)
            for j in range(num_classes):
                ind_j = np.arange(m[j].numpy() + 1)
                y_avg_j = np.array(tf.math.reduce_mean(y[j], axis=0).numpy()).flatten()
                ax.bar(ind_j, y_avg_j, alpha=0.5, color=colors[j], label=class_names[j])
            plt.title(str_ws)
            plt.legend()
            fig.tight_layout()
            plt.draw()
            plt.savefig(fn_plot, format="png")
