import os
import math
import platform
import matplotlib
import numpy as np
import typing as t
import tensorflow as tf
import matplotlib.pyplot as plt

from cyclegan.utils import utils
from cyclegan.utils import tensorboard

# Matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names, so fall back to the new name when necessary.
plt.style.use('seaborn-deep' if 'seaborn-deep' in plt.style.available else
              'seaborn-v0_8-deep')


def extra_attention_mask(inputs, model, gating_name: str, sigmoid_name: str):
  """Extract the gating signal and sigmoid mask of one attention gate.

  Builds a sub-model of `model` whose outputs are the activations of the
  layers named `gating_name` and `sigmoid_name`, and runs it on `inputs`.
  Only the top-level layers in `model.layers` are searched.

  Args:
    inputs: rank-4 batch holding exactly one sample (batch dimension 1).
    model: `tf.keras.Model` containing both layers.
    gating_name: name of the layer that provides the gating signal.
    sigmoid_name: name of the attention sigmoid layer.

  Returns:
    dict with numpy arrays under 'gating_signal' (averaged over the last
    axis, which is kept with size 1) and 'sigmoid_mask'.

  Raises:
    ValueError: if either layer name is not found in `model.layers`.
  """
  assert len(inputs.shape) == 4 and inputs.shape[0] == 1

  layers_by_name = {layer.name: layer for layer in model.layers}
  missing = [
      name for name in (gating_name, sigmoid_name) if name not in layers_by_name
  ]
  if missing:
    # report exactly which layer(s) are absent rather than always both
    raise ValueError(f'layer(s) {missing} not found in model {model.name}.')

  attention_gate = tf.keras.Model(
      inputs=model.input,
      outputs=[
          layers_by_name[gating_name].output,
          layers_by_name[sigmoid_name].output
      ],
      name='attention_gate')
  gating_signal, sigmoid_mask = attention_gate(inputs)
  # average over channels so the gating signal becomes a single-channel map
  gating_signal = tf.reduce_mean(gating_signal, axis=-1, keepdims=True)

  return {
      'gating_signal': gating_signal.numpy(),
      'sigmoid_mask': sigmoid_mask.numpy()
  }


def transpose_data(inputs: np.ndarray, outputs: np.ndarray,
                   gates: t.Dict[int, t.Dict[str, np.ndarray]]):
  """Convert every array to the (num. neurons, time-steps) layout.

  Each array is reduced with `x[0, ..., 0]` (first batch element, first
  channel) and its two remaining axes are swapped. `gates` keeps its nested
  {gate index: {name: array}} structure.

  Returns:
    Tuple (inputs, outputs, gates) holding the converted arrays.
  """

  def to_neurons_by_time(array: np.ndarray) -> np.ndarray:
    return np.transpose(array[0, ..., 0], axes=(1, 0))

  new_inputs = to_neurons_by_time(inputs)
  new_outputs = to_neurons_by_time(outputs)
  new_gates = {}
  for gate_id, arrays in gates.items():
    new_gates[gate_id] = {
        name: to_neurons_by_time(value) for name, value in arrays.items()
    }
  return new_inputs, new_outputs, new_gates


def remove_spines(axis, remove_left: bool = True):
  """Hide the frame lines (spines) of a matplotlib axis.

  The top, bottom and right spines are always hidden; the left spine is
  hidden only when `remove_left` is true.
  """
  hidden = ['top', 'bottom', 'right']
  if remove_left:
    hidden.append('left')
  for side in hidden:
    axis.spines[side].set_visible(False)


