import os
import typing as t
import numpy as np
from math import ceil
import tensorflow as tf
import tensorflow_datasets as tfds

from cyclegan.utils import utils

# Shorthand for tf.data.AUTOTUNE, which lets tf.data choose parallelism and
# prefetch buffer sizes at runtime; used by every pipeline in this module.
AUTOTUNE = tf.data.AUTOTUNE


def get_horse2zebra(args, buffer_size=256):
  """ Build the horse2zebra pipelines from TensorFlow Datasets.

  Both domains are truncated to the size of the smaller one so that the
  zipped (x, y) datasets have the same number of elements.

  Side effects on args: sets train_steps, val_steps, test_steps,
  scaled_data, ds_min, ds_max and input_shape.

  Args:
    args: namespace that provides global_batch_size.
    buffer_size: shuffle buffer size for the training datasets.

  Returns:
    (train_ds, val_ds, test_ds, sample_ds) where every element is a pair of
    {'data': image} dictionaries. The validation dataset is also returned
    as the test dataset, and sample_ds holds the first 5 validation images
    of each domain in batches of 1.
  """
  data, metadata = tfds.load('cycle_gan/horse2zebra',
                             as_supervised=True,
                             with_info=True)

  def split_size(name):
    return metadata.splits.get(name).num_examples

  num_train_samples = min(split_size('trainA'), split_size('trainB'))
  num_val_samples = min(split_size('testA'), split_size('testB'))

  # NOTE(review): the batched datasets below use drop_remainder=True, so they
  # yield floor(n / batch) batches while these counters use ceil — confirm
  # how the training loop consumes *_steps before changing either side.
  args.train_steps = ceil(num_train_samples / args.global_batch_size)
  args.val_steps = ceil(num_val_samples / args.global_batch_size)
  args.test_steps = args.val_steps

  def scale(image):
    # map pixel values from [0, 255] to [-1, 1]
    image = tf.cast(image, dtype=tf.float32)
    return (image / 127.5) - 1

  def preprocess_train(image, label):
    # random flip, upsample to 286x286, then random crop back to 256x256
    image = tf.image.random_flip_left_right(image)
    image = tf.image.resize(image, size=(286, 286))
    image = tf.image.random_crop(image, size=(256, 256, 3))
    image = scale(image)
    return {'data': image}

  def preprocess_val(image, label):
    image = tf.image.resize(image, size=(256, 256))
    image = scale(image)
    return {'data': image}

  def make_train(ds):
    ds = ds.take(num_train_samples)
    # Cache the raw images BEFORE the random augmentation. Caching after the
    # map would store the first epoch's flips/crops and replay those exact
    # images in every later epoch.
    ds = ds.cache()
    ds = ds.shuffle(buffer_size)
    ds = ds.map(preprocess_train, num_parallel_calls=AUTOTUNE)
    return ds.batch(args.global_batch_size, drop_remainder=True)

  def make_val(ds):
    """ return (batched validation dataset, 5-image sample dataset) """
    ds = ds.take(num_val_samples)
    ds = ds.map(preprocess_val, num_parallel_calls=AUTOTUNE)
    sample = ds.take(5).batch(1)
    return ds.batch(args.global_batch_size, drop_remainder=True), sample

  x_train = make_train(data['trainA'])
  y_train = make_train(data['trainB'])
  x_val, x_sample = make_val(data['testA'])
  y_val, y_sample = make_val(data['testB'])

  args.scaled_data = False
  args.ds_min, args.ds_max = 0.0, 255.0
  args.input_shape = (256, 256, 3)

  train_ds = tf.data.Dataset.zip((x_train, y_train)).prefetch(AUTOTUNE)
  val_ds = tf.data.Dataset.zip((x_val, y_val))
  sample_ds = tf.data.Dataset.zip((x_sample, y_sample))

  # there is no separate test split: validation doubles as test
  return train_ds, val_ds, val_ds, sample_ds


