"""Common methods for visualizing data & metrics."""
import itertools
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import os

# Default matplotlib `figsize` (width, height in inches) for figures created by
# helpers in this module when the caller does not pass one in.
_FIGSIZE = (8, 4)

def plot_traces(xvals, traces_to_plot, labels, title, xlabel, ylabel,
                markers=None, colors=None, linestyles=None, fig=None):
  """Plots multiple traces on the same figure.

  Args:
    xvals: list of np.ndarray (or list). Each element of this list is the x-axis
      value corresponding to a particular trace. Must be a list of the same length
      as traces_to_plot.
    traces_to_plot: list of np.ndarray (or list). Each element of this list is the
      y-axis value corresponding to a particular trace. Must be a list of the same
      length as xvals.
    labels: list of str. Labels for each trace.
    title: str. Title of figure.
    xlabel: str. X-axis label of figure.
    ylabel: str. Y-axis label of figure.
    markers: Optional. list of str. Corresponds to marker types to use for each
      trace. If omitted or empty, matplotlib's default is used.
    colors: Optional. list of str. Corresponds to line/marker color to use for
      each trace. If omitted or empty, matplotlib's default is used.
    linestyles: Optional. list of str. Corresponds to the linestyles to use for
      each trace. If omitted or empty, matplotlib's default is used.
    fig: Optional. matplotlib.figure.Figure. If none provided will automatically
      create a new figure, otherwise the provided figure is made the current
      pyplot figure and the traces are plotted on it.

  Returns:
    Figure with plots.
  """
  if fig is None:
    fig = plt.figure(figsize=_FIGSIZE)
  else:
    # The plt.* calls below draw on the *current* figure; activate the
    # provided one so the traces actually land on it.
    plt.figure(fig.number)

  for i, trace in enumerate(traces_to_plot):
    kwargs = {}
    if markers:
      kwargs['marker'] = markers[i]
    if colors:
      kwargs['color'] = colors[i]
    if linestyles:
      kwargs['linestyle'] = linestyles[i]
    plt.plot(xvals[i], trace, label=labels[i], **kwargs)

  plt.legend()
  plt.xlabel(xlabel)
  plt.ylabel(ylabel)
  plt.title(title)
  return fig

def plot_errorbars(data, xlabels, ylabel, title, colors=None,
                   fig=None, figsize=(3,3)):
  """Plots the mean +/- standard error of each group as an error bar.

  Args:
    data: list of array-likes. Each element is one group of samples; its mean
      and standard error (np.std / sqrt(n), population std) are plotted.
    xlabels: list of str. Tick label per group; same length as data.
    ylabel: str. Y-axis label.
    title: str. Figure title.
    colors: Optional. list of colors, one per group. If omitted or empty,
      colors are drawn from cycle_colors().
    fig: Optional. matplotlib.figure.Figure to draw on. If None, a new figure
      of size figsize is created; otherwise fig is made the current figure.
    figsize: Optional. (width, height) used only when a new figure is created.

  Returns:
    Figure with plots.

  Raises:
    ValueError: if data and xlabels differ in length.
  """
  if len(xlabels) != len(data):
    raise ValueError('Data and labels must be the same length. ')

  if fig is None:
    fig = plt.figure(figsize=figsize)
  else:
    # plt.* calls target the current figure; make the provided one current.
    plt.figure(fig.number)

  if not colors:
    clr_whl = cycle_colors()
    colors = [next(clr_whl) for _ in range(len(data))]

  x_values = []
  for data_ind, this_data in enumerate(data):
    # Groups are spaced 0.1 apart on the x-axis starting at 0.
    this_x = 0.1 * data_ind
    data_mean = np.mean(this_data)
    data_sem = np.std(this_data) / np.sqrt(len(this_data))
    plt.errorbar(this_x, data_mean, data_sem, marker='_',
                 markersize=10, color=colors[data_ind], elinewidth=1, lw=2)
    x_values.append(this_x)

  plt.title(title)
  plt.ylabel(ylabel)
  plt.xticks(x_values, labels=xlabels)
  return fig

