import os
import argparse
import warnings
import numpy as np
import pandas as pd
from tqdm import tqdm
from multiprocessing import Pool

from cyclegan.utils import utils
from cyclegan.utils import h5 as h5
from cyclegan.utils import spike_helper
from cyclegan.utils.tensorboard import Summary

# Silence the noisy warning categories for the whole analysis run; the filters
# are registered in the same order as before (DeprecationWarning last).
for _ignored_category in (UserWarning, FutureWarning, RuntimeWarning,
                          DeprecationWarning):
  warnings.simplefilter(action='ignore', category=_ignored_category)

# Fixed seed so the randomly chosen raster trial is the same on every run.
np.random.seed(1234)


def get_legend(name: str) -> str:
  """ Return the plot legend for a sample key.

  The generated ('fake_*') and cycled ('cycle_*') keys map to a LaTeX-style
  label showing which generator composition produced them. Any other key is
  returned unchanged.
  """
  legends = {
      'fake_x': r'$\hat{x}$ = F(y)',
      'fake_y': r'$\hat{y}$ = G(x)',
      'cycle_x': r'$\bar{x}$ = F(G(x))',
      'cycle_y': r'$\bar{y}$ = G(F(y))',
  }
  return legends.get(name, name)


def get_neo_trains(args,
                   key: str,
                   data_format: str,
                   neuron: int = None,
                   trial: int = None):
  """ Load the spikes stored under `key` and convert them with trains_to_neo.

  Args:
    args: namespace providing `spikes_filename`.
    key: name of the dataset to read from the spikes file.
    data_format: 'NW' keeps the loaded array as is, 'CW' swaps its first two
      axes before conversion (presumably the letters name the axis layout --
      TODO confirm against spike_helper).
    neuron: neuron selector forwarded to h5.get.
    trial: trial selector forwarded to h5.get.
  At least one of `neuron` and `trial` must be given.

  Returns:
    Whatever spike_helper.trains_to_neo returns for the loaded spikes.
  """
  assert data_format in ('NW', 'CW')
  assert not (neuron is None and trial is None)
  loaded = h5.get(args.spikes_filename, key, neuron=neuron, trial=trial)
  trains = np.swapaxes(loaded, 0, 1) if data_format == 'CW' else loaded
  return spike_helper.trains_to_neo(trains)


def plot_raster(args, key1: str, key2: str, summary):
  """ Draw a raster plot comparing `key1` and `key2` on one random trial.

  The trial index is drawn uniformly from [0, args.num_samples) using the
  global NumPy random state. For each key the stored values of that trial are
  read from args.spikes_filename, converted with
  spike_helper.get_spike_trains and reversed along axis 1 before being handed
  to summary.raster_plot.
  """
  trial = np.random.randint(args.num_samples)

  def load_flipped_trains(key):
    rates = h5.get(args.spikes_filename, key, trial=trial)
    return np.flip(spike_helper.get_spike_trains(rates), axis=1)

  flipped1 = load_flipped_trains(key1)
  flipped2 = load_flipped_trains(key2)

  tag = f'{key1}|{key2}/raster/trial_{trial:03d}'
  legends = [get_legend(key1), get_legend(key2)]
  summary.raster_plot(tag,
                      spikes1=flipped1,
                      spikes2=flipped2,
                      xlabel='Time (s)',
                      ylabel='Neuron',
                      legends=legends)


def kl_divergence(p, q):
  """ Return sum(p * log(p / q)) for two arrays of bin probabilities.

  Entries that are exactly 0 are swapped for 1e-10 first so that neither the
  ratio nor the logarithm is evaluated at zero. The inputs are not
  re-normalised after the substitution.
  """
  epsilon = 1e-10
  safe_p = np.where(p == 0, epsilon, p)
  safe_q = np.where(q == 0, epsilon, q)
  log_ratio = np.log(safe_p / safe_q)
  return np.sum(safe_p * log_ratio)


