import os
import sys
import numpy as np
import logging

import tensorflow as tf
import matplotlib.pyplot as plt

from mvhg.tf_heaviside import heaviside
from mvhg.tf_mvhg import MVHG

tf.compat.v1.disable_eager_execution()

if __name__ == "__main__":
    # Smoke-test / visualisation script for the TF MVHG sampler.
    # Draws `num_samples` samples per repeat and then either:
    #   * create_plot=False: prints every sample and checks that the per-class
    #     counts x sum to n (exits with status 1 on the first mismatch), or
    #   * create_plot=True: saves a bar plot of the per-class mean of y to
    #     ./results/tf_samples_f_<log-weights>.png.
    dir_results = os.path.join(".", "results")
    log_level = logging.INFO
    logging.basicConfig(level=log_level)
    class_names = ["violet", "yellow"]
    colors = ["blueviolet", "yellow"]
    m = tf.Variable(tf.constant([10, 10], dtype=tf.float32))
    n = tf.Variable(tf.constant([10], dtype=tf.float32))
    # Class weights are handed to the sampler in log space (weights of 1.0
    # become log-weights of 0.0).
    w = tf.math.log(tf.Variable(tf.constant([1.0, 1.0], dtype=tf.float32)))
    num_classes = 2
    num_samples = 10
    create_plot = True
    # One row of log-weights per sample: shape (num_samples, num_classes).
    w = tf.repeat(tf.expand_dims(w, 0), repeats=num_samples, axis=0)
    n_repeats = 1
    mvhg = MVHG(num_classes, m, n)

    if create_plot:
        # savefig does not create missing directories.
        os.makedirs(dir_results, exist_ok=True)

    for h in range(n_repeats):
        with tf.compat.v1.Session() as sess:
            y, x, _, log_p = mvhg(w)
            # Create the initializer *after* mvhg(w) has built its graph so
            # any variables created during that call are initialized as well.
            sess.run(tf.compat.v1.global_variables_initializer())
            # Fetch everything in a single run: every separate eval() would
            # re-execute the graph and draw a new random sample, making the
            # printed log_p, the checked counts and the plot disagree.
            y_val, x_val, log_p_val, m_val, n_val, w_val = sess.run(
                [y, x, log_p, m, n, w]
            )
            print(log_p_val)
            if not create_plot:
                print()
                for i in range(num_samples):
                    print(h, i)
                    sum_sample = 0
                    for j in range(num_classes):
                        print(y_val[j][i], x_val[j][i])
                        sum_sample += x_val[j][i]
                    # NOTE(review): x may be float-valued (n is float32), so
                    # compare with a tolerance instead of exact equality.
                    if not np.allclose(sum_sample, n_val[0]):
                        print("ERROR !!!! ")
                        # Non-zero status so the failure is visible to callers.
                        sys.exit(1)
                    print("correct: ", sum_sample)
            else:
                str_ws = [str(w_j) for w_j in w_val[0]]
                str_weights = "_".join(str_ws)
                fn_plot = os.path.join(
                    dir_results, "tf_samples_f_" + str_weights + ".png"
                )
                fig, ax = plt.subplots()
                for j in range(num_classes):
                    # One bar per possible count 0..m[j]; assumes y[j] has
                    # shape (num_samples, m[j] + 1, ...) — TODO confirm.
                    ind_j = np.arange(int(m_val[j]) + 1)
                    y_avg_j = np.asarray(y_val[j]).mean(axis=0).flatten()
                    ax.bar(
                        ind_j, y_avg_j, alpha=0.5, color=colors[j], label=class_names[j]
                    )
                ax.set_title(str_ws)
                ax.legend()
                fig.tight_layout()
                fig.savefig(fn_plot, format="png")
                # Release the figure; otherwise one figure leaks per repeat.
                plt.close(fig)
