import os
import io
import math
import platform
import matplotlib
import numpy as np
import typing as t
import seaborn as sns
import tensorflow as tf
from shutil import rmtree
import matplotlib.cm as cm
import matplotlib.pyplot as plt

from cyclegan.utils import utils

# 'jet' colormap object. It is fetched through pyplot because
# `matplotlib.cm.get_cmap` was deprecated in matplotlib 3.7 and removed in 3.9,
# whereas `plt.get_cmap` is available in every release.
JET = plt.get_cmap('jet')
# lookup table of the colormap evaluated at the indexes 0..255, keeping only
# the first three channels; gray2rgb indexes it with uint8 values
COLORMAP = JET(np.arange(256))[:, :3]


def normalize(x: np.ndarray):
  """ normalize array x to [0, 1]

  Negative values are clipped to 0 and the result is divided by the maximum
  value of x.

  Args:
    x: array to normalize
  Returns:
    array with the same shape as x. When the maximum of x is not positive
    (e.g. an all-zero array) every entry is 0 instead of NaN or inf.
  """
  clipped = np.maximum(x, 0)
  x_max = np.max(x)
  if x_max <= 0:
    # no positive entry: clipped is already all zeros, so only convert it to
    # the dtype a true division would produce and skip the division by zero
    return np.true_divide(clipped, 1)
  return clipped / x_max


def remove_top_right_spines(axis):
  """ hide the top and right spines of axis """
  for side in ('top', 'right'):
    axis.spines[side].set_visible(False)


def remove_spines(axis):
  """ hide all four spines (top, left, right and bottom) of axis """
  for side in ('top', 'left', 'right', 'bottom'):
    axis.spines[side].set_visible(False)


def plot_reward_zones(axis,
                      reward_ranges: t.List[t.Tuple[int, int]],
                      alpha: float = 0.8):
  """ mark every reward zone with a pair of dashed vertical lines
  Args:
    axis: axis to draw on
    reward_ranges: list of (start, end) x positions of the reward zones
    alpha: opacity of the lines
  The start of a zone is drawn in yellow and its end in orange.
  """
  line_style = dict(alpha=alpha, linewidth=1, linestyle='--')
  for zone in reward_ranges:
    zone_start, zone_end = zone
    axis.axvline(x=zone_start, color='yellow', **line_style)
    axis.axvline(x=zone_end, color='orange', **line_style)


def plot_licks(axis, licks: t.Union[t.List[int], np.ndarray]):
  """ draw a dashed lime green vertical line at every x position in licks """
  line_style = dict(alpha=0.8,
                    color='limegreen',
                    linewidth=0.7,
                    linestyle='--')
  for position in licks:
    axis.axvline(x=position, **line_style)


def plot_rewards(axis, rewards: t.Union[t.List[int], np.ndarray]):
  """ draw a dashed orange-red vertical line at every x position in rewards """
  line_style = dict(alpha=0.8,
                    color='orangered',
                    linewidth=0.7,
                    linestyle='--')
  for position in rewards:
    axis.axvline(x=position, **line_style)


def gray2rgb(x: t.Union[np.ndarray, tf.Tensor]):
  ''' convert x from gray scale to RGB using the COLORMAP lookup table

  x is multiplied by 255 and cast to uint8 to obtain the lookup indexes
  NOTE(review): presumably x is expected to be in [0, 1] - confirm with callers
  '''
  array = x.numpy() if tf.is_tensor(x) else x
  indexes = np.uint8(255.0 * array)
  return COLORMAP[indexes]


def resize(x: np.ndarray, height: int, width: int):
  ''' resize array x to (height, width) using PIL.resize

  A 2D input gets a temporary trailing channel dimension for the conversion
  to an image, which is removed again before returning. The resized image is
  divided by 255 before it is returned.
  '''
  is_2d = len(x.shape) == 2
  array = np.expand_dims(x, axis=-1) if is_2d else x
  image = tf.keras.preprocessing.image.array_to_img(array)
  # PIL expects the size as (width, height)
  image = image.resize(size=(width, height))
  resized = tf.keras.preprocessing.image.img_to_array(image)
  resized /= 255.0
  return np.squeeze(resized, axis=-1) if is_2d else resized


