import os
import io
import sys
import copy
import json
import pickle
import subprocess
import numpy as np
import typing as t
import tensorflow as tf

from cyclegan.utils.gradcam import gradcam
from cyclegan.utils import h5, attention_gate


def split_index(length, n):
  """
  Return n (start, end) index pairs that cut `length` items into
  contiguous chunks whose sizes differ by at most one. When `length` is not
  divisible by n, the leading chunks receive the extra items.
  """
  base, extra = divmod(length, n)
  bounds, start = [], 0
  for chunk in range(n):
    # the first `extra` chunks absorb one leftover item each
    end = start + base + (1 if chunk < extra else 0)
    bounds.append((start, end))
    start = end
  return bounds


def split(sequence, n):
  """
  Slice sequence into n contiguous sub-sequences of (nearly) equal length,
  using the boundaries computed by split_index.
  """
  return [sequence[start:end] for start, end in split_index(len(sequence), n)]


def scale(x, ds_min: float, ds_max: float):
  """
  Min-max normalise x: ds_min maps to 0 and ds_max maps to 1.
  Values outside [ds_min, ds_max] fall outside [0, 1].
  """
  value_range = ds_max - ds_min
  shifted = x - ds_min
  return shifted / value_range


def unscale(x, ds_min: float, ds_max: float):
  """
  Map x back to its original range: 0 becomes ds_min and 1 becomes ds_max.
  """
  value_range = ds_max - ds_min
  return x * value_range + ds_min


def order_neuron(x, order: np.ndarray):
  """
  Re-arrange the neurons in segment x according to order.
  Note: x must be in shape (..., H, W, 1) or (..., H, W)

  The re-ordered axis is the second to last one when the trailing dimension
  has size 1, and the last one otherwise. Its length must equal len(order).
  Tensor inputs are returned as tf.float32 tensors; any other input is
  indexed directly and returned as is.
  """
  has_channel = x.shape[-1] == 1
  neuron_axis = -2 if has_channel else -1
  assert x.shape[neuron_axis] == len(order)

  from_tensor = tf.is_tensor(x)
  # fancy indexing is done on the NumPy side
  values = x.numpy() if from_tensor else x
  if has_channel:
    values = values[..., order, :]
  else:
    values = values[..., order]

  if from_tensor:
    return tf.convert_to_tensor(values, tf.float32)
  return values


def postprocess(signals,
                input_2d: bool = False,
                ds_min: float = None,
                ds_max: float = None,
                reverse_order: np.ndarray = None):
  """
  Post-process calcium signals.

  Steps, applied in this order:
    1. tensors are converted to NumPy arrays
    2. the trailing axis is squeezed unless input_2d is set
    3. values are un-scaled when both ds_min and ds_max are provided
    4. neurons are re-arranged when reverse_order is provided
  """
  result = signals.numpy() if tf.is_tensor(signals) else signals

  if not input_2d:
    result = np.squeeze(result, axis=-1)

  has_range = ds_min is not None and ds_max is not None
  if has_range:
    result = unscale(result, ds_min=ds_min, ds_max=ds_max)

  if reverse_order is not None:
    result = order_neuron(result, order=reverse_order)

  return result


def postprocess_cycle(args, samples: t.Dict[str, np.ndarray]):
  """
  Post-process the x, y, fake_x, fake_y, cycle_x, cycle_y, same_x and same_y
  entries of samples in-place, using args.input_2d, args.ds_min and
  args.ds_max. Other entries in samples are left untouched.
  """
  keys = ('x', 'y', 'fake_x', 'fake_y', 'cycle_x', 'cycle_y', 'same_x',
          'same_y')
  for key in keys:
    samples[key] = postprocess(samples[key],
                               input_2d=args.input_2d,
                               ds_min=args.ds_min,
                               ds_max=args.ds_max)


def update_dict(target, source, replace=False):
  """
  Merge the items of source into target in-place.

  With replace=True each value overwrites whatever target holds for that key.
  Otherwise each value is appended to the list stored under its key, and the
  list is created on first use.
  """
  for key, value in source.items():
    if replace:
      target[key] = value
      continue
    target.setdefault(key, []).append(value)


