import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import seaborn as sns; sns.set()

def plot_histogram(flat_array, num_bins, title, save_path):
    """Draw a 1-D histogram of ``flat_array`` and write it to ``save_path``.

    Args:
        flat_array: Values handed to ``hist``.
        num_bins: Forwarded as the ``bins`` argument of ``hist``.
        title: Text placed above the axes.
        save_path: Destination passed to ``savefig``.
    """
    figure, axes = plt.subplots(1)
    axes.hist(flat_array, bins=num_bins)
    axes.set_title(title)
    figure.savefig(save_path, bbox_inches='tight')
    plt.close(figure)

def plot_2dhistogram(x, y, num_bins, title, save_path, ax_lims=None):
    """Draw a 2-D histogram of the paired samples and save the figure.

    Args:
        x, y: Sample coordinates handed to ``hist2d``.
        num_bins: Forwarded as the ``bins`` argument of ``hist2d``.
        title: Text placed above the axes.
        save_path: Destination passed to ``savefig``.
        ax_lims: Optional ``(x_limits, y_limits)`` pair; when ``None`` the
            limits chosen by matplotlib are kept.
    """
    figure, axes = plt.subplots(1)
    axes.set_title(title)
    # The freshly created axes are the current ones, so the pyplot call
    # draws into ``axes``.
    plt.hist2d(x, y, bins=num_bins)
    if ax_lims is not None:
        axes.set_xlim(ax_lims[0])
        axes.set_ylim(ax_lims[1])
    # Same scale on both axes so the bins are not visually stretched.
    axes.set_aspect('equal')
    figure.savefig(save_path, bbox_inches='tight')
    plt.close(figure)

def plot_seaborn_heatmap(x, y, num_bins, title, save_path, ax_lims=None):
    """Save a bivariate seaborn KDE plot of ``x`` against ``y``.

    Args:
        x, y: Samples forwarded positionally to ``sns.kdeplot``.
        num_bins: Unused by this function.
        title: Unused by this function (no title is drawn).
        save_path: Destination passed to ``savefig``.
        ax_lims: Optional ``(x_limits, y_limits)`` pair. When ``None`` the
            limits picked by seaborn are left untouched.
    """
    # NOTE(review): newer seaborn releases appear to require ``x=``/``y=``
    # keywords here -- confirm against the seaborn version this project pins.
    g = sns.kdeplot(x, y, cbar=True, cmap='RdBu')
    # ``ax_lims`` defaults to None, so only index it when it was supplied.
    if ax_lims is not None:
        g.set(xlim=tuple(ax_lims[0]), ylim=tuple(ax_lims[1]))
    g.figure.savefig(save_path)
    plt.close()

def plot_scatter(x, y, num_bins, title, save_path, ax_lims=None):
    """Save a scatter plot of the points ``(x, y)`` drawn with small markers.

    Args:
        x, y: Point coordinates handed to ``scatter``.
        num_bins: Unused by this function.
        title: Unused by this function (no title is drawn).
        save_path: Destination passed to ``savefig``.
        ax_lims: Optional ``(x_limits, y_limits)`` pair; when ``None`` the
            limits chosen by matplotlib are kept.
    """
    figure, axes = plt.subplots(1)
    # The freshly created axes are the current ones, so the pyplot call
    # draws into ``axes``.
    plt.scatter(x, y, s=0.5)
    if ax_lims is not None:
        axes.set_xlim(ax_lims[0])
        axes.set_ylim(ax_lims[1])
    axes.set_aspect('equal')
    figure.savefig(save_path, bbox_inches='tight')
    plt.close(figure)

def plot_seaborn_grid(grid, vmin, vmax, title, save_path):
    """Render ``grid`` as a seaborn heatmap and save it.

    Args:
        grid: 2-D data handed to ``sns.heatmap``.
        vmin, vmax: Colour-scale bounds forwarded to ``sns.heatmap``.
        title: Unused by this function (no title is drawn).
        save_path: Destination passed to ``savefig``.
    """
    heat_axes = sns.heatmap(grid, cmap="RdBu", vmin=vmin, vmax=vmax)
    owning_figure = heat_axes.figure
    owning_figure.savefig(save_path)
    plt.close()