def load_neuron_order(args):
  """ Set args.neuron_order, loading it from args.neuron_order_json if given.

  When args.neuron_order_json is None the identity order
  [0, 1, ..., args.num_neurons - 1] is used. Otherwise the JSON file is read
  with utils.load_json and its 'order' entry becomes the neuron order.
  In both cases the order is stored as an int16 NumPy array.

  Raises:
    FileNotFoundError: args.neuron_order_json is set but does not exist.
  """
  if args.neuron_order_json is None:
    args.neuron_order = np.arange(args.num_neurons, dtype=np.int16)
    return

  # Raise explicitly instead of using assert: assertions are stripped when
  # Python runs with -O, and get_info reports missing paths the same way.
  if not os.path.exists(args.neuron_order_json):
    raise FileNotFoundError(f'{args.neuron_order_json} not found')

  content = utils.load_json(args.neuron_order_json)
  args.neuron_order = np.array(content['order'], dtype=np.int16)
  print(f'\nloaded neuron orders from {args.neuron_order_json}')


def get_info(args):
  """ Read info.json from args.dataset and copy its settings onto args.

  Sets args.input_shape, num_neurons, scaled_data, frame_rate, ds_min,
  ds_max, train_steps, val_steps and test_steps, then calls
  load_neuron_order(args).

  Returns:
    dict mapping '{x,y}_{train,val,test,sample}' to the '*.record' glob
    pattern of that split inside args.dataset.

  Raises:
    FileNotFoundError: args.dataset does not exist.
  """
  dataset_dir = args.dataset
  if not os.path.exists(dataset_dir):
    raise FileNotFoundError(f'{dataset_dir} not found')

  info = utils.load_json(os.path.join(dataset_dir, 'info.json'))

  # signal shape is (W, C) according to the original note
  args.input_shape = tuple(info['signal_shape'])
  args.num_neurons = info['num_neurons']
  args.scaled_data = info['scaled_data']
  args.frame_rate = info['frame_rate']

  args.ds_min = float(info['ds_min'])
  args.ds_max = float(info['ds_max'])

  # steps are derived from the smallest entry in each *_sizes list
  smallest_train = min(info['train_sizes'])
  args.train_steps = ceil(smallest_train / args.global_batch_size)
  smallest_val = min(info['val_sizes'])
  args.val_steps = ceil(smallest_val / args.global_batch_size)
  # not divided by the batch size, unlike train and val
  args.test_steps = min(info['test_sizes'])

  load_neuron_order(args)

  def record_pattern(prefix):
    return os.path.join(dataset_dir, f'{prefix}*.record')

  split_names = ('x_train', 'x_val', 'x_test', 'x_sample',
                 'y_train', 'y_val', 'y_test', 'y_sample')
  return {
      name: record_pattern(info[f'{name}_prefix']) for name in split_names
  }


def get_diagonal_mask(input_shape: t.Tuple[int, int], num_shifts: int = 8):
  """ Return the boolean shifting mask used for data augmentation.

  In column n the first n * num_shifts rows are True, so the masked run
  grows by num_shifts rows per column; column 0 is never masked. Runs that
  would extend past the last row are clamped to the array height instead
  of raising IndexError.

  Args:
    input_shape: shape of the mask; axis 0 is rows, axis 1 is columns.
    num_shifts: number of additional masked rows per column.

  Returns:
    boolean np.ndarray of shape input_shape.
  """
  num_rows, num_columns = input_shape[0], input_shape[1]
  # allocate as bool directly rather than float64 followed by a cast
  mask = np.zeros(input_shape, dtype=bool)
  for n in range(1, num_columns):
    # clamp to [0, num_rows]: a non-positive shift masks nothing and an
    # oversized shift masks the entire column
    stop = max(0, min(n * num_shifts, num_rows))
    mask[:stop, n] = True
  return mask


