import os
import scipy.io
import argparse
import platform
import matplotlib
import numpy as np
import typing as t
from glob import glob
import tensorflow as tf
import matplotlib.cm as cm
import matplotlib.pyplot as plt

if platform.system() == 'Darwin':
  # On macOS, select the Tk backend and render text through LaTeX with a
  # Computer Modern serif font.
  # NOTE: "text.usetex" relies on a LaTeX installation being available.
  matplotlib.use('TkAgg')
  plt.rcParams.update({
      "text.usetex": True,
      "font.family": "serif",
      "font.serif": ["Computer Modern"]
  })

# The seaborn style sheets were renamed to 'seaborn-v0_8-<name>' in
# matplotlib 3.6 and the old names were removed in 3.8. Use the original name
# when the installed matplotlib still lists it, otherwise the renamed one.
plt.style.use('seaborn-deep' if 'seaborn-deep' in
              plt.style.available else 'seaborn-v0_8-deep')

from cyclegan.utils.gradcam import gradcam
from cyclegan.models.registry import get_models
from cyclegan.alogrithms.registry import get_algorithm
from cyclegan.utils import utils, h5, tensorboard, attention_gate, dataset


def remove_top_right_spines(axis):
  """ hide the right and top spines of the given axis """
  for side in ('right', 'top'):
    axis.spines[side].set_visible(False)


def load_signals(args):
  """ load every key reported by h5.get_keys for args.signals_filename

  Returns:
    dictionary mapping each key to the value returned by h5.get
  """
  signals = {}
  for key in h5.get_keys(args.signals_filename):
    signals[key] = h5.get(args.signals_filename, key=key)
  return signals


def MAE(x: np.ndarray, y: np.ndarray):
  """ mean absolute error between x and y, averaged over axes 1 and 2

  Returns:
    one error value per entry along axis 0
  """
  difference = x - y
  magnitude = np.absolute(difference)
  return magnitude.mean(axis=(1, 2))


def direct_comparison(data: t.Dict[str, np.ndarray],
                      key1: str,
                      key2: str,
                      augment: bool = False):
  """ print the mean and standard deviation of the per-sample MAE between
  data[key1] and data[key2]

  Args:
    data: dictionary of arrays that contains key1 and key2
    key1: key of the first array
    key2: key of the second array
    augment: when True, each data[key1][i] is passed through dataset.augment
      before it is compared against data[key2][i]
  """
  if augment:
    # the mask is built once from the per-sample shape and reused
    diagonal_mask = dataset.get_diagonal_mask(input_shape=data[key1].shape[1:])
    errors = np.zeros((len(data[key1]),), dtype=np.float32)
    for i, sample in enumerate(data[key1]):
      augmented = dataset.augment(sample, diagonal_mask=diagonal_mask)
      errors[i] = np.mean(np.abs(augmented - data[key2][i]))
  else:
    errors = MAE(data[key1], data[key2])

  mean, std = np.mean(errors), np.std(errors)
  # the backslash is escaped explicitly: a bare '\p' is an invalid escape
  # sequence. The printed text is still a literal '\pm'.
  print(f'MAE({key1}, {key2}) = {mean:.4f} \\pm {std:.4f}')


def direct_comparisons(data: t.Dict[str, np.ndarray]):
  """ print the MAE of x against y, followed by the MAE of x and of y against
  their fake, cycle and same counterparts """
  direct_comparison(data, key1='x', key2='y')

  for source in ('x', 'y'):
    # blank lines separate the groups in the printed report
    print('\n')
    for prefix in ('fake', 'cycle', 'same'):
      direct_comparison(data, key1=source, key2=f'{prefix}_{source}')
  print('\n')