def plot_eigenvalues(eigs, title, labels=('true', 'id'),
                     markers=('o', 'x'),
                     edgecolors=('#0000ff', 'None'),
                     facecolors=('None', '#00aa00'),
                     alpha_vals=(0.5, 0.5), bound_lims_to_circle=False,
                     show_legend=True, legend_location='lower_center',
                     fig=None, ax=None):
  """Plots eigenvalues on z-transform unit circle.

  Args:
    eigs: list of array-likes of (possibly complex) eigenvalues. Each element
      is drawn as one scatter series at (real part, imaginary part).
    title: str. Axes title.
    labels: sequence of str. Legend label per element of eigs.
    markers: sequence of str. Marker per element of eigs.
    edgecolors: sequence. Marker edge color per element of eigs.
    facecolors: sequence. Marker face color per element of eigs.
    alpha_vals: sequence of float. Opacity per element of eigs. If shorter
      than eigs it is padded with its first value. The caller's sequence is
      never modified.
    bound_lims_to_circle: bool. If True, both axis limits are fixed to
      [-1.1, 1.1].
    show_legend: bool. Whether to draw a legend.
    legend_location: str or any matplotlib legend `loc` value. Underscores in
      a string are treated as spaces, so 'lower_center' means 'lower center'.
    fig: Optional. matplotlib.figure.Figure. Used (via fig.gca()) when ax is
      not given.
    ax: Optional. matplotlib Axes to draw on.

  Returns:
    The figure containing the plot (ax.figure when only ax is supplied).
  """
  if ax is None:
    if fig is None:
      fig = plt.figure(figsize=_FIGSIZE)
      # Create a 1x2 grid and drop the right panel, leaving a single axes in
      # the left half of the figure.
      axs = fig.subplots(1, 2)
      axs[1].remove()
      ax = axs[0]
    else:
      ax = fig.gca()
  elif fig is None:
    fig = ax.figure

  ax.axis('equal')
  ax.add_patch(patches.Circle((0, 0), radius=1, fill=False, color='black',
                              alpha=0.2, ls='-'))
  # Faint real and imaginary axes through the origin.
  ax.plot([-1, 1, 0, 0, 0], [0, 0, 0, -1, 1], color='black', alpha=0.2, ls='-')

  # Gracefully handle a user passing fewer alpha values than series: pad with
  # the first one. Copy first so neither the default argument nor the caller's
  # list is mutated.
  alpha_vals = list(alpha_vals)
  if len(alpha_vals) < len(eigs):
    alpha_vals.extend([alpha_vals[0]] * (len(eigs) - len(alpha_vals)))

  # matplotlib's location strings use spaces ('lower center'); accept the
  # underscore spelling used by the default argument.
  if isinstance(legend_location, str):
    legend_location = legend_location.replace('_', ' ')

  for i, eig in enumerate(eigs):
    ax.scatter(np.real(eig), np.imag(eig), marker=markers[i],
               edgecolors=edgecolors[i], facecolors=facecolors[i],
               alpha=alpha_vals[i], label=labels[i])

  if show_legend:
    ax.legend(loc=legend_location)
  ax.set_title(title)
  if bound_lims_to_circle:
    ax.set_xlim([-1.1, 1.1])
    ax.set_ylim([-1.1, 1.1])
  return fig