def update_json(filename, data):
  """
  Add or overwrite the items of data in the JSON file at filename.
  The file is created when it does not exist yet.
  """
  content = load_json(filename) if os.path.exists(filename) else {}
  content.update(data.items())
  save_json(filename, content)


def load_json(filename):
  """ read filename and return its parsed JSON content """
  with open(filename, 'r') as file:
    return json.load(file)


def save_json(filename: str, data: t.Dict):
  """
  Write data to filename as JSON.

  NumPy arrays are stored as (nested) lists and NumPy scalars as the
  matching Python scalar, at any nesting depth. data itself is not modified.

  Args:
    filename: path of the file to write; existing content is overwritten.
    data: dictionary to serialize.
  Raises:
    TypeError: if data is not a dict, or if it contains a value that is
      neither JSON serializable nor a NumPy array/scalar.
  """
  if not isinstance(data, dict):
    raise TypeError(f'data must be a dict, got {type(data).__name__}')

  def to_builtin(value):
    """ fallback used by json for objects it cannot encode natively """
    if isinstance(value, np.ndarray):
      return value.tolist()
    if isinstance(value, np.generic):
      return value.item()
    raise TypeError(
        f'Object of type {type(value).__name__} is not JSON serializable')

  # serialize before opening the file, so that a failure cannot leave a
  # truncated file behind
  payload = json.dumps(data, default=to_builtin)
  with open(filename, 'w') as file:
    file.write(payload)


def check_output(command: list):
  """
  Execute command in a subprocess and return what it wrote to stdout,
  decoded to a string with the surrounding whitespace removed.
  """
  raw = subprocess.check_output(command)
  return raw.strip().decode()


def save_args(args):
  """
  Store a copy of the attributes of args, together with the current git
  description and the hostname, in output_dir/args.json
  """
  arguments = copy.deepcopy(vars(args))
  arguments.update(git_hash=check_output(['git', 'describe', '--always']),
                   hostname=check_output(['hostname']))
  destination = os.path.join(args.output_dir, 'args.json')
  save_json(filename=destination, data=arguments)


def load_args(args):
  """
  Populate args with the entries of output_dir/args.json.
  Attributes that args already has are kept, and neuron_order is converted
  to an integer NumPy array when it is loaded from the file.
  """
  stored = load_json(os.path.join(args.output_dir, 'args.json'))
  for name in stored:
    if hasattr(args, name):
      continue
    setattr(args, name, stored[name])
    if name == 'neuron_order':
      args.neuron_order = np.array(args.neuron_order, dtype=int)


def count_trainable_params(model):
  ''' sum the parameter counts of all trainable variables in model '''
  counts = [
      tf.keras.backend.count_params(variable)
      for variable in model.trainable_variables
  ]
  return np.sum(counts)


def model_summary(args, model):
  '''
  Return the summary of model as a string and append it to
  output_dir/<model.name>.txt
  '''
  lines = []
  # model.summary emits one line per print_fn call, without the line break
  model.summary(print_fn=lambda line: lines.append(line + '\n'))
  summary = ''.join(lines)
  filename = os.path.join(args.output_dir, f'{model.name}.txt')
  with open(filename, 'a') as file:
    file.write(summary)
  return summary


def remove_nan(array):
  """ return the non-NaN elements of array as a 1D array """
  nan_mask = np.isnan(array)
  return array[~nan_mask]


def inference_dataset(ds, gan):
  '''
  Run gan on every (x, y) batch in ds and return a dictionary of NumPy
  arrays, each concatenated over the batches along the first axis.

  The dictionary holds x, y, fake_x, fake_y, cycle_x, cycle_y, same_x and
  same_y. When the batches contain a 'reward_zone' entry, the lick, reward
  and reward_zone entries of x and y are collected as well, under
  x_<name> and y_<name>.
  '''
  collected = {}
  for x, y in ds:
    if 'reward_zone' in x:
      behaviour = {}
      for name in ('lick', 'reward', 'reward_zone'):
        behaviour[f'x_{name}'] = x[name]
        behaviour[f'y_{name}'] = y[name]
      update_dict(collected, behaviour)

    x_data, y_data = x['data'], y['data']
    fake_x, fake_y, cycle_x, cycle_y = gan.cycle_step(x_data,
                                                      y_data,
                                                      training=False)
    same_x = gan.F(x_data, training=False)
    same_y = gan.G(y_data, training=False)
    update_dict(
        collected,
        dict(x=x_data,
             y=y_data,
             fake_x=fake_x,
             fake_y=fake_y,
             cycle_x=cycle_x,
             cycle_y=cycle_y,
             same_x=same_x,
             same_y=same_y))

  outputs = {}
  for key, batches in collected.items():
    outputs[key] = tf.concat(batches, axis=0).numpy()
  return outputs