def plot_cycle_trace(args, filename: str, traces: t.List[t.List[np.ndarray]],
                     labels: t.List[str], title: t.List[str],
                     colors: t.List[str]):
  """ plot a 3 x 3 grid of traces and save it under args.plot_dir

  Each row shows the three traces of one entry of traces, one per column.

  Args:
    args: namespace that provides dpi, frame_rate and plot_dir
    filename: name of the saved figure, joined onto args.plot_dir
    traces: 3 rows, each holding 3 traces
    labels: column titles, shown above the first row only
    title: 3 row labels, shown on the right side of each row
    colors: colors[0] is used for the first and third column, colors[1] for
      the second column
  """
  assert len(traces) == len(title) == 3
  figure, axes = plt.subplots(nrows=3,
                              ncols=3,
                              gridspec_kw={
                                  'wspace': 0.025,
                                  'hspace': 0.2
                              },
                              figsize=(8, 3.5),
                              sharex=True,
                              dpi=args.dpi)

  # x tick locations are sample indexes; their labels are whole seconds
  xticks_loc = np.linspace(0, len(traces[0][0]) - 1, 5)
  xticks = (xticks_loc / args.frame_rate).astype(int)
  fontsize = 10
  column_colors = (colors[0], colors[1], colors[0])

  for i, row_traces in enumerate(traces):
    trace1, trace2, trace3 = row_traces
    # y tick locations span the joint range of the three traces in this row
    n_min = np.min([trace1, trace2, trace3])
    n_max = np.max([trace1, trace2, trace3])

    for j, trace in enumerate((trace1, trace2, trace3)):
      axes[i, j].plot(trace, color=column_colors[j], linewidth=1, alpha=0.8)

    yticks_loc = np.linspace(n_min, n_max, 3)
    tensorboard.set_yticks(axis=axes[i, 0],
                           ticks_loc=yticks_loc,
                           ticks=np.around(yticks_loc, 2),
                           label=r'$\Delta F/F$' if i == 1 else '',
                           fontsize=fontsize)
    # NOTE(review): only the first column receives the joint tick range; the
    # other two columns get no explicit y-limits here. Confirm whether the
    # three columns are meant to share the same y-scale.
    axes[i, 1].set_yticks([])
    axes[i, 2].set_yticks([])

    # add the row label (e.g. the neuron number) on the right side
    ax2 = axes[i, 2].twinx()
    ax2.set_ylabel(title[i], rotation=270, va='bottom')
    ax2.set_yticks([])
    tensorboard.remove_spines(axis=ax2)

    for j in range(3):
      remove_top_right_spines(axis=axes[i, j])

    if i == 0:
      for j in range(3):
        axes[i, j].set_title(labels[j])

    if i == 2:
      for j in range(3):
        # only the middle column carries the x-axis label
        label = {'label': 'Time (s)'} if j == 1 else {}
        tensorboard.set_xticks(axis=axes[i, j],
                               ticks_loc=xticks_loc,
                               ticks=xticks,
                               fontsize=fontsize,
                               **label)

    for j in range(3):
      axes[i, j].tick_params(axis='both', which='both', length=0)

  filename = os.path.join(args.plot_dir, filename)
  figure.savefig(filename, dpi=args.dpi, bbox_inches='tight', pad_inches=0.02)
  # release the figure once it is on disk; pyplot otherwise keeps it alive
  plt.close(figure)
  print(f'saved {filename}')


def plot_specify_neuron(args, filename: str, forward: t.List, backward: t.List):
  """ plot the forward and backward cycle traces in a 2 x 3 grid and save the
  figure under args.plot_dir

  Args:
    args: namespace that provides dpi, frame_rate and plot_dir
    filename: name of the saved figure, joined onto args.plot_dir
    forward: traces plotted in the first row with the titles x, G(x) and
      F(G(x))
    backward: traces plotted in the second row with the titles y, F(y) and
      G(F(y))
  """
  assert len(forward) == len(backward)

  nrows, ncols = 2, 3
  figure, axes = plt.subplots(nrows=nrows,
                              ncols=ncols,
                              gridspec_kw={
                                  'wspace': 0.025,
                                  'hspace': 0.5
                              },
                              figsize=(8, 2),
                              sharex=True,
                              dpi=args.dpi)

  linewidth, alpha, fontsize = 1, 0.8, 10

  # (traces, per-column colors, per-column titles) for each row
  rows = [
      (forward, ('dodgerblue', 'orangered', 'dodgerblue'),
       (r'x', r'G(x)', r'F(G(x))')),
      (backward, ('orangered', 'dodgerblue', 'orangered'),
       (r'y', r'F(y)', r'G(F(y))')),
  ]

  for i, (traces, colors, titles) in enumerate(rows):
    for j in range(ncols):
      axes[i, j].plot(traces[j],
                      color=colors[j],
                      linewidth=linewidth,
                      alpha=alpha)

    # y tick locations span the joint range of all traces in the row and are
    # only shown on the first column
    yticks_loc = np.linspace(np.min(traces), np.max(traces), 3)
    tensorboard.set_yticks(axis=axes[i, 0],
                           ticks_loc=yticks_loc,
                           ticks=np.around(yticks_loc, 2),
                           fontsize=fontsize)
    for j in range(1, ncols):
      axes[i, j].set_yticks([])

    for j in range(ncols):
      axes[i, j].set_title(titles[j], va='center')

  # x tick locations are sample indexes; their labels are whole seconds
  xticks_loc = np.linspace(0, len(forward[0]) - 1, 5)
  xticks = (xticks_loc / args.frame_rate).astype(int)
  for j in range(ncols):
    # only the middle column carries the x-axis label
    label = {'label': 'Time (s)'} if j == 1 else {}
    tensorboard.set_xticks(axis=axes[1, j],
                           ticks_loc=xticks_loc,
                           ticks=xticks,
                           fontsize=fontsize,
                           **label)

  for i in range(nrows):
    for j in range(ncols):
      remove_top_right_spines(axis=axes[i, j])
      axes[i, j].tick_params(axis='both', which='both', length=0)

  # a single y-axis label for both rows, positioned in figure coordinates
  axes[0, 0].annotate(r'$\Delta F/F$',
                      xy=(0.03, 0.5),
                      xycoords='figure fraction',
                      rotation=90,
                      fontsize=10)

  filename = os.path.join(args.plot_dir, filename)
  figure.savefig(filename, dpi=args.dpi, bbox_inches='tight', pad_inches=0.02)
  # release the figure once it is on disk; pyplot otherwise keeps it alive
  plt.close(figure)
  print(f'saved {filename}')


