import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.colors import LogNorm, Normalize, SymLogNorm
import matplotlib.cm as cm
import matplotlib as mpl
import numpy as np
import itertools
from seaborn import heatmap
import wandb
import math
import torch

# ---------------------------------------------------------------------------
# Shared styling constants used by the plotting helpers in this module.
# ---------------------------------------------------------------------------

# Fixed colours for the first few curves of a plot.
COLORS = [
    "blue",
    "red",
    "#FAA402",
    "#118141",
]

# Marker symbols cycled through when several curves share one plot.
# Every symbol is a single character, so the tuple is built from one string.
MARKERS = tuple(".vo,^<>12348spP*hHxXDd|_")

# rcParams overrides applied by the plotting helpers.
# (An alternative that was tried for "font.family" is "Times".)
RC_PARAMS = {
    "font.size": 22,
    "text.usetex": False,
    "font.family": "DejaVu Sans",
    "font.sans-serif": ["Helvetica"],
}

# Line widths and marker size for curve plots.
LEGEND_LINE_WIDTH = 3
PLOT_LINE_WIDTH = 3
MARKER_SIZE = 15

# Font sizes used on heatmaps (cell annotations and tick labels).
HEATMAP_ANNOT_FONT_SZ = 5
HEATMAP_TICK_LABEL_SIZE = 7

# Default figure size passed as ``figsize`` to matplotlib.
FIGSIZE = (10, 8)

# Import-time side effect: overwrite the global matplotlib rcParams with the
# library defaults (mpl.rcParamsDefault), discarding any customisation that
# was active before this module was imported.
mpl.rcParams.update(mpl.rcParamsDefault)

