import os
import scipy.io
import argparse
import numpy as np
import typing as t
from math import ceil
import tensorflow as tf

from shutil import rmtree

from cyclegan.utils import utils

# use CPU only: hide every CUDA device from TensorFlow
# NOTE(review): this is set after `import tensorflow`; it only takes effect if
# TensorFlow reads the variable lazily when devices are first initialized --
# confirm, or move it above the import.
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# frame rate written to info.json as 'frame_rate' (presumably frames per
# second -- confirm against the recording setup)
FRAMERATE = 24.0


def calculate_num_per_shard(args):
  """Return how many segments to store in each TFRecord shard.

  The count for a 1GB shard is inversely proportional to
  `args.sequence_length` (22,400 segments when the sequence length is 120).
  It is then scaled linearly by `args.target_shard_size`, the desired shard
  size in GB, and truncated to an int.
  """
  one_gb_count = ceil(2240 * (120 / args.sequence_length)) * 10
  scaled_count = one_gb_count * args.target_shard_size
  return int(scaled_count)


def load_mat(args, filename: str):
  """Load one VR recording .mat file and flatten its trials into arrays.

  Trials are concatenated back to back along the time axis.

  Args:
    args: command-line arguments (currently unused).
    filename: path to a .mat file holding `VRdata.pertrial`, whose per-trial
      entries contain `new_pos`, `new_time`, `new_velocity`, `new_lick`,
      `new_reward`, `new_dF` and `new_spikes`.

  Returns:
    a dict with `num_neurons` plus the arrays `trial`, `position`, `time`,
    `velocity`, `lick`, `reward` and `reward_zone` of shape (time-steps,),
    and `signals` and `spikes` of shape (time-steps, num_neurons).
  """
  content = scipy.io.loadmat(filename)
  recordings = content['VRdata']['pertrial'][0, 0]

  # NOTE(review): the last entry of `pertrial` is skipped; presumably it is an
  # incomplete trial -- confirm.
  num_trials = np.shape(recordings['new_pos'])[1] - 1
  # note the first two rows in spikes and signals are average information
  num_neurons = np.shape(recordings['new_spikes'][0][0])[0] - 2

  # duration (number of time-steps) of each trial & the total duration
  trial_durations = [
      np.shape(recordings['new_pos'][0, trial])[0]
      for trial in range(num_trials)
  ]
  # int() keeps np.zeros happy even when there are no trials (sum of [] is 0.0)
  total_duration = int(np.sum(trial_durations))

  # trials are stored contiguously: trial k occupies [starts[k], ends[k])
  ends = np.cumsum(trial_durations, dtype=np.int64)
  starts = ends - np.asarray(trial_durations, dtype=np.int64)

  # create empty arrays for handling recordings
  data = {
      'num_neurons': num_neurons,
      'trial': np.zeros(total_duration, dtype=np.int16),
      'position': np.zeros(total_duration, dtype=np.float32),
      'time': np.zeros(total_duration, dtype=np.float32),
      'velocity': np.zeros(total_duration, dtype=np.float32),
      'lick': np.zeros(total_duration, dtype=np.int16),
      'reward': np.zeros(total_duration, dtype=np.int16),
      'signals': np.zeros((num_neurons, total_duration), dtype=np.float32),
      'spikes': np.zeros((num_neurons, total_duration), dtype=np.int16)
  }

  for trial, (start, end) in enumerate(zip(starts, ends)):
    data['trial'][start:end] = trial
    data['position'][start:end] = recordings['new_pos'][0, trial][:, 0]
    data['time'][start:end] = recordings['new_time'][0, trial][:, 0]
    data['velocity'][start:end] = recordings['new_velocity'][0, trial][:, 0]
    data['lick'][start:end] = recordings['new_lick'][0, trial][:, 0]
    data['reward'][start:end] = recordings['new_reward'][0, trial][:, 0]
    for n in range(num_neurons):
      # rows 0 and 1 are averages, so neuron n is stored in row n + 2
      data['signals'][n, start:end] = recordings['new_dF'][0, trial][n + 2]
      data['spikes'][n, start:end] = recordings['new_spikes'][0, trial][n + 2]

  # reward zones are between 120cm to 140cm in the corridor
  zone_start, zone_end = 120, 140
  reward_zone = np.zeros_like(data['trial'])
  for start, end in zip(starts, ends):
    # need to iterate every single trial because sometimes the mouse moves
    # backward, hence cannot simply do np.where(120 <= position <= 140).
    # Within a trial the zone spans from the first sample >= 120cm to the
    # last sample <= 140cm; the slice includes the trial's final sample.
    position = data['position'][start:end]
    entered = np.flatnonzero(position >= zone_start)
    not_past = np.flatnonzero(position <= zone_end)
    if entered.size == 0 or not_past.size == 0:
      # no sample is >= 120cm (or none is <= 140cm): nothing to mark
      continue
    reward_zone[start + entered[0]:start + not_past[-1] + 1] = 1
  data['reward_zone'] = reward_zone

  # swap signals and spikes to have shape (time-steps, num. neurons)
  data['signals'] = np.transpose(data['signals'], axes=[1, 0])
  data['spikes'] = np.transpose(data['spikes'], axes=[1, 0])

  return data