def plot_sweeping_eigenvalues(gt_eigs, id_eigs_per_nx, algo_label,
                              brev_gt_eigs=None, unencoded_eigs=None,
                              nx_to_plot=None, mode_colors=None):
  """Plots multiple sets of eigenvalues in the same matplotlib subplots figure.

  Typical usage is to plot a series of unit circles with the identified poles as
  a function of some parameter (typically the order of the latent process).

  Args:
    gt_eigs: array-like. Ground-truth eigenvalues drawn on every panel. When
      brev_gt_eigs is given they are labeled 'Behaviorally irrelevant'.
    id_eigs_per_nx: indexable, or list of indexables (one per algorithm), such
      that id_eigs_per_nx[k][nx] are the eigenvalues identified by algorithm k
      for nx. A non-list value is treated as a single algorithm.
    algo_label: str or list of str. Legend label per algorithm.
    brev_gt_eigs: Optional. Ground-truth behaviorally relevant eigenvalues.
    unencoded_eigs: Optional. Eigenvalues unencoded in neural activity; only
      drawn when brev_gt_eigs is also given.
    nx_to_plot: Optional. Sequence of nx values, one panel each. Defaults to
      range(len(gt_eigs) + len(brev_gt_eigs)), the brev term being 0 when
      brev_gt_eigs is None.
    mode_colors: Optional. dict overriding any of 'default', 'brev',
      'unencoded' (ground-truth edge colors) and 'id' (a color or list of
      colors, one per algorithm). Missing 'id' colors come from
      cycle_colors(). The caller's dict is not modified.

  Returns:
    Figure with one unit-circle panel per nx value.
  """
  if not isinstance(id_eigs_per_nx, list):
    id_eigs_per_nx = [id_eigs_per_nx]
  if not isinstance(algo_label, list):
    algo_label = [algo_label]

  if nx_to_plot is None or len(nx_to_plot) == 0:
    num_brev = 0 if brev_gt_eigs is None else len(brev_gt_eigs)
    nx_to_plot = range(len(gt_eigs) + num_brev)

  # Overlay user colors on the defaults in a fresh dict.
  colors = {'default': 'deepskyblue', 'brev': 'indigo',
            'unencoded': 'lightcoral'}
  if mode_colors:
    colors.update(mode_colors)
  id_colors = colors.get('id')
  if id_colors is None or (isinstance(id_colors, list) and not id_colors):
    possible_id_colors = cycle_colors()
    id_colors = [next(possible_id_colors) for _ in id_eigs_per_nx]
  elif not isinstance(id_colors, list):
    id_colors = [id_colors]

  num_cols = len(nx_to_plot)
  # squeeze=False keeps `axes` 2-D even for a single panel.
  fig, axes = plt.subplots(1, num_cols, figsize=(10, int(50 // num_cols)),
                           squeeze=False)
  axes = axes[0]

  eigs_to_viz = [gt_eigs]
  if brev_gt_eigs is not None:
    labels = ['Behaviorally irrelevant', 'Behaviorally relevant']
    markers = ['o', 'o']
    alpha_vals = [0.3, 0.9]
    edgecolors = [colors['default'], colors['brev']]
    eigs_to_viz.append(brev_gt_eigs)
    if unencoded_eigs is not None:
      labels.append('Unencoded in neural activity')
      markers.append('o')
      alpha_vals.append(0.3)
      edgecolors.append(colors['unencoded'])
      eigs_to_viz.append(unencoded_eigs)
  else: # not brev_gt_eigs
    labels = ['True eigenvalues']
    markers = ['o']
    edgecolors = [colors['default']]
    alpha_vals = [0.7]
  # Ground-truth markers are white-filled with a colored edge.
  facecolors = ['w'] * len(markers)

  # One 'x' series per algorithm, colored via its face color.
  for ind in range(len(id_eigs_per_nx)):
    labels.append('{0}'.format(algo_label[ind]))
    markers.append('x')
    edgecolors.append(None)
    facecolors.append(id_colors[ind])
    alpha_vals.append(0.9)

  last_ind = num_cols - 1
  for ind, nx in enumerate(nx_to_plot):
    id_eigs_agg = [algo_eigs[nx] for algo_eigs in id_eigs_per_nx]
    plot_eigenvalues(eigs_to_viz + id_eigs_agg, 'nx={0}'.format(nx),
                     markers=markers, edgecolors=edgecolors,
                     facecolors=facecolors, labels=labels,
                     alpha_vals=alpha_vals, show_legend=(ind == last_ind),
                     legend_location='lower center',
                     bound_lims_to_circle=True, ax=axes[ind], fig=fig)
    axes[ind].axis('off')
  return fig

def plot_average_w_stderr(x_vals, y_vals, xlabel, ylabel,
                          label=None, fig=None, color='b',
                          marker='.', linestyle='-', alpha=1,
                          y_vals_ind=None):
  """Trendlines plot of mean with standard error shading.

  Args:
    x_vals: list. Independent variable values. Visualized on x-axis. Points
      are sorted by x before plotting.
    y_vals: list of list. Length of outer list should be the same length as
      x_vals. Inner lists are the values of interest per x_val condition.
      These values will be averaged for the plots and standard error will be
      computed over these values. NaNs are replaced with 0 (np.nan_to_num)
      before averaging.
    xlabel: str. X-axis label; also used in the title.
    ylabel: str. Y-axis label; also used in the title.
    label: Optional. Legend label prefix; the list of per-x sample counts is
      appended to it.
    fig: Optional. matplotlib.figure.Figure to draw on. If None, a new figure
      of size _FIGSIZE is created; otherwise fig is made the current figure.
    color: matplotlib color for the mean line and the shaded band.
    marker: matplotlib marker for the mean line.
    linestyle: matplotlib linestyle for the mean line.
    alpha: float. Opacity of the mean line (the band always uses 0.2).
    y_vals_ind: Optional. int. If given, each y_vals element is treated as a
      2-D array and only column y_vals_ind is used (column 0 included).

  Returns:
    Figure with plots.
  """
  x_vals = np.asarray(x_vals)
  sorted_inds = np.argsort(x_vals)
  x_vals = x_vals[sorted_inds]
  means, stds, total_counts = [], [], []
  for ind in sorted_inds:
    val = y_vals[ind]
    # `is not None` so that column 0 is honored (0 is falsy).
    if y_vals_ind is not None: # Plot time series of a particular variable in matrix.
      val = np.asarray(val)[:, y_vals_ind]
    total_counts.append(len(val))
    val = np.nan_to_num(val)
    means.append(np.mean(val))
    stds.append(np.std(val))
  means = np.array(means)
  # https://en.wikipedia.org/wiki/Standard_error
  stderr = np.divide(np.array(stds), np.sqrt(np.array(total_counts)))

  if fig is None:
    fig = plt.figure(figsize=_FIGSIZE)
  else:
    # plt.* calls target the current figure; make the provided one current.
    plt.figure(fig.number)
  plt.title('Average {0} over all {1}'.format(ylabel, xlabel))
  plt.xlabel(xlabel)
  plt.ylabel(ylabel)
  full_label = '{0} {1}'.format(label, total_counts)
  plt.plot(x_vals, means, color=color, label=full_label, marker=marker,
           linestyle=linestyle, alpha=alpha)
  plt.fill_between(x_vals, means-stderr, means+stderr, color=color, alpha=0.2)
  return fig

def plot_comparison_scatter(x_data, y_data, title, xlabel, ylabel,
                            fig=None, gen_line=True, **kwargs):
  """Scatter-plots y_data against x_data, optionally with a y = x guide line.

  Args:
    x_data: array-like. X coordinates.
    y_data: array-like. Y coordinates.
    title: str. Figure title.
    xlabel: str. X-axis label.
    ylabel: str. Y-axis label.
    fig: Optional. matplotlib.figure.Figure to draw on. If None, a new figure
      of size _FIGSIZE is created; otherwise fig is made the current figure.
    gen_line: bool. If True, draw the identity line y = x (green) from 0.1
      below the smallest data value up to max(1, largest data value).
    **kwargs: Forwarded to plt.scatter.

  Returns:
    Figure with plots.
  """
  if fig is None:
    fig = plt.figure(figsize=_FIGSIZE)
  else:
    # plt.* calls target the current figure; make the provided one current.
    plt.figure(fig.number)
  if gen_line:
    start = min(np.min(x_data), np.min(y_data)) - 0.1
    # End at 1, or further when the data exceed 1, so the guide spans all points.
    stop = max(1, np.max(x_data), np.max(y_data))
    t = np.linspace(start, stop)
    plt.plot(t, t, alpha=0.5, c='g', linewidth=2)

  plt.scatter(x_data, y_data, **kwargs)
  plt.title(title)
  plt.xlabel(xlabel)
  plt.ylabel(ylabel)
  return fig

def raster_plot(spikes, delta_t, t=None, tunits='ms'):
  """Generates a raster plot of spiking activity for multiple channels.

  Args:
    spikes: array-like of shape (neurons, time). Spiking activity; any bin with
      a value > 0 is drawn as a spike (binary or count data).
    delta_t: float. Time corresponding to each bin.
    t: Optional. array-like. Corresponds to the values to use on the x-axis.
      Default is None meaning time axis will start at 0.
    tunits: Optional. string. Time units used for labeling. Default 'ms'.

  Returns:
    Figure corresponding to raster plot.
  """
  fig = plt.figure(figsize=(20, 5))
  spikes = np.asarray(spikes)
  num_neurons, num_samples = spikes.shape
  if t is None:
    t = np.arange(num_samples) * delta_t
  else:
    t = np.asarray(t)  # fancy indexing below needs an array
  offset = 0
  for i in range(num_neurons):
    # `> 0` rather than `== 1` so bins holding more than one spike still show.
    inds = np.where(spikes[i, :] > 0)[0]
    plt.eventplot(t[inds], lineoffsets=offset)
    offset += 2  # rows are spaced 2 units apart on the y-axis
  plt.xlabel('time {0}'.format(tunits))
  plt.ylabel('neuron')
  return fig
    
def plot_2d(data, title, figsize=(20,5)):
  """Plot 2D images, either multiple side-by-side or a single plot.

  Args:
    data: 2-D array-like, or a list of them. A non-list gets a single panel
      with a colorbar; a list gets one panel (each with its own colorbar) per
      element, laid out in one row.
    title: str, or list of str (one per panel) when data is a list. A single
      str is reused for every panel.
    figsize: (width, height) of the created figure.

  Returns:
    Figure with plots.
  """
  if not isinstance(data, list):
    fig = plt.figure(figsize=figsize)
    plt.imshow(data)
    plt.colorbar()
    plt.title(title)
    fig.gca().set_aspect('equal')
    return fig

  if not isinstance(title, list):
    title = [title] * len(data)

  # squeeze=False keeps `axs` 2-D, so a one-element list still indexes cleanly.
  fig, axs = plt.subplots(1, len(data), figsize=figsize, squeeze=False)
  for i, ax in enumerate(axs[0]):
    ax_map = ax.imshow(data[i])
    fig.colorbar(ax_map, ax=ax, shrink=0.3)
    ax.set_title(title[i])
    ax.set_aspect('equal')
  return fig

def plot_boxplots(data, labels, title, ylabel, fig=None):
  """Draws boxplots of data with the given tick labels.

  Args:
    data: Any input accepted by plt.boxplot (e.g. a list of 1-D array-likes,
      one per box).
    labels: list of str. Tick label per box.
    title: str. Figure title.
    ylabel: str. Y-axis label.
    fig: Optional. matplotlib.figure.Figure to draw on. If None, a new figure
      with matplotlib's default size is created; otherwise fig is made the
      current figure.

  Returns:
    Figure with plots.
  """
  if fig is None:
    fig = plt.figure()
  else:
    # plt.* calls target the current figure; make the provided one current.
    plt.figure(fig.number)
  plt.boxplot(data, labels=labels)
  plt.ylabel(ylabel)
  plt.title(title)
  return fig

def plot_multi_box(x_ticks, y_vals, category_labels,
                   title, xlabel, ylabel, fig=None):
  """Draws a grouped bar chart: one group per x tick, one bar per category.

  Args:
    x_ticks: list of str. Names for the x ticks (one per group).
    y_vals: list of values per category. Specifically y_vals should be a list
      of lists (or numpy arrays) of length len(category_labels), and each
      element of y_vals is of length len(x_ticks).
    category_labels: list of labels for the categories to visualize.
    title: str. Figure title.
    xlabel: str. X-axis label.
    ylabel: str. Y-axis label.
    fig: Optional. matplotlib.figure.Figure to draw on. If None, a new figure
      is created; otherwise fig is made the current figure.

  Returns:
    Figure with plots.
  """
  if fig is None:
    fig = plt.figure()
  else:
    # plt.* calls target the current figure; make the provided one current.
    plt.figure(fig.number)

  group_starts = np.arange(len(x_ticks))
  group_width = 0.8
  bar_width = group_width / len(category_labels)

  for cat_ind, category in enumerate(category_labels):
    plt.bar(group_starts + bar_width * cat_ind,
            y_vals[cat_ind],
            width=bar_width, align='edge', label=category)

  plt.legend()
  plt.title(title)
  # Edge-aligned bars fill [start, start + group_width]; center each tick on
  # its group rather than on the group's left edge.
  plt.xticks(group_starts + group_width / 2, x_ticks)
  plt.xlabel(xlabel)
  plt.ylabel(ylabel)
  return fig

def plot_scatter_boxplots(data, xlabels, ylabel, title, colors=None,
                          showfliers=False, whis=1.5, fig=None,
                          figsize=(20,7)):
  """Draws a boxplot per group overlaid with that group's jittered points.

  Args:
    data: list of 1-D array-likes, one per box. Groups may differ in size.
    xlabels: list of str. Tick label per box; same length as data.
    ylabel: str. Y-axis label.
    title: str. Figure title.
    colors: Optional. list of colors for each group's points. If omitted or
      empty, colors are drawn from cycle_colors().
    showfliers: bool. Passed to plt.boxplot.
    whis: float. Passed to plt.boxplot.
    fig: Optional. matplotlib.figure.Figure to draw on. If None, a new figure
      of size figsize is created; otherwise fig is made the current figure.
    figsize: (width, height) used only when a new figure is created.

  Returns:
    Figure with plots.

  Raises:
    ValueError: if data and xlabels differ in length.
  """
  if len(xlabels) != len(data):
    raise ValueError('Data and labels must be the same length.')

  if fig is None:
    fig = plt.figure(figsize=figsize)
  else:
    # plt.* calls target the current figure; make the provided one current.
    plt.figure(fig.number)

  if not colors:
    clr_whl = cycle_colors()
    colors = [next(clr_whl) for _ in range(len(data))]

  box_data = []
  for data_ind, this_data in enumerate(data):
    this_data = np.asarray(this_data)
    # Boxes sit at x = 1, 2, ...; jitter the points horizontally around them.
    x = np.random.normal(1 + data_ind, 0.04, size=len(this_data))
    plt.scatter(x, this_data, c=colors[data_ind], alpha=0.6, s=5)
    box_data.append(this_data)

  # A list of 1-D arrays (instead of hstacked columns) allows unequal sizes.
  plt.boxplot(box_data, labels=xlabels, showfliers=showfliers, whis=whis)
  plt.title(title)
  plt.ylabel(ylabel)
  return fig

def compute_barplot_edges(arr, iqr_scale=2):
  """Returns (bottom, top) limits around the interquartile range of arr.

  Args:
    arr: array-like of numbers (must be non-empty).
    iqr_scale: float. Number of interquartile ranges to extend below the first
      quartile and above the third quartile. Defaults to 2.

  Returns:
    Tuple (bottom, top) = (Q1 - iqr_scale * IQR, Q3 + iqr_scale * IQR), where
    Q1/Q3 are np.quantile at 0.25/0.75 and IQR = Q3 - Q1.
  """
  quartile_one, quartile_three = np.quantile(arr, [0.25, 0.75])
  iqr = quartile_three - quartile_one
  bottom = quartile_one - iqr_scale * iqr
  top = quartile_three + iqr_scale * iqr
  return bottom, top

def plot_scatter_barplots(data, xlabels, ylabel, title, colors=None,
                          add_boxplot=False, barplot_truncation=None,
                          showfliers=False, whis=1.5, add_scatter=True,
                          fig=None, ax=None, ax_ind=0, figsize=(20,7)):
  """Bar plot of per-group means, optionally with boxplots and raw points.

  Reference: https://stackoverflow.com/questions/51027717/pyplot-bar-charts-with-individual-data-points

  Args:
    data: list of 1-D array-likes, one per bar. Groups may differ in size.
    xlabels: list of str. Tick label per bar; same length as data.
    ylabel: str. Y-axis label.
    title: str. Axes title.
    colors: Optional. list of bar colors. If omitted or empty, colors are
      drawn from cycle_colors().
    add_boxplot: bool. If True, overlay boxplots (median line hidden, black
      boxes at half opacity) and omit the standard-error bars.
    barplot_truncation: Optional. y-limits passed to set_ylim.
    showfliers: bool. Passed to boxplot.
    whis: float. Passed to boxplot.
    add_scatter: bool. If True, overlay horizontally jittered raw points.
    fig: Optional. matplotlib.figure.Figure. When given without ax, its
      current axes (fig.gca()) is used.
    ax: Optional. A single Axes, or an indexable collection of Axes (e.g. from
      plt.subplots) from which ax[ax_ind] is used.
    ax_ind: int. Index into ax when ax is a collection.
    figsize: (width, height) used only when a new figure is created.

  Returns:
    Tuple (fig, ax, xtick_locs): the figure, the ax argument (or the newly
    created Axes), and the x positions of the bars.

  Raises:
    ValueError: if data and xlabels differ in length.
  """
  if len(xlabels) != len(data):
    raise ValueError('Data and labels must be the same length.')

  if ax is None:
    if fig is None:
      fig, ax = plt.subplots(figsize=figsize)
    else:
      ax = fig.gca()
    plotting_ax = ax
  else:
    # Accept either a single Axes or an array/list of Axes.
    is_collection = isinstance(ax, (np.ndarray, list, tuple))
    plotting_ax = ax[ax_ind] if is_collection else ax
    if fig is None:
      fig = plotting_ax.figure

  if not colors:
    clr_whl = cycle_colors()
    colors = [next(clr_whl) for _ in range(len(data))]

  arrays = [np.asarray(this_data) for this_data in data]
  bar_heights = [arr.mean() for arr in arrays]
  # https://en.wikipedia.org/wiki/Standard_error (population std, ddof=0).
  # With a boxplot overlay the spread is shown by the boxes instead.
  if add_boxplot:
    bar_sem = None
  else:
    bar_sem = [arr.std() / np.sqrt(arr.size) for arr in arrays]

  bar_width = 0.5
  xtick_locs = np.arange(len(bar_heights))*0.6 + 1
  plotting_ax.bar(xtick_locs, height=bar_heights,
                  yerr=bar_sem, width=bar_width,
                  capsize=12,  # error bar cap width in points
                  tick_label=xlabels, color=colors, edgecolor=colors,
                  alpha=0.5)
  if add_boxplot:
    # A list of 1-D arrays (instead of hstacked columns) allows unequal sizes.
    bplot = plotting_ax.boxplot(arrays, positions=xtick_locs, labels=xlabels,
                                medianprops=dict(linestyle='None'),
                                patch_artist=True, showfliers=showfliers,
                                whis=whis)
    for patch in bplot['boxes']:
      patch.set_alpha(0.5)
      patch.set_facecolor('k')

  plotting_ax.set_ylabel(ylabel)
  plotting_ax.set_title(title)
  if barplot_truncation is not None:
    plotting_ax.set_ylim(barplot_truncation)

  if add_scatter:
    for loc, arr in zip(xtick_locs, arrays):
      # Add scatterplot with horizontal jitter around the bar center.
      x = np.random.normal(loc, 0.03, size=len(arr))
      plotting_ax.scatter(x, arr, c='k', alpha=0.25, s=2)
  return fig, ax, xtick_locs

def cycle_colors(colors='rgbycmk'):
  """Returns an endless iterator that repeats the given colors in order.

  Args:
    colors: iterable of matplotlib color specs. The default string yields the
      single-letter colors 'r', 'g', 'b', 'y', 'c', 'm', 'k'.

  Returns:
    itertools.cycle over the elements of colors.
  """
  color_wheel = itertools.cycle(colors)
  return color_wheel

def push_figure_and_zoom(fig, xlim, ylim):
  """Zooms fig's current axes to (xlim, ylim), recording both views.

  When fig.canvas has a toolbar, the unzoomed view and then the zoomed view
  are pushed onto its navigation stack via push_current(). When the canvas
  has no toolbar (attribute missing or None), only the limits are set.

  Args:
    fig: matplotlib.figure.Figure whose current axes (fig.gca()) is zoomed.
    xlim: x-axis limits passed to set_xlim.
    ylim: y-axis limits passed to set_ylim.

  Returns:
    fig.
  """
  toolbar = getattr(fig.canvas, 'toolbar', None)
  # Zoom fig's own axes, not whichever figure happens to be current in pyplot.
  ax = fig.gca()
  if toolbar is not None:
    toolbar.push_current()  # save 'unzoomed' view to stack
  ax.set_xlim(xlim)
  ax.set_ylim(ylim)
  if toolbar is not None:
    toolbar.push_current()  # save 'zoomed' view to stack
  return fig

def show():
  """Displays all open matplotlib figures by delegating to plt.show()."""
  plt.show()

def save_fig(out_path, fig_title, img_formats=None):
  """Saves the current figure.

  Always writes '<fig_title>.png' into out_path at 1200 dpi on an opaque white
  background. Each entry of img_formats additionally writes
  '<fig_title>.<format>' at 1200 dpi with a transparent background; an entry
  of 'png' therefore overwrites the default PNG with a transparent one.

  Args:
    out_path: str. Directory in which to save the figure.
    fig_title: str. Filename (without extension) for figure.
    img_formats: list of str. Optional. Extra output image formats, e.g. 'svg',
      'pdf'. A leading '.' is ignored, so '.svg' and 'svg' are equivalent.
  """
  # By default save as PNG.
  plt.savefig(os.path.join(out_path, '{0}.{1}'.format(fig_title, 'png')),
              dpi=1200, facecolor='white', transparent=False,
              bbox_inches='tight')

  for img_format in img_formats or []:
    # Avoid 'name..svg' when the caller passes an extension-style '.svg'.
    img_format = img_format.lstrip('.')
    plt.savefig(os.path.join(out_path, '{0}.{1}'.format(fig_title, img_format)),
                format=img_format, dpi=1200, transparent=True,
                bbox_inches='tight')