def pairs_kl_divergence(pairs):
  """ Compute the KL divergence between the two sample sets of every pair.

  For each (samples1, samples2) pair, both sets are binned together into 30
  equal-width bins spanning their combined range (pandas.cut). The per-set
  bin counts are divided by the set size to give two histograms, which are
  passed to kl_divergence.

  Args:
    pairs: sequence of (samples1, samples2) tuples of 1-D numeric samples.

  Returns:
    float32 array with one KL divergence value per pair.
  """
  num_bins = 30
  kl = np.zeros((len(pairs),), dtype=np.float32)

  for index, (samples1, samples2) in enumerate(pairs):
    num_samples1, num_samples2 = len(samples1), len(samples2)

    # bin both sets at once so they share the same bin edges; cutting a plain
    # array returns a Categorical whose integer codes are the bin indexes
    binned = pd.cut(np.concatenate([samples1, samples2]),
                    bins=num_bins,
                    labels=np.arange(num_bins))
    codes = np.asarray(binned.codes)
    codes1, codes2 = codes[:num_samples1], codes[num_samples1:]

    # code -1 marks a value that was not assigned to any bin (e.g. NaN); such
    # values are left out of the counts but still count towards the set size
    counts1 = np.bincount(codes1[codes1 >= 0], minlength=num_bins)
    counts2 = np.bincount(codes2[codes2 >= 0], minlength=num_bins)

    pdf1 = counts1.astype(np.float32) / num_samples1
    pdf2 = counts2.astype(np.float32) / num_samples2

    kl[index] = kl_divergence(pdf1, pdf2)
  return kl


def sort_heatmap(matrix):
  """ Reorder the rows and columns of a matrix so its minimum is top left.

  The row containing the global minimum becomes the first row and the columns
  are sorted by that row's values in ascending order. Each following row i is
  the not-yet-used row with the smallest value in the i-th sorted column.

  Args:
    matrix: 2-D array with at least as many columns as rows. It is not
      modified.

  Returns:
    heatmap: float32 array of the same shape holding the reordered values.
    row_order: integer array, row_order[i] is the original index of row i.
    column_order: integer array, column_order[j] is the original index of
      column j.
  """
  matrix = np.asarray(matrix)
  num_trials = len(matrix)

  # working copy in which used rows are masked with inf; it has to be a
  # floating array, otherwise the inf assignment below fails for int input
  if np.issubdtype(matrix.dtype, np.floating):
    remaining = np.copy(matrix)
  else:
    remaining = matrix.astype(np.float64)

  heatmap = np.full(matrix.shape, fill_value=np.nan, dtype=np.float32)

  # get the index with the minimum value
  min_index = np.unravel_index(np.argmin(matrix), matrix.shape)

  # row and column order for the sorted matrix
  # (np.int was removed in NumPy 1.24; it was an alias of the builtin int)
  row_order = np.full((num_trials,), fill_value=-1, dtype=int)
  row_order[0] = min_index[0]
  column_order = np.argsort(matrix[min_index[0]])

  for i in range(num_trials):
    if i != 0:
      row_order[i] = np.argsort(remaining[:, column_order[i]])[0]
    heatmap[i] = matrix[row_order[i], column_order]
    # exclude the chosen row from all later selections
    remaining[row_order[i], :] = np.inf

  return heatmap, row_order, column_order


def firing_rate(args, key1: str, key2: str, neuron: int):
  """ Return the mean firing rates of one neuron for `key1` and `key2`.

  The spikes of `neuron` are loaded in 'NW' format for each key and passed to
  spike_helper.mean_firing_rate; the two results are returned as a tuple in
  the order (key1, key2).
  """
  rates = []
  for key in (key1, key2):
    rates.append(get_neo_trains(args, key, data_format='NW', neuron=neuron))
  return tuple(spike_helper.mean_firing_rate(trains) for trains in rates)