def plot_trial(args,
               trial_id: int,
               inputs: np.ndarray,
               outputs: np.ndarray,
               gates: t.Dict[int, t.Dict[str, np.ndarray]],
               info: t.Dict[str, np.ndarray],
               tag: str,
               summary,
               epoch: int = 0):
  """Plot the input, attention-gate overlays and output of a single trial.

  Rows: the input (with reward zone, licks and rewards overlaid), one row per
  attention gate, then the output. The figure is passed to `summary.figure`.

  Args:
    args: run arguments; only `args.dpi` is read here.
    trial_id: trial number shown in the title.
    inputs: 2-D input array already cropped to the trial.
    outputs: 2-D output array already cropped to the trial.
    gates: {gate index: {'gating_signal': ..., 'sigmoid_mask': ...}} with
      keys 0..len(gates)-1.
    info: per-time-step arrays; 'reward_zone', 'lick' and 'reward' are read.
    tag: summary tag.
    summary: writer exposing `figure(tag=..., figure=..., step=..., mode=...)`.
    epoch: step recorded with the figure.
  """
  reward_ranges = utils.get_reward_ranges(info['reward_zone'])
  assert len(reward_ranges) == 1

  cmap, aspect = 'gray', 'auto'
  # one row for the input, one per attention gate and one for the output;
  # with two gates this is the original 4-row, (4.5, 12) layout
  nrows = len(gates) + 2
  figure, axes = plt.subplots(nrows=nrows,
                              ncols=1,
                              gridspec_kw={'hspace': 0.15},
                              figsize=(4.5, 3 * nrows),
                              dpi=args.dpi)
  # with ncols=1 `axes` is a 1-D array, so it must be indexed as axes[row]

  row = 0
  axes[row].imshow(inputs, cmap=cmap, aspect=aspect)
  axes[row].tick_params(axis='both', which='both', length=0)
  axes[row].set_xlabel(f'({inputs.shape[0]}, {inputs.shape[1]})')
  axes[row].set_ylabel('Input')
  axes[row].set_title(f'Trial #{trial_id:03d}')

  # overlay trial information on the input panel
  tensorboard.plot_reward_zones(axes[row], reward_ranges=reward_ranges)
  tensorboard.plot_licks(axes[row], licks=np.nonzero(info['lick'])[0])
  tensorboard.plot_rewards(axes[row], rewards=np.nonzero(info['reward'])[0])

  row += 1

  for gate_id in range(len(gates)):
    gating_signal = gates[gate_id]['gating_signal']
    sigmoid_mask = gates[gate_id]['sigmoid_mask']

    # divide by the maximum so the strongest attention maps to ~1; the
    # epsilon guards against an all-zero mask
    sigmoid_mask = sigmoid_mask / (np.max(sigmoid_mask) + 1e-6)
    heatmap = tensorboard.gray2rgb(sigmoid_mask)
    overlay = tensorboard.superimpose(heatmap=heatmap, background=gating_signal)
    axes[row].imshow(overlay, aspect=aspect)
    axes[row].tick_params(axis='both', which='both', length=0)
    axes[row].set_xlabel(f'({overlay.shape[0]}, {overlay.shape[1]})')
    axes[row].set_ylabel(f'AG{gate_id + 1}')

    row += 1

  axes[row].imshow(outputs, cmap=cmap, aspect=aspect)
  axes[row].tick_params(axis='both', which='both', length=0)
  axes[row].set_xlabel(f'({outputs.shape[0]}, {outputs.shape[1]})')
  axes[row].set_ylabel('Output')

  plt.setp(axes, xticks=[], yticks=[])

  summary.figure(tag=tag, figure=figure, step=epoch, mode=1)


def _crop_gates_to_trial(gates: t.Dict[int, t.Dict[str, np.ndarray]],
                         trial_start: int, trial_end: int,
                         num_steps: int) -> t.Dict[int, t.Dict[str, np.ndarray]]:
  """Crop every gate array along its time axis (axis 1) to one trial.

  Gate arrays can have a different number of time-steps than the input
  (`num_steps`), so the trial boundaries are rescaled to each gate's length,
  rounding the start down and the end up.
  """
  cropped = {}
  for gate_id in range(len(gates)):
    gating_signal = gates[gate_id]['gating_signal']
    sigmoid_mask = gates[gate_id]['sigmoid_mask']
    scale = gating_signal.shape[1] / num_steps
    gate_start = math.floor(scale * trial_start)
    gate_end = math.ceil(scale * trial_end)
    cropped[gate_id] = {
        'gating_signal': gating_signal[:, gate_start:gate_end],
        'sigmoid_mask': sigmoid_mask[:, gate_start:gate_end]
    }
  return cropped