def segmentation(args, data: t.Dict[str, np.ndarray]):
  """Cut the recording into evenly spaced, fixed-length segments.

  Every window of `args.sequence_length` consecutive time-steps is a
  candidate (sliding window with step size 1, including the last window that
  ends at the final time-step). `args.train_size + args.val_size +
  args.test_size` of them, evenly spaced over the recording, are kept.

  Args:
    args: arguments providing `sequence_length`, `train_size`, `val_size`
      and `test_size`.
    data: the dictionary returned by `load_mat`.

  Returns:
    a dict mapping `signals` to an array of shape
    (num_samples, sequence_length, num_neurons) and `rewards`, `licks`,
    `reward_zones`, `positions` and `trials` to arrays of shape
    (num_samples, sequence_length).

  Raises:
    ValueError: if the recording holds fewer windows than requested samples.
  """
  sequence_length = args.sequence_length
  num_samples = args.train_size + args.val_size + args.test_size

  # windows [i, i + sequence_length) fit for i = 0 .. num_steps - sequence_length
  num_steps = data['signals'].shape[0]
  num_segments = max(num_steps - sequence_length + 1, 0)
  if num_segments < num_samples:
    raise ValueError(f'recording has {num_segments} segments of length '
                     f'{sequence_length} but {num_samples} samples were '
                     f'requested')

  # sub-sample evenly spaced window starts, then gather only those windows
  # rather than materializing every overlapping window first
  starts = np.linspace(start=0,
                       stop=num_segments - 1,
                       num=num_samples,
                       dtype=int)
  windows = starts[:, None] + np.arange(sequence_length)

  # output key -> key in data
  sources = {
      'signals': 'signals',
      'rewards': 'reward',
      'licks': 'lick',
      'reward_zones': 'reward_zone',
      'positions': 'position',
      'trials': 'trial',
  }
  return {name: data[key][windows] for name, key in sources.items()}


def load_and_segment(args, filename: str):
  """Load a .mat recording and cut it into fixed-length segments.

  Args:
    args: command-line arguments forwarded to `load_mat` and `segmentation`.
    filename: path to the .mat recording.

  Returns:
    the segment dictionary produced by `segmentation`.

  Raises:
    FileNotFoundError: if `filename` does not exist.
    ValueError: if `filename` does not end with `.mat`.
  """
  # explicit checks instead of `assert`, which is stripped under `python -O`
  if not os.path.exists(filename):
    raise FileNotFoundError(f'{filename} does not exist')
  if not filename.endswith('.mat'):
    raise ValueError(f'{filename} is not a .mat file')
  print(f'processing file {filename}...')
  data = load_mat(args, filename=filename)
  return segmentation(args, data=data)


def scale_data(
    args,
    ds1: t.Dict[str, np.ndarray],
    ds2: t.Dict[str, np.ndarray],
):
  """Record the joint signal range and, if requested, rescale both datasets.

  Sets `args.ds_min` / `args.ds_max` to the smallest / largest signal value
  across ds1 and ds2. If `args.scale_data` is True, the `signals` of both
  datasets are replaced in place by `utils.scale(..., args.ds_min,
  args.ds_max)` (reported as scaling to [0, 1]). Summary statistics are
  printed before scaling and again after it.
  """

  def report_stats():
    # read the dicts at call time so the post-scaling call sees new values
    x_signals, y_signals = ds1["signals"], ds2["signals"]
    print(f'dataset X min {np.min(x_signals):.04f}\t'
          f'max {np.max(x_signals):.04f}\t'
          f'mean {np.mean(x_signals):.04f}\n'
          f'dataset Y min {np.min(y_signals):.04f}\t'
          f'max {np.max(y_signals):.04f}\t'
          f'mean {np.mean(y_signals):.04f}')

  report_stats()

  args.ds_min = min(np.min(ds1["signals"]), np.min(ds2["signals"]))
  args.ds_max = max(np.max(ds1["signals"]), np.max(ds2["signals"]))

  if not args.scale_data:
    return

  print('\nscale data to [0, 1]...')
  for ds in (ds1, ds2):
    ds["signals"] = utils.scale(ds["signals"], args.ds_min, args.ds_max)
  report_stats()