def augment(signal, diagonal_mask: np.ndarray, ds_min: float, ds_max: float):
  """ Zero the masked entries of signal and add Gaussian noise.

  The signal is passed through utils.unscale, the entries selected by
  diagonal_mask are zeroed, noise drawn from a normal distribution with the
  axis-0 mean and standard deviation of the unscaled signal is added with
  weight 0.25, and the result is passed through utils.scale.

  A tensor input is converted to NumPy for the computation and the result
  is converted back to a tensor; any other input is returned as produced
  by utils.scale.
  """
  was_tensor = tf.is_tensor(signal)
  values = signal.numpy() if was_tensor else signal
  values = utils.unscale(values, ds_min=ds_min, ds_max=ds_max)

  result = np.copy(values)
  # blank out the masked area
  result[diagonal_mask] = 0.0

  # noise statistics come from the unmasked signal, computed along axis 0
  mean = np.mean(values, axis=0)
  std = np.std(values, axis=0)
  noise = np.random.normal(mean, std, size=values.shape)
  result += 0.25 * noise.astype(np.float32)

  result = utils.scale(result, ds_min=ds_min, ds_max=ds_max)
  return tf.convert_to_tensor(result) if was_tensor else result


def get_dataset(args,
                train_files: str,
                val_files: str,
                test_files: str,
                sample_files: str,
                neuron_order: np.ndarray = None,
                diagonal_mask: np.ndarray = None,
                buffer_size: int = 512,
                repeat: bool = False):
  """ Build the train/val/test/sample datasets of one domain from TFRecords.

  Args:
    args: namespace providing input_shape, input_2d, ds_min, ds_max and
      global_batch_size.
    train_files, val_files, test_files, sample_files: file patterns passed
      to tf.data.Dataset.list_files.
    neuron_order: if given, the signal is passed through utils.order_neuron.
    diagonal_mask: if given, the signal is passed through augment.
    buffer_size: shuffle buffer size of the training dataset.
    repeat: repeat the training dataset indefinitely.

  Returns:
    (train_ds, val_ds, test_ds, sample_ds). Train and val are batched with
    args.global_batch_size and drop the remainder; test and sample use a
    batch size of 1. Each element is a dict with keys 'data', 'reward',
    'lick', 'reward_zone', 'position' and 'trial'.
  """
  serialized_keys = ('signal', 'reward', 'lick', 'reward_zone', 'position',
                     'trial')

  def parse_sample(example):
    # every field is stored as a raw byte string
    example = tf.io.parse_single_example(
        example,
        features={
            key: tf.io.FixedLenFeature([], tf.string)
            for key in serialized_keys
        })

    def decode(key, dtype, shape):
      raw = tf.io.decode_raw(example[key], out_type=dtype)
      return tf.reshape(raw, shape=shape)

    signal = decode('signal', tf.float32, args.input_shape)

    if neuron_order is not None:
      signal = tf.py_function(utils.order_neuron,
                              inp=[signal, neuron_order],
                              Tout=tf.float32)

    # NOTE(review): train/val/test call cache() after this map, so the random
    # noise added by augment is presumably drawn once per sample — confirm
    # that this is intended.
    if diagonal_mask is not None:
      signal = tf.py_function(
          augment,
          inp=[signal, diagonal_mask, args.ds_min, args.ds_max],
          Tout=tf.float32)

    if not args.input_2d:
      # append a trailing singleton axis
      signal = tf.expand_dims(signal, axis=-1)

    # trial information: one value per step along the first signal axis
    per_step = (args.input_shape[0],)
    reward = decode('reward', tf.int16, per_step)
    lick = decode('lick', tf.int16, per_step)
    reward_zone = decode('reward_zone', tf.int16, per_step)
    position = decode('position', tf.float32, per_step)
    trial = decode('trial', tf.int16, per_step)

    return {
        'data': signal,
        'reward': reward,
        'lick': lick,
        'reward_zone': reward_zone,
        'position': position,
        'trial': trial
    }

  def read_records(pattern, parallel):
    """ list the files, interleave their records and parse every example """
    files = tf.data.Dataset.list_files(pattern)
    if parallel:
      records = files.interleave(tf.data.TFRecordDataset,
                                 num_parallel_calls=AUTOTUNE)
      return records.map(parse_sample, num_parallel_calls=AUTOTUNE)
    records = files.interleave(tf.data.TFRecordDataset)
    return records.map(parse_sample)

  # train: parallel read, cached, shuffled, optionally repeated
  train_ds = read_records(train_files, parallel=True)
  train_ds = train_ds.cache().shuffle(buffer_size)
  if repeat:
    train_ds = train_ds.repeat()
  train_ds = train_ds.batch(args.global_batch_size, drop_remainder=True)

  # validation: parallel read, cached, no shuffling
  val_ds = read_records(val_files, parallel=True).cache()
  val_ds = val_ds.batch(args.global_batch_size, drop_remainder=True)

  # test: sequential read, cached, one sample per batch
  test_ds = read_records(test_files, parallel=False).cache().batch(1)

  # sample: sequential read, not cached, one sample per batch
  sample_ds = read_records(sample_files, parallel=False).batch(1)

  return train_ds, val_ds, test_ds, sample_ds