def plot_trials(args,
                inputs: np.ndarray,
                outputs: np.ndarray,
                gates: t.Dict[int, t.Dict[str, np.ndarray]],
                info: t.Dict[str, np.ndarray],
                tag: str,
                summary,
                epoch: int = 0,
                max_plots: int = 5):
  """Plot up to `max_plots` individual trials, one figure per trial.

  Does nothing for the horse2zebra dataset. The arrays are converted with
  `transpose_data`, split into trials with
  `utils.get_trial_ranges(info['trial'])`, and each trial is drawn with
  `plot_trial` under the tag `{tag}/trial_#<trial id>`.

  Args:
    args: run arguments; `args.dataset` is read here, the rest is forwarded.
    inputs: generator input in the layout expected by `transpose_data`.
    outputs: generator output, same layout as `inputs`.
    gates: {gate index: {'gating_signal': ..., 'sigmoid_mask': ...}}.
    info: per-time-step arrays; 'trial' is used to split the segment.
    tag: prefix of the summary tags.
    summary: summary writer forwarded to `plot_trial`.
    epoch: step recorded with each figure.
    max_plots: maximum number of trials to plot.
  """
  if args.dataset == 'horse2zebra':
    return

  inputs, outputs, gates = transpose_data(inputs, outputs, gates)
  num_steps = inputs.shape[1]

  trial_ranges = utils.get_trial_ranges(info['trial'])

  for plot_index, (trial_start, trial_end) in enumerate(trial_ranges):
    if plot_index >= max_plots:
      break

    trial_id = info['trial'][trial_start]

    # crop to trial range
    trial_inputs = inputs[:, trial_start:trial_end]
    trial_outputs = outputs[:, trial_start:trial_end]
    trial_info = {k: v[trial_start:trial_end] for k, v in info.items()}
    trial_gates = _crop_gates_to_trial(gates, trial_start, trial_end,
                                       num_steps)

    plot_trial(args,
               trial_id=trial_id,
               inputs=trial_inputs,
               outputs=trial_outputs,
               gates=trial_gates,
               info=trial_info,
               tag=f'{tag}/trial_#{trial_id:03d}',
               summary=summary,
               epoch=epoch)


def _set_horse2zebra_ticks(axis, image: np.ndarray, label: str, fontsize: int):
  """Put 5 evenly spaced pixel-index ticks on both axes of an image panel."""
  xticks_loc = np.linspace(0, image.shape[1] - 1, 5)
  tensorboard.set_xticks(axis=axis,
                         ticks_loc=xticks_loc,
                         ticks=xticks_loc.astype(int),
                         fontsize=fontsize)
  yticks_loc = np.linspace(0, image.shape[0] - 1, 5)
  tensorboard.set_yticks(axis=axis,
                         ticks_loc=yticks_loc,
                         ticks=yticks_loc.astype(int),
                         label=label,
                         fontsize=fontsize)


def plot_horse2zebra(args, inputs: np.ndarray, outputs: np.ndarray,
                     gates: t.Dict[int, t.Dict[str, np.ndarray]], tag: str,
                     summary: tensorboard.Summary, epoch: int):
  """Plot an image, its attention-gate overlays and the generated image.

  Rows: input image, one overlay per attention gate (sigmoid mask heatmap on
  top of the gating signal), then the output image. The figure is passed to
  `summary.figure`.

  Args:
    args: run arguments; only `args.dpi` is read here.
    inputs: image batch of one sample; values are mapped from [-1, 1] to
      [0, 255] for display.
    outputs: generator output with the same shape as `inputs`.
    gates: {gate index: {'gating_signal': ..., 'sigmoid_mask': ...}} with
      keys 0..len(gates)-1; channel 0 of sample 0 is plotted.
    tag: summary tag.
    summary: summary writer the figure is sent to.
    epoch: step recorded with the figure.
  """
  assert inputs.shape == outputs.shape and inputs.shape[0] == 1

  nrows = len(gates) + 2
  cmap, aspect, figsize, fontsize = None, 'equal', (3.5, 3.5 * nrows), 10

  def to_uint8(images: np.ndarray) -> np.ndarray:
    # map [-1, 1] to [0, 255]; clipping first makes out-of-range values
    # saturate instead of wrapping around in the uint8 cast
    return ((np.clip(images[0], -1.0, 1.0) + 1) * 127.5).astype(np.uint8)

  inputs, outputs = to_uint8(inputs), to_uint8(outputs)

  overlays = []
  for gate_id in range(len(gates)):
    gating_signal = gates[gate_id]['gating_signal'][0, ..., 0]
    sigmoid_mask = gates[gate_id]['sigmoid_mask'][0, ..., 0]
    overlays.append(
        tensorboard.superimpose(heatmap=tensorboard.gray2rgb(sigmoid_mask),
                                background=gating_signal,
                                alpha=0.8))

  figure, axes = plt.subplots(nrows=nrows,
                              ncols=1,
                              gridspec_kw={
                                  'wspace': 0.01,
                                  'hspace': 0.1
                              },
                              figsize=figsize,
                              dpi=args.dpi)
  row = 0

  axes[row].imshow(inputs, cmap=cmap, aspect=aspect)
  _set_horse2zebra_ticks(axes[row], inputs, label='Input', fontsize=fontsize)
  row += 1

  for gate_id, overlay in enumerate(overlays):
    axes[row].imshow(overlay, aspect=aspect)
    _set_horse2zebra_ticks(axes[row],
                           overlay,
                           label=f'AG{gate_id+1}',
                           fontsize=fontsize)
    row += 1

  axes[row].imshow(outputs, cmap=cmap, aspect=aspect)
  _set_horse2zebra_ticks(axes[row], outputs, label='Output', fontsize=fontsize)

  # hide the tick marks on every panel (tick labels set above are kept)
  for axis in axes:
    axis.tick_params(axis='both', which='both', length=0)

  summary.figure(tag, figure=figure, step=epoch, mode=1)


