import os
import argparse
import platform
import numpy as np
import typing as t
from glob import glob
import tensorflow as tf

import matplotlib
if platform.system() == 'Darwin':
  matplotlib.use('TkAgg')
import matplotlib.cm as cm
import matplotlib.pyplot as plt

# matplotlib 3.6 renamed its bundled seaborn styles to 'seaborn-v0_8-*' (the
# old names were deprecated and later removed), so prefer the new name when it
# exists and fall back to the legacy one on older matplotlib versions.
plt.style.use('seaborn-v0_8-deep'
              if 'seaborn-v0_8-deep' in plt.style.available else 'seaborn-deep')

from cyclegan.utils.gradcam import gradcam
from cyclegan.utils.tensorboard import Summary
from cyclegan.models.registry import get_models
from cyclegan.alogrithms.registry import get_algorithm
from cyclegan.utils import utils, attention_gate, dataset


def _setup_paths(args):
  """Derive dataset and output sub-directory paths on ``args`` in place.

  Sets ``dataset`` (prefixed with ``..``), ``plots_dir``, ``samples_dir``,
  ``signals_filename`` and ``checkpoint_dir`` from ``args.output_dir``, and
  forces ``save_plots`` on.
  """
  # NOTE(review): the '..' prefix assumes the stored dataset path is relative
  # to the parent of the current working directory — confirm.
  args.dataset = os.path.join('..', args.dataset)
  args.plots_dir = os.path.join(args.output_dir, 'plots')
  args.samples_dir = os.path.join(args.output_dir, 'samples')
  args.signals_filename = os.path.join(args.samples_dir, 'signals.h5')
  args.checkpoint_dir = os.path.join(args.output_dir, 'checkpoints')
  args.save_plots = True


def _load_gan(args):
  """Build both generator/discriminator pairs and restore them from checkpoint.

  Returns:
    A ``(gan, epoch)`` tuple: the algorithm object wrapping G, F, X and Y, and
    the epoch returned by ``load_checkpoint`` plus one.
  """
  G, Y = get_models(args,
                    g_name='generator_G',
                    d_name='discriminator_Y',
                    write_summary=False)
  F, X = get_models(args,
                    g_name='generator_F',
                    d_name='discriminator_X',
                    write_summary=False)
  gan = get_algorithm(args, G, F, X, Y)
  # expect_partial: only the parts needed for inference/plotting are restored.
  epoch = gan.load_checkpoint(args, expect_partial=True)
  return gan, epoch + 1


def main(args):
  """Load a trained run from ``args.output_dir`` and save its figures.

  Restores the saved run arguments and checkpoint, then writes Grad-CAM and
  attention-gate plots into ``<output_dir>/plots``.

  Args:
    args: parsed command-line namespace; must provide ``output_dir``. It is
      mutated in place with the restored run arguments and derived paths.

  Raises:
    FileNotFoundError: if ``args.output_dir`` is not an existing directory.
  """
  # Validate before any other setup so a bad path fails fast; isdir (rather
  # than exists) also rejects a regular file passed by mistake.
  if not os.path.isdir(args.output_dir):
    raise FileNotFoundError(f'{args.output_dir} not found.')

  np.random.seed(777)
  tf.random.set_seed(777)

  print(f'loading {args.output_dir}...\n')

  utils.load_args(args)
  _setup_paths(args)
  # exist_ok avoids the check-then-create race of exists() + makedirs().
  os.makedirs(args.plots_dir, exist_ok=True)

  summary = Summary(args)

  _, _, _, ds = dataset.get_datasets(args, dry_run=False)

  gan, epoch = _load_gan(args)

  gradcam.plot_gradcam(args, ds=ds, gan=gan, summary=summary, epoch=epoch)
  attention_gate.plot_attention_gate(args,
                                     ds=ds,
                                     gan=gan,
                                     summary=summary,
                                     epoch=epoch)

  print(f'\nFigures saved at {args.plots_dir}')


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.add_argument('--output_dir', type=str, required=True)
  parser.add_argument('--verbose', type=int, default=1)
  parser.add_argument('--format', type=str, default='pdf')
  parser.add_argument('--dpi', type=str, default=120)
  main(parser.parse_args())