def measure_firing_rate(args, key1: str, key2: str, summary):
  """ Compare the firing rates of `key1` and `key2` and plot the results.

  firing_rate is evaluated for every neuron in range(args.num_neurons) on a
  pool of args.num_processors worker processes. Two kinds of histograms are
  written to `summary`:
    - for each neuron in args.neurons, the firing rates of both keys
    - the KL divergence between the two keys over all neurons

  Args:
    args: namespace with num_processors, num_neurons, neurons, reverse_order
      and whatever firing_rate reads.
    key1: name of the first dataset.
    key2: name of the second dataset.
    summary: object providing plot_histogram.
  """
  print(f'measure firing rate between {key1} and {key2}')

  # the context manager releases the worker processes even when a task raises
  with Pool(args.num_processors) as pool:
    results = pool.starmap(
        firing_rate,
        [(args, key1, key2, neuron) for neuron in range(args.num_neurons)])

  # plot histograms of specific neurons
  for neuron in args.neurons:
    # results are indexed by position in the file, so translate the neuron
    # number through args.reverse_order first
    index = args.reverse_order[neuron]
    rates1, rates2 = results[index]
    summary.plot_histogram(f'{key1}|{key2}/firing_rate/neuron_{neuron:03d}',
                           data={
                               get_legend(key1): rates1,
                               get_legend(key2): rates2
                           },
                           xlabel='Hz',
                           ylabel='Count',
                           title=f'Neuron {neuron}')

  kl = pairs_kl_divergence(results)
  summary.plot_histogram(f'{key1}|{key2}/firing_rate/kl',
                         data={'kl': kl},
                         xlabel='KL divergence',
                         ylabel='Count',
                         title='Firing Rate',
                         annotate_statistic=True)


def correlation(args, key1: str, key2: str, trial: int):
  """ Return the pairwise correlation coefficients of one trial for both keys.

  For each key the spikes of `trial` are loaded in 'CW' format and passed to
  spike_helper.correlation_coefficients. Only the entries strictly above the
  main diagonal of the resulting args.num_neurons square matrix are kept, and
  NaN entries are replaced through np.nan_to_num.

  Returns:
    Tuple (coefficients of key1, coefficients of key2), each a 1-D array.
  """
  upper_triangle = np.triu_indices(args.num_neurons, k=1)

  def pairwise_coefficients(key):
    trains = get_neo_trains(args, key, data_format='CW', trial=trial)
    coefficients = spike_helper.correlation_coefficients(trains, None)
    return np.nan_to_num(coefficients[upper_triangle])

  return pairwise_coefficients(key1), pairwise_coefficients(key2)


def measure_correlation(args, key1: str, key2: str, summary):
  """ Compare the pairwise correlations of `key1` and `key2` and plot them.

  correlation is evaluated for every trial stored under `key1` on a pool of
  args.num_processors worker processes. Two kinds of histograms are written
  to `summary`:
    - for each trial in args.trials, the correlations of both keys
    - the KL divergence between the two keys over all trials

  Args:
    args: namespace with num_processors, spikes_filename, trials and whatever
      correlation reads.
    key1: name of the first dataset; its length sets the number of trials.
    key2: name of the second dataset.
    summary: object providing plot_histogram.
  """
  print(f'measure correlations of {key1} and {key2}')

  num_trials = h5.get_length(args.spikes_filename, key1)

  # the context manager releases the worker processes even when a task raises
  with Pool(args.num_processors) as pool:
    results = pool.starmap(
        correlation,
        [(args, key1, key2, trial) for trial in range(num_trials)])

  # plot histograms of specific trials; results are indexed by trial number
  for trial in args.trials:
    correlation1, correlation2 = results[trial]
    summary.plot_histogram(f'{key1}|{key2}/correlation/trial_{trial:03d}',
                           data={
                               get_legend(key1): correlation1,
                               get_legend(key2): correlation2
                           },
                           xlabel='Pair-wise correlation',
                           ylabel='Count',
                           title=f'Trial {trial}')

  kl = pairs_kl_divergence(results)
  summary.plot_histogram(f'{key1}|{key2}/correlation/kl',
                         data={'kl': kl},
                         xlabel='KL divergence',
                         ylabel='Count',
                         title='Correlation KL',
                         annotate_statistic=True)


