import torch
from matplotlib import pyplot as pl
import numpy as np

def get_mesh_grid(x_range=(0.0, 1.0), y_range=(0.0, 1.0), x_num=10, y_num=10):
    """Build a regular 2-D grid of ``(x, y)`` points as a float32 tensor.

    Args:
        x_range: ``(min, max)`` pair (any indexable) for the x axis.
        y_range: ``(min, max)`` pair (any indexable) for the y axis.
        x_num: Number of evenly spaced samples along x (endpoints included).
        y_num: Number of evenly spaced samples along y (endpoints included).

    Returns:
        ``(grid, plot_lim)`` where ``grid`` has shape ``(x_num, y_num, 2)``
        with ``grid[i, j] == (x_i, y_j)`` ('ij' indexing), and ``plot_lim`` is
        the list ``[x_min, x_max, y_min, y_max]`` (usable as imshow ``extent``).
    """
    plot_lim = [x_range[0], x_range[1], y_range[0], y_range[1]]
    x_axis = torch.as_tensor(np.linspace(x_range[0], x_range[1], num=x_num), dtype=torch.float32)
    y_axis = torch.as_tensor(np.linspace(y_range[0], y_range[1], num=y_num), dtype=torch.float32)
    # Equivalent to torch.meshgrid(x_axis, y_axis, indexing='ij'), written out
    # so it neither warns on newer torch nor needs the `indexing` kwarg
    # (only available from torch 1.10).
    n_x, n_y = x_axis.numel(), y_axis.numel()
    plotx = x_axis.unsqueeze(1).expand(n_x, n_y)
    ploty = y_axis.unsqueeze(0).expand(n_x, n_y)
    # stack materialises a new contiguous tensor from the expanded views.
    return torch.stack((plotx, ploty), dim=-1), plot_lim

def density_plot(plot_func, range_opt=None, show=True, cmap=None, vmin=None, vmax=None, cbar=True, legend=None):
    """Evaluate ``plot_func`` on a regular 2-D grid and draw it with ``imshow``.

    Args:
        plot_func: Callable receiving an ``(x_num * y_num, 2)`` float32 tensor
            of ``(x, y)`` points and returning a tensor with
            ``x_num * y_num`` elements (one value per point).
        range_opt: Dict with keys ``'x_range'`` and ``'y_range'``
            (``[min, max]`` pairs) and ``'x_num'`` / ``'y_num'`` (grid
            resolution). ``None`` means the unit square at 10 x 10 points.
        show: If true, call ``pl.show()`` after drawing.
        cmap, vmin, vmax: Forwarded to ``pl.imshow``.
        cbar: If true, attach a colorbar for the drawn image.
        legend: Forwarded as the image artist's ``label``.

    Returns:
        The image artist returned by ``pl.imshow``.
    """
    if range_opt is None:
        range_opt = {'x_range': [0.0, 1.0], 'y_range': [0.0, 1.0], 'x_num': 10, 'y_num': 10}
    x_range = range_opt['x_range']
    y_range = range_opt['y_range']
    x_num = range_opt['x_num']
    y_num = range_opt['y_num']
    plot_grid, plot_lim = get_mesh_grid(x_range, y_range, x_num, y_num)
    plot_vals = plot_func(plot_grid.reshape(x_num * y_num, 2))
    # reshape (not view) tolerates non-contiguous outputs from plot_func;
    # .cpu() lets results computed on an accelerator be converted to numpy.
    values = plot_vals.detach().cpu().reshape(x_num, y_num).numpy()
    # The grid is indexed [x, y] while imshow expects [row=y, col=x], so
    # transpose; origin='lower' puts y_min at the bottom of the image.
    pcm = pl.imshow(np.transpose(values), extent=plot_lim, cmap=cmap, vmin=vmin, vmax=vmax, origin='lower', label=legend)
    if cbar:
        pl.colorbar(pcm)
    if show:
        pl.show()
    return pcm

def contour_plot(plot_func, range_opt=None, show=True, color_opt='black', linestyles=None, levels=None, label=True, outline=False, legend=None):
    """Evaluate ``plot_func`` on a regular 2-D grid and draw contour lines.

    Args:
        plot_func: Callable receiving an ``(x_num * y_num, 2)`` float32 tensor
            of ``(x, y)`` points and returning a tensor with
            ``x_num * y_num`` elements (one value per point).
        range_opt: Dict with keys ``'x_range'`` and ``'y_range'``
            (``[min, max]`` pairs) and ``'x_num'`` / ``'y_num'`` (grid
            resolution). ``None`` means the unit square at 10 x 10 points.
        show: If true, call ``pl.show()`` after drawing.
        color_opt: Forwarded as ``colors`` to ``pl.contour``.
        linestyles: Forwarded to ``pl.contour``.
        levels: Contour levels forwarded positionally to ``pl.contour``;
            ``None`` lets matplotlib choose them.
        label: If true, annotate the contour lines with their values.
        outline: If true, first draw the same contours in white underneath.
        legend: Legend label attached to the contour lines.

    Returns:
        The ``ContourSet`` of the (non-outline) contours. Earlier versions
        returned ``None``; callers ignoring the result are unaffected.
    """
    if range_opt is None:
        range_opt = {'x_range': [0.0, 1.0], 'y_range': [0.0, 1.0], 'x_num': 10, 'y_num': 10}
    x_range = range_opt['x_range']
    y_range = range_opt['y_range']
    x_num = range_opt['x_num']
    y_num = range_opt['y_num']
    plot_grid, _ = get_mesh_grid(x_range, y_range, x_num, y_num)
    plot_vals = plot_func(plot_grid.reshape(x_num * y_num, 2))
    # Convert to numpy once. reshape (not view) tolerates non-contiguous
    # outputs; .cpu() allows results computed on an accelerator.
    grid_x = plot_grid[:, :, 0].detach().cpu().numpy()
    grid_y = plot_grid[:, :, 1].detach().cpu().numpy()
    values = plot_vals.detach().cpu().reshape(x_num, y_num).numpy()
    # Pass levels positionally only when supplied, exactly as before.
    level_args = () if levels is None else (levels,)
    if outline:
        # Drawn first so the coloured contours below end up on top of it.
        pl.contour(grid_x, grid_y, values, *level_args, colors='white')
    contours = pl.contour(grid_x, grid_y, values, *level_args, colors=color_opt, linestyles=linestyles)
    if label:
        pl.clabel(contours, inline=True, fontsize=8, rightside_up=True)
    if legend is not None:
        try:
            # Matplotlib < 3.10: per-level artists live in .collections
            # (deprecated since 3.8).
            contours.collections[0].set_label(legend)
        except AttributeError:
            # Matplotlib >= 3.10 removed .collections; the ContourSet is
            # itself a Collection. NOTE(review): confirm legend rendering
            # on the installed matplotlib version.
            contours.set_label(legend)
    if show:
        pl.show()
    return contours
