import os
import h5py
import argparse
import platform
import matplotlib
import numpy as np
import typing as t
import tensorflow as tf
from matplotlib import cm
from functools import partial
import matplotlib.pyplot as plt

from cyclegan.utils import dataset, tensorboard

if platform.system() == 'Darwin':
  # On macOS, render all figure text through LaTeX in a Computer Modern serif
  # font. NOTE(review): usetex needs a working LaTeX install; the platform
  # check presumably reflects where one is available — confirm.
  plt.rcParams['text.usetex'] = True
  plt.rcParams['font.family'] = 'serif'
  plt.rcParams['font.serif'] = ['Computer Modern']


def parse_segment(example, input_shape: t.Tuple[int, int]):
  """Decode one serialized segment record into a dict of tensors.

  Every feature is stored in the record as a single raw byte string and is
  decoded with `tf.io.decode_raw` using its per-feature dtype.

  Args:
    example: a scalar string tensor holding one serialized `tf.train.Example`.
    input_shape: target shape of the 'signal' feature. Every other feature is
      reshaped to `(input_shape[0],)`, i.e. one value per row of 'signal'.

  Returns:
    A dict with keys 'signal' (float32, `input_shape`), 'reward', 'lick',
    'reward_zone', 'trial' (int16, `(input_shape[0],)`) and 'position'
    (float32, `(input_shape[0],)`).
  """
  # Feature name -> dtype of the raw bytes; insertion order fixes the order
  # of keys in the returned dict.
  feature_dtypes = {
      'signal': tf.float32,
      'reward': tf.int16,
      'lick': tf.int16,
      'reward_zone': tf.int16,
      'position': tf.float32,
      'trial': tf.int16,
  }
  parsed = tf.io.parse_single_example(
      example,
      features={
          name: tf.io.FixedLenFeature([], tf.string)
          for name in feature_dtypes
      })

  num_rows = input_shape[0]
  decoded = {}
  for name, dtype in feature_dtypes.items():
    # 'signal' is the only 2-D feature; the rest hold one value per row.
    target_shape = input_shape if name == 'signal' else (num_rows,)
    raw = tf.io.decode_raw(parsed[name], out_type=dtype)
    decoded[name] = tf.reshape(raw, shape=target_shape)
  return decoded


def get_ds(args):
  """Build the x- and y-domain sample datasets.

  File patterns come from `dataset.get_info(args)` under the keys
  'x_sample' and 'y_sample'. Each dataset interleaves the matching TFRecord
  files, decodes records with `parse_segment` and batches them with size 1.

  Args:
    args: parsed arguments; `args.input_shape` must be set by the time
      `dataset.get_info` returns (presumably set by it — TODO confirm).

  Returns:
    A tuple `(x_ds, y_ds)` of batched `tf.data.Dataset`s.
  """
  filenames = dataset.get_info(args)
  decode = partial(parse_segment, input_shape=args.input_shape)

  def _load(pattern):
    # Same pipeline for both domains: list -> interleave -> decode -> batch.
    records = tf.data.Dataset.list_files(pattern).interleave(
        tf.data.TFRecordDataset)
    return records.map(decode).batch(1)

  return _load(filenames['x_sample']), _load(filenames['y_sample'])