def plot_segment(args, inputs: np.ndarray, outputs: np.ndarray,
                 gates: t.Dict[int, t.Dict[str, np.ndarray]],
                 info: t.Dict[str, np.ndarray], tag: str,
                 summary: tensorboard.Summary, epoch: int):
  """Plot input, attention-gate overlays and output of a recording segment.

  The wide left column shows the input, one overlay per gate and the output
  as (neurons, time-steps) images after `transpose_data`. For each gate the
  narrow right column shows the summed attention per group of neurons as a
  horizontal bar chart.

  Args:
    args: run arguments; `args.dpi` and `args.neuron_order` are read.
    inputs: generator input with a leading batch dimension of 1.
    outputs: generator output with the same shape as `inputs`.
    gates: {gate index: {'gating_signal': ..., 'sigmoid_mask': ...}} with
      keys 0..len(gates)-1.
    info: optional per-time-step arrays; when not None, the reward zones in
      info['reward_zone'] are drawn on the input panel.
    tag: summary tag.
    summary: summary writer the figure is sent to.
    epoch: step recorded with the figure.
  """
  assert inputs.shape == outputs.shape and inputs.shape[0] == 1

  nrows = len(gates) + 2
  cmap, aspect, figsize = 'gray', 'auto', (6, (nrows * 2) + 0.5)

  sigmoid_masks, superimposes = [], []
  inputs, outputs, gates = transpose_data(inputs, outputs, gates)
  for gate_id in range(len(gates)):
    gating_signal = gates[gate_id]['gating_signal']
    sigmoid_mask = gates[gate_id]['sigmoid_mask']
    overlay = tensorboard.superimpose(
        heatmap=tensorboard.gray2rgb(sigmoid_mask), background=gating_signal)
    superimposes.append(overlay)
    sigmoid_masks.append(sigmoid_mask)

  figure, axes = plt.subplots(nrows=nrows,
                              ncols=2,
                              gridspec_kw={
                                  'width_ratios': [1, 0.05],
                                  'wspace': 0.005,
                                  'hspace': 0.2
                              },
                              figsize=figsize,
                              dpi=args.dpi)
  row = 0

  axes[row, 0].imshow(inputs, cmap=cmap, aspect=aspect)
  remove_spines(axes[row, 1])
  xticks_loc = np.linspace(0, inputs.shape[1] - 1, 5)
  tensorboard.set_xticks(axis=axes[row, 0],
                         ticks_loc=xticks_loc,
                         ticks=xticks_loc.astype(int),
                         label='')
  yticks_loc = np.linspace(0, inputs.shape[0] - 1, 5)
  tensorboard.set_yticks(axis=axes[row, 0],
                         ticks_loc=yticks_loc,
                         ticks=args.neuron_order[yticks_loc.astype(int)] + 1,
                         label='Input')
  axes[row, 0].tick_params(axis='both', which='both', length=0)
  plt.setp(axes[row, 1], xticks=[], yticks=[])
  row += 1

  for gate_id, overlay in enumerate(superimposes):
    axes[row, 0].imshow(overlay, aspect=aspect)
    # total attention per neuron (row of the mask), grouped into <= 12 bins
    y_intensity = np.sum(sigmoid_masks[gate_id], axis=1)
    chunks = np.array_split(y_intensity,
                            indices_or_sections=min(len(y_intensity), 12))
    width = np.array([np.sum(chunk) for chunk in chunks])
    max_width = np.max(width)
    if max_width > 0:
      # scale so the most-attended group spans the full bar width; skipped
      # for an all-zero mask to avoid a division by zero
      width = width / max_width
    # reversed positions draw the first group of neurons at the top
    axes[row, 1].barh(y=list(range(len(width)))[::-1],
                      width=width,
                      color='crimson',
                      alpha=0.6)
    axes[row, 0].tick_params(axis='both', which='both', length=0)
    plt.setp(axes[row, 1], xticks=[], yticks=[])
    remove_spines(axes[row, 1], remove_left=False)
    xticks_loc = np.linspace(0, overlay.shape[1] - 1, 5)
    tensorboard.set_xticks(axis=axes[row, 0],
                           ticks_loc=xticks_loc,
                           ticks=xticks_loc.astype(int),
                           label='')
    yticks_loc = np.linspace(0, overlay.shape[0] - 1, 5)
    tensorboard.set_yticks(axis=axes[row, 0],
                           ticks_loc=yticks_loc,
                           ticks=yticks_loc.astype(int) + 1,
                           label=f'AG{gate_id+1}')
    row += 1

  axes[row, 0].imshow(outputs, cmap=cmap, aspect=aspect)
  axes[row, 0].tick_params(axis='both', which='both', length=0)
  xticks_loc = np.linspace(0, outputs.shape[1] - 1, 5)
  tensorboard.set_xticks(axis=axes[row, 0],
                         ticks_loc=xticks_loc,
                         ticks=xticks_loc.astype(int),
                         label='Time-step')
  yticks_loc = np.linspace(0, outputs.shape[0] - 1, 5)
  tensorboard.set_yticks(axis=axes[row, 0],
                         ticks_loc=yticks_loc,
                         ticks=args.neuron_order[yticks_loc.astype(int)] + 1,
                         label='Output')
  remove_spines(axes[row, 1])
  plt.setp(axes[row, 1], xticks=[], yticks=[])

  if info is not None:
    tensorboard.plot_reward_zones(axes[0, 0],
                                  reward_ranges=utils.get_reward_ranges(
                                      info['reward_zone']))

  summary.figure(tag, figure=figure, step=epoch, mode=1)