def save_samples(args, gan, ds):
  """
  Run gan on ds, post-process the samples and write them to
  samples_dir/signals.h5, replacing the file if it already exists.
  The path of the file is stored in args.signals_filename.
  """
  samples = inference_dataset(ds, gan)
  postprocess_cycle(args, samples=samples)

  samples_dir = args.samples_dir
  if not os.path.exists(samples_dir):
    os.makedirs(samples_dir)

  filename = os.path.join(samples_dir, 'signals.h5')
  args.signals_filename = filename
  # start from a clean file instead of adding to a previous one
  if os.path.exists(filename):
    os.remove(filename)

  h5.write(filename, samples)

  print(f'saved signal samples to {args.signals_filename}\n')


def augmentation(signals):
  """
  Augment signals by adding random noise to their Fourier coefficients.

  The last two axes of signals are swapped before the FFT and swapped back
  afterwards, and only the real part of the inverse FFT is returned.
  Note: assumes signals has 3 dimensions, presumably (N, W, C) - TODO confirm
  """
  # convert signals to shape (NCW), the FFT runs over the last axis
  channels_first = tf.transpose(signals, perm=(0, 2, 1))
  spectrum = tf.signal.fft(tf.cast(channels_first, dtype=tf.complex64))
  # a single noise pattern with a leading axis of size 1, hence it is
  # broadcast over the first axis of the spectrum
  noise = tf.random.normal(spectrum.shape[1:])
  noise = tf.expand_dims(noise, axis=0)
  # random noise amplitude in [0, 0.6)
  amplitude = tf.random.uniform((noise.shape[0], 1, 1), maxval=0.6)
  # the noise is real-valued, so only the real part of the spectrum changes
  perturbation = tf.cast(amplitude * noise, dtype=tf.complex64)
  restored = tf.signal.ifft(spectrum + perturbation)
  return tf.transpose(tf.math.real(restored), perm=(0, 2, 1))


def plot_augmentation(args, samples, summary, epoch):
  """
  Plot up to the first 2 samples of x and y against their augmented
  versions, under the augment_x/ and augment_y/ tags of summary.
  """
  x, y = samples
  augmented_x, augmented_y = augmentation(x), augmentation(y)
  pairs = (('x', x, augmented_x), ('y', y, augmented_y))
  for i in range(min(2, len(x))):
    for name, original, augmented in pairs:
      summary.plot_comparison(f'augment_{name}/sample_#{i:02d}',
                              traces=[original[i], augmented[i]],
                              labels=[name, f'augmented {name}'],
                              step=epoch,
                              training=True)


def layer_exists(model, layer_name: str):
  """ check whether any layer in model.layers is named layer_name """
  return any(layer.name == layer_name for layer in model.layers)


def get_reward_ranges(
    reward_zone: t.Union['tf.Tensor', np.ndarray, t.Sequence[int]]
) -> t.List[t.Tuple[int, int]]:
  '''
  Return the list of (start, end) index ranges of the reward zones.

  start is the index of the first 1 after a 0, and end is the index of the
  first 0 that follows it, i.e. reward_zone[start:end] is all 1s.

  Only zones whose entry and exit are both visible are returned: a zone that
  is already active at index 0, or still active at the last index, is
  skipped.

  Args:
    reward_zone: 1D sequence of 0s and 1s.
  Returns:
    list of (start, end) tuples in order of occurrence.
  '''
  # also converts eager tensors, without needing a separate tensor check
  reward_zone = np.asarray(reward_zone)

  ranges, start = [], None
  for i in range(1, len(reward_zone)):
    previous, current = reward_zone[i - 1], reward_zone[i]
    if previous == 0 and current == 1:
      start = i
    elif previous == 1 and current == 0 and start is not None:
      # an exit is only recorded when its matching entry was seen, otherwise
      # it would be paired with the entry of the next zone
      ranges.append((start, i))
      start = None

  return ranges


