from inits import *
import matplotlib.pyplot as plt

# log(2*pi): normalising term shared by the Gaussian log-density helpers below.
logc = np.log(2. * np.pi)
# The same quantity scaled by -1/2, i.e. the per-dimension log normaliser
# -0.5 * log(2*pi) of a Gaussian density.
c = -0.5 * logc


def tf_normal_logpdf(x, mu, log_sigma_sq):
    """Elementwise log-density of a Gaussian parameterised by its log-variance.

    Computes
        log N(x; mu, exp(log_sigma_sq))
            = -0.5 * log(2*pi) - 0.5 * log_sigma_sq
              - (x - mu)**2 / (2 * exp(log_sigma_sq))
    with no reduction over any axis.

    Args:
        x: Points at which to evaluate the density.
        mu: Mean; combined with ``x`` via broadcasting subtraction.
        log_sigma_sq: Log-variance; combined elementwise with the rest.

    Returns:
        Tensor of elementwise log-densities.
    """
    # tf.div is deprecated in TF 1.x and removed in TF 2.x; tf.divide exists in
    # both. tf.exp only accepts floating/complex tensors, so this is always a
    # true (not floor) division and the result is unchanged.
    return (- 0.5 * logc - log_sigma_sq / 2.
            - tf.divide(tf.square(tf.subtract(x, mu)), 2 * tf.exp(log_sigma_sq)))


def tf_gaussian_ent(log_sigma_sq):
    """Elementwise expected log-density of a Gaussian under itself.

    Returns -0.5 * (log(2*pi) + 1 + log_sigma_sq), which equals E_q[log q(z)]
    for q = N(mu, exp(log_sigma_sq)). Despite the function name, this is the
    *negative* of the differential entropy.
    """
    inner = logc + 1.0 + log_sigma_sq
    return -0.5 * inner


def tf_gaussian_marg(mu, log_sigma_sq):
    """Elementwise expected standard-normal log-density under a Gaussian.

    Returns -0.5 * (log(2*pi) + (mu**2 + exp(log_sigma_sq))), which equals
    E_q[log p(z)] for q = N(mu, exp(log_sigma_sq)) and p = N(0, 1).
    """
    # mu**2 + sigma**2 is the second moment E_q[z**2].
    second_moment = tf.square(mu) + tf.exp(log_sigma_sq)
    return -0.5 * (logc + second_moment)


def plot_embedding(X, y, d, title=None, save_path="TransLATE.pdf"):
    """Plot a 2-D embedding, drawing each point's class label coloured by domain.

    Each column of ``X`` is min-max scaled to [0, 1]. Point ``i`` is drawn as
    the text ``str(y[i])`` at ``(X[i, 0], X[i, 1])``, coloured by the ``bwr``
    colormap evaluated at ``d[i]``. The figure is saved to ``save_path`` and
    then shown.

    Args:
        X: Array of points; only the first two columns are plotted.
        y: Per-point class labels, rendered with ``str``.
        d: Per-point domain values; presumably in [0, 1] (e.g. 0/1 domain
            ids) -- TODO confirm with callers. Out-of-range values receive the
            colormap's under/over colour.
        title: Optional figure title.
        save_path: File the figure is written to. The default keeps the
            previously hard-coded file name.
    """
    X = np.asarray(X)
    x_min, x_max = np.min(X, 0), np.max(X, 0)
    span = x_max - x_min
    # A constant column would otherwise give 0/0 = NaN coordinates; scaling by
    # 1 maps such a column to 0 instead.
    span = np.where(span == 0, 1, span)
    X = (X - x_min) / span

    plt.figure(figsize=(10, 10))
    for i, row in enumerate(X):
        # Dividing by 1. passes a float: matplotlib colormaps treat integer
        # inputs as lookup-table indices rather than positions in [0, 1].
        plt.text(row[0], row[1], str(y[i]),
                 color=plt.cm.bwr(d[i] / 1.),
                 fontdict={'weight': 'bold', 'size': 9})

    plt.xticks([])
    plt.yticks([])
    if title is not None:
        plt.title(title)
    plt.tight_layout()
    plt.savefig(save_path, bbox_inches="tight")
    plt.show()


def generate_data(num_pos, num_neg, theta=0.0, radius=1.5, noise_std=0.5):
    """Sample a toy two-class 2-D dataset made of two isotropic Gaussian blobs.

    Positive points are centred at ``radius * (cos(theta), sin(theta))``.
    Negative points are centred at the antipodal point. Both classes get
    i.i.d. Gaussian noise with standard deviation ``noise_std`` per
    coordinate, drawn from NumPy's global random state (positives first).

    Args:
        num_pos: Number of positive samples.
        num_neg: Number of negative samples.
        theta: Angle, in radians, of the positive-class centre.
        radius: Distance of each class centre from the origin.
        noise_std: Per-coordinate standard deviation of the noise.

    Returns:
        ``(X, labels)``. ``X`` has shape ``(num_pos + num_neg, 2)``.
        ``labels`` is the matching one-hot array: ``[1, 0]`` for the first
        ``num_pos`` rows and ``[0, 1]`` for the remaining ``num_neg`` rows.
    """
    center = np.array([radius * math.cos(theta), radius * math.sin(theta)])
    # Each class is sized by its own count, so the rows of X line up with the
    # rows of labels even when num_pos != num_neg.
    X_p = center + noise_std * np.random.randn(num_pos, 2)
    X_n = -center + noise_std * np.random.randn(num_neg, 2)
    X = np.concatenate([X_p, X_n], axis=0)
    labels = np.vstack([np.tile([1., 0.], [num_pos, 1]),
                        np.tile([0., 1.], [num_neg, 1])])
    return X, labels
