
import sys
import numpy as np
import logging

import tensorflow as tf
import matplotlib.pyplot as plt

from heaviside import heaviside


def sample_gumbel(shape, eps=1e-20):
    """Draw standard Gumbel(0, 1) noise with the given shape.

    Uses the inverse-CDF transform ``-log(-log(U))`` with ``U`` drawn from
    ``tf.random.uniform(minval=0, maxval=1)``. ``eps`` is added inside both
    logarithms so that neither ever sees an exact zero.
    """
    uniform_draw = tf.random.uniform(shape, minval=0, maxval=1)
    neg_log_u = -tf.math.log(uniform_draw + eps)
    return -tf.math.log(neg_log_u + eps)

def gumbel_softmax(logits, temp, hard=True):
    """Draw a Gumbel-softmax sample over the last axis of ``logits``.

    Args:
        logits: unnormalised log-probabilities; the categorical axis is the
            last axis. Entries equal to -inf get probability 0.
        temp: softmax temperature, broadcast against ``logits``.
        hard: if True, the forward value is a one-hot tensor (exactly one 1
            along the last axis), while gradients flow through the soft
            sample (straight-through estimator).

    Returns:
        Tensor with the same shape as ``logits``.
    """
    gs_sample = logits + sample_gumbel(tf.shape(logits))
    y = tf.nn.softmax(gs_sample / temp, axis=-1)
    if hard:
        # argmax + one_hot works for any rank (the old code hard-coded
        # axis 2). It also guarantees a single hot entry even if two soft
        # values tie at the maximum, whereas an equality test against the
        # max could mark several.
        y_hard = tf.one_hot(tf.argmax(y, axis=-1),
                            depth=tf.shape(y)[-1],
                            dtype=y.dtype)
        # Forward pass: y_hard; backward pass: gradient of the soft y.
        y = tf.stop_gradient(y_hard - y) + y
    return y


def get_logits(x, n, m, m2, w, hside_check, eps=1e-20):
    """Return masked, unnormalised log-probabilities for x1 = 0..x.

    For each candidate count ``x1`` in ``0..x``, with ``x2 = relu(n - x1)``,
    this computes::

        x1 * log(w) - [log x1! + log x2! + log (m - x1)! + log (m2 - x2)!]

    This is the log of a Fisher-type noncentral hypergeometric weight, up to
    an additive constant. It then adds ``log(hside_check)``, so entries where
    the mask is 0 become -inf.

    Args:
        x: scalar upper bound of the candidate range (``tf.range(x + 1)``).
        n: number of draws still to assign (the caller passes shape
            (n_samples, 1, 1)).
        m: size of the current class.
        m2: total size of the remaining classes.
        w: odds ratio of the current class vs. the rest; its leading
            dimension is taken as the number of samples.
        hside_check: feasibility mask that broadcasts to the logits shape;
            zeros mark infeasible counts.
        eps: currently unused; kept for interface compatibility.

    Returns:
        Logits tensor, shape (n_samples, 1, x + 1) for the caller's inputs.
    """
    n_samples = tf.shape(w)[0]
    x_all = tf.range(x + 1)
    x_all = tf.repeat(tf.expand_dims(x_all, 0), repeats=n_samples, axis=0)
    x_all = tf.expand_dims(x_all, 1)
    # Count assigned to the remaining classes for each candidate x1.
    x_rest = tf.nn.relu(n - x_all)
    # lgamma(k + 1) == log(k!)
    log_x1_fact = tf.math.lgamma(x_all + 1.0)
    log_m1_x1_fact = tf.math.lgamma(m - x_all + 1.0)
    log_x2_fact = tf.math.lgamma(x_rest + 1.0)
    log_m2_x2_fact = tf.math.lgamma(m2 - x_rest + 1.0)
    log_m_x = log_x1_fact + log_x2_fact + log_m1_x1_fact + log_m2_x2_fact
    x1_log_w = x_all * tf.math.log(w)
    log_p_shifted = x1_log_w - log_m_x
    # log(0) = -inf removes infeasible counts from the softmax.
    log_p_shifted_sup = log_p_shifted + tf.math.log(hside_check)
    # Debug output only when it will actually be emitted. All values are
    # fetched in one session run instead of four separate .eval() calls.
    # Without a default session, logging is skipped rather than crashing.
    if logging.getLogger().isEnabledFor(logging.INFO):
        sess = tf.compat.v1.get_default_session()
        if sess is not None:
            shape_p, val_p, shape_sup, val_sup = sess.run(
                [tf.shape(log_p_shifted), log_p_shifted,
                 tf.shape(log_p_shifted_sup), log_p_shifted_sup])
            logging.info('log_p_shifted.shape: \n' + str(shape_p))
            logging.info('log_p_shifted: \n' + str(val_p[0]))
            logging.info('log_p_shifted_sup.shape: \n' + str(shape_sup))
            logging.info('log_p_shifted_sup: \n' + str(val_sup[0]))
    return log_p_shifted_sup