def get_trial_ranges(trial: t.Union[tf.Tensor, np.ndarray]):
  '''
  Return the (start, end) index ranges between consecutive changes of value
  in trial.

  The segments before the first change and after the last change are not
  part of the result.
  '''
  if tf.is_tensor(trial):
    trial = trial.numpy()

  # indexes at which the value differs from the previous one
  boundaries = [i for i in range(1, len(trial)) if trial[i] != trial[i - 1]]
  return list(zip(boundaries[:-1], boundaries[1:]))


def get_num_gpus():
  ''' count the physical GPU devices that TensorFlow can see '''
  gpus = tf.config.list_physical_devices(device_type='GPU')
  return len(gpus)


def plot_cycle(args, ds, gan, summary, epoch: int = 0):
  """
  Run gan on ds and plot the X -> G(X) -> F(G(X)) and Y -> F(Y) -> G(F(Y))
  cycles to summary.

  For the horse2zebra dataset the samples are plotted as images. For any
  other dataset the samples are post-processed and plotted as population
  and trace figures, which requires the reward, lick and reward_zone
  entries to be present in the samples.
  """
  samples = inference_dataset(ds, gan)
  x_keys = ('x', 'fake_y', 'cycle_x')
  y_keys = ('y', 'fake_x', 'cycle_y')

  if args.dataset == 'horse2zebra':
    # convert to 8-bit pixel values, assumes samples are in range [-1, 1]
    images = {}
    for key, value in samples.items():
      images[key] = ((value + 1) * 127.5).astype(np.uint8)
    summary.image_cycle('X_cycle',
                        images=[images[key] for key in x_keys],
                        labels=['X', 'G(X)', 'F(G(X))'],
                        step=epoch,
                        mode=1)
    summary.image_cycle('Y_cycle',
                        images=[images[key] for key in y_keys],
                        labels=['Y', 'F(Y)', 'G(F(Y))'],
                        step=epoch,
                        mode=1)
    return

  postprocess_cycle(args, samples=samples)
  info_names = ('reward', 'lick', 'reward_zone')
  x_titles = ['x', 'G(x)', 'F(G(x))']
  y_titles = ['y', 'F(y)', 'G(F(y))']

  summary.population_cycle(
      tag='X_population_cycle',
      samples=[samples[key] for key in x_keys],
      titles=x_titles,
      info={name: samples[f'x_{name}'] for name in info_names},
      step=epoch,
      mode=1)
  summary.population_cycle(
      tag='Y_population_cycle',
      samples=[samples[key] for key in y_keys],
      titles=y_titles,
      info={name: samples[f'y_{name}'] for name in info_names},
      step=epoch,
      mode=1)

  colors = ['dodgerblue', 'orangered']
  summary.trace_cycle(tag='X_cycle',
                      samples=[samples[key] for key in x_keys],
                      titles=x_titles,
                      colors=colors,
                      step=epoch,
                      mode=1)
  # the colour order is swapped for the Y cycle
  summary.trace_cycle(tag='Y_cycle',
                      samples=[samples[key] for key in y_keys],
                      titles=y_titles,
                      colors=colors[::-1],
                      step=epoch,
                      mode=1)


def plot(args, test_ds, sample_ds, gan, summary, epoch: int):
  '''
  Plot the evaluation figures for epoch: cycle, GradCAM and attention gate
  plots on sample_ds, and positional GradCAM plots on test_ds.
  '''
  # keyword arguments common to all plotting calls
  shared = dict(gan=gan, summary=summary, epoch=epoch)
  plot_cycle(args, ds=sample_ds, **shared)
  gradcam.plot_gradcam(args, ds=sample_ds, **shared)
  gradcam.plot_positional_gradcam(args,
                                  ds=test_ds,
                                  ds_size=args.test_steps,
                                  **shared)
  attention_gate.plot_attention_gate(args, ds=sample_ds, **shared)
