import os
import scipy.io
import argparse
import platform
import numpy as np
import typing as t
from glob import glob

import matplotlib
# On macOS, switch matplotlib to the TkAgg backend. This runs before
# matplotlib.pyplot is imported further down, so keep it in this position.
if platform.system() == 'Darwin':
  matplotlib.use('TkAgg')
import seaborn as sns
import matplotlib.pyplot as plt

# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and
# later releases dropped the old names, so use whichever name this
# installation actually provides instead of failing at import time.
plt.style.use('seaborn-deep' if 'seaborn-deep' in plt.style.available else
              'seaborn-v0_8-deep')

from cyclegan.utils import spike_helper


def load_mat(filename: str):
  """Read a recording from a MATLAB file and concatenate its trials.

  The per-trial records are read from content['VRdata']['pertrial'] and
  joined end-to-end along the time axis.

  Args:
    filename: path of the .mat file to load.

  Returns:
    A dictionary with
      'num_neurons': number of neurons in the recording,
      'trial': trial index of every time-step (int16),
      'position', 'time', 'velocity': float32 arrays, one value per time-step,
      'lick', 'reward': int16 arrays, one value per time-step,
      'signals': float32 array with shape (time-steps, num. neurons),
      'spikes': int16 array with shape (time-steps, num. neurons),
      'reward_zone': int16 array, 1 where 80 <= position <= 100 and 0
        elsewhere.
  """
  mat = scipy.io.loadmat(filename)
  per_trial = mat['VRdata']['pertrial'][0, 0]

  # the last entry of 'new_pos' is not loaded
  # NOTE(review): presumably the final trial is incomplete - confirm
  n_trials = np.shape(per_trial['new_pos'])[1] - 1
  # note the first two rows in spikes and signals are average information
  n_neurons = np.shape(per_trial['new_spikes'][0][0])[0] - 2

  # number of time-steps in each trial and over the whole recording
  lengths = [np.shape(per_trial['new_pos'][0, i])[0] for i in range(n_trials)]
  n_steps = np.sum(lengths)

  # offsets of each trial within the concatenated arrays
  stops = np.cumsum(lengths)
  begins = np.concatenate((np.array([0]), stops[:-1]))

  def _empty(dtype, *leading):
    # zero-filled buffer whose last axis is the concatenated time axis
    return np.zeros((*leading, n_steps) if leading else n_steps, dtype=dtype)

  data = {
      'num_neurons': n_neurons,
      'trial': _empty(np.int16),
      'position': _empty(np.float32),
      'time': _empty(np.float32),
      'velocity': _empty(np.float32),
      'lick': _empty(np.int16),
      'reward': _empty(np.int16),
      'signals': _empty(np.float32, n_neurons),
      'spikes': _empty(np.int16, n_neurons)
  }

  # (output key, field name in the .mat record) of the per-step channels
  channels = (
      ('position', 'new_pos'),
      ('time', 'new_time'),
      ('velocity', 'new_velocity'),
      ('lick', 'new_lick'),
      ('reward', 'new_reward'),
  )
  # skip the two leading rows of average information
  neuron_rows = slice(2, n_neurons + 2)

  for index, (lo, hi) in enumerate(zip(begins, stops)):
    span = slice(lo, hi)
    data['trial'][span] = index
    for key, field in channels:
      data[key][span] = per_trial[field][0, index][:, 0]
    data['signals'][:, span] = per_trial['new_dF'][0, index][neuron_rows]
    data['spikes'][:, span] = per_trial['new_spikes'][0, index][neuron_rows]

  # reward zones are between 80cm to 100cm in the corridor
  position = data['position']
  data['reward_zone'] = ((position >= 80) & (position <= 100)).astype(np.int16)

  # swap signals and spikes to have shape (time-steps, num. neurons)
  data['signals'] = data['signals'].T
  data['spikes'] = data['spikes'].T

  return data