def _log_tensor_values(entries):
    """Evaluate tensors in a single session run and log them at INFO level.

    Args:
        entries: sequence of ``(label, tensor, first_only)`` triples. Each
            logged line is ``label + str(value)``, where ``value`` is the
            evaluated tensor, or its element ``[0]`` when ``first_only``.

    Fetching everything in one ``Session.run`` means any random op the
    tensors depend on (e.g. Gumbel noise) is sampled once, so the logged
    values are consistent with each other. Does nothing when no default
    session is registered.
    """
    sess = tf.compat.v1.get_default_session()
    if sess is None or not entries:
        return
    values = sess.run([tensor for _, tensor, _ in entries])
    for (label, _, first_only), value in zip(entries, values):
        logging.info(label + str(value[0] if first_only else value))


def pmf_noncentral_fmvhg(m_all, n, w_all, temperature, n_c):
    """Build a differentiable sampler for class counts, one class at a time.

    For each class ``i`` (except the last), a count ``x_i`` in ``0..m_i`` is
    drawn with a straight-through Gumbel-softmax. The logits come from
    ``get_logits`` (a Fisher-type noncentral hypergeometric weight of class
    ``i`` against the pooled remaining classes), masked to feasible counts.
    The last class receives the feasible remainder via the mask alone.

    Args:
        m_all: 1-D float tensor of per-class sizes (``m_all[i]`` must be a
            scalar; the values are presumably integral).
        n: number of draws per sample; expanded at axis 1 here (the
            ``__main__`` driver passes shape (n_samples, 1)).
        w_all: per-sample class weights; its first dimension is the number of
            samples (the driver passes shape (n_samples, n_c)).
        temperature: Gumbel-softmax temperature.
        n_c: number of classes (Python int).

    Returns:
        ``(y_all, x_all, y_mask_all)``, each a list of ``n_c`` tensors:
        ``y_all[i]`` is the one-hot sample over ``0..m_i``,
        ``x_all[i]`` the drawn count with shape (n_samples, 1, 1), and
        ``y_mask_all[i]`` has ones at positions ``0..x_i``.
    """
    log_info = logging.getLogger().isEnabledFor(logging.INFO)
    n = tf.expand_dims(n, 1)
    n_samples = tf.shape(w_all)[0]
    n_out = tf.zeros((n_samples, 1, 1), dtype=tf.float32)  # draws assigned so far
    m_out = tf.zeros((1), dtype=tf.float32)  # class sizes consumed so far
    m_total = tf.math.reduce_sum(m_all)
    w_all = tf.expand_dims(w_all, 1)
    y_all = []
    y_mask_all = []
    x_all = []
    for i in range(n_c):
        is_last = i + 1 == n_c
        n_new = n - n_out  # draws still to assign
        m_i = m_all[i]
        w_i = tf.expand_dims(w_all[:, :, i], 1)
        x_range = tf.range(m_i + 1)
        x_range = tf.repeat(tf.expand_dims(tf.expand_dims(x_range, 0), 1),
                            repeats=n_samples, axis=0)
        if is_last:
            m_rest = tf.zeros((1), dtype=tf.float32)
        else:
            # Sizes of classes i+1..n_c-1.
            m_rest = m_total - m_out - m_i
            # Size-weighted mean weight of the remaining classes; class i's
            # odds are taken relative to it.
            w_rest_enum = tf.math.reduce_sum(
                w_all[:, :, i + 1:] * m_all[i + 1:], axis=2)
            w_rest = tf.expand_dims(w_rest_enum, 1) / m_rest
            w = w_i / w_rest

        x_rest = tf.nn.relu(n_new - x_range)
        # Feasibility mask: x_i may not exceed the draws still to assign, and
        # the remainder must fit into the remaining classes.
        # NOTE(review): assumes heaviside(v) is 1 for v >= 0 and 0 otherwise;
        # it is defined in the project's heaviside module - confirm.
        # TODO: check for differentiability.
        hside_check = heaviside(n_new - x_range) * heaviside(m_rest - x_rest)
        if log_info:
            entries = [('hside_check: \n', hside_check, True)]
            if not is_last:
                entries += [
                    ('m_i: \n', m_i, False),
                    ('m_rest: \n', m_rest, False),
                    ('n_i: \n', n_new, True),
                    ('x_i: \n', x_range, True),
                    ('x_rest: \n', x_rest, True),
                    ('w_i: \n', w_i, True),
                    ('w: \n', w, True),
                ]
            _log_tensor_values(entries)

        if is_last:
            # Nothing left to choose: the mask itself selects the remainder.
            y_i = tf.dtypes.cast(hside_check, tf.float32)
        else:
            logits_p_x_i = get_logits(m_i, n_new, m_i, m_rest, w, hside_check)
            y_i = gumbel_softmax(logits_p_x_i, temperature)
        y_all.append(y_i)

        # Lower-triangular ones: y_i @ lt puts ones at positions 0..x_i for a
        # one-hot y_i at x_i. Shape dims must be integers, not floats.
        size = tf.cast(m_i, tf.int32) + 1
        ones = tf.ones((n_samples, size, size))
        lt = tf.linalg.LinearOperatorLowerTriangular(ones).to_dense()
        mask_filled = tf.linalg.matmul(y_i, lt)
        y_mask_all.append(mask_filled)
        x_i = tf.expand_dims(tf.reduce_sum(mask_filled, axis=2), 2) - 1.0
        x_all.append(x_i)

        if log_info:
            # One run for all post-sampling values, so the logged y_i, mask
            # and x_i come from the same Gumbel draw.
            entries = []
            if not is_last:
                p_x_i = tf.nn.softmax(logits_p_x_i)
                entries += [
                    ('logits(p_x): \n', logits_p_x_i, True),
                    ('shape(logits(p_x)): \n', tf.shape(logits_p_x_i), False),
                    ('p_x_i: \n', p_x_i, True),
                    ('sum(p_x_i): \n', tf.reduce_sum(p_x_i, axis=2), True),
                ]
            entries += [
                ('y_i: \n', y_i, True),
                ('mask_filled_i: \n', mask_filled, True),
                ('x_i: \n', x_i, True),
            ]
            _log_tensor_values(entries)
        n_out += x_i
        m_out += m_i
    # NOTE: by construction sum_i x_all[i] should equal n for every sample.
    return y_all, x_all, y_mask_all