def shuffle(ds: t.Dict[str, np.ndarray], rng: np.random.Generator):
  """Permute every array in `ds` in place with one shared random order.

  The order is a shuffled `arange` over the length of `ds['signals']`, so all
  arrays must have that many rows.
  """
  order = np.arange(len(ds['signals']))
  rng.shuffle(order)
  for key, values in list(ds.items()):
    ds[key] = values[order]


def get_record_filename(args, prefix: str, shard_id: int, num_shards: int):
  """Return the path of shard `shard_id` (0-based) inside `args.output_dir`.

  The file name numbers shards from 1 and zero-pads to three digits, e.g.
  `x-train-001-of-004.record`.
  """
  return os.path.join(
      args.output_dir,
      '{}-{:03d}-of-{:03d}.record'.format(prefix, shard_id + 1, num_shards))


def serialize_example(signal: np.ndarray, reward: np.ndarray, lick: np.ndarray,
                      reward_zone: np.ndarray, position: np.ndarray,
                      trial: np.ndarray):
  """Serialize one segment into a `tf.train.Example` byte string.

  Each array is stored as its raw `ndarray.tobytes()` in a single-element
  BytesList feature, so dtype and shape are not recorded and must be known
  by the reader.
  """
  arrays = {
      'signal': signal,
      'reward': reward,
      'lick': lick,
      'reward_zone': reward_zone,
      'position': position,
      'trial': trial,
  }
  feature = {
      name: tf.train.Feature(bytes_list=tf.train.BytesList(
          value=[array.tobytes()]))
      for name, array in arrays.items()
  }
  example = tf.train.Example(features=tf.train.Features(feature=feature))
  return example.SerializeToString()


def write_to_records(args, ds: t.Dict[str, np.ndarray], prefix: str):
  """Write the segments in `ds` to one or more TFRecord shards.

  The shard count is ceil(num_samples / args.num_per_shard), or 1 when
  `args.num_per_shard` is 0. `utils.split_index` splits the samples into one
  contiguous range per shard, and each shard is written to
  `get_record_filename(args, prefix, shard_id, num_shards)`.
  """
  num_samples = len(ds['signals'])

  # calculate the number of records to create
  if args.num_per_shard == 0:
    num_shards = 1
  else:
    num_shards = ceil(num_samples / args.num_per_shard)

  print(f'\nwriting {num_samples} segments to {num_shards} records...')

  # slice every array of ds into one contiguous chunk per shard
  shards_ds = [{key: value[start:end] for key, value in ds.items()}
               for (start, end) in utils.split_index(num_samples, n=num_shards)]

  for shard_id in range(num_shards):
    filename = get_record_filename(args, prefix, shard_id, num_shards)
    shard = shards_ds[shard_id]
    print(f'\twriting {len(shard["signals"])} segments to {filename}...')
    with tf.io.TFRecordWriter(filename) as writer:
      for i, signal in enumerate(shard['signals']):
        writer.write(
            serialize_example(signal=signal,
                              reward=shard['rewards'][i],
                              lick=shard['licks'][i],
                              reward_zone=shard['reward_zones'][i],
                              position=shard['positions'][i],
                              trial=shard['trials'][i]))