def plot_cycle_traces(args, data: t.Dict[str, np.ndarray], trial: int):
  """ plot the forward and backward cycle traces of three neurons for the
  given trial, plus a separate figure for the last of those neurons

  Args:
    args: namespace that provides num_neurons and format, plus the
      attributes read by plot_cycle_trace and plot_specify_neuron
    data: dictionary with the keys x, y, fake_x, fake_y, cycle_x and cycle_y;
      each value is indexed as [trial, :, neuron]
    trial: index of the trial to plot
  """
  colors = ['dodgerblue', 'orangered']
  # 5 evenly spaced neuron indexes; the first and last one are dropped
  neuron_indexes = np.linspace(0, args.num_neurons - 1, 5, dtype=int)
  neurons = neuron_indexes[1:-1]
  # row labels; only the middle row spells out the word "Neuron"
  title = [f'{neurons[0]}', f'Neuron\n{neurons[1]}', f'{neurons[2]}']

  forward_cycle_traces = [[
      data['x'][trial, :, n], data['fake_y'][trial, :, n],
      data['cycle_x'][trial, :, n]
  ] for n in neurons]
  plot_cycle_trace(args,
                   filename=f'forward_cycle_traces.{args.format}',
                   traces=forward_cycle_traces,
                   labels=[r'x', r'G(x)', r'F(G(x))'],
                   title=title,
                   colors=colors)

  backward_cycle_traces = [[
      data['y'][trial, :, n], data['fake_x'][trial, :, n],
      data['cycle_y'][trial, :, n]
  ] for n in neurons]
  plot_cycle_trace(args,
                   filename=f'backward_cycle_traces.{args.format}',
                   traces=backward_cycle_traces,
                   labels=[r'y', r'F(y)', r'G(F(y))'],
                   title=title,
                   colors=colors[::-1])

  # name the file after the neuron that is actually plotted; the previous
  # hard-coded "neuron75" only matched when neurons[2] happened to be 75
  plot_specify_neuron(args,
                      filename=f'neuron{neurons[2]}_cycle.{args.format}',
                      forward=forward_cycle_traces[2],
                      backward=backward_cycle_traces[2])