def plot_realfc_vs_time(path, loss_list, label_list, x_units = None,
                        error_bars=False, extra_title_info="", x_label="", y_label="", y_scale='log', show=False):
    """Plot loss curves against ``x_units`` and save the figure as a PDF.

    With ``error_bars=False`` every entry of ``loss_list`` is drawn as its own
    labelled line. With ``error_bars=True`` the entries are treated as
    repeated runs of one experiment: their mean is drawn together with a
    shaded band of one standard deviation on either side.

    Args:
        path: Prefix of the output file; the figure is written to
            ``path + "results_" + extra_title_info + ".pdf"``.
        loss_list: Sequence of curves, each a NumPy array or torch tensor.
            Tensors are detached and moved to the CPU before plotting.
        label_list: One legend label per curve (only used, and only
            length-checked, when ``error_bars`` is False).
        x_units: X coordinates shared by all curves. Defaults to
            ``0 .. n-1`` where ``n`` is the last dimension of the first curve.
        error_bars: Select the mean +/- std rendering described above.
        extra_title_info: Suffix inserted into the output file name.
        x_label: X axis label.
        y_label: Y axis label.
        y_scale: Matplotlib scale name for the y axis.
        show: If True, call ``plt.show()`` after saving.

    Returns:
        The matplotlib figure that was drawn.

    Raises:
        AssertionError: If ``error_bars`` is False and the numbers of curves
            and labels differ.
    """
    def _as_numpy(values):
        # Tensors may live on an accelerator or carry autograd history;
        # neither can be handed to NumPy/matplotlib directly.
        if isinstance(values, torch.Tensor):
            return values.detach().cpu().numpy()
        return np.asarray(values)

    plt.rcParams.update(RC_PARAMS)
    marker = itertools.cycle(MARKERS)
    fig = plt.figure(figsize=FIGSIZE)

    plt.yscale(y_scale)
    # Major and minor grid lines are switched on explicitly here, once.
    plt.grid(True, which="both", ls="-")

    curves = [_as_numpy(loss) for loss in loss_list]
    if x_units is None and curves:
        x_units = np.arange(curves[0].shape[-1])

    if not error_bars:
        #### plot each line in the list separately
        assert len(loss_list) == len(label_list), "There must be as many labels as losses when error_bars == False."

        for i, (loss, label) in enumerate(zip(curves, label_list)):
            # Show roughly 20 markers per curve regardless of its length.
            marker_freq = max(1, loss.shape[0] // 20)
            line_kwargs = dict(linewidth=PLOT_LINE_WIDTH, label=label, marker=next(marker),
                               markevery=marker_freq, markersize=MARKER_SIZE)
            # The first curves get the fixed palette; any further curve falls
            # back to matplotlib's default colour cycle.
            if i < len(COLORS):
                line_kwargs["color"] = COLORS[i]
            plt.plot(x_units, loss, **line_kwargs)

        # Build the legend once, after all curves exist, instead of
        # recreating it on every loop iteration.
        if curves:
            for legend_line in plt.legend().get_lines():
                plt.setp(legend_line, linewidth=LEGEND_LINE_WIDTH)

    else:
        #### treat the lines as belonging to the same experiment and thus providing stdev margins
        losses = np.zeros((len(curves), curves[0].shape[0]))
        for idx, loss in enumerate(curves):
            losses[idx, :] = loss
        losses_mean = np.mean(losses, axis=0)
        losses_std = np.std(losses, axis=0)
        plt.plot(x_units, losses_mean, color='red', lw=3)
        plt.fill_between(x_units, losses_mean - losses_std, losses_mean + losses_std, color='red', alpha=0.1)
        plt.tick_params(axis='both', which='major', labelsize=30, width=3, length=10)
        plt.tick_params(axis='both', which='minor', labelsize=20, width=3, length=5)

    plt.ylabel(y_label)
    plt.xlabel(x_label)
    plt.tight_layout()
    # NOTE: no second, argument-less plt.grid() call here. Without arguments
    # grid() *toggles* visibility, which would switch the major grid lines
    # enabled above back off.
    plt.savefig(path + "results_" + extra_title_info+ ".pdf", bbox_inches='tight')
    if show:
        plt.show()

    return fig

def plot_heatmap(matrix_list, label_list, file_path):
    """Save one annotated grayscale heatmap PDF per matrix.

    Each matrix is drawn on its own figure: the image shows the absolute
    values on a fixed [0, 1] colour scale, while every cell is annotated
    with its (signed) value formatted to two decimals. Tick labels are
    1-based row/column indices. The figure is written to
    ``file_path + '/ICL_dynsys_<label>.pdf'`` and then closed.

    Args:
        matrix_list: Sequence of 2-D matrices (torch tensors or anything
            ``np.asarray`` accepts).
        label_list: One label per matrix, used in the output file name.
        file_path: Directory the PDFs are written to.

    Raises:
        AssertionError: If the two lists differ in length.
    """
    assert len(label_list) == len(matrix_list), "Error -- matrix and label list have diff. lenghts."

    plt.rcParams.update(RC_PARAMS)
    cmap = cm.gray_r

    for matrix, label in zip(matrix_list, label_list):
        if isinstance(matrix, torch.Tensor):
            values = matrix.detach().cpu().numpy()
        else:
            values = np.asarray(matrix)

        # A fresh figure per matrix: reusing one axes would stack the images,
        # cell annotations and colorbars of all previous matrices into every
        # later PDF.
        fig, ax = plt.subplots(figsize=FIGSIZE)
        try:
            normalize = mcolors.Normalize(vmin=0., vmax=1.)
            im = ax.imshow(np.abs(values), cmap=cmap, norm=normalize)

            nrows = values.shape[0]
            ncols = values.shape[-1]
            for i in range(nrows):
                for j in range(values.shape[1]):
                    # .item() yields a plain Python scalar for formatting.
                    ax.text(j, i, format(values[i, j].item(), '.2f'), ha='center', va='center', color='r')

            # Add a colorbar for reference
            fig.colorbar(im, ax=ax)
            ax.set_yticks(np.arange(nrows))
            ax.set_yticklabels(np.arange(1, nrows + 1))
            ax.set_xticks(np.arange(ncols))
            ax.set_xticklabels(np.arange(1, ncols + 1))
            fig.savefig(file_path + '/ICL_dynsys_{}.pdf'.format(label), dpi=600)
        finally:
            # Nothing is returned to the caller, so release the figure here.
            plt.close(fig)



    

def plot_errs(names, err_lss, legend_loc="upper right", ax=None, shade=True):
    """Plot mean prediction-error curves on a log-scaled y axis.

    For every ``(name, err_ls)`` pair the mean over axis 0 is drawn as a
    labelled line, optionally with a shaded band of one standard deviation
    (also over axis 0) on either side. Two summary numbers are printed per
    entry: the mean of the per-row sums (``sum(axis=1)``) and the last value
    of the mean curve.

    Args:
        names: Legend labels, one per entry of ``err_lss``.
        err_lss: Sequence of 2-D error arrays.
            NOTE(review): presumably laid out as (trajectory, time step),
            judging by the axes used and the "t" axis label -- confirm.
        legend_loc: Legend location passed to ``ax.legend``.
        ax: Axes to draw on. A new 15x9 figure is created when omitted.
        shade: Whether to draw the +/- one standard deviation band.
    """
    if ax is None:
        fig = plt.figure(figsize=(15, 9))
        ax = fig.add_subplot(111)
    ax.set_yscale('log')

    for name, err_ls in zip(names, err_lss):
        traj_errs = err_ls.sum(axis=1)
        print("Sequence-sum of prediction error of " + name, "{:.2f}".format(traj_errs.mean()))
        avg, std = err_ls.mean(axis=0), err_ls.std(axis=0)
        print("Final prediction error of " + name, "{:.2f}".format(avg[-1]))
        lines = ax.plot(avg, label=name, linewidth=3)
        if shade:
            # Reuse the colour of the line just drawn for its band.
            ax.fill_between(range(len(avg)), avg-std,
                            avg+std, facecolor=lines[-1].get_color(), alpha=0.33)

    ax.legend(fontsize=30, loc=legend_loc)
    ax.set_ylabel("Prediction Error", fontsize=30)
    ax.set_xlabel("t", fontsize=30)
    # Enable major and minor grid lines explicitly. grid() without a
    # visibility argument toggles the current state, so calling it twice
    # (once plain, once with which="both") left the major grid switched off.
    ax.grid(True, which="both")
    ax.tick_params(axis='both', which='major', labelsize=30)
    ax.tick_params(axis='both', which='minor', labelsize=20)

def plot_2d_curves(curve_list, label_list, path, extra_title_info="", show=True):
    """Plot 2-D point sequences with a smooth cubic curve through each.

    For every curve the raw points are drawn as a scatter plot and a cubic
    interpolation through them (parameterised uniformly on [0, 1] and
    evaluated at 100 positions) is drawn as a labelled line. The figure is
    saved to ``path + "results_" + extra_title_info + ".pdf"``.

    Args:
        curve_list: Sequence of arrays whose column 0 holds x and column 1
            holds y coordinates. Cubic interpolation needs at least four
            points per curve.
        label_list: One legend label per curve; its length determines how
            many curves are plotted.
        path: Prefix of the output file.
        extra_title_info: Suffix inserted into the output file name.
        show: If True (the default), call ``plt.show()`` after saving.
    """
    # Imported locally because this file does not import scipy at module
    # level; the previous `scipy.interpolate.interp1d` reference raised
    # NameError.
    from scipy.interpolate import interp1d

    plt.rcParams.update(RC_PARAMS)
    marker_size = MARKER_SIZE*15
    marker = itertools.cycle(MARKERS)
    plt.figure(figsize=FIGSIZE)

    for i, label in enumerate(label_list):
        curve = curve_list[i]
        if isinstance(curve, torch.Tensor):
            curve = curve.detach().cpu().numpy()
        else:
            curve = np.asarray(curve)
        path_x = curve[:, 0]
        path_y = curve[:, 1]

        # Arbitrary parameter in [0, 1] used to parameterise the curve.
        path_t = np.linspace(0, 1, len(path_x))

        # Position vector: row 0 holds the x coordinates, row 1 the y ones.
        r = np.vstack((path_x, path_y))

        # Interpolate both coordinates jointly along the last axis.
        spline = interp1d(path_t, r, kind='cubic')

        # Evaluation points must stay inside the range the spline was
        # defined on, i.e. within [min(path_t), max(path_t)].
        t = np.linspace(np.min(path_t), np.max(path_t), 100)
        r = spline(t)

        # Wrap around the palette so that more curves than colours do not
        # raise IndexError.
        color = COLORS[i % len(COLORS)]
        plt.scatter(path_x, path_y, marker=next(marker), color=color, s=marker_size)
        plt.plot(r[0, :], r[1, :], linewidth=PLOT_LINE_WIDTH, label=label, color=color)

    plt.ylabel(r'$y$')
    plt.xlabel(r'$x$')

    for legend_line in plt.legend().get_lines():
        plt.setp(legend_line, linewidth=LEGEND_LINE_WIDTH)

    plt.tight_layout()
    plt.grid()
    plt.savefig(path + "results_" + extra_title_info+ ".pdf", bbox_inches='tight')
    if show:
        plt.show()



def project_and_plot_high_dim_curves(curve_a, curve_b, OUT_FOLDER):
    """Plot random 2-D projections of paired high-dimensional trajectories.

    For every index ``i`` two coordinate axes are drawn at random, both
    ``curve_a[i]`` and ``curve_b[i]`` are restricted to those two columns,
    and the pair is plotted with ``plot_2d_curves`` under the labels
    "true" and "pred". Trajectory ``i`` is saved with prefix
    ``OUT_FOLDER + "/traj_"`` and suffix ``str(i)``.

    Args:
        curve_a: Sequence of 2-D arrays, one per trajectory; the number of
            columns of the first entry defines the range the axes are drawn
            from.
        curve_b: Sequence of arrays paired index-wise with ``curve_a``.
        OUT_FOLDER: Directory the plots are written to.
    """
    # TODO: add test coverage for this function.
    if len(curve_a) == 0:
        return

    num_dims = curve_a[0].shape[1]
    for i, traj_a in enumerate(curve_a):
        traj_b = curve_b[i]

        # Pick 2 indices uniformly at random. They are drawn without
        # replacement so a dimension is never plotted against itself;
        # replacement is only allowed when fewer than two dimensions exist.
        dim1, dim2 = (int(d) for d in np.random.choice(num_dims, size=2, replace=num_dims < 2))

        true_c = traj_a[:, [dim1, dim2]]
        pred_c = traj_b[:, [dim1, dim2]]

        # Put the trajectory index in the file name; with an empty suffix
        # every trajectory was written to the same file and overwrote the
        # previous one.
        plot_2d_curves([true_c, pred_c], ["true", "pred"], OUT_FOLDER + "/" + "traj_", extra_title_info=str(i))





################################################################################
def wandb_log_heatmap(wandb_writer, matrix, min_matrix, max_matrix, title, square=True, 
                      norm=None, annot=False, step=None, special_step=None, special_step_name=None):
    """Render one heatmap, or a square grid of heatmaps, and log it to wandb.

    A 2-D ``matrix`` yields a single heatmap. A 3-D ``matrix`` is treated as
    a stack along its first dimension and laid out on the smallest square
    grid of subplots that fits; unused cells are hidden. The figure is
    logged as ``{title: wandb.Image(fig)}`` and closed afterwards.

    Args:
        wandb_writer: Object whose ``log(info)`` / ``log(info, step=...)``
            method receives the image.
        matrix: 2-D or 3-D torch tensor or array-like. ``None`` makes the
            call a no-op.
        min_matrix: Lower colour limit(s), or ``None`` to leave the norm's
            limits untouched. Used as-is when a single heatmap is drawn and
            indexed per subplot otherwise.
        max_matrix: Upper colour limit(s), handled like ``min_matrix``.
        title: Key under which the image is logged.
        square: Forwarded to ``seaborn.heatmap``.
        norm: Optional matplotlib norm used as a template. It is copied for
            every subplot and never modified.
        annot: Forwarded to ``seaborn.heatmap``.
        step: Optional ``step`` for ``wandb_writer.log``.
        special_step: Optional extra value logged next to the image.
        special_step_name: Key for ``special_step``; required when
            ``special_step`` is given.

    Raises:
        AssertionError: If ``matrix`` is not 2-D or 3-D, or if
            ``special_step`` is given without ``special_step_name``.
    """
    import copy

    if matrix is None:
        return

    # seaborn needs host-side data; tensors may be on an accelerator or
    # carry autograd history.
    if isinstance(matrix, torch.Tensor):
        matrix = matrix.detach().cpu().numpy()
    else:
        matrix = np.asarray(matrix)

    assert matrix.ndim in (2, 3), "This fc only plots heatmaps or arrays of heatmaps (2 & 3 dimensions)."

    if matrix.ndim == 2:
        matrix = matrix[np.newaxis]

    # The first dimension is the number of subplots,
    # e.g. [h, d, d] for W_Q^h, W_K^h, W_V^h.
    num_mats = matrix.shape[0]
    if num_mats == 0:
        return
    grid_side = int(math.ceil(math.sqrt(num_mats)))
    single_plot = grid_side == 1

    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid, so one loop
    # serves both the single-heatmap and the grid case.
    fig, axs = plt.subplots(figsize=FIGSIZE, nrows=grid_side, ncols=grid_side, squeeze=False)
    try:
        # axs.flat walks the grid row by row, i.e. index = row * grid_side + col.
        for mat_idx, ax in enumerate(axs.flat):
            if mat_idx >= num_mats:
                ax.set_visible(False)
                continue

            # Every subplot gets its own norm. Sharing (and mutating) one
            # instance made all heatmaps end up with the limits assigned for
            # the last subplot and also modified the caller's `norm`.
            nrm = copy.deepcopy(norm) if norm is not None else Normalize()
            if min_matrix is not None:
                if single_plot:
                    nrm.vmin = min_matrix
                    nrm.vmax = max_matrix
                else:
                    nrm.vmin = min_matrix[mat_idx]
                    nrm.vmax = max_matrix[mat_idx]

            # No vmin/vmax here: colour limits are carried by `nrm`, and the
            # former inverted pair (vmin=1, vmax=-1) was redundant next to a
            # norm.
            hm = heatmap(matrix[mat_idx], square=square, norm=nrm, annot=annot,
                         annot_kws={"size": HEATMAP_ANNOT_FONT_SZ},
                         ax=ax, cmap='coolwarm') #cmap='PuOr', 'twilight', 'Greys', 'seismic'

            # Display properties
            ax.set_xticklabels(ax.get_xticklabels(), fontsize=HEATMAP_TICK_LABEL_SIZE)
            ax.set_yticklabels(ax.get_yticklabels(), fontsize=HEATMAP_TICK_LABEL_SIZE)
            cbar = hm.collections[0].colorbar
            cbar.ax.tick_params(labelsize=HEATMAP_TICK_LABEL_SIZE)

        # create logging info
        info = {title: wandb.Image(fig)}
        if special_step is not None:
            assert special_step_name is not None
            info[special_step_name] = special_step

        if step is None:
            wandb_writer.log(info)
        else:
            wandb_writer.log(info, step=step)
    finally:
        # Close exactly this figure, even when plotting or logging fails.
        plt.close(fig)

    