def van_rossum_heatmap(args, key1: str, key2: str, neuron: int):
  """ Build the sorted van Rossum distance heatmap of one neuron.

  The spikes of `neuron` are loaded in 'NW' format for both keys, their
  distances are computed with spike_helper.van_rossum_distance and the
  resulting matrix is reordered by sort_heatmap.

  Returns:
    Dictionary with the sorted 'heatmap' and the tick labels for plotting.
  """
  trains1 = get_neo_trains(args, key1, data_format='NW', neuron=neuron)
  trains2 = get_neo_trains(args, key2, data_format='NW', neuron=neuron)
  sorted_distances, row_order, column_order = sort_heatmap(
      spike_helper.van_rossum_distance(trains1, trains2))
  # NOTE(review): the x tick labels receive the row order and the y tick
  # labels the column order -- confirm this matches how plot_heatmap lays out
  # the matrix axes.
  result = {}
  result['heatmap'] = sorted_distances
  result['xticklabels'] = row_order
  result['yticklabels'] = column_order
  return result


def van_rossum_pairwise(args, key1: str, key2: str, trial: int):
  """ Pairwise van Rossum distances within one trial, for each key.

  For each key the spikes of `trial` are loaded in 'CW' format and
  spike_helper.van_rossum_distance is called with None as second argument.
  Only the entries strictly above the main diagonal of each distance matrix
  are returned, as a tuple in the order (key1, key2).
  """

  def distance_matrix(key):
    trains = get_neo_trains(args, key=key, data_format='CW', trial=trial)
    return spike_helper.van_rossum_distance(trains, None)

  distances1 = distance_matrix(key1)
  distances2 = distance_matrix(key2)
  # the index set is sized from the first matrix and applied to both
  upper_triangle = np.triu_indices(len(distances1), k=1)
  return distances1[upper_triangle], distances2[upper_triangle]


def measure_van_rossum(args, key1: str, key2: str, summary):
  """ Compare the van Rossum distances of `key1` and `key2` and plot them.

  Two results are written to `summary`:
    - for each neuron in args.neurons, the sorted heatmap returned by
      van_rossum_heatmap
    - a histogram of the KL divergence between the pairwise distances of the
      two keys, computed over range(args.num_samples) trials

  Both computations run on a pool of args.num_processors worker processes.

  Args:
    args: namespace with num_processors, neurons, reverse_order, num_samples
      and whatever the worker functions read.
    key1: name of the first dataset.
    key2: name of the second dataset.
    summary: object providing plot_heatmap and plot_histogram.
  """
  print(f'measure van Rossum distance of {key1} and {key2}')

  # translate the neuron numbers through args.reverse_order before loading
  indexes = [args.reverse_order[neuron] for neuron in args.neurons]

  # generate van Rossum distance heatmap; the context manager releases the
  # worker processes even when a task raises
  with Pool(args.num_processors) as pool:
    results = pool.starmap(van_rossum_heatmap,
                           [(args, key1, key2, neuron) for neuron in indexes])

  assert len(results) == len(args.neurons)

  for neuron, result in zip(args.neurons, results):
    summary.plot_heatmap(f'{key1}|{key2}/van_rossum/neuron_{neuron:03d}',
                         matrix=result['heatmap'],
                         xlabel=get_legend(key2),
                         ylabel=get_legend(key1),
                         xticklabels=result['xticklabels'],
                         yticklabels=result['yticklabels'],
                         title=f'Neuron {neuron}')

  # compute pairwise van Rossum distance KL divergence
  with Pool(args.num_processors) as pool:
    van_rossum_pairs = pool.starmap(
        van_rossum_pairwise,
        [(args, key1, key2, t) for t in range(args.num_samples)])

  kl = pairs_kl_divergence(van_rossum_pairs)
  summary.plot_histogram(f'{key1}|{key2}/van_rossum/kl',
                         data={'kl': kl},
                         xlabel='KL divergence',
                         ylabel='Count',
                         title='van-Rossum distance KL',
                         annotate_statistic=True)


