import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
# NOTE: module-level side effect — importing this module switches pyplot into
# interactive mode, so figures drawn by the functions below update without a
# blocking plt.show() call.
plt.ion()

def visualize_samples(samples, thinning=1, create_figure=True):
    """Play back joint-angle samples as a sequence of n-link arm drawings.

    Every ``thinning``-th row of ``samples`` is passed to
    ``visualize_n_link`` with unit link lengths. With its default arguments
    that function clears the figure and pauses after each drawing, so the
    rows appear one after another as an animation.

    Args:
        samples: 2-D array whose rows are joint-angle vectors; its second
            dimension is the number of links.
        thinning: stride between the rows that get drawn.
        create_figure: open a fresh figure first instead of drawing into the
            current one.
    """
    if create_figure:
        plt.figure()
    num_samples, num_dimensions = samples.shape
    # Every link has length 1; the array is read-only for the callee, so one
    # instance can be shared by all frames.
    unit_lengths = np.ones(num_dimensions)
    for row_index in range(0, num_samples, thinning):
        visualize_n_link(samples[row_index], num_dimensions, unit_lengths)

def visualize_n_link(theta, num_dimensions, l=None, clear_fig=True):
    """Draw one planar n-link arm into the current figure, then pause briefly.

    ``theta`` holds relative joint angles: link ``i`` points along the
    cumulative angle ``theta[0] + ... + theta[i]`` and has length ``l[i]``.
    The arm starts at the origin, its end point is marked with a circle, and
    a red cross is drawn at ``(0.7 * num_dimensions, 0)`` (presumably the
    target position — confirm with callers).

    Args:
        theta: joint angles; only the first ``num_dimensions`` are used.
        num_dimensions: number of links to draw; also scales the axis limits.
        l: link lengths (at least ``num_dimensions`` entries). Defaults to
            all ones, matching ``visualize_mixture``.
        clear_fig: clear the current figure before drawing.
    """
    if l is None:
        l = np.ones(num_dimensions)
    if clear_fig:
        plt.clf()
    plt.xlim([-0.2 * num_dimensions, num_dimensions])
    plt.ylim([-0.5 * num_dimensions, 0.5 * num_dimensions])

    # Absolute orientation of each link is the running sum of the relative
    # joint angles; one cumsum replaces re-summing theta[:i+1] for every link.
    angles = np.cumsum(np.asarray(theta, dtype=float)[:num_dimensions])
    lengths = np.asarray(l, dtype=float)[:num_dimensions]
    # Joint positions, starting from the base at the origin.
    x = np.concatenate(([0.0], np.cumsum(lengths * np.cos(angles))))
    y = np.concatenate(([0.0], np.cumsum(lengths * np.sin(angles))))

    # The whole arm is a single polyline rather than one artist per link.
    plt.plot(x, y, color='k', linestyle='-', linewidth=2)
    plt.plot(x[-1], y[-1], 'o')
    plt.plot(0.7 * num_dimensions, 0, 'rx')
    # Let the GUI event loop redraw; repeated calls (e.g. from
    # visualize_samples) therefore play back as an animation.
    plt.pause(0.1)

def visualize_mixture(mixture_weights, mixture_means, l=None, clear_fig=True, markerPoses=None):
    """Draw every component mean of a mixture as a planar n-link arm.

    Each mean is a vector of relative joint angles: link ``i`` points along
    the cumulative angle ``theta[0] + ... + theta[i]`` and has length
    ``l[i]``. Each arm's opacity reflects its mixture weight. Weights are
    linearly rescaled to ``[0.1, 1]`` so the lightest component stays
    visible; if all weights are equal, every arm is drawn fully opaque.

    Args:
        mixture_weights: one weight per entry of ``mixture_means``.
        mixture_means: sequence of joint-angle vectors. The number of links
            is taken from the length of the first one.
        l: link lengths (at least that many entries). Defaults to all ones.
        clear_fig: clear the current figure before drawing.
        markerPoses: optional iterable of ``(x, y)`` points to mark with a
            red cross. Defaults to no markers.
    """
    # None instead of a shared mutable [] default.
    if markerPoses is None:
        markerPoses = []
    num_dimensions = len(mixture_means[0])
    if l is None:
        l = np.ones(num_dimensions)
    if clear_fig:
        plt.clf()

    plt.xlim([-num_dimensions, num_dimensions])
    plt.ylim([-num_dimensions, num_dimensions])

    # asarray so plain lists of weights are handled by vectorized arithmetic.
    mixture_weights = np.asarray(mixture_weights, dtype=float)
    min_weight = np.min(mixture_weights)
    weight_range = np.max(mixture_weights) - min_weight
    if weight_range != 0:
        weights = 0.1 + 0.9 * (mixture_weights - min_weight) / weight_range
    else:
        weights = np.ones(len(mixture_weights))

    lengths = np.asarray(l, dtype=float)[:num_dimensions]
    for weight, theta in zip(weights, mixture_means):
        # Absolute link orientations, then joint positions from the origin.
        angles = np.cumsum(np.asarray(theta, dtype=float)[:num_dimensions])
        x = np.concatenate(([0.0], np.cumsum(lengths * np.cos(angles))))
        y = np.concatenate(([0.0], np.cumsum(lengths * np.sin(angles))))
        # A single polyline per arm. Separate translucent segments would
        # overlap at the joints and render darker there.
        plt.plot(x, y, color='k', linestyle='-', linewidth=2, alpha=weight)
        # Slightly larger black dot under the red one gives the end point a
        # thin dark outline.
        plt.plot(x[-1], y[-1], 'o', color="k", alpha=weight, markersize=6.1)
        plt.plot(x[-1], y[-1], 'o', color="red", alpha=weight, markersize=6)

    # Grey square centred on the origin, where every arm starts; drawn on top.
    rect = patches.Rectangle((-0.25, -0.25), 0.5, 0.5, linewidth=1, edgecolor='k', fill=True, facecolor='dimgrey',
                             zorder=1000)
    ax = plt.gca()
    ax.add_patch(rect)
    for pose in markerPoses:
        plt.plot(pose[0], pose[1], 'rx', markersize=10, mew=2)