def save_pytorch_tensor_as_img(tensor, save_path):
    """Save a PyTorch image tensor to ``save_path`` via ``imshow``.

    The tensor is indexed as channel-first: axis 0 is moved to the end
    before display, and a single-channel tensor is tiled to three channels.

    Args:
        tensor: 3-D tensor; assumed to be laid out as (channels, H, W) --
            TODO confirm against callers.
        save_path: Destination passed to ``savefig``.
    """
    # ``Tensor.numpy()`` raises for tensors that track gradients or live on
    # an accelerator, so detach and move to host memory first.
    tensor = tensor.detach().cpu()
    if tensor.size(0) == 1:
        # Tile the single channel so imshow receives a 3-channel image.
        tensor = tensor.repeat(3, 1, 1)
    fig, ax = plt.subplots(1)
    ax.imshow(np.transpose(tensor.numpy(), (1, 2, 0)))
    fig.savefig(save_path)
    plt.close(fig)


def generate_gif(list_of_img_list, names, save_path):
    """Animate several image sequences in stacked subplots and save a GIF.

    Args:
        list_of_img_list: One image sequence per subplot row. The number of
            frames is taken from the first sequence.
        names: Title for each subplot row, parallel to ``list_of_img_list``.
        save_path: Output path handed to ``Animation.save``; the
            'imagemagick' writer is used.
    """
    num_rows = len(list_of_img_list)
    # With the default squeeze=True a single row comes back as a bare Axes,
    # which cannot be indexed. Ask for the 2-D array and flatten it so one
    # sequence works the same as many.
    fig, axarr = plt.subplots(num_rows, squeeze=False)
    axarr = axarr.ravel()

    def update(t):
        # Redraw frame ``t`` of every sequence in its own row.
        for j in range(num_rows):
            axarr[j].imshow(list_of_img_list[j][t])
            axarr[j].set_title(names[j])
        return axarr

    anim = FuncAnimation(
        fig, update, frames=np.arange(len(list_of_img_list[0])), interval=2000
    )
    anim.save(save_path, dpi=80, writer='imagemagick')
    plt.close(fig)


def get_cmap(n, name='hsv'):
    '''Returns a function that maps each index in 0, 1, ..., n-1 to a distinct
    RGB color; the keyword argument name must be a standard mpl colormap name.

    The colormap is resampled to ``n + 1`` entries rather than ``n``: the
    original author observed that the first and last entries looked almost
    identical (presumably because the default 'hsv' map is cyclic), so one
    extra entry keeps indices 0..n-1 away from the wrap-around colour.
    '''
    # ``plt.cm.get_cmap`` is deprecated and removed in recent matplotlib
    # releases; ``plt.get_cmap(name, lut)`` is the same lookup and is
    # available in both old and new versions.
    return plt.get_cmap(name, n + 1)


def plot_returns_on_same_plot(arr_list, names, title, save_path, x_axis_lims=None, y_axis_lims=None):
    """Plot several 1-D curves on shared axes with a legend below the plot.

    Each array is plotted against its own index range ``0..len-1``.

    Args:
        arr_list: Sequence of arrays (``.size`` and ``.shape`` are read).
        names: Legend label for each array, parallel to ``arr_list``.
        title: Text placed above the axes.
        save_path: Destination passed to ``savefig``.
        x_axis_lims: Optional x-limits; left automatic when ``None``.
        y_axis_lims: Optional y-limits; left automatic when ``None``.
    """
    fig, ax = plt.subplots(1)
    cmap = get_cmap(len(arr_list))

    for i, (ret, name) in enumerate(zip(arr_list, names)):
        # Arrays with at most one element are skipped. The colour index still
        # advances, so the remaining curves keep their position-based colour.
        if ret.size <= 1:
            continue
        ax.plot(np.arange(ret.shape[0]), ret, color=cmap(i), label=name)

    ax.set_title(title)
    if x_axis_lims is not None:
        ax.set_xlim(x_axis_lims)
    if y_axis_lims is not None:
        ax.set_ylim(y_axis_lims)
    # The legend sits below the axes; pass it as an extra artist so the tight
    # bounding box does not crop it.
    lgd = ax.legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), shadow=False, ncol=3)
    fig.savefig(save_path, bbox_extra_artists=(lgd,), bbox_inches='tight')
    plt.close(fig)


def plot_multiple_plots(plot_list, names, title, save_path):
    """Overlay several ``(x, y)`` curves on one set of axes and save the figure.

    Args:
        plot_list: Sequence of curves; element ``[0]`` of each is used as the
            x data and element ``[1]`` as the y data.
        names: Legend label for each curve, parallel to ``plot_list``.
        title: Text placed above the axes.
        save_path: Destination passed to ``savefig``.
    """
    figure, axes = plt.subplots(1)
    palette = get_cmap(len(plot_list))

    for idx, (curve, label) in enumerate(zip(plot_list, names)):
        axes.plot(curve[0], curve[1], color=palette(idx), label=label)

    axes.set_title(title)
    # The legend is anchored below the axes, so it has to be listed as an
    # extra artist for the tight bounding box to include it.
    legend = axes.legend(
        loc='upper center', bbox_to_anchor=(0.5, -0.05), shadow=False, ncol=3
    )
    figure.savefig(save_path, bbox_extra_artists=(legend,), bbox_inches='tight')
    plt.close(figure)


