import numpy as np
import typing as t
import matplotlib.cm as cm
import matplotlib.pyplot as plt

from cyclegan.utils import utils
from cyclegan.utils import tensorboard


def area_average_attention(inputs: np.ndarray, heatmap: np.ndarray,
                           reward_range: t.Tuple[int, int]):
  """Summarise how attention is split before, inside and after a reward zone.

  Args:
    inputs: 2-D array whose second axis is time; only its width is used, to
      place the tick for the post-reward zone.
    heatmap: 2-D attention map; its columns are split at `reward_range`.
    reward_range: (start, end) column indices of the reward zone; columns
      `[start:end]` form the reward zone.

  Returns:
    A tuple `(labels, locations)`. `labels` holds the mean attention of the
    pre-reward, reward and post-reward zones as percentages of their sum,
    formatted like '33.3%'. `locations` holds the x mid-point of each zone,
    suitable for `set_xticks`.
  """
  start, end = reward_range[0], reward_range[1]
  zones = (heatmap[:, :start], heatmap[:, start:end], heatmap[:, end:])
  # An empty zone (reward zone at the very start or end) would make np.mean
  # return NaN and poison every percentage, so count it as zero attention.
  means = [float(np.mean(zone)) if zone.size else 0.0 for zone in zones]
  total = sum(means) + 1e-6
  labels = [f'{(mean / total) * 100:.1f}%' for mean in means]
  # calculate mid-points for the 3 zones
  locations = [
      start / 2,
      ((end - start) / 2) + start,
      ((inputs.shape[1] - end) / 2) + end,
  ]
  return labels, locations


def plot_trial(args,
               trial_id: int,
               inputs: np.ndarray,
               activation: np.ndarray,
               licks: np.ndarray,
               reward_zones: np.ndarray,
               tag: str,
               summary,
               epoch: int = 0):
  """Plot one trial's input above its GradCAM overlay and log the figure.

  The top panel shows `inputs` in grayscale with the reward zone and lick
  times marked. The bottom panel shows the max-normalised `activation`
  superimposed on `inputs`; its x-tick labels give the share of mean
  attention before, inside and after the reward zone.

  Args:
    args: namespace providing at least `dpi`.
    trial_id: trial number shown in the title.
    inputs: 2-D array drawn with rows labelled 'Neuron' and columns
      labelled 'Time-step'.
    activation: attention map with the same layout as `inputs`.
    licks: per-time-step lick indicator; indices of non-zero entries are
      marked.
    reward_zones: reward-zone labels passed to `utils.get_reward_ranges`;
      exactly one range must be found.
    tag: summary tag for the figure.
    summary: writer whose `figure` method receives the plot.
    epoch: step passed to `summary.figure`.
  """
  ranges = utils.get_reward_ranges(reward_zones)
  assert len(ranges) == 1

  layout = {'height_ratios': [1, 1], 'wspace': 0.025, 'hspace': 0.1}
  figure, (top, bottom) = plt.subplots(nrows=2,
                                       ncols=1,
                                       gridspec_kw=layout,
                                       figsize=(6, 4),
                                       sharex=True,
                                       dpi=args.dpi)

  # Top panel: raw input with reward zone and lick markers.
  top.set_title(f'Trial #{trial_id:03d}')
  top.imshow(inputs, cmap='gray', aspect='auto')
  top.set_ylabel('Neuron')
  tensorboard.plot_reward_zones(top, reward_ranges=ranges)
  tensorboard.plot_licks(top, licks=np.nonzero(licks)[0])

  # Bottom panel: max-normalised attention superimposed on the input.
  scaled = activation / (np.max(activation) + 1e-6)
  bottom.imshow(tensorboard.superimpose(heatmap=tensorboard.gray2rgb(scaled),
                                        background=inputs),
                aspect='auto')
  bottom.set_ylabel('Neuron')
  bottom.set_xlabel('Time-step')
  tensorboard.plot_reward_zones(bottom, reward_ranges=ranges)

  labels, positions = area_average_attention(inputs,
                                             heatmap=scaled,
                                             reward_range=ranges[0])
  bottom.set_xticks(positions)
  bottom.set_xticklabels(labels)

  # Right-hand panel names are drawn as y-labels of blank twin axes.
  for panel, name in ((top, 'Input'), (bottom, 'GradCAM')):
    twin = panel.twinx()
    twin.set_ylabel(name, rotation=270, va='center')
    twin.set_yticks([])
    tensorboard.remove_top_right_spines(twin)

  summary.figure(tag, figure=figure, step=epoch, mode=1)


