import os
import h5py
import argparse
import platform
import matplotlib
import numpy as np
import typing as t
import seaborn as sns
from math import ceil
from tqdm import tqdm
from time import time
from glob import glob
import tensorflow as tf
from shutil import rmtree
import matplotlib.pyplot as plt

import recording_statistics
from cyclegan.utils import utils
from cyclegan.utils import dataset
from cyclegan.utils import tensorboard
from cyclegan.utils import spike_helper
from cyclegan.utils.cascade import cascade
from cyclegan.models import utils as model_utils

plt.style.use('seaborn-deep')


def firing_rate_method(args):
  ''' Rank entries by their firing rate averaged over two recordings.

  Searches args.data_dir for the .mat files whose names end with
  'AllVR1.mat' (x) and 'AllVR4.mat' (y), computes the firing rates of each
  recording with recording_statistics and averages the two.

  Args:
    args: namespace providing data_dir, the directory holding the .mat files.
  Returns:
    np.ndarray of indices that sort the averaged firing rates in descending
    order (presumably one index per neuron; this depends on what
    recording_statistics.get_firing_rates returns).
  Raises:
    AssertionError: if data_dir does not exist or either file is missing.
  '''
  assert os.path.exists(args.data_dir)

  # suffix -> matched path; if several files share a suffix, the last one
  # returned by glob is kept
  matches = {'AllVR1.mat': None, 'AllVR4.mat': None}
  for path in glob(os.path.join(args.data_dir, '*.mat')):
    for suffix in matches:
      if path.endswith(suffix):
        matches[suffix] = path
        break
  x_filename, y_filename = matches['AllVR1.mat'], matches['AllVR4.mat']

  assert x_filename is not None and y_filename is not None

  x_data = recording_statistics.load_mat(x_filename)
  y_data = recording_statistics.load_mat(y_filename)
  x_rates = recording_statistics.get_firing_rates(x_data)
  y_rates = recording_statistics.get_firing_rates(y_data)

  # argsort is ascending, reverse it to put the highest average rate first
  descending = np.argsort((x_rates + y_rates) / 2)[::-1]
  return descending


def unscale(inputs: t.List[tf.Tensor], ds_min: float, ds_max: float):
  ''' apply utils.unscale with ds_min and ds_max to every item in inputs '''
  unscaled = []
  for tensor in inputs:
    unscaled.append(utils.unscale(tensor, ds_min=ds_min, ds_max=ds_max))
  return unscaled


def squeeze(inputs: t.List[tf.Tensor]):
  ''' remove the last axis (via tf.squeeze) from every item in inputs '''
  squeezed = []
  for tensor in inputs:
    squeezed.append(tf.squeeze(tensor, axis=-1))
  return squeezed


def mean_squared_error(target: tf.Tensor,
                       prediction: tf.Tensor,
                       axis: t.Union[int, t.List[int]] = None):
  ''' Mean of the squared difference between target and prediction.

  Args:
    target: reference values.
    prediction: values compared against target.
    axis: axis or axes passed to tf.reduce_mean; with the default None the
      mean is taken over all elements.
  Returns:
    the reduced mean squared error.
  '''
  squared_error = tf.square(target - prediction)
  return tf.reduce_mean(squared_error, axis=axis)


def save_order(order: np.ndarray, filename: str):
  ''' Write order to filename as JSON under the key 'order'.

  The parent directory of filename is created when it is missing.

  Args:
    order: the ordering to store.
    filename: destination path of the JSON file.
  '''
  directory = os.path.dirname(filename)
  # dirname is '' for a bare filename and os.makedirs('') would raise, so
  # only create a directory when there is one to create; exist_ok avoids
  # failing if the directory appears between a check and the creation
  if directory:
    os.makedirs(directory, exist_ok=True)
  utils.save_json(filename=filename, data={'order': order})