def superimpose(heatmap: np.ndarray,
                background: np.ndarray,
                alpha: float = 0.4):
  ''' blend heatmap over background and normalize the result

  heatmap must have 3 dimensions and share its first two dimensions with
  background. A 2D background gets a trailing dimension so that it broadcasts
  against heatmap. alpha is the weight of heatmap, (1 - alpha) the weight of
  background.
  '''
  assert len(heatmap.shape) == 3 and heatmap.shape[:2] == background.shape[:2]
  if len(background.shape) == 2:
    background = np.expand_dims(background, axis=-1)
  weighted_heatmap = alpha * heatmap
  weighted_background = (1 - alpha) * background
  return normalize(weighted_heatmap + weighted_background)


def remove_ticks(axis):
  ''' clear the x ticks and then the y ticks of axis '''
  for set_ticks in (axis.set_xticks, axis.set_yticks):
    set_ticks([])


def set_xticks(axis,
               ticks_loc: t.Union[np.ndarray, list],
               ticks: t.Union[np.ndarray, list],
               label: str = '',
               fontsize: int = None):
  """ set the x tick locations and tick labels of axis
  Args:
    axis: axis to modify
    ticks_loc: locations of the ticks
    ticks: labels shown at ticks_loc
    label: x-axis label, the label is left untouched if empty
    fontsize: font size of the tick labels
  """
  axis.set_xticks(ticks_loc)
  axis.set_xticklabels(ticks, fontsize=fontsize)
  if not label:
    return
  axis.set_xlabel(label)


def set_yticks(axis,
               ticks_loc: t.Union[np.ndarray, list],
               ticks: t.Union[np.ndarray, list],
               label: str = '',
               fontsize: int = None):
  """ set the y tick locations and tick labels of axis
  Args:
    axis: axis to modify
    ticks_loc: locations of the ticks
    ticks: labels shown at ticks_loc
    label: y-axis label, the label is left untouched if empty
    fontsize: font size of the tick labels
  """
  axis.set_yticks(ticks_loc)
  axis.set_yticklabels(ticks, fontsize=fontsize)
  if not label:
    return
  axis.set_ylabel(label)


def set_right_label(axis, label: str):
  """ show label on the right-hand side of axis

  The label is attached to a twin axis created with axis.twinx(); the twin
  axis has its y ticks cleared and its top and right spines hidden.
  """
  twin = axis.twinx()
  # rotate the text by 270 degrees so that it reads from top to bottom
  twin.set_ylabel(label, rotation=270, va='bottom')
  twin.set_yticks([])
  remove_top_right_spines(twin)


def save_figure(figure: 'plt.Figure',
                filename: str,
                dpi: int,
                close: bool = True):
  """ save figure to filename, creating the parent directory if needed
  Args:
    figure: figure to save
    filename: path of the output file; the format is left for savefig to
              derive from the filename
    dpi: resolution of the saved figure
    close: close the figure after it has been saved
  """
  directory = os.path.dirname(filename)
  # dirname is '' for a bare filename, and os.makedirs('') would raise
  if directory:
    # exist_ok avoids the race between checking for and creating the folder
    os.makedirs(directory, exist_ok=True)
  figure.savefig(filename,
                 dpi=dpi,
                 bbox_inches='tight',
                 pad_inches=0.01,
                 transparent=True)
  if close:
    plt.close(figure)