def _extract_trial_info(sample) -> t.Dict[str, np.ndarray]:
  """Return the first batch element of each behavioural field as numpy.

  Reads the 'lick', 'reward', 'reward_zone', 'position' and 'trial' entries
  of `sample`, a dict of tensors yielded by the dataset.
  """
  keys = ('lick', 'reward', 'reward_zone', 'position', 'trial')
  return {key: sample[key][0].numpy() for key in keys}


def plot_attention_gate(args, ds, gan, summary, epoch: int = 0):
  """Plot attention-gate activations of both generators for every pair in ds.

  For each (x, y) pair yielded by `ds`, the gating signal and sigmoid mask of
  every gate in `layer_names` are extracted from `gan.G` (run on x) and
  `gan.F` (run on y). Each direction is then plotted next to the generator's
  input and output, with `plot_horse2zebra` for the horse2zebra dataset and
  `plot_segment` otherwise.

  Does nothing unless `args.model` is one of the attention-gated models.

  Args:
    args: run arguments; `args.model` and `args.dataset` are read here.
    ds: iterable of (x, y) dicts holding 'data' and, for recording datasets,
      the behavioural fields read by `_extract_trial_info`.
    gan: object exposing the generators `G` (x -> y) and `F` (y -> x).
    summary: summary writer the figures are sent to.
    epoch: step recorded with the figures.
  """
  if args.model not in ['agresnet', 'agunet', 'aagresnet']:
    return None

  # (gating-signal layer, attention sigmoid layer) per gate; the list
  # position becomes the gate index shown as AG1, AG2, ...
  layer_names = [('down_2/dropout', 'up_1/attention/sigmoid'),
                 ('down_1/dropout', 'up_2/attention/sigmoid')]

  for i, (x, y) in enumerate(ds):
    x_info, y_info = None, None
    if 'reward_zone' in x:
      x_info = _extract_trial_info(x)
      y_info = _extract_trial_info(y)
    x, y = x['data'], y['data']

    # G(x) produces fake y and F(y) produces fake x; plot both directions
    for prefix, data, generator, info in (('AG_G(x)', x, gan.G, x_info),
                                          ('AG_F(y)', y, gan.F, y_info)):
      # extract attention gate signals and sigmoid masks
      gates = {
          gate_id: extra_attention_mask(inputs=data,
                                        model=generator,
                                        gating_name=gating_name,
                                        sigmoid_name=sigmoid_name)
          for gate_id, (gating_name, sigmoid_name) in enumerate(layer_names)
      }
      real, fake = data.numpy(), generator(data).numpy()
      plot_tag = f'{prefix}/trial_{i:03d}'
      if args.dataset == 'horse2zebra':
        plot_horse2zebra(args,
                         inputs=real,
                         outputs=fake,
                         gates=gates,
                         tag=plot_tag,
                         summary=summary,
                         epoch=epoch)
      else:
        plot_segment(args,
                     inputs=real,
                     outputs=fake,
                     gates=gates,
                     info=info,
                     tag=plot_tag,
                     summary=summary,
                     epoch=epoch)