def main(args):
  """Convert the two .mat recordings into TFRecord datasets.

  Loads and segments `args.input_x` (X) and `args.input_y` (Y), records the
  joint signal range (optionally scaling both sets), and shuffles each set
  independently. It then writes train/val/test/sample shards to
  `args.output_dir` and saves the dataset description to `info.json`.

  Raises:
    FileExistsError: if `args.output_dir` exists and `args.overwrite` is not
      set.
    ValueError: if the two recordings yield segments of different shapes.
  """
  if os.path.exists(args.output_dir) and args.overwrite:
    rmtree(args.output_dir)

  # without exist_ok this fails when output_dir exists and was not removed
  os.makedirs(args.output_dir)

  ds1 = load_and_segment(args, filename=args.input_x)
  ds2 = load_and_segment(args, filename=args.input_y)

  assert sorted(ds1.keys()) == sorted(ds2.keys())
  # depends on the input files (e.g. number of neurons), so validate explicitly
  if ds1['signals'].shape[1:] != ds2['signals'].shape[1:]:
    raise ValueError(f'segment shapes differ: {ds1["signals"].shape[1:]} '
                     f'vs {ds2["signals"].shape[1:]}')

  args.signal_shape = tuple(ds1['signals'].shape[1:])
  args.num_neurons = ds1['signals'].shape[2]

  scale_data(args, ds1=ds1, ds2=ds2)

  # shuffle the two sets independently
  rng = np.random.default_rng(seed=args.seed)
  shuffle(ds=ds1, rng=rng)
  shuffle(ds=ds2, rng=rng)

  args.num_per_shard = calculate_num_per_shard(args)

  # [start, end) rows of each split after shuffling; the sample set (used for
  # plotting) is taken from the first rows, i.e. it overlaps the train split
  val_start = args.train_size
  test_start = args.train_size + args.val_size
  splits = (
      ('train', 0, args.train_size),
      ('val', val_start, test_start),
      ('test', test_start, ds1['signals'].shape[0]),
      ('sample', 0, args.sample_size),
  )
  # writes x-train, y-train, x-val, y-val, x-test, y-test, x-sample, y-sample
  for split, start, end in splits:
    for name, ds in (('x', ds1), ('y', ds2)):
      write_to_records(args,
                       ds={k: v[start:end] for k, v in ds.items()},
                       prefix=f'{name}-{split}')

  # save information of the dataset
  utils.save_json(
      os.path.join(args.output_dir, 'info.json'),
      data={
          'train_sizes': [args.train_size, args.train_size],
          'val_sizes': [args.val_size, args.val_size],
          'test_sizes': [args.test_size, args.test_size],
          'signal_shape': args.signal_shape,
          'sequence_length': args.sequence_length,
          'num_neurons': args.num_neurons,
          'scaled_data': args.scale_data,
          'x_train_prefix': 'x-train',
          'x_val_prefix': 'x-val',
          'x_test_prefix': 'x-test',
          'x_sample_prefix': 'x-sample',
          'y_train_prefix': 'y-train',
          'y_val_prefix': 'y-val',
          'y_test_prefix': 'y-test',
          'y_sample_prefix': 'y-sample',
          # ds_min / ds_max are numpy scalars; the json module cannot
          # serialize those, so convert them to Python floats
          'ds_min': float(args.ds_min),
          'ds_max': float(args.ds_max),
          'frame_rate': FRAMERATE,
          'git_hash': utils.check_output(['git', 'describe', '--always'])
      })

  print(f'\nsaved dataset to {args.output_dir}')


if __name__ == '__main__':
  parser = argparse.ArgumentParser(
      description='Convert two VR .mat recordings into TFRecord datasets.')
  parser.add_argument('--input_x',
                      type=str,
                      default='data/vr_data/ST260/ST260-AllVR1.mat',
                      help='path to the .mat recording for dataset X')
  parser.add_argument('--input_y',
                      type=str,
                      default='data/vr_data/ST260/ST260-AllVR4.mat',
                      help='path to the .mat recording for dataset Y')
  parser.add_argument('--output_dir',
                      type=str,
                      required=True,
                      help='directory to write the TFRecords and info.json')
  parser.add_argument('--sequence_length',
                      type=int,
                      required=True,
                      help='number of time-steps per segment')
  parser.add_argument('--scale_data',
                      action='store_true',
                      help='scale signals using the joint min/max of X and Y')
  parser.add_argument('--overwrite',
                      action='store_true',
                      help='delete output_dir first if it already exists')
  parser.add_argument('--train_size',
                      type=int,
                      default=3000,
                      help='number of training samples')
  parser.add_argument('--val_size',
                      type=int,
                      default=200,
                      help='number of validation samples')
  parser.add_argument('--test_size',
                      type=int,
                      default=200,
                      help='number of test samples')
  parser.add_argument('--sample_size',
                      type=int,
                      default=6,
                      help='number of samples allocated for plotting')
  parser.add_argument('--target_shard_size',
                      type=float,
                      default=0.25,
                      help='target size of each TFRecord file in GB')
  parser.add_argument('--seed', type=int, default=777, help='random seed')

  main(parser.parse_args())