class Summary(object):
  """ Helper class to write TensorBoard summaries """

  def __init__(self, args, analysis: bool = False):
    """ configure plotting and create the TensorBoard file writers
    Args:
      args: settings object, the attributes dpi, dataset, format, save_plots,
            output_dir are always read; frame_rate, num_neurons and
            neuron_order are read unless dataset is 'horse2zebra'; verbose
            is read on macOS only
      analysis: create a single writer in <output_dir>/analysis instead of
                the pair of writers in <output_dir> and <output_dir>/validation
    """
    self.dpi = args.dpi
    self.dataset = args.dataset

    if self.dataset != 'horse2zebra':
      self.frame_rate = args.frame_rate
      self.num_neurons = args.num_neurons
      self.neuron_order = args.neuron_order
      # select 3 neurons that are evenly spaced to plot: split the range into
      # 5 points and drop the first and the last one
      self.neuron_indexes = np.linspace(0, self.num_neurons - 1, 5, dtype=int)
      self.neuron_indexes = self.neuron_indexes[1:-1]

    # settings for saving plots to disk
    self.format = args.format
    self.save_plots = args.save_plots

    # create plot directory; exist_ok avoids the race between checking for
    # and creating the directory
    self.plots_dir = os.path.join(args.output_dir, 'plots')
    os.makedirs(self.plots_dir, exist_ok=True)

    if platform.system() == 'Darwin':
      plt.rcParams.update({
          "text.usetex": True,
          "font.family": "serif",
          "font.serif": ["Computer Modern"]
      })
      if args.verbose == 2:
        matplotlib.use('TkAgg')

    sns.set_style('white')
    # matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*'
    # and later removed the old names, so fall back to the new name when the
    # old one is not listed
    style = 'seaborn-deep'
    if style not in plt.style.available:
      style = 'seaborn-v0_8-deep'
    plt.style.use(style)

    # color palette to use
    self.colors = [
        'dodgerblue', 'orangered', 'forestgreen', 'darkviolet', 'dimgray',
        'magenta'
    ]

    # TensorBoard file writers
    self.analysis = analysis
    if analysis:
      self.writer = tf.summary.create_file_writer(
          os.path.join(args.output_dir, 'analysis'))
    else:
      # index 0 writes to output_dir and index 1 to output_dir/validation,
      # the index is what the `mode` argument of the other methods selects
      self.writer = [
          tf.summary.create_file_writer(args.output_dir),
          tf.summary.create_file_writer(
              os.path.join(args.output_dir, 'validation'))
      ]

  def close(self):
    """ flush and close all summary writers """
    # self.writer is a list of writers in training mode and a single writer
    # in analysis mode, see __init__
    if isinstance(self.writer, list):
      for writer in self.writer:
        writer.close()
    else:
      self.writer.close()

  def get_writer(self, mode: int = 0):
    """ return the file writer to use
    The single analysis writer is returned when self.analysis is set (mode is
    ignored in that case), otherwise the writer at index mode is returned.
    """
    if self.analysis:
      return self.writer
    return self.writer[mode]

  def scalar(self, tag, value, step=0, mode: int = 0):
    """ write scalar value under tag at step with the writer given by mode """
    with self.get_writer(mode).as_default():
      tf.summary.scalar(tag, value, step=step)

  def histogram(self, tag, values, step=0, mode: int = 0):
    """ write histogram of values under tag at step with the writer given by
    mode """
    with self.get_writer(mode).as_default():
      tf.summary.histogram(tag, values, step=step)

  def image(self, tag, values, step=0, mode: int = 0):
    """ write images values under tag at step with the writer given by mode
    max_outputs is set to len(values) so that every image is written
    """
    num_images = len(values)
    with self.get_writer(mode).as_default():
      tf.summary.image(tag, data=values, step=step, max_outputs=num_images)

  def figure(self, tag, figure, step=0, close: bool = True, mode: int = 0):
    """ Write matplotlib figure to summary
    The figure is also saved to self.plots_dir when self.save_plots is set,
    in the sub-directory 'analysis' in analysis mode and 'epoch_<step>'
    otherwise.
    Args:
      tag: data identifier
      figure: matplotlib figure
      step: global step value to record
      close: flag to close figure
      mode: index of the file writer to use, see get_writer
    """
    try:
      if self.save_plots:
        save_figure(figure,
                    filename=os.path.join(
                        self.plots_dir,
                        'analysis' if self.analysis else f'epoch_{step:03d}',
                        f'{tag}.{self.format}'),
                    dpi=self.dpi,
                    close=False)
      # render the figure as PNG in memory and decode it to an RGBA tensor
      with io.BytesIO() as buffer:
        figure.savefig(buffer,
                       dpi=self.dpi,
                       format='png',
                       bbox_inches='tight',
                       pad_inches=0.02)
        image = tf.image.decode_png(buffer.getvalue(), channels=4)
      # add the batch dimension expected by the image summary
      self.image(tag, tf.expand_dims(image, 0), step=step, mode=mode)
    finally:
      # close the figure even if saving or writing failed, so that figures
      # do not accumulate in memory
      if close:
        plt.close(figure)

  def plot_comparison(self,
                      tag: str,
                      traces: t.Union[t.List[tf.Tensor], t.List[np.ndarray]],
                      labels: t.List[str],
                      step: int = 0,
                      mode: int = 0):
    """ plot the traces side by side, one figure per neuron in
    self.neuron_indexes
    Args:
      tag: name of the plot in TensorBoard, '/neuron_<index>' is appended
      traces: list of 2D arrays or tensors, the second dimension is indexed
              by neuron
      labels: titles of the columns, one per item in traces
      step: TensorBoard step
      mode: file writer mode
    """
    assert len(traces) == len(labels)
    assert len(traces[0].shape) == 2
    # convert tf.Tensor to numpy array
    traces = [
        trace.numpy() if tf.is_tensor(trace) else trace for trace in traces
    ]

    for neuron in self.neuron_indexes:
      n_traces = [trace[:, neuron] for trace in traces]
      # squeeze=False always returns a 2D array of axes, so that a single
      # trace can be plotted as well
      figure, axes = plt.subplots(nrows=1,
                                  ncols=len(traces),
                                  squeeze=False,
                                  gridspec_kw={
                                      'wspace': 0.1,
                                      'hspace': 0.5
                                  },
                                  figsize=(4 * len(traces), 2),
                                  dpi=self.dpi)
      axes = axes[0]
      figure.suptitle(f'Neuron {neuron}')

      # share the y-axis limit between the columns, with a 15% margin
      n_min, n_max = np.min(n_traces), np.max(n_traces)
      ylim = (n_min - np.abs(n_min * 0.15), n_max + np.abs(n_max * 0.15))

      xticks_loc = np.linspace(0, len(n_traces[0]) - 1, 5)
      for i, (axis, trace) in enumerate(zip(axes, n_traces)):
        axis.plot(trace, color='dodgerblue', linewidth=1.5, alpha=0.8)
        axis.set_title(labels[i])
        if i == 0:
          axis.set_ylabel(r'$\Delta F/F$')
        axis.set_ylim(ylim)
        remove_top_right_spines(axis=axis)
        axis.tick_params(axis='both', which='both', length=1)
        # tick locations are divided by the frame rate to label them in
        # seconds
        set_xticks(axis=axis,
                   ticks_loc=xticks_loc,
                   ticks=(xticks_loc / self.frame_rate).astype(int),
                   label='Time (s)')

      self.figure(f'{tag}/neuron_{neuron:03d}',
                  figure=figure,
                  step=step,
                  close=True,
                  mode=mode)

  def image_cycle(self,
                  tag: str,
                  images: t.List[np.ndarray],
                  labels: t.List[str],
                  step: int = 0,
                  mode: int = 0):
    """ plot the first three image batches side by side, one figure per sample
    Args:
      tag: name of the plot in TensorBoard, '/sample_#<index>' is appended
      images: list of image batches, images[0..2] fill the three columns
      labels: titles of the three columns
      step: TensorBoard step
      mode: file writer mode
    """
    num_samples = len(images[0])
    for index in range(num_samples):
      figure, axes = plt.subplots(nrows=1,
                                  ncols=3,
                                  figsize=(9, 3.25),
                                  dpi=self.dpi)
      for column in range(3):
        axes[column].imshow(images[column][index, ...], interpolation='none')
        axes[column].set_title(labels[column])

      plt.setp(axes, xticks=[], yticks=[])
      plt.tight_layout()
      figure.subplots_adjust(wspace=0.02, hspace=0.02)
      self.figure(f'{tag}/sample_#{index:03d}',
                  figure,
                  step=step,
                  close=True,
                  mode=mode)

  def trace_cycle(self,
                  tag: str,
                  samples: t.List[np.ndarray],
                  titles: t.List[str],
                  colors: t.List[str],
                  step: int = 0,
                  mode: int = 0):
    ''' plot cycle traces of the neurons in self.neuron_indexes, one figure
    per trial
    Args
      tag: str, name of the plot in TensorBoard
      samples: list of arrays indexed as [trial, time-step, neuron],
               corresponds to 1st, 2nd and 3rd column
      titles: list of str, the titles of the 1st, 2nd and 3rd column
      colors: list of str, the color of the signal
              colors[0] corresponds to the 1st and 3rd columns
              colors[1] corresponds to the 2nd column
      step: int, TensorBoard step
      mode: int, file writer mode
    '''
    assert len(samples) == len(titles) == 3 and len(colors) == 2
    shape = samples[0].shape
    num_rows = len(self.neuron_indexes)

    fontsize, linewidth, alpha = 11, 0.8, 0.8
    xticks_loc = np.linspace(0, shape[1] - 1, 5)
    xticks = (xticks_loc / self.frame_rate).astype(int)

    # the loop variable is named trial so that it does not shadow the
    # module-level alias `t` of typing
    for trial in range(shape[0]):
      # squeeze=False keeps axes 2D even if there is a single row
      figure, axes = plt.subplots(nrows=num_rows,
                                  ncols=3,
                                  gridspec_kw={
                                      'wspace': 0.2,
                                      'hspace': 0.2
                                  },
                                  figsize=(8, 1.2 * num_rows),
                                  sharex=True,
                                  squeeze=False,
                                  dpi=self.dpi)

      for i, n in enumerate(self.neuron_indexes):
        trace1 = samples[0][trial, :, n]
        trace2 = samples[1][trial, :, n]
        trace3 = samples[2][trial, :, n]

        axes[i][0].plot(trace1,
                        color=colors[0],
                        linewidth=linewidth,
                        alpha=alpha)
        axes[i][1].plot(trace2,
                        color=colors[1],
                        linewidth=linewidth,
                        alpha=alpha)
        axes[i][2].plot(trace3,
                        color=colors[0],
                        linewidth=linewidth,
                        alpha=alpha)

        set_right_label(axis=axes[i][2],
                        label=f'Neuron {self.neuron_order[n]+1}')

        # the 2nd column has its own y ticks
        yticks_loc = np.linspace(np.min(trace2), np.max(trace2), 4)
        set_yticks(axis=axes[i][1],
                   ticks_loc=yticks_loc,
                   ticks=np.around(yticks_loc, 1),
                   fontsize=fontsize)

        # get the y-axis limit for the 1st and 3rd column
        t_min = min(np.min(trace1), np.min(trace3))
        t_max = max(np.max(trace1), np.max(trace3))
        yticks_loc = np.linspace(t_min, t_max, 4)
        yticks = np.around(yticks_loc, 1)
        set_yticks(axis=axes[i][0],
                   ticks_loc=yticks_loc,
                   ticks=yticks,
                   fontsize=fontsize)
        set_yticks(axis=axes[i][2],
                   ticks_loc=yticks_loc,
                   ticks=yticks,
                   fontsize=fontsize)

        if i == 1:
          axes[i][0].set_ylabel(r'$\Delta F/F$')

        for j in range(3):
          # titles go on the first row and x ticks on the last row; the two
          # checks are independent so that a single row gets both
          if i == 0:
            axes[i][j].set_title(titles[j])
          if i == num_rows - 1:
            set_xticks(axis=axes[i][j],
                       ticks_loc=xticks_loc,
                       ticks=xticks,
                       label='Time (s)' if j == 1 else '',
                       fontsize=fontsize)
          axes[i][j].tick_params(axis='both', which='both', length=0)
          remove_top_right_spines(axis=axes[i][j])

      self.figure(tag=f'{tag}/trial_{trial:03d}',
                  figure=figure,
                  step=step,
                  mode=mode)

  def population_cycle(self,
                       tag: str,
                       samples: t.List[np.ndarray],
                       titles: t.List[str],
                       info: t.Dict[str, np.ndarray] = None,
                       step: int = 0,
                       mode: int = 0):
    ''' plot the cycle step of the whole population, one figure per trial
    Args
      tag: str, name of the plot in TensorBoard
      samples: list of arrays indexed as [trial, time-step, neuron],
               corresponds to the 1st, 2nd and 3rd row
      titles: list of str, the titles of the 1st, 2nd and 3rd row
      info: dictionary of np.ndarray, trial information; info['reward_zone']
            is used to draw the reward zones on the 1st row
      step: int, TensorBoard step
      mode: int, file writer mode
    '''
    assert len(samples) == len(titles) == 3
    shape = samples[0].shape

    def scale_and_swap(x):
      # scale x with its own minimum and maximum and swap the first two
      # dimensions so that neurons are on the y-axis
      scaled = utils.scale(x, ds_min=np.min(x), ds_max=np.max(x))
      return np.swapaxes(scaled, axis1=0, axis2=1)

    fontsize = 11
    xticks_loc = np.linspace(0, shape[1] - 1, 6)
    xticks = (xticks_loc / self.frame_rate).astype(int)
    yticks_loc = np.linspace(0, shape[2] - 1, 6)
    yticks = self.neuron_order[yticks_loc.astype(int)] + 1

    # the loop variable is named trial so that it does not shadow the
    # module-level alias `t` of typing
    for trial in range(shape[0]):
      figure, axes = plt.subplots(nrows=3,
                                  ncols=1,
                                  gridspec_kw={
                                      'wspace': 0.15,
                                      'hspace': 0.1
                                  },
                                  figsize=(6, 6),
                                  sharex=True,
                                  dpi=self.dpi)
      populations = [scale_and_swap(sample[trial]) for sample in samples]

      for i, population in enumerate(populations):
        axes[i].imshow(population, cmap='gray', aspect='auto')
        set_yticks(axis=axes[i],
                   ticks_loc=yticks_loc,
                   ticks=yticks,
                   label='Neuron' if i == 1 else '',
                   fontsize=fontsize)
        set_right_label(axis=axes[i], label=titles[i])
        axes[i].tick_params(axis='both', which='both', length=0)

      # the x-axis is shared, hence the ticks are set once on the last row
      set_xticks(axis=axes[2],
                 ticks_loc=xticks_loc,
                 ticks=xticks,
                 label='Time (s)',
                 fontsize=fontsize)

      # draw reward information
      if info is not None:
        plot_reward_zones(axes[0],
                          reward_ranges=utils.get_reward_ranges(
                              info['reward_zone'][trial]),
                          alpha=0.7)

      self.figure(tag=f'{tag}/trial_{trial:03d}',
                  figure=figure,
                  step=step,
                  mode=mode)

  def plot_histogram(self,
                     tag: str,
                     data: t.Dict[str, np.ndarray],
                     xlabel: str = '',
                     ylabel: str = '',
                     title: str = None,
                     legend: bool = True,
                     step: int = 0,
                     mode: int = 0,
                     annotate_statistic: bool = False):
    """ plot the distributions in data as overlaid histograms
    Args:
      tag: name of the plot in TensorBoard
      data: dictionary of label to values, one histogram per item
      xlabel: label of the x-axis
      ylabel: label of the y-axis
      title: title of the plot
      legend: show the legend if there is more than one distribution
      step: TensorBoard step
      mode: file writer mode
      annotate_statistic: annotate the mean and standard deviation if there
                          is exactly one distribution
    """
    figure, axis = plt.subplots(nrows=1,
                                ncols=1,
                                gridspec_kw={
                                    'hspace': 0.01,
                                    'wspace': 0.01
                                },
                                figsize=(5, 5),
                                dpi=self.dpi)
    figure.patch.set_facecolor('white')

    values = list(data.values())
    # compute the range per distribution so that distributions with a
    # different number of samples are supported
    lower = min(np.min(v) for v in values)
    upper = max(np.max(v) for v in values)
    # all distributions share the same 30 bin edges
    bins = np.linspace(lower, upper, 30)

    max_height = 0
    for i, (k, v) in enumerate(data.items()):
      # cycle through the palette if there are more distributions than colors
      color = self.colors[i % len(self.colors)] if len(data) > 1 \
          else 'darkgreen'
      heights, _, _ = axis.hist(v,
                                bins=bins,
                                color=color,
                                label=k,
                                alpha=0.6)
      max_height = max(max_height, np.max(heights))

    # the histograms share the bin edges, so the x ticks span the bins
    xticks_loc = np.linspace(bins[0], bins[-1], 5)
    set_xticks(axis=axis,
               ticks_loc=xticks_loc,
               ticks=np.round(xticks_loc, decimals=1),
               label=xlabel)
    yticks_loc = np.linspace(0, math.ceil(max_height), 5)
    set_yticks(axis=axis,
               ticks_loc=yticks_loc,
               ticks=yticks_loc.astype(int),
               label=ylabel)
    axis.tick_params(axis='both', which='both', length=1)
    remove_top_right_spines(axis=axis)
    if legend and len(data) > 1:
      axis.legend(frameon=False)
    if title is not None:
      axis.set_title(title)

    # show the mean and standard deviation if there is only one distribution
    if annotate_statistic and len(data) == 1:
      mean, std = np.mean(values[0]), np.std(values[0])
      # raw strings keep the backslashes of the math commands, the newline
      # is added separately
      axis.annotate(rf'$\mu = {mean:.4f}$' + '\n' + rf'$\sigma = {std:.4f}$',
                    xy=(0.75, 0.9),
                    xycoords='axes fraction',
                    fontsize=10)

    figure.tight_layout()
    self.figure(tag, figure=figure, step=step, close=True, mode=mode)

  def plot_heatmap(self,
                   tag: str,
                   matrix: np.ndarray,
                   xlabel: str,
                   ylabel: str,
                   xticklabels: t.Union[np.ndarray, list],
                   yticklabels: t.Union[np.ndarray, list],
                   title: str = '',
                   step: int = 0,
                   mode: int = 0):
    """ plot matrix as a heatmap with a colorbar
    Args:
      tag: name of the plot in TensorBoard
      matrix: 2D array to plot
      xlabel: label of the x-axis
      ylabel: label of the y-axis
      xticklabels: labels along the x-axis, 12 evenly spaced ones are shown
      yticklabels: labels along the y-axis, 12 evenly spaced ones are shown
      title: title of the plot
      step: TensorBoard step
      mode: file writer mode
    """
    figure, axes = plt.subplots(nrows=1,
                                ncols=2,
                                gridspec_kw={
                                    'width_ratios': [1, 0.03],
                                    'hspace': 0.01,
                                    'wspace': 0.01
                                },
                                figsize=(5.2, 5),
                                dpi=self.dpi)
    figure.patch.set_facecolor('white')

    image = axes[0].imshow(matrix,
                           cmap='YlOrRd',
                           aspect='equal',
                           interpolation='none')

    fontsize = 10
    # convert to arrays so that lists can be indexed with an index array
    xticklabels = np.asarray(xticklabels)
    yticklabels = np.asarray(yticklabels)
    # compute the tick locations per axis, as the two axes may differ in size
    xticks_loc = np.linspace(0, len(xticklabels) - 1, 12)
    yticks_loc = np.linspace(0, len(yticklabels) - 1, 12)
    set_xticks(
        axis=axes[0],
        ticks_loc=xticks_loc,
        ticks=xticklabels[xticks_loc.astype(int)],
        label=xlabel,
        fontsize=fontsize,
    )
    set_yticks(
        axis=axes[0],
        ticks_loc=yticks_loc,
        ticks=yticklabels[yticks_loc.astype(int)],
        label=ylabel,
        fontsize=fontsize,
    )
    axes[0].tick_params(axis='both', which='both', length=1)

    # build the colorbar from the plotted image so that it covers the value
    # range of matrix, which the ticks below are placed on
    figure.colorbar(image, cax=axes[1])
    cbar_ticks_loc = np.linspace(np.min(matrix), np.max(matrix), 6)
    set_yticks(axis=axes[1],
               ticks_loc=cbar_ticks_loc,
               ticks=cbar_ticks_loc.astype(int),
               label='',
               fontsize=fontsize)
    axes[1].tick_params(axis='both', which='both', length=0)

    if title:
      axes[0].set_title(title)

    figure.tight_layout()
    self.figure(tag, figure=figure, step=step, close=True, mode=mode)

  def raster_plot(self,
                  tag: str,
                  spikes1,
                  spikes2,
                  xlabel: str = '',
                  ylabel: str = '',
                  legends: t.List[str] = None,
                  step: int = 0,
                  mode: int = 0,
                  yticks_loc: t.Union[np.ndarray, t.List[str]] = None,
                  yticks: t.Union[np.ndarray, t.List[str]] = None):
    """ raster plot of two spike trains with marginal histograms
    The non-zero entries of spikes1 (blue) and spikes2 (red) are drawn as
    markers; the histograms of their positions along the first and the second
    dimension are shown above and to the right of the raster plot.
    Args:
      tag: name of the plot in TensorBoard
      spikes1: 2D array of the first spike trains
      spikes2: 2D array of the second spike trains, same shape as spikes1
      xlabel: label of the x-axis
      ylabel: label of the y-axis
      legends: names of spikes1 and spikes2, no legend is drawn if None
      step: TensorBoard step
      mode: file writer mode
      yticks_loc: locations of the y ticks, evenly spaced ticks labelled with
                  self.neuron_order are used if None
      yticks: labels of the y ticks, used if yticks_loc is provided
    """
    assert spikes1.shape == spikes2.shape
    if spikes1.shape[0] == self.num_neurons:
      # spikes should be in shape (WC)
      spikes1, spikes2 = spikes1.T, spikes2.T
    spikes1_x, spikes1_y = np.nonzero(spikes1)
    spikes2_x, spikes2_y = np.nonzero(spikes2)

    has_legends = legends is not None
    label1 = legends[0] if has_legends else None
    label2 = legends[1] if has_legends else None

    figure, axes = plt.subplots(nrows=2,
                                ncols=2,
                                gridspec_kw={
                                    'width_ratios': [1, 0.06],
                                    'height_ratios': [0.1, 1],
                                    'wspace': 0.0025,
                                    'hspace': 0.0
                                },
                                figsize=(8, 6),
                                dpi=self.dpi)

    # remove ticks and spines for histograms
    for histogram_axis in (axes[0, 0], axes[0, 1], axes[1, 1]):
      remove_spines(axis=histogram_axis)
    for histogram_axis in (axes[0, 0], axes[0, 1], axes[1, 1]):
      remove_ticks(axis=histogram_axis)

    axes[1, 0].scatter(x=spikes1_x,
                       y=spikes1_y,
                       c='dodgerblue',
                       marker='|',
                       alpha=0.6,
                       edgecolors=None,
                       label=label1)
    axes[1, 0].scatter(x=spikes2_x,
                       y=spikes2_y,
                       c='orangered',
                       marker='|',
                       alpha=0.6,
                       edgecolors=None,
                       label=label2)
    xticks_loc = np.linspace(0, spikes1.shape[0] - 1, 6)
    set_xticks(axis=axes[1, 0],
               ticks_loc=xticks_loc,
               ticks=(xticks_loc / self.frame_rate).astype(int),
               label=xlabel)
    if yticks_loc is None:
      yticks_loc = np.linspace(0, spikes1.shape[1] - 1, 6)
      yticks = self.neuron_order[yticks_loc.astype(int)][::-1] + 1
    set_yticks(axis=axes[1, 0],
               ticks_loc=yticks_loc,
               ticks=yticks,
               label=ylabel)
    axes[1, 0].tick_params(axis='both', which='both', length=1)

    # show legend
    if has_legends:
      axes[1, 0].legend(labels=legends,
                        loc=(0, -0.10),
                        ncol=2,
                        prop={'weight': 'regular'},
                        frameon=False,
                        framealpha=1,
                        handletextpad=-0.5,
                        columnspacing=1.0)

    # plot temporal histograms, skipped if there is no spike at all as the
    # bin range is undefined in that case
    all_x = np.concatenate([spikes1_x, spikes2_x])
    if all_x.size > 0:
      x_bins = np.linspace(np.min(all_x), np.max(all_x), 30)
      axes[0, 0].hist(spikes1_x, bins=x_bins, color='dodgerblue', alpha=0.6)
      axes[0, 0].hist(spikes2_x, bins=x_bins, color='orangered', alpha=0.6)

    # plot spatio histograms
    all_y = np.concatenate([spikes1_y, spikes2_y])
    if all_y.size > 0:
      y_bins = np.linspace(np.min(all_y), np.max(all_y), 25)
      axes[1, 1].hist(spikes1_y,
                      bins=y_bins,
                      orientation='horizontal',
                      color='dodgerblue',
                      alpha=0.6)
      axes[1, 1].hist(spikes2_y,
                      bins=y_bins,
                      orientation='horizontal',
                      color='orangered',
                      alpha=0.6)

    self.figure(tag, figure=figure, step=step, close=True, mode=mode)