def measure_pair(args, key1: str, key2: str, summary):
  """ Run every comparison between `key1` and `key2`, in a fixed order.

  The raster plot comes first, followed by the firing rate, correlation and
  van Rossum distance measurements; all of them report to `summary`.
  """
  steps = (
      plot_raster,
      measure_firing_rate,
      measure_correlation,
      measure_van_rossum,
  )
  for step in steps:
    step(args, key1, key2, summary)


def main(args):
  """ Analyse the samples stored in args.output_dir and save the plots.

  The sample files are expected in the 'samples' sub-directory of
  args.output_dir. If spikes.h5 does not exist yet,
  spike_helper.deconvolve_samples is called and the process exits with
  status 0 without running the analysis. Otherwise the recorded data ('x',
  'y') are compared against their cycled and generated counterparts.

  Args:
    args: parsed command line arguments; the derived paths and plotting
      settings are added to it as attributes.

  Raises:
    FileNotFoundError: if args.output_dir does not exist.
    SystemExit: after deconvolving samples when spikes.h5 was missing.
  """
  if not os.path.exists(args.output_dir):
    raise FileNotFoundError(f'{args.output_dir} not found.')

  args.samples_dir = os.path.join(args.output_dir, 'samples')
  args.signals_filename = os.path.join(args.samples_dir, 'signals.h5')
  args.spikes_filename = os.path.join(args.samples_dir, 'spikes.h5')
  # presumably restores the settings saved with the run -- confirm in utils
  utils.load_args(args)

  if not os.path.exists(args.spikes_filename):
    spike_helper.deconvolve_samples(args)
    # exit() is a helper the site module adds for interactive sessions and is
    # missing under `python -S`; raise SystemExit directly instead
    raise SystemExit(0)

  # neurons and trials to plot: the 3 interior points of 5 evenly spaced
  # neuron indexes, and a fixed set of trials
  neuron_indexes = np.linspace(0, args.num_neurons - 1, 5, dtype=int)
  args.neurons = neuron_indexes[1:-1]
  args.trials = [0, 5, 10]
  args.num_samples = h5.get_length(args.spikes_filename, 'x')
  args.reverse_order = np.argsort(args.neuron_order)

  summary = Summary(args, analysis=True)

  # each recorded set is compared with its cycled and its generated version
  pairs = (
      ('x', 'cycle_x'),
      ('x', 'fake_x'),
      ('y', 'cycle_y'),
      ('y', 'fake_y'),
  )
  for key1, key2 in pairs:
    measure_pair(args, key1=key1, key2=key2, summary=summary)

  summary.close()

  print(f'results saved in {os.path.join(args.output_dir, "analysis")}')


if __name__ == '__main__':
  # command line entry point: collect the options and run the analysis
  cli = argparse.ArgumentParser()
  cli.add_argument('--output_dir', type=str, default='runs/pycharm')

  # integer options for the worker pool and the plots
  for flag, value in (('--num_processors', 8), ('--num_neuron_plots', 6),
                      ('--num_trial_plots', 6), ('--plots_per_row', 3),
                      ('--dpi', 120)):
    cli.add_argument(flag, default=value, type=int)

  cli.add_argument('--format', choices=['pdf', 'png', 'svg'], default='pdf')
  cli.add_argument('--save_plots', action='store_true')

  for flag, value in (('--seed', 1234), ('--verbose', 1)):
    cli.add_argument(flag, default=value, type=int)

  main(cli.parse_args())