def inference(args, x, y, model):
  ''' Reconstruct x and y with model and unscale inputs and outputs.

  Args:
    args: namespace providing ds_min, ds_max and input_2d.
    x: batch passed through the model.
    y: second batch passed through the model.
    model: callable invoked as model(batch, training=False).
  Returns:
    (x, y, x_reconstruction, y_reconstruction), all unscaled with
    args.ds_min / args.ds_max; their last axis is squeezed unless
    args.input_2d is set.
  '''
  tensors = [x, y, model(x, training=False), model(y, training=False)]
  tensors = unscale(inputs=tensors, ds_min=args.ds_min, ds_max=args.ds_max)
  if not args.input_2d:
    tensors = squeeze(inputs=tensors)
  x, y, x_reconstruction, y_reconstruction = tensors
  return x, y, x_reconstruction, y_reconstruction


def get_order(args, ds, model):
  ''' Order entries by the model's reconstruction error over ds.

  Runs inference on every (x, y) pair in ds, concatenates the batches along
  axis 0 and averages the x and y mean squared errors.

  Args:
    args: namespace providing test_steps and verbose, plus whatever
      inference reads.
    ds: iterable of (x, y) pairs, each a mapping with a 'data' entry.
    model: model passed on to inference.
  Returns:
    order: np.ndarray of indices sorting the averaged error in ascending
      order.
    samples: dict with keys 'x', 'y', 'x_reconstruction' and
      'y_reconstruction' holding the concatenated tensors.
  '''
  names = ('x', 'y', 'x_reconstruction', 'y_reconstruction')
  collected = {}
  progress = tqdm(ds,
                  desc='Inference',
                  total=args.test_steps,
                  disable=args.verbose == 0)
  for x, y in progress:
    outputs = inference(args, x=x['data'], y=y['data'], model=model)
    utils.update_dict(collected, dict(zip(names, outputs)))
  samples = {
      name: tf.concat(batches, axis=0) for name, batches in collected.items()
  }
  # calculate neuron-wise mean squared error
  x_error, y_error = [
      mean_squared_error(target=samples[name],
                         prediction=samples[f'{name}_reconstruction'],
                         axis=[0, 1]) for name in ('x', 'y')
  ]
  total_error = 0.5 * (x_error + y_error)
  return tf.argsort(total_error).numpy(), samples


def get_coordinates(filename: str):
  ''' Read the 'mnCoordinates' entry of every ROI stored in an HDF5 file.

  Args:
    filename: path to the HDF5 (.mat) file; must exist.
  Returns:
    list with one 'mnCoordinates' value per entry of the file's 'data'
    dataset, excluding the first two entries.
  Raises:
    AssertionError: if filename does not exist.
  '''
  assert os.path.exists(filename)
  with h5py.File(filename, 'r') as mat:
    # the first two entries of 'data' are skipped (presumably they are not
    # ROIs -- TODO confirm); element 0 of each remaining entry is used as
    # the key of the object holding 'mnCoordinates'
    references = mat['data'][()][2:]
    return [mat[ref[0]]['mnCoordinates'][()] for ref in references]


def plot_comparison(args, ds, model, summary, epoch: int):
  ''' Plot the first trace of each batch in ds next to its reconstruction.

  For every (x, y) pair in ds, one comparison is written under the tag
  'x_reconstruction' and one under 'y_reconstruction', both at step epoch.

  Args:
    args: namespace passed on to inference.
    ds: iterable of (x, y) pairs, each a mapping with a 'data' entry.
    model: model passed on to inference.
    summary: object providing plot_comparison(...).
    epoch: step under which the plots are recorded.
  '''
  for x, y in ds:
    x, y, x_reconstruction, y_reconstruction = inference(args,
                                                         x=x['data'],
                                                         y=y['data'],
                                                         model=model)
    pairs = (('x_reconstruction', x, x_reconstruction),
             ('y_reconstruction', y, y_reconstruction))
    for tag, signal, reconstruction in pairs:
      summary.plot_comparison(tag,
                              traces=[signal[0], reconstruction[0]],
                              labels=['input', 'reconstruction'],
                              step=epoch,
                              mode=1)