def plot_augmentation(args, ds):
  """Plot one sample next to its augmented version and save the figure.

  Takes the first item of the first batch of `ds`, applies `dataset.augment`
  to its 'signal', and draws the original (left) and augmented (middle)
  signals as grayscale images plus a colorbar (right). The figure is written
  to `<args.output_dir>/synthetic_sample.<args.format>`.

  Args:
    args: parsed arguments. Reads `input_shape`, `ds_min`, `ds_max`,
      `frame_rate`, `dpi`, `output_dir` and `format`. Several of these are
      not defined by this file's argument parser and are presumably set by
      `dataset.get_info` — TODO confirm.
    ds: a batched dataset whose elements are dicts with a 'signal' tensor
      (e.g. the output of `get_ds`).
  """
  sample = next(iter(ds))
  signal = sample['signal'].numpy()[0]

  diagonal_mask = dataset.get_diagonal_mask(input_shape=args.input_shape)
  augmented_signal = dataset.augment(signal,
                                     diagonal_mask=diagonal_mask,
                                     ds_min=args.ds_min,
                                     ds_max=args.ds_max)

  # Transpose so image rows are neurons and columns are time steps.
  signal = np.transpose(signal, axes=[1, 0])
  augmented_signal = np.transpose(augmented_signal, axes=[1, 0])

  # Scale the original to a peak of 1. Skip scaling when the peak is not
  # positive: dividing would produce NaN/inf (peak 0) or invert the image
  # (negative peak).
  peak = np.max(signal)
  display_signal = signal / peak if peak > 0 else signal

  figure, axes = plt.subplots(nrows=1,
                              ncols=3,
                              gridspec_kw={
                                  'wspace': 0.025,
                                  'hspace': 0.1,
                                  'width_ratios': [1, 1, 0.04]
                              },
                              figsize=(8.5, 1.5),
                              dpi=args.dpi)
  image_axes = axes[:2]
  label_fontsize, tick_fontsize = 11, 9

  axes[0].imshow(display_signal, cmap='gray', aspect='auto')
  axes[0].set_ylabel('Neuron', fontsize=label_fontsize, labelpad=0)
  axes[0].set_xlabel('Time (s)', fontsize=label_fontsize, labelpad=0)

  axes[1].imshow(augmented_signal, cmap='gray', aspect='auto')
  axes[1].set_xlabel('Time (s)', fontsize=label_fontsize, labelpad=0)
  axes[1].set_yticks([])

  # NOTE(review): tick positions are inset (20 columns from the left, 36 from
  # the right) while the labels span the full [0, num_frames] / frame_rate
  # range, so the time labels are approximate. This is presumably deliberate,
  # to keep labels inside the axes.
  num_frames = signal.shape[1]
  xticks_loc = np.linspace(20, num_frames - 36, 5)
  xticks = (np.linspace(0, num_frames, 5) / args.frame_rate).astype(int)
  for axis in image_axes:
    tensorboard.set_xticks(axis=axis,
                           ticks_loc=xticks_loc,
                           ticks=xticks,
                           fontsize=tick_fontsize)
  # Neuron ticks are labelled 1-based.
  yticks_loc = np.linspace(0, signal.shape[0] - 1, 5)
  tensorboard.set_yticks(axis=axes[0],
                         ticks_loc=yticks_loc,
                         ticks=(yticks_loc + 1).astype(int),
                         fontsize=tick_fontsize)
  for axis in image_axes:
    axis.tick_params(axis='both', which='both', length=0, pad=2)

  # Pass the colormap by name: `cm.get_cmap` was deprecated in Matplotlib 3.7
  # and removed in 3.9.
  cbar = figure.colorbar(cm.ScalarMappable(cmap='gray'), cax=axes[2])
  cbar.outline.set_visible(False)
  cbar_ticks_loc = np.linspace(0, 1, 6)
  tensorboard.set_yticks(axis=axes[2],
                         ticks_loc=cbar_ticks_loc,
                         ticks=np.round(cbar_ticks_loc, decimals=1),
                         fontsize=tick_fontsize)
  axes[2].tick_params(axis='both', which='both', length=0)

  for axis in axes:
    tensorboard.remove_spines(axis=axis)

  filename = os.path.join(args.output_dir, f'synthetic_sample.{args.format}')
  tensorboard.save_figure(figure=figure, filename=filename, dpi=args.dpi)
  # Release the figure's memory; this is a no-op if save_figure already
  # closed it.
  plt.close(figure)

  print(f'saved plot to {filename}.')


def main(args):
  """Validate paths, load the sample datasets and plot one augmentation.

  Args:
    args: parsed command-line arguments. `args.dataset` must be an existing
      path. `args.output_dir` is created if it is missing.

  Raises:
    FileNotFoundError: if `args.dataset` does not exist.
  """
  # Raise explicitly rather than using `assert`, which `python -O` strips.
  if not os.path.exists(args.dataset):
    raise FileNotFoundError(f'dataset {args.dataset} not found.')

  # exist_ok avoids the race between an existence check and directory creation.
  os.makedirs(args.output_dir, exist_ok=True)

  # Only the x-domain samples are plotted; the y-domain dataset is unused.
  x_ds, _ = get_ds(args)

  plot_augmentation(args, x_ds)


if __name__ == '__main__':
  arg_parser = argparse.ArgumentParser()
  arg_parser.add_argument(
      '--dataset', type=str, default='../dataset/tfrecords/ST260/vr14/sl2048')
  arg_parser.add_argument('--output_dir', type=str, default='plots')
  arg_parser.add_argument('--dpi', type=int, default=120)
  arg_parser.add_argument(
      '--format', type=str, default='pdf', choices=['pdf', 'png', 'svg'])
  params = arg_parser.parse_args()

  # Fixed settings that are not exposed as CLI flags; presumably consumed by
  # dataset.get_info — TODO confirm.
  for key, value in (('synthetic_data', False),
                     ('neuron_order_json', None),
                     ('global_batch_size', 1)):
    setattr(params, key, value)

  if platform.system() == 'Darwin':
    matplotlib.use('TkAgg')

  main(params)
