import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.transforms as mtransforms


def get_pal(ncolors, npercolor, plot=False):
    """Build a palette of shaded variants of the seaborn "deep" colors.

    For each of the first ``ncolors`` colors of the "deep" palette, a
    ``sns.light_palette`` blend (``reverse=True``) with
    ``npercolor + npercolor // 3`` entries is generated and only its first
    ``npercolor`` entries are kept, i.e. the extra shades at the end of the
    blend are discarded.

    Side effect: the resulting palette is installed as the active seaborn
    palette via ``sns.set_palette``.

    Args:
        ncolors: Number of base colors to take from the "deep" palette.
        npercolor: Number of shades to keep per base color.
        plot: If true, also display the palette with ``sns.palplot``.

    Returns:
        The stacked palette produced by ``np.vstack`` (one row per shade).
    """
    base_colors = sns.color_palette("deep")[:ncolors]
    # Integer floor division keeps the shade count an int; the previous
    # ``// 3.`` produced a float count for light_palette.
    n_off = npercolor // 3
    pal = np.vstack([
        sns.light_palette(c, npercolor + n_off, reverse=True)[:npercolor]
        for c in base_colors
    ])
    sns.set_palette(pal)
    if plot:
        sns.palplot(pal)
    return pal

def add_label(label, xoff=-0.1, yoff=1.3):
    """Write a bold panel label on the current axes.

    The position ``(xoff, yoff)`` is interpreted in axes coordinates
    (``transAxes``); the text is anchored by its top-right corner.

    Args:
        label: Value to display; formatted with ``'%s'``.
        xoff: Horizontal position in axes coordinates.
        yoff: Vertical position in axes coordinates.
    """
    text_style = dict(fontsize=12, fontweight='bold', va='top', ha='right')
    axes = plt.gca()
    axes.text(xoff, yoff, '%s' % label, transform=axes.transAxes, **text_style)

def pcolor(bias_sigmas, weight_sigmas, q, vmin, vmax):
    """Draw ``q`` as a plasma-colored mesh whose cell edges are hidden.

    Edges are hidden by giving every cell edge the color of its face.

    Args:
        bias_sigmas: Coordinates passed as the first (x) argument of
            ``plt.pcolormesh``.
        weight_sigmas: Coordinates passed as the second (y) argument.
        q: Values to color.
        vmin: Lower bound of the color scale.
        vmax: Upper bound of the color scale.

    Returns:
        The mesh object returned by ``plt.pcolormesh``.
    """
    mesh_options = {
        'cmap': plt.cm.plasma,
        'vmin': vmin,
        'vmax': vmax,
        'shading': 'auto',
    }
    mesh = plt.pcolormesh(bias_sigmas, weight_sigmas, q, **mesh_options)
    mesh.set_edgecolor('face')
    return mesh

def sigma_pcolor(q, weight_sigmas, bias_sigmas, draw_colorbar=True, **kwargs):
    """Plot ``log10(q)`` as a color mesh with an optional power-of-ten colorbar.

    NaN entries of ``q`` are replaced by the largest non-NaN value before the
    logarithm is taken. The color limits are the floor/ceil of the min/max of
    ``log10(q)``, so they are whole powers of ten.

    Args:
        q: Array-like of values to plot. It is copied, so the caller's data
            is never modified.
        weight_sigmas: Coordinates forwarded to ``pcolor`` as its second
            argument.
        bias_sigmas: Coordinates forwarded to ``pcolor`` as its first
            argument.
        draw_colorbar: If true, add a colorbar with ticks at the lower limit,
            the midpoint and the upper limit, labelled as ``10^x``.
        **kwargs: Accepted for call-site compatibility; currently unused.
    """
    # Work on a private float copy: the NaN fill below must not write into
    # the array owned by the caller.
    q = np.array(q, dtype=float)
    nan_mask = np.isnan(q)
    if nan_mask.any():
        q[nan_mask] = np.nanmax(q)
    q = np.log10(q)
    vmax = int(np.ceil(np.max(q)))
    vmin = int(np.floor(np.min(q)))
    pcolor(bias_sigmas, weight_sigmas, q, vmin=vmin, vmax=vmax)
    if draw_colorbar:
        mid = (vmin + vmax) / 2.0
        cbar = plt.colorbar(ticks=(vmin, mid, vmax))
        cbar.ax.set_yticklabels(
            [f'$10^{{{vmin}}}$', f'$10^{{{mid}}}$', f'$10^{{{vmax}}}$'],
            fontsize=40,
        )