def plot_coordinates(args, order: np.ndarray, summary, epoch: int):
  ''' Draw every ROI outline labelled with its rank in order.

  The coordinates are loaded from args.coordinate_filename on first use and
  cached on args.coordinates. Nothing is plotted when the number of
  coordinates differs from len(order).

  Args:
    args: namespace providing dpi and coordinate_filename (or an already
      populated coordinates attribute).
    order: indices into the coordinates; position i is labelled i + 1.
    summary: object providing figure(...).
    epoch: step under which the figure is recorded.
  '''
  if not hasattr(args, 'coordinates'):
    args.coordinates = get_coordinates(filename=args.coordinate_filename)
  outlines = args.coordinates
  if len(outlines) != len(order):
    return

  centers = [outline.mean(axis=1) for outline in outlines]
  palette = sns.color_palette('husl', len(centers))
  figure, axis = plt.subplots(nrows=1,
                              ncols=1,
                              gridspec_kw={
                                  'wspace': 0.01,
                                  'hspace': 0.01
                              },
                              figsize=(5, 5),
                              dpi=args.dpi)
  axis.set_facecolor('black')
  # palette has one colour per entry of order (lengths were checked above)
  for rank, (neuron, color) in enumerate(zip(order, palette), start=1):
    axis.plot(*outlines[neuron], color=color, alpha=0.5)
    axis.text(*centers[neuron],
              s=f'{rank}',
              color=color,
              horizontalalignment='center',
              verticalalignment='center')
  tensorboard.remove_ticks(axis=axis)
  summary.figure(tag='coordinates', figure=figure, step=epoch, mode=1)


def plot_raster(args,
                order: np.ndarray,
                samples: t.Dict[str, np.ndarray],
                summary: tensorboard.Summary,
                epoch: int,
                num_plots: int = 5):
  ''' Plot unordered against ordered spike rasters for the first trials.

  Spike rates are deconvolved from the first num_plots entries of
  samples['x'] and samples['y'] the first time this runs and cached on
  args.x_spike_rates / args.y_spike_rates; later calls reuse the cache.

  Args:
    args: namespace providing verbose, x_spike_rates and y_spike_rates.
    order: indices used to reorder axis 1 of the spike trains.
    samples: dict with (at least) the keys 'x' and 'y'.
    summary: summary writer providing raster_plot(...).
    epoch: step under which the plots are recorded.
    num_plots: number of trials to deconvolve and plot.
  '''
  if args.x_spike_rates is None or args.y_spike_rates is None:
    if args.verbose:
      print('deconvolve signals...')
    args.x_spike_rates = cascade.deconvolve_signals(samples['x'][:num_plots])
    args.y_spike_rates = cascade.deconvolve_signals(samples['y'][:num_plots])

  # key -> (unordered, ordered) spike trains
  rasters = {}
  for key, spike_rates in (('x', args.x_spike_rates),
                           ('y', args.y_spike_rates)):
    # convert spike trains to (NCW)
    unordered = np.transpose(spike_helper.get_spike_trains(spike_rates),
                             axes=[0, 2, 1])
    rasters[key] = (unordered, unordered[:, order, :])

  # set raster plot y-axis label to be "new order | original order"
  yticks_loc = np.linspace(0, len(order) - 1, 6)
  original_order = yticks_loc.astype(int)
  new_order = order[original_order]
  # labels are 1-based and listed in reverse
  yticks = [
      f'{new + 1:03d}|{original + 1:03d}'
      for new, original in zip(new_order, original_order)
  ][::-1]

  for i in range(num_plots):
    for key, (unordered, ordered) in rasters.items():
      summary.raster_plot(f'{key}_raster_plot/trial_{i}',
                          spikes1=unordered[i],
                          spikes2=ordered[i],
                          xlabel='Time (s)',
                          ylabel='Neuron',
                          legends=['unordered', 'ordered'],
                          step=epoch,
                          mode=1,
                          yticks_loc=yticks_loc,
                          yticks=yticks)


