import numpy as np
import typing as t
from tqdm import tqdm
import tensorflow as tf

from cyclegan.utils.gradcam import plot
from cyclegan.utils import utils, tensorboard


def append_results(results: t.Dict[str, t.List[np.ndarray]],
                   gradcam: t.Dict[str, np.ndarray],
                   trial_info: t.Optional[t.Dict[str, 'tf.Tensor']] = None,
                   swap_axes: bool = False):
  ''' Append one GradCAM result (and optional trial information) to results.

  results is mutated in place: every key maps to a list that is created on
  first use and grows by one entry per call.

  Args:
    results: accumulator dictionary of lists, updated in place.
    gradcam: dictionary with keys 'inputs', 'cam' and 'prediction'.
    trial_info: optional dictionary with keys 'lick', 'reward',
      'reward_zone', 'position' and 'trial'. For each of them, element [0]
      is converted with .numpy() and appended under the matching plural key
      ('licks', 'rewards', 'reward_zones', 'positions', 'trials').
    swap_axes: when True, the first two axes of 'inputs' and 'cam' are
      swapped before appending. 'prediction' is always appended unchanged.
  '''

  def maybe_swap(array):
    return np.swapaxes(array, axis1=0, axis2=1) if swap_axes else array

  results.setdefault('inputs', []).append(maybe_swap(gradcam['inputs']))
  results.setdefault('cam', []).append(maybe_swap(gradcam['cam']))
  # prediction is never axis-swapped
  results.setdefault('prediction', []).append(gradcam['prediction'])

  if trial_info is not None:
    # (key in trial_info, key in results), in insertion order
    trial_keys = (('lick', 'licks'), ('reward', 'rewards'),
                  ('reward_zone', 'reward_zones'), ('position', 'positions'),
                  ('trial', 'trials'))
    for info_key, result_key in trial_keys:
      results.setdefault(result_key,
                         []).append(trial_info[info_key][0].numpy())


def list2array(d: t.Dict[str, t.List[np.ndarray]]):
  ''' Stack each non-empty list of arrays in d into a single np.ndarray.

  Keys whose list is empty are omitted from the returned dictionary.
  '''
  return {key: np.stack(arrays) for key, arrays in d.items() if arrays}


def GradCAM(inputs, model, layer_name: str) -> t.Dict[str, np.ndarray]:
  """ Compute the Grad-CAM heatmap of layer_name with respect to inputs.

  A single auxiliary model ('grad_model') is built. It maps model.input to
  both the output of layer_name and model.output, so the gradient of the
  model output with respect to that layer's activations is taken in one pass.

  Args:
    inputs: one sample with shape (1, H, W, C); must support .numpy()
      (presumably an eager tf.Tensor - TODO confirm with callers).
    model: Keras model containing a layer named layer_name.
    layer_name: name of the layer whose activations are weighted.

  Returns:
    dictionary with
    - 'inputs': inputs[0] as np.ndarray
    - 'cam': non-negative heatmap scaled so its maximum is 1 (all zeros when
      no location is positive), resized to (H, W) when its shape differs
    - 'prediction': mean of the model output
  """
  assert len(inputs.shape) == 4 and inputs.shape[0] == 1, \
    f'inputs should have shape (1, H, W, C), got {inputs.shape}'

  layer = model.get_layer(layer_name)
  grad_model = tf.keras.Model(inputs=model.input,
                              outputs=[layer.output, model.output],
                              name='grad_model')

  with tf.GradientTape() as tape:
    layer_output, model_output = grad_model(inputs)

  # gradient of the (implicitly summed) model output w.r.t. the layer output
  gradient = tape.gradient(target=model_output, sources=layer_output)
  # average over every axis except the last (presumably channel) axis,
  # giving one weight per channel
  pooled_gradient = tf.reduce_mean(gradient,
                                   axis=list(range(gradient.ndim - 1)))

  # weighted sum over the last axis of the single sample's activations
  cam = layer_output[0] @ pooled_gradient[..., tf.newaxis]
  # drop only the trailing axis added by the matmul; a bare squeeze would
  # also collapse spatial axes of size 1
  cam = tf.squeeze(cam, axis=-1)
  cam = tf.maximum(cam, 0)
  # divide_no_nan yields 0 instead of NaN when no location is positive
  cam = tf.math.divide_no_nan(cam, tf.reduce_max(cam))
  cam = cam.numpy()

  if cam.shape != tuple(inputs.shape[1:3]):
    cam = tensorboard.resize(cam, height=inputs.shape[1], width=inputs.shape[2])

  return {
      'inputs': inputs[0].numpy(),
      'cam': cam,
      'prediction': np.mean(model_output.numpy())
  }


