import os
import argparse
import scipy.io
import numpy as np
import tensorflow as tf
from pathlib import Path
from shutil import rmtree

from cyclegan.utils import utils

# use CPU only: an empty CUDA_VISIBLE_DEVICES hides every CUDA device from
# this process
os.environ["CUDA_VISIBLE_DEVICES"] = ""

# fix the seed of NumPy's global RNG so the np.random.shuffle calls in main()
# give the same ordering on every run
np.random.seed(1234)

# frame rate of the calcium imaging (frames per second)
FRAMERATE = 32.6825
# take segment between 1st and 5th second i.e. 35th and 163rd frame
# used as the half-open slice START:STOP, so each segment has STOP - START
# (128) frames
START, STOP = 35, 163


def concat_dict(inputs):
  """ flatten a nested dictionary of arrays into a single float32 array

  inputs maps an outer key to a dictionary which in turn maps an inner key to
  an array. The arrays are visited in insertion order (outer keys first, then
  inner keys) and joined along their first axis.
  """
  arrays = [
      array for inner in inputs.values() for array in inner.values()
  ]
  joined = np.concatenate(arrays, axis=0)
  return joined.astype(np.float32)


def get_day1_df(filename,
                contrasts=[1., 0.5, 0.25, 0.125],
                angles=[0, 30, 60, 90, 120, 150]):
  """ load the day 1 recording and group its calcium traces by stimulus

  For every requested (contrast, angle) pair, the entries of the 'contrast'
  and 'orient' arrays in the MATLAB file are compared against the pair and the
  matching cells of 'df_block' are cut to the frames START:STOP.

  Returns a nested dictionary traces[contrast][angle] whose values have the
  format (trials, neurons, time-steps). Pairs without any matching trial are
  reported on stdout and left out of the dictionary.
  """
  print(f'load {filename}...')
  mat = scipy.io.loadmat(filename)
  blocks = mat['df_block']
  # the neuron count is read from the first block and applied to all of them
  num_neurons = blocks[0][0].shape[0]
  num_frames = STOP - START

  traces = {}
  num_traces = 0
  for contrast in contrasts:
    for angle in angles:
      mask = (mat['contrast'] == contrast) & (mat['orient'] == angle)
      if not np.any(mask):
        print(f'\tNo trace found with contrast {contrast} and angle {angle}')
        continue
      # one output row per matching cell of df_block, in row-major order
      rows, cols = np.nonzero(mask)
      trace = np.zeros((np.count_nonzero(mask), num_neurons, num_frames))
      for trial, (x, y) in enumerate(zip(rows, cols)):
        trace[trial, :] = blocks[x, y][:num_neurons, START:STOP]
      traces.setdefault(contrast, {})[angle] = trace
      # counts (contrast, angle) groups, not individual trials
      num_traces += 1
  print(f'\textracted {num_traces} traces from {filename}...')
  return traces


def get_day2_df(filename, contrasts=[1., 0.5, 0.25, 0.125], angles=[0, 30]):
  """ load the day 2 recording and group its calcium traces by stimulus

  For every requested (contrast, angle) pair, the trials whose 'orientations'
  and 'contrast_orientation' entries match the pair are selected from
  'final_deltaf' and cut to the frames START:STOP.

  Returns a nested dictionary traces[contrast][angle] whose values have the
  format (trials, neurons, time-steps). Pairs without any matching trial are
  reported on stdout and left out of the dictionary.
  """
  print(f'load {filename}...')
  mat = scipy.io.loadmat(filename)
  # final_deltaf is in format (num. neurons, time-step, trial)
  deltaf = mat['final_deltaf']
  num_neurons = deltaf.shape[0]

  traces, num_traces = {}, 0
  for contrast in contrasts:
    for angle in angles:
      matches = ((mat['orientations'] == angle) &
                 (mat['contrast_orientation'] == contrast))
      # drop singleton axes so the mask can index the trial axis
      mask = np.squeeze(matches)
      if not np.any(mask):
        print(f'\tNo trace found with contrast {contrast} and angle {angle}')
        continue
      segment = deltaf[:num_neurons, START:STOP, mask]
      # move the trial axis to the front: (trials, neurons, time-steps)
      traces.setdefault(contrast, {})[angle] = segment.transpose(2, 0, 1)
      # counts (contrast, angle) groups, not individual trials
      num_traces += 1
  print(f'\textracted {num_traces} traces from {filename}...')
  return traces


def _bytes_feature(value):
  """ wrap a single value in a tf.train.Feature holding a one-item BytesList """
  bytes_list = tf.train.BytesList(value=[value])
  return tf.train.Feature(bytes_list=bytes_list)


def serialize_example(signal):
  """ serialize one signal into a tf.train.Example string

  The raw bytes of signal (signal.tobytes()) are stored under the feature key
  'signal'; shape and dtype are not stored alongside them.
  """
  raw = signal.tobytes()
  feature_map = tf.train.Features(feature={'signal': _bytes_feature(raw)})
  return tf.train.Example(features=feature_map).SerializeToString()


def write_to_record(args, signals, filename):
  """ write every item of signals as one record to args.output_dir / filename

  args.output_dir is combined with filename using the / operator; each item
  of signals is serialized with serialize_example.
  """
  record_path = str(args.output_dir / filename)
  print(f'writing {len(signals)} segments to {record_path}...')
  with tf.io.TFRecordWriter(record_path) as writer:
    for signal in signals:
      writer.write(serialize_example(signal))


def _print_dataset_stats(ds1, ds2):
  """ print the min, max and mean of both datasets """
  print(f'ds1 min {np.min(ds1):.04f}\t'
        f'max {np.max(ds1):.04f}\t'
        f'mean {np.mean(ds1):.04f}\n'
        f'ds2 min {np.min(ds2):.04f}\t'
        f'max {np.max(ds2):.04f}\t'
        f'mean {np.mean(ds2):.04f}')