def plot_trials(args,
                inputs: np.ndarray,
                activation: np.ndarray,
                licks: np.ndarray,
                reward_zones: np.ndarray,
                trials: np.ndarray,
                tag: str,
                summary,
                epoch: int = 0,
                max_plots: int = 5):
  """Plot the first `max_plots` trials, one figure per trial.

  Trial boundaries come from `utils.get_trial_ranges(trials)`. Each trial's
  id is the value of `trials` at the trial's first column. Each trial's
  columns of `inputs`/`activation` and entries of `licks`/`reward_zones` are
  forwarded to `plot_trial`.
  """
  for count, (start, end) in enumerate(utils.get_trial_ranges(trials)):
    if count >= max_plots:
      break
    window = slice(start, end)
    trial_id = trials[start]
    plot_trial(args,
               trial_id=trial_id,
               inputs=inputs[:, window],
               activation=activation[:, window],
               licks=licks[window],
               reward_zones=reward_zones[window],
               tag=f'{tag}/trial_#{trial_id:03d}',
               summary=summary,
               epoch=epoch)


def plot_horse2zebra(args,
                     inputs: np.ndarray,
                     cam: np.ndarray,
                     summary: tensorboard.Summary,
                     tag: str,
                     epoch: int,
                     title: str = ''):
  """Overlay an attention map on an image and log the figure.

  Args:
    args: namespace providing `dpi`.
    inputs: image, presumably scaled to [-1, 1] since it is mapped onto
      [0, 255] below — TODO confirm against the data pipeline.
    cam: attention map; converted with `tensorboard.gray2rgb` and scaled the
      same way as `inputs` before superimposing.
    summary: writer whose `figure` method receives the plot.
    tag: summary tag for the figure.
    epoch: step passed to `summary.figure`.
    title: optional axis title; skipped when empty.
  """
  aspect, fontsize = 'equal', 10
  figure, axis = plt.subplots(nrows=1,
                              ncols=1,
                              gridspec_kw={
                                  'wspace': 0.1,
                                  'hspace': 0.1
                              },
                              figsize=(5, 5),
                              dpi=args.dpi)

  def to_uint8(x: np.ndarray) -> np.ndarray:
    # (x + 1) * 127.5 maps [-1, 1] onto [0, 255]. Clip before casting:
    # converting out-of-range floats to uint8 does not saturate (the result
    # is platform-dependent), so slight overshoots would become garbage.
    return np.clip((x + 1) * 127.5, 0, 255).astype(np.uint8)

  superimpose = tensorboard.superimpose(
      heatmap=to_uint8(tensorboard.gray2rgb(cam)),
      background=to_uint8(inputs),
      alpha=0.4,
  )
  axis.imshow(superimpose, cmap=None, aspect=aspect)

  # five evenly spaced pixel-index ticks on each axis
  xticks_loc = np.linspace(0, superimpose.shape[1] - 1, 5)
  tensorboard.set_xticks(axis=axis,
                         ticks_loc=xticks_loc,
                         ticks=xticks_loc.astype(int),
                         fontsize=fontsize)
  yticks_loc = np.linspace(0, superimpose.shape[0] - 1, 5)
  tensorboard.set_yticks(axis=axis,
                         ticks_loc=yticks_loc,
                         ticks=yticks_loc.astype(int),
                         fontsize=fontsize)
  axis.tick_params(axis='both', which='both', length=0)

  if title:
    axis.set_title(title)

  summary.figure(tag, figure=figure, step=epoch, mode=1)