def get_models_gradcam(args,
                       ds,
                       gan,
                       g_layer_name: str,
                       d_layer_name: str,
                       ds_size: int = None):
  '''
  Get GradCAM activation maps for generators with g_layer_name
  and discriminators with d_layer_name

  For every (x, y) pair in ds, GradCAM is computed for gan.G and gan.X on
  x['data'] and for gan.F and gan.Y on y['data']. The sample dictionary is
  passed as trial information only when it contains 'reward_zone'.

  Args:
    args: needs .dataset and .verbose.
    ds: iterable of (x, y) dictionaries, each with a 'data' entry.
    gan: object exposing models G, F, X and Y.
    g_layer_name: layer used for the generators G and F.
    d_layer_name: layer used for the discriminators X and Y.
    ds_size: optional length of ds; the progress bar is shown only when it
      is given and args.verbose != 0.

  Returns:
    (G_results, F_results, X_results, Y_results), each a dictionary of
    stacked np.ndarray produced by list2array.
  '''
  G_results, F_results, X_results, Y_results = {}, {}, {}, {}
  # loop-invariant: inputs and cam are axis-swapped for every dataset
  # except horse2zebra
  swap_axes = args.dataset != 'horse2zebra'
  # (accumulator, model, layer name, domain) with domain 0 -> x, 1 -> y,
  # evaluated in G, F, X, Y order for every sample
  targets = ((G_results, gan.G, g_layer_name, 0),
             (F_results, gan.F, g_layer_name, 1),
             (X_results, gan.X, d_layer_name, 0),
             (Y_results, gan.Y, d_layer_name, 1))

  for x, y in tqdm(ds,
                   desc='GradCAM',
                   total=ds_size,
                   disable=args.verbose == 0 or ds_size is None):
    samples = ((x, x if 'reward_zone' in x else None),
               (y, y if 'reward_zone' in y else None))
    for results, model, layer_name, domain in targets:
      sample, trial_info = samples[domain]
      gradcam = GradCAM(inputs=sample['data'],
                        model=model,
                        layer_name=layer_name)
      append_results(results,
                     gradcam=gradcam,
                     trial_info=trial_info,
                     swap_axes=swap_axes)

  return (list2array(G_results), list2array(F_results),
          list2array(X_results), list2array(Y_results))


def compute_positional_attention(args, results: t.Dict[str, np.ndarray],
                                 summary: tensorboard.Summary, tag: str,
                                 epoch: int):
  ''' Average GradCAM heatmap columns per 1cm position bin and plot them.

  For each sample, the time steps between the start of its first trial
  range and the end of its last trial range (from utils.get_trial_ranges)
  are assigned to the bin of their clipped position. results['cam'][s][:, i]
  is collected for each such step, and every non-empty bin is averaged into
  a column of a (args.num_neurons, 160) float32 matrix passed to
  plot.plot_positional_attention. Bins without samples stay zero.

  Args:
    results: needs 'positions', 'trials' and 'cam' entries.
  '''
  # maximum distance in a trial is 160cm -> bins 0..160
  max_distance = 161
  bins = [[] for _ in range(max_distance)]

  # clip recorded positions to [0, 160] and group by cm
  positions = np.clip(results['positions'], a_min=0, a_max=160).astype(int)

  for sample, trials in enumerate(results['trials']):
    trial_ranges = utils.get_trial_ranges(trials)
    if not trial_ranges:
      continue
    sample_cam = results['cam'][sample]
    start, stop = trial_ranges[0][0], trial_ranges[-1][1]
    for step in range(start, stop):
      bins[positions[sample, step]].append(sample_cam[:, step])

  # NOTE(review): bin 160 is filled above but cam has only 160 columns
  # (0..159), so samples at the clipped 160cm position are never plotted.
  # Confirm whether this is intended.
  num_columns = max_distance - 1
  cam = np.zeros(shape=(args.num_neurons, num_columns), dtype=np.float32)
  for column in range(num_columns):
    column_samples = bins[column]
    if column_samples:
      cam[:, column] = np.mean(np.stack(column_samples), axis=0)

  plot.plot_positional_attention(args,
                                 cam=cam,
                                 summary=summary,
                                 tag=tag,
                                 epoch=epoch)