if __name__ == '__main__':
    # Demo: draw num_samples samples over 3 classes of 7 items each, with
    # n = 10 draws, then either check the counts or plot per-class averages.
    import os

    log_level = logging.INFO
    logging.basicConfig(level=log_level)
    class_names = ['violet', 'yellow', 'green']
    colors = ['blueviolet', 'yellow', 'green']
    m = tf.Variable(tf.constant([7, 7, 7], dtype=tf.float32))
    n = tf.Variable(tf.constant([10], dtype=tf.float32))
    w = tf.Variable(tf.constant([1.0, 1.0, 1.0], dtype=tf.float32))
    tau = tf.Variable(tf.constant([1.0]))  # Gumbel-softmax temperature
    num_classes = 3
    num_samples = 10000
    create_plot = True
    results_dir = './results'
    n = tf.expand_dims(tf.repeat(n, repeats=num_samples), 1)
    w = tf.repeat(tf.expand_dims(w, 0), repeats=num_samples, axis=0)
    n_repeats = 1
    for h in range(n_repeats):
        init = tf.compat.v1.global_variables_initializer()
        with tf.compat.v1.Session() as sess:
            sess.run(init)
            y, x, _ = pmf_noncentral_fmvhg(m, n, w, tau, num_classes)
            # Evaluate everything in ONE run: separate .eval() calls would
            # resample the Gumbel noise, so classes would come from
            # different draws and the counts would not add up.
            y_val, x_val, n_val, m_val, w_val = sess.run([y, x, n, m, w])
            if not create_plot:
                print()
                for i in range(num_samples):
                    print(h, i)
                    sum_sample = 0
                    for j in range(num_classes):
                        print(y_val[j][i], x_val[j][i])
                        sum_sample += x_val[j][i]
                    if np.any(sum_sample != n_val[i]):
                        print('ERROR !!!! ')
                        sys.exit(1)
                    else:
                        print('correct: ', sum_sample)
            if create_plot:
                str_ws = [str(w_j) for w_j in w_val[0]]
                str_weights = '_'.join(str_ws)
                os.makedirs(results_dir, exist_ok=True)
                fn_plot = results_dir + '/tf_samples_f_' + str_weights + '.png'
                fig = plt.figure()
                ax = fig.add_subplot(1, 1, 1)
                for j in range(num_classes):
                    ind_j = np.arange(m_val[j] + 1)
                    # Empirical frequency of each count 0..m_j.
                    y_avg_j = np.asarray(y_val[j]).mean(axis=0).flatten()
                    ax.bar(ind_j, y_avg_j, alpha=0.5, color=colors[j],
                           label=class_names[j])
                plt.title(str_ws)
                plt.legend()
                fig.tight_layout()
                plt.savefig(fn_plot, format='png')
                plt.close(fig)