def log_progress(args, test_ds, sample_ds, model, summary, epoch: int):
  ''' Record the plots and the current order for the given epoch.

  Plots input/reconstruction comparisons for sample_ds, computes the order
  from test_ds, saves it to <output_dir>/history/epoch_XXX.json and then
  plots the coordinates and the raster plots.

  Args:
    args: namespace providing output_dir plus whatever the plotting and
      inference helpers read; num_plots is used when present.
    test_ds: dataset used to compute the order.
    sample_ds: dataset used for the comparison plots.
    model: model to evaluate.
    summary: summary writer passed on to the plotting helpers.
    epoch: epoch number used as step and in the history filename.
  '''
  plot_comparison(args, ds=sample_ds, model=model, summary=summary, epoch=epoch)
  order, samples = get_order(args, ds=test_ds, model=model)
  history_file = os.path.join(args.output_dir, 'history',
                              f'epoch_{epoch:03d}.json')
  save_order(order=order, filename=history_file)
  plot_coordinates(args, order=order, summary=summary, epoch=epoch)
  # forward the --num_plots setting instead of always relying on
  # plot_raster's default; fall back to that default (5) when args has no
  # num_plots attribute
  plot_raster(args,
              order=order,
              samples=samples,
              summary=summary,
              epoch=epoch,
              num_plots=getattr(args, 'num_plots', 5))


def conv_block(inputs,
               filters: int,
               kernel_size: int,
               normalization: str = 'batchnorm',
               activation: str = 'lrelu',
               dropout: float = 0.0,
               name: str = 'conv_block'):
  ''' Two consecutive conv -> normalization -> activation -> dropout stages.

  Args:
    inputs: tensor fed to the first convolution.
    filters: number of filters of both convolutions.
    kernel_size: kernel size of both convolutions.
    normalization: normalization name passed to model_utils.Normalization.
    activation: activation name passed to model_utils.Activation.
    dropout: rate passed to model_utils.SpatialDropout.
    name: prefix of the layer names; layers are suffixed with _1 and _2.
  Returns:
    the output of the second dropout layer.
  '''
  outputs = inputs
  for stage in (1, 2):
    outputs = model_utils.Conv(filters,
                               kernel_size,
                               strides=1,
                               padding='same',
                               name=f'{name}/conv_{stage}')(outputs)
    outputs = model_utils.Normalization(
        normalization, name=f'{name}/norm_{stage}')(outputs)
    outputs = model_utils.Activation(
        activation, name=f'{name}/activation_{stage}')(outputs)
    outputs = model_utils.SpatialDropout(
        dropout, name=f'{name}/dropout_{stage}')(outputs)
  return outputs