def get_firing_rates(data: t.Dict[str, np.ndarray]):
  """Compute firing rates from the spike trains stored in data['spikes'].

  Args:
    data: recording dictionary whose 'spikes' entry has shape
      (time-steps, num. neurons).

  Returns:
    The result of spike_helper.mean_firing_rate on the converted spike trains.
    NOTE(review): presumably one rate per neuron - confirm in spike_helper.
  """
  # the helper expects neuron-major input, i.e. (num. neurons, time-steps)
  neuron_major = np.swapaxes(data['spikes'], 0, 1)
  return spike_helper.mean_firing_rate(
      spike_helper.trains_to_neo(neuron_major))


def get_correlation(data: t.Dict[str, np.ndarray]):
  """Return the pairwise spike-train correlations of distinct neuron pairs.

  Args:
    data: recording dictionary whose 'spikes' entry has shape
      (time-steps, num. neurons).

  Returns:
    1D array with the entries strictly above the main diagonal of the matrix
    returned by spike_helper.correlation_coefficients, with NaNs replaced by
    np.nan_to_num.
  """
  # convert spike trains to (num. neurons, time-steps)
  neuron_major = np.swapaxes(data['spikes'], 0, 1)
  # indices strictly above the diagonal, so each pair is taken once and
  # self-correlations are left out
  upper = np.triu_indices(neuron_major.shape[0], k=1)
  matrix = spike_helper.correlation_coefficients(
      spike_helper.trains_to_neo(neuron_major), None)
  return np.nan_to_num(matrix[upper])


def get_statistic(filename: str):
  """Load one recording and print its summary statistics.

  Prints the total duration, number of trials, average trial duration,
  average firing rate, number of licks, total rewards and average velocity.

  Args:
    filename: path of the .mat recording to summarise.
  """
  print(f'\n\nprocessing {filename}...')
  data = load_mat(filename)

  # NOTE(review): assumes data['time'] is one continuous clock over the whole
  # recording; if the clock restarts in every trial this is not the total
  # duration - confirm against the recording format.
  total_duration = data['time'][-1] - data['time'][0]
  # data['trial'] holds zero-based trial indices (load_mat assigns them with
  # enumerate), so the number of trials is the last index plus one. Convert
  # to a Python int first so the addition cannot overflow int16.
  num_trials = int(data['trial'][-1]) + 1
  average_duration = total_duration / num_trials
  average_firing_rate = np.mean(get_firing_rates(data))
  num_licks = np.sum(data['lick'])
  total_rewards = np.sum(data['reward'])
  average_velocity = np.mean(data['velocity'])

  print(f'total duration: {total_duration:.2f}s\n'
        f'number of trials: {num_trials}\n'
        f'average duration: {average_duration:.2f}s\n'
        f'average firing rate: {average_firing_rate:.2f}Hz\n'
        f'number of licks: {num_licks}\n'
        f'total rewards: {total_rewards}\n'
        f'average velocity: {average_velocity:.2f}')


def main(args):
  """Print statistics for every 'ST*.mat' file found in args.data_dir.

  Args:
    args: parsed command line arguments; only args.data_dir is read.

  Raises:
    FileNotFoundError: if args.data_dir does not exist.
  """
  data_dir = args.data_dir
  if not os.path.exists(data_dir):
    raise FileNotFoundError(f'{data_dir} not found.')

  # sorted so that the recordings are processed in a stable order
  pattern = os.path.join(data_dir, 'ST*.mat')
  recordings = sorted(glob(pattern))
  print(f'found {len(recordings)} .mat files in {data_dir}')

  for recording in recordings:
    get_statistic(recording)


if __name__ == '__main__':
  # command line entry point: --data_dir is the folder holding the ST*.mat
  # recordings to summarise
  cli = argparse.ArgumentParser()
  cli.add_argument(
      '--data_dir', type=str, default='../dataset/data/vr_data/ST260')
  cli_args = cli.parse_args()
  main(cli_args)
