import os
import argparse
import platform
import typing as t
import numpy as np

import matplotlib
if platform.system() == 'Darwin':
  matplotlib.use('TkAgg')
import seaborn as sns
import matplotlib.pyplot as plt

# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and
# 3.8 removed the old names; prefer the new name and fall back to the old one
# on matplotlib versions that predate the rename.
if 'seaborn-v0_8-deep' in plt.style.available:
  plt.style.use('seaborn-v0_8-deep')
else:
  plt.style.use('seaborn-deep')

from cyclegan.utils import utils
from sort_neurons import get_coordinates


def get_order(filename: str, data: str):
  """Return the entry stored under key `data` in the JSON file `filename`.

  The file is read with the project's `utils.load_json` helper; a missing key
  propagates as the lookup error raised by the loaded object.
  """
  return utils.load_json(filename)[data]


def plot_coordinates(args, filename: str, coordinates,
                     centers: t.List[np.ndarray], order: t.List[int]):
  """Draw each neuron's outline labelled by its rank in `order` and save it.

  Args:
    args: parsed CLI namespace; `args.dpi` and `args.output_dir` are read.
    filename: output file name, joined onto `args.output_dir`.
    coordinates: per-neuron coordinate arrays; each entry is unpacked as
      positional arguments to `Axes.plot` (presumably x/y rows — confirm
      against `get_coordinates`).
    centers: per-neuron label positions, unpacked into `Axes.text`.
    order: neuron indices in display order; the n-th entry is drawn with the
      n-th palette colour and labelled `n + 1`.
  """
  # One palette colour per rank position, sized by the number of neurons.
  colors = sns.color_palette('husl', len(centers))
  figure, axis = plt.subplots(nrows=1, ncols=1, figsize=(5, 5), dpi=args.dpi)

  axis.set_facecolor('black')
  axis.grid(False)
  plt.setp(axis, xticks=[], yticks=[])
  for rank, neuron in enumerate(order):
    color = colors[rank]
    axis.plot(*coordinates[neuron], color=color, alpha=0.5)
    axis.text(*centers[neuron],
              s=f'{rank + 1}',
              color=color,
              horizontalalignment='center',
              verticalalignment='center')

  output_path = os.path.join(args.output_dir, filename)
  try:
    figure.savefig(output_path,
                   dpi=args.dpi,
                   bbox_inches='tight',
                   pad_inches=0.0)
  finally:
    # pyplot keeps every figure alive until it is explicitly closed; release
    # it so repeated calls in one process do not accumulate open figures.
    plt.close(figure)
  print(f'plot saved to {output_path}')


def main(args):
  """Render the neuron outlines under several orderings and save each plot.

  Reads neuron coordinates from `args.coordinate_filename` and writes one
  figure per ordering into `args.output_dir` (created if missing), named
  `coordinates_<name>.<args.format>`:
    - `original`: neurons in their stored index order;
    - `ae`: the 'order' entry of runs/ST260/vr14/sl2048/001_8f/order.json;
    - `fr`: the 'order' entry of runs/ST260/vr14/sl2048/firing_rate/order.json.
  """
  # exist_ok avoids the check-then-create race of exists() + makedirs().
  os.makedirs(args.output_dir, exist_ok=True)

  coordinates = get_coordinates(args.coordinate_filename)
  # Each neuron's label is placed at the mean of its coordinates along axis 1.
  centers = [c.mean(axis=1) for c in coordinates]

  # All orderings are loaded before any plotting, matching the original flow.
  orders = [
      ('original', list(range(len(centers)))),
      ('ae', get_order('runs/ST260/vr14/sl2048/001_8f/order.json',
                       data='order')),
      ('fr', get_order('runs/ST260/vr14/sl2048/firing_rate/order.json',
                       data='order')),
  ]

  for name, order in orders:
    plot_coordinates(args,
                     filename=f'coordinates_{name}.{args.format}',
                     coordinates=coordinates,
                     centers=centers,
                     order=order)


if __name__ == '__main__':
  # Command-line options: (flag, keyword arguments for add_argument).
  cli_options = [
      ('--coordinate_filename',
       dict(type=str,
            default='../dataset/data/vr_data/ST260/MC_20181117_P01.mat')),
      ('--output_dir', dict(type=str, default='plots')),
      ('--format', dict(type=str,
                        default='pdf',
                        choices=['png', 'pdf', 'svg'])),
      ('--dpi', dict(type=int, default=120)),
  ]
  cli = argparse.ArgumentParser()
  for flag, spec in cli_options:
    cli.add_argument(flag, **spec)
  parsed_args = cli.parse_args()
  main(parsed_args)