def plot_segment(args,
                 inputs: np.ndarray,
                 cam: np.ndarray,
                 reward_zones: t.Optional[np.ndarray],
                 summary: tensorboard.Summary,
                 tag: str,
                 title: t.Optional[str] = None,
                 epoch: int = 0):
  """Plot an input segment above its GradCAM overlay and log the figure.

  Args:
    args: namespace providing `dpi`, `frame_rate` (columns per second; tick
      positions are divided by it to label the axis in seconds) and
      `neuron_order` (array indexed by row position; entries plus one are
      used as the neuron tick labels).
    inputs: 2-D array drawn in grayscale with rows labelled 'Neuron'.
    cam: attention map; converted with `tensorboard.gray2rgb` and
      superimposed on `inputs`.
    reward_zones: reward-zone labels passed to `utils.get_reward_ranges`, or
      None to skip drawing reward zones.
    summary: writer whose `figure` method receives the plot.
    tag: summary tag for the figure.
    title: optional title for the top panel.
    epoch: step passed to `summary.figure`.
  """
  cmap, aspect = 'gray', 'auto'
  figure, axes = plt.subplots(nrows=2,
                              ncols=1,
                              gridspec_kw={
                                  'height_ratios': [1, 1],
                                  'wspace': 0.025,
                                  'hspace': 0.1
                              },
                              sharex=True,
                              figsize=(6, 4),
                              dpi=args.dpi)

  axes[0].imshow(inputs, cmap=cmap, aspect=aspect)

  superimpose = tensorboard.superimpose(heatmap=tensorboard.gray2rgb(cam),
                                        background=inputs)
  # rescale so the brightest value maps to ~1 for imshow
  superimpose = superimpose / (np.max(superimpose) + 1e-6)
  axes[1].imshow(superimpose, aspect=aspect)

  # plot x label: five ticks, positions converted from columns to seconds
  xticks_loc = np.linspace(0, superimpose.shape[1] - 1, 5)
  tensorboard.set_xticks(axis=axes[1],
                         ticks_loc=xticks_loc,
                         ticks=(xticks_loc / args.frame_rate).astype(int),
                         label='Time (s)')
  # plot y label: row positions mapped through neuron_order, 1-based
  yticks_loc = np.linspace(0, superimpose.shape[0] - 1, 5)
  yticks = args.neuron_order[yticks_loc.astype(int)] + 1
  tensorboard.set_yticks(axis=axes[0],
                         ticks_loc=yticks_loc,
                         ticks=yticks,
                         label='Neuron')
  tensorboard.set_yticks(axis=axes[1],
                         ticks_loc=yticks_loc,
                         ticks=yticks,
                         label='Neuron')

  tensorboard.set_right_label(axes[0], label='Input')
  tensorboard.set_right_label(axes[1], label='GradCAM')

  axes[0].tick_params(axis='both', which='both', length=0)
  axes[1].tick_params(axis='both', which='both', length=0)

  if title:
    axes[0].set_title(title)

  if reward_zones is not None:
    tensorboard.plot_reward_zones(
        axes[0],
        reward_ranges=utils.get_reward_ranges(reward_zones),
    )

  summary.figure(tag, figure=figure, step=epoch, mode=1)


def plot_generator(args,
                   results: t.Dict[str, np.ndarray],
                   summary,
                   tag: str,
                   epoch: int = 0):
  """Plot one GradCAM figure per generator sample in `results`.

  `results` must hold 'inputs' and 'cam' arrays indexed by sample along the
  first axis. 'reward_zones' is optional and is used only for non-image
  datasets. The horse2zebra dataset is drawn as images; every other dataset
  is drawn as segments.
  """
  num_samples = results['cam'].shape[0]
  for index in range(num_samples):
    sample_inputs = results['inputs'][index]
    sample_cam = results['cam'][index]
    if args.dataset == 'horse2zebra':
      plot_horse2zebra(args,
                       inputs=sample_inputs,
                       cam=sample_cam,
                       tag=f'{tag}/trial_{index:03d}',
                       summary=summary,
                       epoch=epoch)
      continue
    if 'reward_zones' in results:
      zones = results['reward_zones'][index]
    else:
      zones = None
    plot_segment(args,
                 inputs=sample_inputs,
                 cam=sample_cam,
                 reward_zones=zones,
                 tag=f'{tag}_segment/{index:03d}',
                 summary=summary,
                 epoch=epoch)