def dry_run_datasets(args, train_ds, val_ds, test_ds, sample_ds):
  """ Iterate once over every dataset and discard the elements.

  Prints a notice first when args.verbose is truthy. Returns None.
  """
  if args.verbose:
    print('\ndry run tf.data.Dataset\n')
  for dataset in (train_ds, val_ds, test_ds, sample_ds):
    for _ in dataset:
      pass


def get_datasets(args, dry_run: bool = True):
  """ Build the paired (x, y) train/val/test/sample datasets.

  When args.dataset is 'horse2zebra' the datasets come from get_horse2zebra.
  Otherwise args.dataset is read through get_info and both domains are
  built with get_dataset. With args.synthetic_data the y domain is built
  from the x files with a diagonal mask from get_diagonal_mask; without it
  the y domain is built from the y files.

  Args:
    args: run configuration; updated in place by the helpers called here.
    dry_run: iterate once over every dataset before returning.

  Returns:
    (train_ds, val_ds, test_ds, sample_ds), each a zip of the x and y
    datasets.
  """
  if args.dataset == 'horse2zebra':
    train_ds, val_ds, test_ds, sample_ds = get_horse2zebra(args)
  else:
    filenames = get_info(args)

    # only reorder neurons when an order file was supplied
    neuron_order = None
    if args.neuron_order_json is not None:
      neuron_order = args.neuron_order

    x_files = dict(train_files=filenames['x_train'],
                   val_files=filenames['x_val'],
                   test_files=filenames['x_test'],
                   sample_files=filenames['x_sample'])
    x_train, x_val, x_test, x_sample = get_dataset(args,
                                                   neuron_order=neuron_order,
                                                   **x_files)

    if args.synthetic_data:
      # y is the x recordings passed through the diagonal-mask augmentation
      y_files = x_files
      diagonal_mask = get_diagonal_mask(args.input_shape,
                                        num_shifts=args.synthetic_shift)
    else:
      y_files = dict(train_files=filenames['y_train'],
                     val_files=filenames['y_val'],
                     test_files=filenames['y_test'],
                     sample_files=filenames['y_sample'])
      diagonal_mask = None
    y_train, y_val, y_test, y_sample = get_dataset(
        args,
        neuron_order=neuron_order,
        diagonal_mask=diagonal_mask,
        **y_files)

    # get_dataset appends a trailing axis to the signal when input_2d is
    # falsy, so extend the recorded shape to match
    if not args.input_2d:
      args.input_shape = args.input_shape + (1,)

    pair = tf.data.Dataset.zip
    train_ds = pair((x_train, y_train)).prefetch(AUTOTUNE)
    val_ds = pair((x_val, y_val))
    test_ds = pair((x_test, y_test))
    sample_ds = pair((x_sample, y_sample))

  if dry_run:
    dry_run_datasets(args,
                     train_ds=train_ds,
                     val_ds=val_ds,
                     test_ds=test_ds,
                     sample_ds=sample_ds)

  return train_ds, val_ds, test_ds, sample_ds