def save_plot(x, y, title, save_path, color='cyan', x_axis_lims=None, y_axis_lims=None):
    """Plot a single curve ``y`` over ``x`` and save it to ``save_path``.

    Args:
        x, y: Curve data handed to ``plot``.
        title: Text placed above the axes.
        save_path: Destination passed to ``savefig``.
        color: Line colour.
        x_axis_lims: Optional x-limits; left automatic when ``None``.
        y_axis_lims: Optional y-limits; left automatic when ``None``.
    """
    figure, axes = plt.subplots(1)
    axes.plot(x, y, color=color)
    axes.set_title(title)
    if x_axis_lims is not None:
        axes.set_xlim(x_axis_lims)
    if y_axis_lims is not None:
        axes.set_ylim(y_axis_lims)
    figure.savefig(save_path, bbox_inches='tight')
    plt.close(figure)


def _save_log_ratio_reward_plot(x, y, y_lims, save_path,
                                line_color='deepskyblue', line_width=4):
    """Plot one reward curve against the log density ratio and save it."""
    fig, ax = plt.subplots(1)
    ax.plot(x, y, color=line_color, linewidth=line_width)
    ax.set_xlim([-10, 10])
    ax.set_ylim(y_lims)
    ax.set_xlabel(r'log$\frac{\rho^{exp}(s,a)}{\rho^\pi(s,a)}$', fontsize='xx-large')
    ax.set_ylabel('$r(s,a)$', fontsize='xx-large')
    # Grey reference lines through the origin.
    ax.axhline(0, color='grey')
    ax.axvline(0, color='grey')
    fig.savefig(save_path, bbox_inches='tight', dpi=150)
    plt.close(fig)


def plot_forward_reverse_KL_rews():
    """Save three figures of reward as a function of the log density ratio.

    With ``x`` the log ratio shown on the horizontal axis, the curves are:

    * reverse KL: ``r = x``                  -> plots/junk_vis/rev_KL_rew.png
    * GAIL:       ``r = -log(1 + exp(-x))``  -> plots/junk_vis/JS_rew.png
    * forward KL: ``r = -x * exp(x)``        -> plots/junk_vis/forw_KL_rew.png

    Side effect: sets the global matplotlib ``font.size`` rcParam to 16. The
    output directory is not created here and must already exist.
    """
    plt.rcParams.update({'font.size': 16})

    x = np.arange(-10, 10, 0.05)

    # reverse KL
    _save_log_ratio_reward_plot(x, x, [-12, 12], 'plots/junk_vis/rev_KL_rew.png')

    # GAIL
    _save_log_ratio_reward_plot(
        x, -np.log(1 + np.exp(-x)), [-12, 12], 'plots/junk_vis/JS_rew.png'
    )

    # forward KL (much smaller vertical range, hence the tighter y-limits)
    _save_log_ratio_reward_plot(
        x, np.exp(x) * (-x), [-2, 0.5], 'plots/junk_vis/forw_KL_rew.png'
    )


def _sample_color_within_radius(center, radius):
    """Sample a 2-D point from the disc of ``radius`` around ``center``.

    A random unit direction (a normalised Gaussian draw) is scaled by
    ``radius * sqrt(u)`` with ``u`` uniform in [0, 1), which spreads points
    uniformly over the disc's area. The result is clipped component-wise to
    [-1, 1].

    Args:
        center: Length-2 array-like disc centre.
        radius: Disc radius.

    Returns:
        numpy array of shape (2,) with entries in [-1, 1].
    """
    direction = np.random.normal(size=2)
    direction /= np.linalg.norm(direction, axis=-1)
    fraction = np.random.uniform()
    offset = radius * (fraction ** 0.5) * direction
    return np.clip(offset + center, -1.0, 1.0)