def get_model(args,
              filters: int = 32,
              num_blocks: int = 3,
              normalization: str = 'instancenorm',
              activation: str = 'lrelu',
              dropout: float = 0.0,
              reduction_factor: t.Union[int, t.Tuple[int, int]] = 2):
  ''' Build the convolutional encoder-decoder model named 'autoencoder'.

  Encoder: num_blocks x (conv_block -> max pool), doubling the number of
  filters after every block. A bottleneck conv_block follows. Decoder:
  num_blocks x (transpose conv -> crop if needed -> conv_block), halving the
  number of filters before every block. A kernel-size-1 convolution with
  args.input_shape[-1] filters and a float32 sigmoid produce the output.

  Args:
    args: namespace providing input_shape.
    filters: number of filters of the first encoder block.
    num_blocks: number of encoder blocks (and of decoder blocks).
    normalization: normalization name passed to conv_block.
    activation: activation name passed to conv_block.
    dropout: dropout rate passed to conv_block.
    reduction_factor: pool size of the max pool layers and stride of the
      transpose convolutions.
  Returns:
    tf.keras.Model mapping the input to its reconstruction.
  '''
  inputs = tf.keras.Input(args.input_shape, name='inputs')
  outputs = inputs

  # encoder outputs taken before pooling; they are only used as shape
  # references for cropping in the decoder, they are not merged back in
  shortcuts = []

  # encoder
  for i in range(num_blocks):
    outputs = conv_block(outputs,
                         filters=filters,
                         kernel_size=3,
                         normalization=normalization,
                         activation=activation,
                         dropout=dropout,
                         name=f'down_{i}/conv_block')
    shortcuts.append(outputs)
    outputs = model_utils.MaxPool(pool_size=reduction_factor,
                                  name=f'down_{i}/max_pool')(outputs)
    filters *= 2

  # bottleneck
  outputs = conv_block(outputs,
                       filters=filters,
                       kernel_size=3,
                       normalization=normalization,
                       activation=activation,
                       dropout=dropout,
                       name='bottleneck')
  # deepest encoder output first, matching the order of the decoder blocks
  shortcuts = shortcuts[::-1]

  # decoder
  for i in range(num_blocks):
    # integer division keeps the filter count an int; it is exact because
    # filters was doubled num_blocks times in the encoder
    filters //= 2
    outputs = model_utils.TransposeConv(filters=filters,
                                        kernel_size=2,
                                        strides=reduction_factor,
                                        padding='same',
                                        name=f'up_{i}/transpose')(outputs)
    # presumably crops outputs to the shape of the shortcut when upsampling
    # overshoots -- see model_utils.cropping
    if outputs.shape[1:] != shortcuts[i].shape[1:]:
      outputs = model_utils.cropping(outputs, shortcuts[i], name=f'up_{i}/crop')
    outputs = conv_block(outputs,
                         filters=filters,
                         kernel_size=3,
                         normalization=normalization,
                         activation=activation,
                         dropout=dropout,
                         name=f'up_{i}/conv_block')

  outputs = model_utils.Conv(filters=args.input_shape[-1],
                             kernel_size=1,
                             name='output/conv')(outputs)

  outputs = model_utils.Activation('sigmoid',
                                   dtype=tf.float32,
                                   name='output/activation')(outputs)

  return tf.keras.Model(inputs=inputs, outputs=outputs, name='autoencoder')


@tf.function
def train_step(x, y, model, optimizer):
  ''' Run one optimization step on the averaged x and y reconstruction loss.

  Returns:
    dict with the keys 'loss/x_loss', 'loss/y_loss' and 'loss/total_loss'.
  '''
  with tf.GradientTape() as tape:
    # both forward passes run before any loss is computed
    x_hat = model(x, training=True)
    y_hat = model(y, training=True)
    x_loss = mean_squared_error(target=x, prediction=x_hat)
    y_loss = mean_squared_error(target=y, prediction=y_hat)
    total_loss = 0.5 * (x_loss + y_loss)
  optimizer.minimize(total_loss,
                     var_list=model.trainable_variables,
                     tape=tape)
  losses = {
      'loss/x_loss': x_loss,
      'loss/y_loss': y_loss,
      'loss/total_loss': total_loss
  }
  return losses


def train(args, ds, model, optimizer, summary, epoch: int):
  ''' Train model for one pass over ds and log the mean of each loss.

  Args:
    args: namespace providing train_steps and verbose.
    ds: iterable of (x, y) pairs, each a mapping with a 'data' entry.
    model: model to train.
    optimizer: optimizer passed on to train_step.
    summary: object providing scalar(...); values are written with mode=0.
    epoch: step under which the scalars are recorded.
  Returns:
    dict mapping each loss name to its mean over all steps.
  '''
  history = {}
  batches = tqdm(ds,
                 desc='Train',
                 total=args.train_steps,
                 disable=args.verbose == 0)
  for x, y in batches:
    step_losses = train_step(x=x['data'],
                             y=y['data'],
                             model=model,
                             optimizer=optimizer)
    utils.update_dict(history, step_losses)
  # replace the per-step values of each loss with their mean
  for name in history:
    history[name] = tf.reduce_mean(history[name])
    summary.scalar(name, history[name], step=epoch, mode=0)
  return history


@tf.function
def validation_step(x, y, model):
  ''' Compute the x, y and averaged reconstruction loss without training.

  Returns:
    dict with the keys 'loss/x_loss', 'loss/y_loss' and 'loss/total_loss'.
  '''
  x_hat = model(x, training=False)
  y_hat = model(y, training=False)
  x_loss = mean_squared_error(target=x, prediction=x_hat)
  y_loss = mean_squared_error(target=y, prediction=y_hat)
  losses = {
      'loss/x_loss': x_loss,
      'loss/y_loss': y_loss,
      'loss/total_loss': 0.5 * (x_loss + y_loss)
  }
  return losses


