from scipy import stats
import numpy as np
import torch
import matplotlib.pyplot as plt
plt.switch_backend('Agg')


def kde(values, fig_size=(8, 8), bbox=(-2, 2, -2, 2), xlabel="", ylabel="", cmap='Blues', show=False, save=None):
    """Draw a filled-contour plot of a 2-D Gaussian kernel density estimate.

    Args:
        values: samples passed straight to ``scipy.stats.gaussian_kde``, which
            expects shape (d, N). The density is evaluated on a 2-D grid, so
            d must be 2. Note that ``real_builder_circle`` /
            ``real_builder_diamond`` return (batch_size, 2), so transpose
            their output before passing it here.
        fig_size: figure size in inches, forwarded to ``plt.subplots``.
        bbox: (xmin, xmax, ymin, ymax) plotting window. The density is
            evaluated on a 300x300 grid spanning it.
        xlabel, ylabel: optional axis labels; empty strings add no label.
        cmap: colormap name for ``contourf``.
        show: if true, call ``plt.show()`` after drawing.
        save: if not None, path (or file object) passed to ``fig.savefig``.

    The figure is always closed before returning, even if drawing or saving
    raises, so repeated calls (e.g. once per training step) do not accumulate
    open figures.
    """
    # Fit before opening a figure: gaussian_kde can raise (e.g. on degenerate /
    # collapsed samples), and a figure created first would then never be closed.
    kernel = stats.gaussian_kde(values)

    xmin, xmax, ymin, ymax = bbox
    # Complex step (300j) makes mgrid produce 300 points inclusive of both ends.
    xx, yy = np.mgrid[xmin:xmax:300j, ymin:ymax:300j]
    positions = np.vstack([xx.ravel(), yy.ravel()])
    density = kernel(positions).reshape(xx.shape)

    fig, ax = plt.subplots(figsize=fig_size)
    try:
        ax.axis([xmin, xmax, ymin, ymax])
        ax.contourf(xx, yy, density, cmap=cmap)
        ax.set_xticks([])
        ax.set_yticks([])
        if xlabel:
            ax.set_xlabel(xlabel)
        if ylabel:
            ax.set_ylabel(ylabel)
        ax.set_aspect('equal')
        fig.tight_layout()

        if save is not None:
            fig.savefig(save)
        if show:
            plt.show()
    finally:
        # Close exactly this figure rather than whatever pyplot considers current.
        plt.close(fig)

def real_builder_circle(batch_size, sigma=0.05, n_modes=8, radius=1.0):
    """Sample points from a ring-shaped 2-D Gaussian mixture.

    The mixture has ``n_modes`` equally weighted isotropic Gaussian components.
    Their centres are evenly spaced on a circle of the given ``radius``, with
    angle t placed at (radius*sin t, radius*cos t), so the first centre is
    (0, radius).

    Args:
        batch_size: number of samples to draw.
        sigma: standard deviation of the Gaussian noise added around each centre.
        n_modes: number of mixture components (default 8, the original layout).
        radius: radius of the circle carrying the centres (default 1.0).

    Returns:
        ndarray of shape (batch_size, 2).

    Randomness comes from the global ``np.random`` state: first the component
    choice, then the noise. Seeding with ``np.random.seed`` therefore
    reproduces the same samples as the original 8-mode implementation.
    """
    # n_modes + 1 points, then drop the last: 2*pi would duplicate angle 0.
    angles = np.linspace(0, 2 * np.pi, n_modes + 1)[:-1]
    centres = radius * np.stack([np.sin(angles), np.cos(angles)], axis=1)
    # Uniform component assignment, then isotropic Gaussian jitter.
    components = np.random.choice(n_modes, batch_size)
    return centres[components] + sigma * np.random.randn(batch_size, 2)


def real_builder_diamond(batch_size, sigma=0.15, n_modes=4, radius=1.2):
    """Sample points from a diamond-shaped 2-D Gaussian mixture.

    The mixture has ``n_modes`` equally weighted isotropic Gaussian components.
    Their centres are evenly spaced on a circle of the given ``radius``, with
    angle t placed at (radius*sin t, radius*cos t). With the defaults
    (4 modes, radius 1.2), the centres sit at (0, 1.2), (1.2, 0), (0, -1.2)
    and (-1.2, 0), up to floating-point rounding.

    Args:
        batch_size: number of samples to draw.
        sigma: standard deviation of the Gaussian noise added around each centre.
        n_modes: number of mixture components (default 4, the original layout).
        radius: distance of each centre from the origin (default 1.2).

    Returns:
        ndarray of shape (batch_size, 2).

    Randomness comes from the global ``np.random`` state: first the component
    choice, then the noise. Seeding with ``np.random.seed`` therefore
    reproduces the same samples as the original 4-mode implementation.
    """
    # n_modes + 1 points, then drop the last: 2*pi would duplicate angle 0.
    angles = np.linspace(0, 2 * np.pi, n_modes + 1)[:-1]
    centres = radius * np.stack([np.sin(angles), np.cos(angles)], axis=1)
    # Uniform component assignment, then isotropic Gaussian jitter.
    components = np.random.choice(n_modes, batch_size)
    return centres[components] + sigma * np.random.randn(batch_size, 2)