def main(args):
  """ convert the day 1 and day 2 recordings to train/validation TFRecords

  Reads args.input_x, args.input_y, args.output_dir, args.overwrite,
  args.scale_data and args.validation_size. Writes x-train, x-val, y-train
  and y-val records plus info.json into args.output_dir.

  As a side effect args.output_dir is replaced by a Path and
  args.num_neurons, args.signal_shape, args.sequence_length, args.ds_min and
  args.ds_max are set on args.

  Raises:
    FileExistsError: args.output_dir exists and args.overwrite is not set.
    ValueError: args.validation_size is negative or not smaller than the
      number of trials in either dataset, or the two datasets do not share
      the same (time-steps, neurons) shape.
  """
  # reject an unusable split before anything on disk is touched
  if args.validation_size < 0:
    raise ValueError(
        f'validation_size must be non-negative, got {args.validation_size}')

  args.output_dir = Path(args.output_dir)
  if args.output_dir.exists():
    if not args.overwrite:
      raise FileExistsError(
          f'{args.output_dir} already exists, use --overwrite to replace it')
    rmtree(args.output_dir)
  args.output_dir.mkdir(parents=True)

  # read from matlab files
  ds1 = get_day1_df(args.input_x, angles=[0, 30])
  ds2 = get_day2_df(args.input_y, angles=[0, 30])

  # concatenate all trials
  ds1 = concat_dict(ds1)
  ds2 = concat_dict(ds2)

  # convert dataset to format (trials, time-steps, neurons)
  ds1 = np.transpose(ds1, axes=[0, 2, 1])
  ds2 = np.transpose(ds2, axes=[0, 2, 1])

  # shuffle datasets (along the trial axis, in place)
  np.random.shuffle(ds1)
  np.random.shuffle(ds2)

  # ensure both datasets have the same number of neurons: keep at most the
  # first 88, and fewer if either recording has fewer, so that the value
  # stored in info.json always matches the written arrays
  max_neurons = 88
  args.num_neurons = min(max_neurons, ds1.shape[2], ds2.shape[2])
  ds1 = ds1[:, :, :args.num_neurons]
  ds2 = ds2[:, :, :args.num_neurons]

  if ds1.shape[1:] != ds2.shape[1:]:
    raise ValueError(f'datasets have different signal shapes: '
                     f'{ds1.shape[1:]} and {ds2.shape[1:]}')

  args.signal_shape = ds1.shape[1:]
  args.sequence_length = ds1.shape[1]

  # report the value range of the raw datasets
  _print_dataset_stats(ds1, ds2)

  # both datasets share one min/max so they are scaled identically
  args.ds_min = min(np.min(ds1), np.min(ds2))
  args.ds_max = max(np.max(ds1), np.max(ds2))

  # scale datasets if specified (presumably to [0, 1], the exact mapping is
  # defined by utils.scale)
  if args.scale_data:
    print('scale datasets...')
    ds1 = utils.scale(ds1, args.ds_min, args.ds_max)
    ds2 = utils.scale(ds2, args.ds_min, args.ds_max)
    _print_dataset_stats(ds1, ds2)

  # at least one trial must remain for training in each dataset
  if len(ds1) <= args.validation_size or len(ds2) <= args.validation_size:
    raise ValueError(
        f'validation_size ({args.validation_size}) must be smaller than the '
        f'number of trials in both datasets ({len(ds1)} and {len(ds2)})')

  validation_sizes = [args.validation_size, args.validation_size]
  train_sizes = [len(ds1) - validation_sizes[0], len(ds2) - validation_sizes[1]]

  # write dataset to TFRecords; the last validation_size shuffled trials of
  # each dataset form its validation set
  write_to_record(args, ds1[:train_sizes[0]], filename='x-train.record')
  write_to_record(args, ds1[train_sizes[0]:], filename='x-val.record')
  write_to_record(args, ds2[:train_sizes[1]], filename='y-train.record')
  write_to_record(args, ds2[train_sizes[1]:], filename='y-val.record')

  # save information of the dataset
  utils.save_json(args.output_dir / 'info.json',
                  data={
                      'train_sizes': train_sizes,
                      'validation_sizes': validation_sizes,
                      'signal_shape': args.signal_shape,
                      'sequence_length': args.sequence_length,
                      'num_neurons': args.num_neurons,
                      'scaled_data': args.scale_data,
                      'x_train_prefix': 'x-train',
                      'x_validation_prefix': 'x-val',
                      'y_train_prefix': 'y-train',
                      'y_validation_prefix': 'y-val',
                      'ds_min': args.ds_min,
                      'ds_max': args.ds_max,
                      'frame_rate': FRAMERATE
                  })

  print(f'\nsaved dataset to {args.output_dir}')


if __name__ == '__main__':
  # command-line options, registered in the order they appear in --help
  cli_options = [
      ('--input_x', {
          'type': str,
          'default': 'data/passive_ds/20200723_Thy171_alldata.mat'
      }),
      ('--input_y', {
          'type': str,
          'default': 'data/passive_ds/20200724_Thy171_alldata.mat'
      }),
      ('--output_dir', {'type': str, 'default': 'tfrecords/passive'}),
      ('--validation_size', {'type': int, 'default': 20}),
      ('--overwrite', {'action': 'store_true'}),
      ('--scale_data', {'action': 'store_true'}),
      ('--verbose', {'default': 1, 'type': int}),
  ]

  parser = argparse.ArgumentParser()
  for flag, options in cli_options:
    parser.add_argument(flag, **options)

  cli_args = parser.parse_args()
  main(cli_args)