def validate(args, ds, model, summary, epoch: int):
  ''' Evaluate model over ds and log the mean of each loss.

  Args:
    args: namespace providing val_steps and verbose.
    ds: iterable of (x, y) pairs, each a mapping with a 'data' entry.
    model: model to evaluate.
    summary: object providing scalar(...); values are written with mode=1.
    epoch: step under which the scalars are recorded.
  Returns:
    dict mapping each loss name to its mean over all steps.
  '''
  history = {}
  batches = tqdm(ds,
                 desc='Validation',
                 total=args.val_steps,
                 disable=args.verbose == 0)
  for x, y in batches:
    step_losses = validation_step(x=x['data'], y=y['data'], model=model)
    utils.update_dict(history, step_losses)
  # replace the per-step values of each loss with their mean
  for name in history:
    history[name] = tf.reduce_mean(history[name])
    summary.scalar(name, history[name], step=epoch, mode=1)
  return history


def autoencoder_method(args):
  ''' Train the autoencoder and derive the order from its test-set errors.

  Builds the datasets, summary writer, model and Adam optimizer, trains for
  args.epochs epochs with a validation pass after each one, records progress
  (plots and order history) before training, every 10th epoch and on the
  last epoch, and finally logs the test losses.

  Args:
    args: namespace providing dataset, num_filters, normalization,
      activation, dropout, learning_rate, epochs and verbose, plus whatever
      the helpers called here read. x_spike_rates and y_spike_rates are
      reset to None on it.
  Returns:
    the order computed by get_order on the test set after training.
  Raises:
    AssertionError: if args.dataset does not exist.
  '''
  assert os.path.exists(args.dataset), f'{args.dataset} not found.'

  train_ds, val_ds, test_ds, sample_ds = dataset.get_datasets(args)

  summary = tensorboard.Summary(args)

  model = get_model(args,
                    filters=args.num_filters,
                    normalization=args.normalization,
                    activation=args.activation,
                    dropout=args.dropout)
  model_summary = utils.model_summary(args, model)
  summary.scalar('model/trainable_parameters',
                 utils.count_trainable_params(model))
  if args.verbose == 2:
    print(model_summary)
  optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)

  utils.save_args(args)

  epoch = 0
  # spike rate cache filled lazily by plot_raster
  args.x_spike_rates, args.y_spike_rates = None, None

  # record the state of the untrained model as epoch 0
  log_progress(args,
               test_ds=test_ds,
               sample_ds=sample_ds,
               model=model,
               summary=summary,
               epoch=epoch)

  # epochs run from 1 to args.epochs inclusive; note that epoch equals
  # args.epochs + 1 once the loop has finished
  while (epoch := epoch + 1) < args.epochs + 1:
    print(f'Epoch {epoch:03d}/{args.epochs:03d}')

    start = time()
    train_result = train(args,
                         ds=train_ds,
                         model=model,
                         optimizer=optimizer,
                         epoch=epoch,
                         summary=summary)
    val_result = validate(args,
                          ds=val_ds,
                          model=model,
                          epoch=epoch,
                          summary=summary)
    elapse = time() - start

    summary.scalar('model/elapse', elapse, step=epoch, mode=0)

    # the validation X loss must read 'loss/x_loss', not 'loss/y_loss'
    print(f'Train\t\tX loss: {train_result["loss/x_loss"]:.4e}\t'
          f'Y loss: {train_result["loss/y_loss"]:.4e}\t'
          f'loss: {train_result["loss/total_loss"]:.4e}\n'
          f'Validation\tX loss: {val_result["loss/x_loss"]:.4e}\t'
          f'Y loss: {val_result["loss/y_loss"]:.4e}\t'
          f'loss: {val_result["loss/total_loss"]:.4e}\n'
          f'Elapse: {elapse:.02f}s\n')

    if epoch % 10 == 0 or epoch == args.epochs:
      log_progress(args,
                   test_ds=test_ds,
                   sample_ds=sample_ds,
                   model=model,
                   summary=summary,
                   epoch=epoch)

  order, samples = get_order(args, ds=test_ds, model=model)

  # test results
  x_test_loss = mean_squared_error(target=samples['x'],
                                   prediction=samples['x_reconstruction'])
  y_test_loss = mean_squared_error(target=samples['y'],
                                   prediction=samples['y_reconstruction'])
  summary.scalar('test_results/x_loss', x_test_loss, step=epoch, mode=1)
  summary.scalar('test_results/y_loss', y_test_loss, step=epoch, mode=1)
  summary.scalar('test_results/total_loss',
                 0.5 * (x_test_loss + y_test_loss),
                 step=epoch,
                 mode=1)

  return order