def plot_discriminator(args,
                       results: t.Dict[str, np.ndarray],
                       summary,
                       tag: str,
                       title: str,
                       epoch: int = 0):
  """Plot one GradCAM figure per discriminator sample, titled with its score.

  `results` must hold 'inputs', 'cam' and 'prediction', all indexed by sample
  along the first axis. 'reward_zones' is optional; when it is absent, no
  reward zones are drawn, matching `plot_generator`. Each figure's title is
  `'{title} = {prediction:.2f}'`.
  """
  for index in range(results['cam'].shape[0]):
    label = rf'{title} = {results["prediction"][index]:.2f}'
    if args.dataset == 'horse2zebra':
      plot_horse2zebra(args,
                       inputs=results['inputs'][index],
                       cam=results['cam'][index],
                       tag=f'{tag}/trial_{index:03d}',
                       summary=summary,
                       epoch=epoch,
                       title=label)
    else:
      # Not every result dict carries reward zones; plot_segment accepts None.
      zones = (results['reward_zones'][index]
               if 'reward_zones' in results else None)
      plot_segment(args,
                   inputs=results['inputs'][index],
                   cam=results['cam'][index],
                   reward_zones=zones,
                   summary=summary,
                   tag=f'{tag}_segment/{index:03d}',
                   title=label,
                   epoch=epoch)


def plot_positional_attention(args,
                              cam: np.ndarray,
                              summary: tensorboard.Summary,
                              tag: str,
                              epoch: int,
                              reward_range: t.Tuple[int, int] = (120, 140)):
  """Plot a max-normalised positional attention map with a colour bar.

  Args:
    args: namespace providing `dpi` and `neuron_order` (array indexed by row
      position; entries plus one are used as the neuron tick labels).
    cam: 2-D attention map with rows labelled 'Neuron' and columns labelled
      'Distance (cm)'.
    summary: writer whose `figure` method receives the plot.
    tag: summary tag for the figure.
    epoch: step passed to `summary.figure`.
    reward_range: (start, end) x positions of the reward zone to highlight.
      Defaults to (120, 140), the previously hard-coded zone.
  """
  figure, axes = plt.subplots(nrows=1,
                              ncols=2,
                              gridspec_kw={
                                  'width_ratios': [1, 0.03],
                                  'wspace': 0.025,
                                  'hspace': 0.025
                              },
                              figsize=(5, 2),
                              dpi=args.dpi)

  cam = cam / (np.max(cam) + 1e-6)
  axes[0].imshow(tensorboard.gray2rgb(cam), aspect='auto')

  # NOTE(review): column indices are used directly as cm labels, which
  # assumes one column per cm — confirm against the binning upstream.
  xticks_loc = np.linspace(0, cam.shape[1] - 1, 5)
  tensorboard.set_xticks(axis=axes[0],
                         ticks_loc=xticks_loc,
                         ticks=xticks_loc.astype(int),
                         label='Distance (cm)')
  yticks_loc = np.linspace(0, cam.shape[0] - 1, 5)
  tensorboard.set_yticks(axis=axes[0],
                         ticks_loc=yticks_loc,
                         ticks=args.neuron_order[yticks_loc.astype(int)] + 1,
                         label='Neuron')
  axes[0].tick_params(axis='both', which='both', length=1)

  # highlight the reward zone (120 to 140cm unless overridden)
  tensorboard.plot_reward_zones(axis=axes[0],
                                reward_ranges=[tuple(reward_range)],
                                alpha=1.0)
  # colour bar drawn in the narrow second column
  figure.colorbar(cm.ScalarMappable(cmap=tensorboard.JET), cax=axes[1])
  axes[1].tick_params(axis='both', which='both', length=-1)

  summary.figure(tag, figure=figure, step=epoch, mode=1)