def plot_population_cycle(args, filename: str, traces: t.List[np.ndarray],
                          titles: t.List[str]):
  """ plot three population recordings as stacked grayscale images and save
  the figure under args.plot_dir

  Each trace is scaled with utils.scale using its own minimum and maximum,
  has its first two axes swapped, and is flipped along axis 0 before being
  shown.

  Args:
    args: namespace that provides dpi, frame_rate and plot_dir
    filename: name of the saved figure, joined onto args.plot_dir
    traces: the recordings to plot; the first three are shown
      (presumably each has shape (time, neurons) - confirm against caller)
    titles: label shown on the right side of each image
  """
  # build a new list rather than overwriting the caller's list in place
  images = []
  for trace in traces:
    scaled = utils.scale(trace, np.min(trace), np.max(trace))
    images.append(np.swapaxes(scaled, axis1=0, axis2=1))

  # y tick labels are 1-based
  yticks_loc = np.linspace(0, images[0].shape[0] - 1, 5)
  yticks = yticks_loc.astype(int) + 1

  figure, axes = plt.subplots(nrows=3,
                              ncols=1,
                              gridspec_kw={
                                  'wspace': 0.15,
                                  'hspace': 0.1
                              },
                              sharex=True,
                              figsize=(6.5, 6),
                              dpi=args.dpi)

  for i in range(3):
    axes[i].imshow(np.flip(images[i], axis=0), cmap='gray', aspect='auto')
    tensorboard.set_yticks(axis=axes[i],
                           ticks_loc=yticks_loc,
                           ticks=yticks,
                           label='Neuron' if i == 1 else '')
    tensorboard.set_right_label(axis=axes[i], label=titles[i])
    axes[i].tick_params(axis='both', which='both', length=0)

  # x tick locations are sample indexes; their labels are whole seconds
  xticks_loc = np.linspace(0, images[0].shape[1] - 1, 5)
  tensorboard.set_xticks(
      axis=axes[2],
      ticks_loc=xticks_loc,
      ticks=(xticks_loc / args.frame_rate).astype(int),
      label='Time (s)',
  )

  filename = os.path.join(args.plot_dir, filename)
  figure.savefig(filename, dpi=args.dpi, bbox_inches='tight', pad_inches=0.02)
  # release the figure once it is on disk; pyplot otherwise keeps it alive
  plt.close(figure)
  print(f'saved {filename}')


def plot_population_cycles(args, data: t.Dict[str, np.ndarray], trial: int):
  """ plot the forward and backward population cycles of the given trial

  Args:
    args: namespace that provides format, plus the attributes read by
      plot_population_cycle
    data: dictionary with the keys x, y, fake_x, fake_y, cycle_x and cycle_y
    trial: index of the trial to plot
  """
  # honour the requested output format instead of always writing a .pdf
  plot_population_cycle(args,
                        filename=f'forward_cycle_population.{args.format}',
                        traces=[
                            data['x'][trial, ...], data['fake_y'][trial, ...],
                            data['cycle_x'][trial, ...]
                        ],
                        titles=[r'x', r'G(x)', r'F(G(x))'])
  plot_population_cycle(args,
                        filename=f'backward_cycle_population.{args.format}',
                        traces=[
                            data['y'][trial, ...], data['fake_x'][trial, ...],
                            data['cycle_y'][trial, ...]
                        ],
                        titles=[r'y', r'F(y)', r'G(F(y))'])


def main(args):
  """ load the saved signals of a run, print the direct comparisons and save
  the cycle figures under <output_dir>/plots

  Args:
    args: parsed command line arguments; output_dir must exist

  Raises:
    FileNotFoundError: if args.output_dir does not exist
  """
  # fixed seed so the same trial is selected on every run
  np.random.seed(777)

  if not os.path.exists(args.output_dir):
    raise FileNotFoundError(f'{args.output_dir} not found.')

  print(f'loading {args.output_dir}...\n')

  args.plot_dir = os.path.join(args.output_dir, 'plots')
  args.samples_dir = os.path.join(args.output_dir, 'samples')
  args.signals_filename = os.path.join(args.samples_dir, 'signals.h5')
  args.checkpoint_dir = os.path.join(args.output_dir, 'checkpoints')
  utils.load_args(args)

  # the attribute is plot_dir; the previous args.plots_dir raised
  # AttributeError whenever the directory had to be created
  os.makedirs(args.plot_dir, exist_ok=True)

  data = load_signals(args)
  direct_comparisons(data)

  trial = np.random.choice(len(data['x']))
  plot_cycle_traces(args, data=data, trial=trial)
  plot_population_cycles(args, data=data, trial=trial)

  print(f'\nFigures saved at {args.plot_dir}')


if __name__ == '__main__':
  # command line interface: the run directory is mandatory, while the figure
  # format and resolution have defaults
  cli = argparse.ArgumentParser()
  cli.add_argument('--output_dir', type=str, required=True)
  cli.add_argument('--format',
                   type=str,
                   default='pdf',
                   choices=['pdf', 'png', 'svg'])
  cli.add_argument('--dpi', type=int, default=120)
  main(cli.parse_args())