def _sample_color_with_min_dist(color, min_dist):
    """Sample a point in [-1, 1)^2 at least ``min_dist`` away from ``color``.

    Uses rejection sampling: uniform draws from the square are repeated until
    one lies at Euclidean distance >= ``min_dist`` from ``color``.

    Args:
        color: Length-2 array-like reference point.
        min_dist: Minimum required Euclidean distance.

    Returns:
        numpy array of shape (2,).

    Raises:
        ValueError: If no point of the square can be that far from ``color``
            (previously this case looped forever).
    """
    # The farthest the square gets from ``color`` is its most distant corner,
    # which lies at per-axis offset 1 + |color_i|.
    max_dist = np.linalg.norm(np.abs(np.asarray(color, dtype=float)) + 1.0)
    if min_dist >= max_dist:
        raise ValueError(
            'min_dist=%r cannot be satisfied inside [-1, 1]^2 for color=%r'
            % (min_dist, color)
        )

    new_color = np.random.uniform(-1.0, 1.0, size=2)
    while np.linalg.norm(new_color - color, axis=-1) < min_dist:
        new_color = np.random.uniform(-1.0, 1.0, size=2)
    return new_color


def visualize_multi_ant_target_percentages(csv_array, num_targets, title='', save_path=''):
    """Plot per-target percentages (stacked) and a combined distance curve.

    Top panel: the ``Target_<i>_Perc`` columns drawn as stacked filled bands.
    Bottom panel: the sum over targets of ``Target_<i>_Dist_Mean`` multiplied
    by ``Target_<i>_Perc``, i.e. a percentage-weighted distance per row.

    Args:
        csv_array: Mapping-like object indexed by the column names above;
            presumably a structured numpy array with one row per epoch --
            verify against the caller.
        num_targets: Number of ``Target_<i>_*`` column pairs to read.
        title: Title of the top panel.
        save_path: Output path; when empty the figure is written to
            'plots/junk_vis/test_multi_ant_plot.png'.
    """
    # almost rainbow :P
    colors = [
        'purple', 'cyan', 'blue', 'green', 'yellow', 'orange', 'pink', 'red'
    ]

    # gather the results
    all_perc = [
        csv_array['Target_%d_Perc' % i] for i in range(num_targets)
    ]
    all_dist = np.array([
        csv_array['Target_%d_Dist_Mean' % i] for i in range(num_targets)
    ])
    # -1 entries are zeroed so they add nothing to the weighted sum
    # (presumably a "no data" marker -- confirm with the CSV producer).
    all_dist[all_dist == -1] = 0.0
    all_dist = np.sum(all_dist * all_perc, axis=0)

    # Turn the percentages into running totals so that band i can be drawn
    # between the cumulative curves i-1 and i.
    for i in range(1, len(all_perc)):
        all_perc[i] = all_perc[i] + all_perc[i-1]

    X = np.arange(all_dist.shape[0])

    # time to plot
    plt.subplot(2, 1, 1)
    alpha = 0.5
    plt.fill_between(X, all_perc[0], color=colors[0], alpha=alpha)
    for i in range(1, len(all_perc)):
        # Cycle through the palette so more targets than colours does not
        # raise an IndexError.
        plt.fill_between(
            X, all_perc[i], y2=all_perc[i-1],
            color=colors[i % len(colors)], alpha=alpha
        )
    plt.ylabel('Perc. Each Target')
    plt.title(title)

    plt.subplot(2, 1, 2)
    plt.plot(X, all_dist, color='royalblue')
    plt.ylim((0.0, 1.5))
    plt.xlabel('epoch')
    plt.ylabel('Dist. to Closest Target')

    if save_path == '':
        plt.savefig('plots/junk_vis/test_multi_ant_plot.png', bbox_inches='tight', dpi=150)
    else:
        plt.savefig(save_path, bbox_inches='tight', dpi=150)
    plt.close()