def main(args):
  ''' Compute the order with the selected method and store it as JSON.

  Seeds NumPy and TensorFlow, prepares args.output_dir (removing it first
  when args.clear_output_dir is set), runs firing_rate_method when
  args.method is 'firing_rate' and autoencoder_method otherwise, and writes
  the result to <output_dir>/order.json under the key 'order'.

  Args:
    args: parsed command-line namespace; global_batch_size and order_json
      are set on it.
  '''
  np.random.seed(args.seed)
  tf.random.set_seed(args.seed)

  output_dir = args.output_dir
  if args.clear_output_dir and os.path.exists(output_dir):
    rmtree(output_dir)
  if not os.path.exists(output_dir):
    os.makedirs(output_dir)

  args.global_batch_size = args.batch_size
  args.order_json = os.path.join(output_dir, 'order.json')

  use_firing_rate = args.method == 'firing_rate'
  method = firing_rate_method if use_firing_rate else autoencoder_method
  order = method(args)

  utils.update_json(args.order_json, data={'order': order})

  print(f'order saved to {args.order_json}')


if __name__ == '__main__':
  # command-line entry point: parse the options and hand them to main()
  parser = argparse.ArgumentParser()

  # data hyper-parameters
  parser.add_argument('--dataset', default='', type=str)
  parser.add_argument('--data_dir',
                      default='../dataset/data/vr_data/ST260',
                      type=str,
                      help='raw data directory')
  parser.add_argument('--coordinate_filename',
                      default='../dataset/data/vr_data/MC_20181117_P01.mat',
                      type=str)
  parser.add_argument('--output_dir', default='runs', type=str)

  # model hyper-parameters
  parser.add_argument('--num_filters', default=64, type=int)
  parser.add_argument('--normalization', default='instancenorm', type=str)
  parser.add_argument('--activation', default='lrelu', type=str)
  parser.add_argument('--dropout', default=0.2, type=float)
  parser.add_argument('--input_2d',
                      help='use 2D input instead of 3D',
                      action='store_true')

  # training settings
  parser.add_argument('--method',
                      default='autoencoder',
                      choices=['autoencoder', 'firing_rate'])
  parser.add_argument('--epochs', type=int, default=200)
  parser.add_argument('--batch_size', type=int, default=32)
  parser.add_argument('--learning_rate', type=float, default=1e-3)

  # plot settings
  parser.add_argument('--num_plots', type=int, default=5)
  parser.add_argument('--dpi', default=120, type=int)
  parser.add_argument('--save_plots',
                      help='save plots to disk',
                      action='store_true')
  parser.add_argument('--format',
                      default='pdf',
                      type=str,
                      help='file format when --save_plots')

  # misc
  parser.add_argument('--clear_output_dir', action='store_true')
  parser.add_argument('--seed', default=777, type=int)
  parser.add_argument('--verbose', type=int, default=1)
  params = parser.parse_args()

  # fixed settings that are not exposed on the command line
  params.synthetic_data = False
  params.neuron_order_json = None

  # switch matplotlib to the TkAgg backend on macOS at the highest verbosity
  on_macos = platform.system() == 'Darwin'
  if params.verbose == 2 and on_macos:
    matplotlib.use('TkAgg')

  main(params)