def plot_gradcam(args,
                 ds,
                 gan,
                 summary: tensorboard.Summary,
                 epoch: int,
                 g_layer_name: str = 'block_9/add',
                 d_layer_name: str = 'output/attention',
                 ds_size: int = None):
  ''' Plot GradCAM heatmaps of generators G, F and discriminators X, Y.

  Returns without plotting when args.input_shape does not have 3 entries.
  Prints a message and returns when g_layer_name is missing from G or F, or
  d_layer_name is missing from X or Y.

  Args:
    ds_size: optional length of ds, forwarded to get_models_gradcam, which
      only shows a progress bar when it is given (and args.verbose != 0).
      Defaults to None, which keeps the progress bar disabled.
  '''
  if len(args.input_shape) != 3:
    return

  if not (utils.layer_exists(gan.G, g_layer_name) and
          utils.layer_exists(gan.F, g_layer_name)):
    print(f'GradCAM: generator layer {g_layer_name} not found.')
    return
  if not (utils.layer_exists(gan.X, d_layer_name) and
          utils.layer_exists(gan.Y, d_layer_name)):
    print(f'GradCAM: discriminator layer {d_layer_name} not found.')
    return

  G_results, F_results, X_results, Y_results = get_models_gradcam(
      args,
      ds=ds,
      gan=gan,
      g_layer_name=g_layer_name,
      d_layer_name=d_layer_name,
      ds_size=ds_size)

  plot.plot_generator(args,
                      results=G_results,
                      summary=summary,
                      tag='GradCAM_G(x)',
                      epoch=epoch)

  plot.plot_generator(args,
                      results=F_results,
                      summary=summary,
                      tag='GradCAM_F(y)',
                      epoch=epoch)

  plot.plot_discriminator(args,
                          results=X_results,
                          summary=summary,
                          tag='GradCAM_Dx(x)',
                          title=r'$D_X$(x)',
                          epoch=epoch)

  plot.plot_discriminator(args,
                          results=Y_results,
                          summary=summary,
                          tag='GradCAM_Dy(y)',
                          title=r'$D_Y$(y)',
                          epoch=epoch)


def plot_positional_gradcam(args,
                            ds,
                            gan,
                            summary: tensorboard.Summary,
                            epoch: int,
                            g_layer_name: str = 'block_9/add',
                            d_layer_name: str = 'output/attention',
                            ds_size: int = None):
  ''' Plot per-position averaged GradCAM heatmaps for G, F, X and Y.

  Skipped for the horse2zebra dataset and when args.input_shape does not
  have 3 entries. Prints a message and returns when a requested layer is
  missing from any of the four models.
  '''
  unsupported = args.dataset == 'horse2zebra' or len(args.input_shape) != 3
  if unsupported:
    return

  generators_ok = (utils.layer_exists(gan.G, g_layer_name) and
                   utils.layer_exists(gan.F, g_layer_name))
  if not generators_ok:
    print(f'GradCAM: generator layer {g_layer_name} not found.')
    return

  discriminators_ok = (utils.layer_exists(gan.X, d_layer_name) and
                       utils.layer_exists(gan.Y, d_layer_name))
  if not discriminators_ok:
    print(f'GradCAM: discriminator layer {d_layer_name} not found.')
    return

  G_results, F_results, X_results, Y_results = get_models_gradcam(
      args,
      ds=ds,
      gan=gan,
      g_layer_name=g_layer_name,
      d_layer_name=d_layer_name,
      ds_size=ds_size)

  tagged_results = (
      (G_results, 'GradCAM_G(x)_positional'),
      (F_results, 'GradCAM_F(y)_positional'),
      (X_results, 'GradCAM_Dx(x)_positional'),
      (Y_results, 'GradCAM_Dy(y)_positional'),
  )
  for model_results, plot_tag in tagged_results:
    compute_positional_attention(args,
                                 results=model_results,
                                 summary=summary,
                                 tag=plot_tag,
                                 epoch=epoch)