# https://medium.com/@BrendanArtley/matplotlib-color-gradients-21374910584b

def hex_to_RGB(hex_str):
    """Convert a ``#RRGGBB`` string to ``[R, G, B]`` integers.

    Example: ``"#FFFFFF" -> [255, 255, 255]``. The first character is
    skipped and the following three two-character pairs are parsed as
    base-16 numbers.
    """
    channels = []
    for start in (1, 3, 5):
        pair = hex_str[start:start + 2]
        channels.append(int(pair, 16))
    return channels

def get_color_gradient(c1, c2, n):
    """Return a list of hex colors blending linearly from ``c1`` to ``c2``.

    Args:
        c1: Start color as a ``#RRGGBB`` string.
        c2: End color as a ``#RRGGBB`` string.
        n: Number of colors requested. For ``n > 1`` the result has ``n``
            evenly spaced colors including both endpoints; otherwise a
            single color equal to ``c2`` (mix fraction 1) is returned.

    Returns:
        List of lowercase ``#rrggbb`` strings.
    """
    start_rgb = np.array(hex_to_RGB(c1)) / 255
    end_rgb = np.array(hex_to_RGB(c2)) / 255

    if n > 1:
        fractions = [step / (n - 1) for step in range(n)]
    else:
        fractions = [1]

    gradient = []
    for frac in fractions:
        blended = (1 - frac) * start_rgb + (frac * end_rgb)
        hex_pairs = (format(int(round(channel * 255)), "02x") for channel in blended)
        gradient.append("#" + "".join(hex_pairs))
    return gradient

# https://github.com/ganguli-lab/deepchaos/blob/master/notebooks/Figures%201%20%26%202%20-%20qmap%20and%20cmap.ipynb

def plot_qmaps(ax, qrange, qmaps, widxs, bidxs, vmax=15, **plot_kwargs):
    """Plot selected curves of ``qmaps`` against ``qrange`` on ``ax``.

    For every pair ``(widx, bidx)`` taken in lockstep from ``widxs`` and
    ``bidxs``, ``qmaps[widx, bidx]`` is plotted. When ``qmaps`` has four
    dimensions, each slice along its third axis is plotted as a separate
    curve.

    Args:
        ax: Axes-like object to draw on.
        qrange: x values shared by all curves.
        qmaps: 3-D or 4-D array indexed as ``[widx, bidx, ...]``.
        widxs: First-axis indices of the curves to plot.
        bidxs: Second-axis indices of the curves to plot.
        vmax: Upper limit of both axes. Four evenly spaced ticks are placed
            on ``[0, vmax]`` (0, 5, 10, 15 for the default).
        **plot_kwargs: Forwarded to every ``ax.plot`` call.
    """
    for widx, bidx in zip(widxs, bidxs):
        if qmaps.ndim == 4:
            for qidx in range(qmaps.shape[-2]):
                ax.plot(qrange, qmaps[widx, bidx, qidx], **plot_kwargs)
        else:
            ax.plot(qrange, qmaps[widx, bidx], **plot_kwargs)
    # Honour the caller's vmax; it used to be overwritten with 15 here.
    ticks = np.linspace(0, vmax, 4)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xlim(0, vmax)
    ax.set_ylim(0, vmax)
    ax.set_xlabel('input length ($q^{l-1})$')
    ax.set_ylabel('output length ($q^l$)')


def set_plot(args, ax, p_cond, crop=0,
             Ts=None, fin_losses_mu=None, final_losses_std=None):
    """Dispatch axis decoration to the per-layer or total-loss helper.

    This helper used to be indented inside ``plot_qmaps``, where it could
    never be called, and it invoked the helpers without their ``args``
    argument. It now lives at module level and takes ``args`` first, like
    ``set_plot_layer`` and ``set_plot_total``.

    Args:
        args: Run configuration object forwarded to the helper.
        ax: Axes to decorate.
        p_cond: Name of the plotted quantity. If it contains ``'total'`` the
            total-loss helper is used, otherwise the per-layer helper.
        crop: Forwarded to ``set_plot_total``.
        Ts: Forwarded to ``set_plot_total``.
        fin_losses_mu: Forwarded to ``set_plot_total``.
        final_losses_std: Forwarded to ``set_plot_total``.
    """
    if 'total' not in p_cond:
        set_plot_layer(args, ax, p_cond, var='sigma_w')
    else:
        set_plot_total(args, ax, Ts, fin_losses_mu, final_losses_std, crop)