def visualize_multi_ant_target_percentages_v2(csv_array, num_targets, title='', save_path=''):
    """Plot per-target percentages (stacked) plus one distance panel per target.

    Top panel: the ``Target_<i>_Perc`` columns drawn as stacked filled bands.
    Below it, one panel per target showing its ``Target_<i>_Dist_Mean``
    column (unlike the non-v2 variant, the distances are not combined).

    Args:
        csv_array: Mapping-like object indexed by the column names above;
            presumably a structured numpy array with one row per epoch --
            verify against the caller.
        num_targets: Number of ``Target_<i>_*`` column pairs to read.
        title: Title of the top panel.
        save_path: Output path; when empty the figure is written to
            'plots/junk_vis/test_multi_ant_plot.png'.
    """
    # almost rainbow :P
    colors = [
        'purple', 'cyan', 'blue', 'green', 'yellow', 'orange', 'pink', 'red'
    ]

    # gather the results
    all_perc = [
        csv_array['Target_%d_Perc' % i] for i in range(num_targets)
    ]
    all_dist = np.array([
        csv_array['Target_%d_Dist_Mean' % i] for i in range(num_targets)
    ])
    # -1 entries are drawn as 0 (presumably a "no data" marker -- confirm
    # with the CSV producer).
    all_dist[all_dist == -1] = 0.0

    # Turn the percentages into running totals so that band i can be drawn
    # between the cumulative curves i-1 and i.
    for i in range(1, len(all_perc)):
        all_perc[i] = all_perc[i] + all_perc[i-1]

    X = np.arange(all_dist[0].shape[0])

    # time to plot
    plt.subplot(num_targets+1, 1, 1)
    alpha = 0.5
    plt.fill_between(X, all_perc[0], color=colors[0], alpha=alpha)
    for i in range(1, len(all_perc)):
        # Cycle through the palette so more targets than colours does not
        # raise an IndexError.
        plt.fill_between(
            X, all_perc[i], y2=all_perc[i-1],
            color=colors[i % len(colors)], alpha=alpha
        )
    plt.ylabel('Perc. Each Target')
    plt.title(title)

    for i in range(num_targets):
        # Panel 1 holds the percentages, so target i goes in panel i + 2.
        plt.subplot(num_targets+1, 1, i+2)
        plt.plot(X, all_dist[i], color=colors[i % len(colors)])
        plt.ylim((0.0, 3.5))
        plt.xlabel('epoch')
        plt.ylabel('Dist. to Closest Target')

    if save_path == '':
        plt.savefig('plots/junk_vis/test_multi_ant_plot.png', bbox_inches='tight', dpi=150)
    else:
        plt.savefig(save_path, bbox_inches='tight', dpi=150)
    plt.close()


def plot_fetch_pedagogical_example():
    """Save six incremental frames illustrating "good" and "bad" regions.

    A star marks a random point in [-1, 1]^2. Each iteration adds one "good"
    disc (centre sampled within 0.5 of the star) and one "bad" disc (centre
    sampled at least 0.5 away), both of radius 0.5 and clipped to the square.
    Every frame shows all discs so far, plus, in solid green, the
    intersection of the good discs minus the union of the bad ones. Frames go
    to 'plots/junk_vis/fetch/img_<i>.png'.

    Requires the optional ``shapely`` and ``descartes`` packages.
    """
    import shapely.geometry as sg
    import descartes

    # Print the RNG state (presumably so a given run can be reproduced).
    print(np.random.get_state())

    workspace = sg.Polygon([(-1.0, -1.0), (-1.0, 1.0), (1.0, 1.0), (1.0, -1.0)])
    star = np.random.uniform(-1.0, 1.0, size=2)

    good_regions = []
    bad_regions = []
    good_overlap = None
    bad_cover = None
    for frame in range(6):
        # Draw the two centres in this order to keep the RNG stream stable.
        near_center = _sample_color_within_radius(star, 0.5)
        far_center = _sample_color_with_min_dist(star, 0.5)

        good_disc = sg.Point(*near_center).buffer(0.5).intersection(workspace)
        good_regions.append(good_disc)
        bad_disc = sg.Point(*far_center).buffer(0.5).intersection(workspace)
        bad_regions.append(bad_disc)

        # Running intersection of the good discs / running union of the bad.
        good_overlap = (
            good_disc if good_overlap is None
            else good_overlap.intersection(good_disc)
        )
        bad_cover = bad_disc if bad_cover is None else bad_cover.union(bad_disc)

        figure, axes = plt.subplots(1)
        for region in good_regions:
            axes.add_patch(
                descartes.PolygonPatch(region, fc='green', ec='green', alpha=0.1)
            )
        for region in bad_regions:
            axes.add_patch(
                descartes.PolygonPatch(region, fc='pink', ec='pink', alpha=0.3)
            )
        # Solid green: inside every good disc and outside every bad disc.
        consistent = good_overlap.difference(bad_cover)
        axes.add_patch(
            descartes.PolygonPatch(consistent, fc='green', ec='green', alpha=1.0)
        )

        axes.plot(
            [star[0]], [star[1]], marker='*',
            markeredgecolor='gold', markerfacecolor='gold', markersize=20.0
        )

        axes.set_xlim([-1.0, 1.0])
        axes.set_ylim([-1.0, 1.0])
        axes.set_aspect('equal')
        figure.savefig(
            'plots/junk_vis/fetch/img_%d.png' % frame, bbox_inches='tight', dpi=150
        )
        plt.close(figure)

# When run as a script (not imported), regenerate the reward-vs-log-ratio
# figures produced by plot_forward_reverse_KL_rews().
if __name__ == '__main__':
    plot_forward_reverse_KL_rews()