# Add labels, title, and legend
def set_plot_layer(args, ax, p_cond, var='sigma_w'):
    """Style ``ax`` for a per-layer plot and build its descriptive title.

    The axes get size-40 tick labels on both axes and a logarithmic y scale.
    The title is assembled and returned but not applied to the axes (the
    ``set_title`` call remains disabled).

    Args:
        args: Object providing ``name`` and ``dataset``; depending on the
            branch taken also ``run_min``, the attribute named by ``var``,
            ``sigma_w`` / ``T`` / ``eta`` (when ``name == 'PC'``) and
            ``o2c_idx`` (when ``dataset == 'o2c'``).
        ax: Axes to style.
        p_cond: Plotted quantity: one of ``'W_grad'``, ``'b_grad'``,
            ``'delta'``, ``'loss_layer'``, ``'latent'``, ``'latent_pre'``,
            ``'latent_prd'``. Any other value yields a title without the
            quantity part.
        var: Name of the swept attribute shown in the title.

    Returns:
        The assembled title string.
    """
    name = args.name
    options = ('sigma_w', 'T', 'eta')
    # Human-readable name of each supported quantity.
    quantity_labels = {
        'W_grad': 'W_grad',
        'b_grad': 'b_grad',  # was mislabelled 'W_grad'
        'delta': 'delta',
        'loss_layer': 'layer loss',
        'latent': 'latent',
        'latent_pre': 'preactivated latent',
        'latent_prd': 'predicted latent',
    }

    # The layer loss is not a "length"; the key must match the one used in
    # the table above ('loss_layer', not 'loss layer').
    if p_cond == 'loss_layer':
        title = f"{name}'s "
    else:
        title = f"{name}'s length of "

    ax.tick_params(axis='x', labelsize=40)
    ax.tick_params(axis='y', labelsize=40)
    ax.set_yscale('log')

    quantity = quantity_labels.get(p_cond)
    if quantity is not None:
        title += f'{quantity} per {var}(min{args.run_min}max{getattr(args, var)})'

    # For the 'PC' model, list the values of the options that are not swept.
    if name == 'PC':
        title += ''.join(
            f', {option}={getattr(args, option)}'
            for option in options
            if option != var
        )
    if args.dataset == 'o2c':
        title += f', o2c_idx={args.o2c_idx}'
    return title

def set_plot_total(args, ax, Ts, fin_losses_mu, final_losses_std, crop):
    """Finish the total-loss plot, save it, and save a final-loss bar chart.

    Two PNG files are written below ``figures/`` (created if missing):
    ``tot_loss<cropped>_sigma_w=<sigma_w>.png`` and
    ``final_loss_sigma_w=<sigma_w>.png``. The bar chart is drawn on a new
    figure that is closed before returning.

    Args:
        args: Object providing ``eta`` and ``sigma_w``.
        ax: Axes holding the total-loss curves; receives labels, a log y
            scale and a title.
        Ts: Bar positions for the final-loss chart.
        fin_losses_mu: Bar heights for the final-loss chart.
        final_losses_std: Error-bar sizes for the final-loss chart.
        crop: When greater than 0, a ``(cropped<crop>)`` tag is added to the
            title and to the first file name.
    """
    import os

    # savefig does not create missing directories.
    os.makedirs('figures', exist_ok=True)

    ax.set_xlabel('Iteration')
    ax.set_ylabel('Total loss')
    ax.set_yscale('log')
    cropped = f'(cropped{crop})' if crop > 0 else ''
    # Raw f-strings keep the LaTeX backslashes literal without relying on
    # invalid escape sequences.
    s1 = rf'Total loss per iteration {cropped}: $\eta$ = {args.eta},'
    s2 = rf' $\sigma_w$ = {args.sigma_w}, schedule = [50,100,...,800]'
    ax.set_title(s1 + s2)
    # NOTE(review): this saves pyplot's *current* figure, which is assumed to
    # be the one containing ``ax`` - confirm against callers.
    plt.savefig(f'figures/tot_loss{cropped}_sigma_w={args.sigma_w}.png')

    # Separate figure for the final loss; distinct names avoid shadowing the
    # ``ax`` parameter.
    bar_fig, bar_ax = plt.subplots()
    bar_ax.bar(Ts, fin_losses_mu, yerr=final_losses_std, capsize=5)
    bar_ax.set_title(
        rf'Final loss per total update steps T, $\sigma_w$ = {args.sigma_w}'
    )
    bar_fig.savefig(f'figures/final_loss_sigma_w={args.sigma_w}.png')
    plt.close(bar_fig)


