import os
import argparse
import numpy as np
from tqdm import tqdm
from time import time
import tensorflow as tf
from shutil import rmtree

from cyclegan.utils import utils, spike_helper
from cyclegan.utils.tensorboard import Summary
from cyclegan.models.registry import get_models
from cyclegan.utils.dataset import get_datasets
from cyclegan.alogrithms.registry import get_algorithm


def initialize_strategy(args):
  """Choose a tf.distribute strategy from the number of visible GPUs.

  With no GPU the whole computation is pinned to the CPU, with exactly one
  GPU it is pinned to that GPU, and with more than one GPU a
  MirroredStrategy is used to replicate across all of them.

  Side effect: sets `args.global_batch_size` to the number of replicas in
  sync times `args.batch_size`.

  Args:
    args: parsed command-line namespace; reads `batch_size` and `verbose`,
      writes `global_batch_size`.

  Returns:
    The constructed tf.distribute strategy.
  """
  gpu_count = utils.get_num_gpus()

  # zero or one GPU -> a single fixed device; anything else -> mirror
  single_device = {0: '/cpu:0', 1: '/gpu:0'}
  if gpu_count in single_device:
    strategy = tf.distribute.OneDeviceStrategy(device=single_device[gpu_count])
  else:
    strategy = tf.distribute.MirroredStrategy()

  replicas = strategy.num_replicas_in_sync
  args.global_batch_size = replicas * args.batch_size

  if args.verbose:
    print(f'\nnumber of compute devices: {replicas}')

  return strategy


def train(args, ds, gan, summary, epoch: int):
  """Run one training epoch and log the mean of every returned metric.

  Each element of `ds` is an (x, y) pair whose `'data'` entries are fed to
  `gan.distributed_train_step`. The per-step results are accumulated with
  `utils.update_dict`, and each accumulated entry is reduced with
  `tf.reduce_mean` and written to `summary` with `mode=0`.

  Args:
    args: parsed command-line namespace; reads `train_steps` (progress-bar
      total) and `verbose` (0 disables the progress bar).
    ds: distributed training dataset.
    gan: training algorithm exposing `distributed_train_step`.
    summary: logger exposing `scalar(name, value, step=..., mode=...)`.
    epoch: epoch index used as the summary step.
  """
  history = {}
  batches = tqdm(ds,
                 desc='Train',
                 total=args.train_steps,
                 disable=args.verbose == 0)
  for x, y in batches:
    step_result = gan.distributed_train_step(x['data'], y['data'])
    utils.update_dict(history, step_result)

  for name in history:
    summary.scalar(name, tf.reduce_mean(history[name]), step=epoch, mode=0)


def validate(args, ds, gan, summary, epoch: int):
  """Run one validation pass and return the mean of every returned metric.

  Each element of `ds` is an (x, y) pair whose `'data'` entries are fed to
  `gan.distributed_validation_step`. The per-step results are accumulated
  with `utils.update_dict` (presumably into per-key lists; verify against
  utils). Entries whose key starts with 'error' are concatenated along the
  last axis before averaging. Every entry is then reduced with
  `tf.reduce_mean`, converted to a numpy value, stored back under the same
  key and written to `summary` with `mode=1`.

  Args:
    args: parsed command-line namespace; reads `val_steps` (progress-bar
      total) and `verbose` (0 disables the progress bar).
    ds: distributed validation dataset.
    gan: training algorithm exposing `distributed_validation_step`.
    summary: logger exposing `scalar(name, value, step=..., mode=...)`.
    epoch: epoch index used as the summary step.

  Returns:
    Dict mapping each metric name to its mean value for this epoch.
  """
  history = {}
  batches = tqdm(ds,
                 desc='Validation',
                 total=args.val_steps,
                 disable=args.verbose == 0)
  for x, y in batches:
    utils.update_dict(history,
                      gan.distributed_validation_step(x['data'], y['data']))

  # overwrite each accumulated entry with its scalar mean (keys unchanged,
  # so rewriting values during iteration is safe)
  for name, values in history.items():
    if name.startswith('error'):
      values = tf.concat(values, axis=-1)
    mean = tf.reduce_mean(values).numpy()
    history[name] = mean
    summary.scalar(name, mean, step=epoch, mode=1)
  return history


def _print_metrics(metrics, elapse):
  """Print the per-epoch validation summary to stdout.

  Args:
    metrics: dict returned by `validate`; must contain every 'critic/...'
      and 'MAE/...' key indexed below.
    elapse: wall-clock seconds spent on the epoch's training plus
      validation.
  """
  print(f'Dx(X): {metrics["critic/Dx(X)"]:.02f}\t\t'
        f'Dx(F(Y)): {metrics["critic/Dx(F(Y))"]:.02f}\t\t'
        f'Dy(Y): {metrics["critic/Dy(Y)"]:.02f}\t\t'
        f'Dy(G(X)): {metrics["critic/Dy(G(X))"]:.02f}\n'
        f'MAE(X, F(G(X))): {metrics["MAE/MAE(X, F(G(X)))"]:.04f}\t\t'
        f'MAE(X, F(X)): {metrics["MAE/MAE(X, F(X))"]:.04f}\n'
        f'MAE(Y, G(F(Y))): {metrics["MAE/MAE(Y, G(F(Y)))"]:.04f}\t\t'
        f'MAE(Y, G(Y)): {metrics["MAE/MAE(Y, G(Y))"]:.04f}\n'
        f'Elapse: {elapse:.0f}s\n')


def main(args):
  """Build the CycleGAN and train it from the latest checkpoint to `args.epochs`.

  Steps:
    1. Optionally delete `args.output_dir`, clear the Keras session and
       seed numpy/TensorFlow.
    2. Optionally enable mixed precision.
    3. Build the distribution strategy, datasets, summary writer, the
       G/Y and F/X generator/discriminator pairs and the algorithm.
    4. Resume from the epoch returned by `gan.load_checkpoint()`.

  Per epoch it trains, validates and logs elapsed time. A checkpoint is
  saved every 5 epochs and plots every 10 epochs, both also on the final
  epoch. Samples are saved with the plots except for horse2zebra. The
  summary writer is closed even if training raises.

  Args:
    args: parsed command-line namespace (see `parse_args`). Mutated here:
      `samples_dir` is set, and `initialize_strategy` sets
      `global_batch_size`.
  """
  if args.clear_output_dir and os.path.exists(args.output_dir):
    rmtree(args.output_dir)

  tf.keras.backend.clear_session()

  np.random.seed(args.seed)
  tf.random.set_seed(args.seed)

  if args.mixed_precision:
    print('\nenable mixed precision training')
    tf.keras.mixed_precision.set_global_policy('mixed_float16')

  strategy = initialize_strategy(args)

  train_ds, val_ds, test_ds, sample_ds = get_datasets(args)

  # create distributed datasets for the train/validation steps; test_ds and
  # sample_ds are passed unchanged to the plotting / sample-saving helpers
  train_ds = strategy.experimental_distribute_dataset(train_ds)
  val_ds = strategy.experimental_distribute_dataset(val_ds)

  summary = Summary(args)
  try:
    # directory to store generated samples
    args.samples_dir = os.path.join(args.output_dir, 'samples')

    G, Y = get_models(args,
                      strategy=strategy,
                      summary=summary,
                      g_name='generator_G',
                      d_name='discriminator_Y')
    F, X = get_models(args,
                      strategy=strategy,
                      summary=summary,
                      g_name='generator_F',
                      d_name='discriminator_X')

    gan = get_algorithm(args, G=G, F=F, X=X, Y=Y, strategy=strategy)

    utils.save_args(args)

    epoch = gan.load_checkpoint()

    utils.plot(args,
               test_ds=test_ds,
               sample_ds=sample_ds,
               gan=gan,
               summary=summary,
               epoch=epoch)

    # pre-increment: trains epochs (loaded epoch + 1) .. args.epochs inclusive
    while (epoch := epoch + 1) < args.epochs + 1:
      print(f'Epoch {epoch:03d}/{args.epochs:03d}')

      start = time()
      train(args, ds=train_ds, gan=gan, summary=summary, epoch=epoch)
      metrics = validate(args, ds=val_ds, gan=gan, summary=summary, epoch=epoch)
      elapse = time() - start

      summary.scalar('model/elapse', elapse, step=epoch)

      _print_metrics(metrics, elapse)

      if epoch % 5 == 0 or epoch == args.epochs:
        gan.save_checkpoint(args, epoch=epoch)

      if epoch % 10 == 0 or epoch == args.epochs:
        utils.plot(args,
                   test_ds=test_ds,
                   sample_ds=sample_ds,
                   gan=gan,
                   summary=summary,
                   epoch=epoch)
        if args.dataset != 'horse2zebra':
          utils.save_samples(args, gan=gan, ds=test_ds)

    if args.dataset != 'horse2zebra' and args.deconvolve_samples:
      spike_helper.deconvolve_samples(args)
  finally:
    # release the summary writer even when training fails or is interrupted
    summary.close()


def parse_args(argv=None):
  """Build the training command-line parser and parse `argv`.

  Args:
    argv: list of argument strings. When None, argparse falls back to
      `sys.argv[1:]`.

  Returns:
    argparse.Namespace. When `--kernel_size` / `--reduction_factor` are
    given on the command line (nargs='+' yields a list), they are converted
    to tuples. Their scalar defaults are left as ints.
  """
  parser = argparse.ArgumentParser()
  # data settings
  parser.add_argument('--dataset', type=str, required=True)
  parser.add_argument('--output_dir', type=str, default='runs')
  parser.add_argument('--synthetic_data',
                      action='store_true',
                      help='construct and use synthetic data')
  parser.add_argument('--synthetic_shift',
                      type=int,
                      default=8,
                      help='the amount of shift in synthetic data')
  parser.add_argument('--neuron_order_json',
                      type=str,
                      default=None,
                      help='path to json file with neuron orders')

  # algorithm hyper-parameters
  parser.add_argument('--algorithm', type=str, default='lsgan')
  parser.add_argument('--lambda_cycle',
                      type=float,
                      default=10.0,
                      help='cycle-consistent loss coefficient')
  parser.add_argument('--lambda_identity',
                      type=float,
                      default=None,
                      help='identity loss coefficient')
  parser.add_argument('--lambda_equivalent',
                      type=float,
                      default=0.0,
                      help='equivalent loss coefficient')
  parser.add_argument('--lambda_gp',
                      default=10.0,
                      type=float,
                      help='gradient penalty coefficient, only applicable '
                      'if algorithms with gradient penalty is used.')
  parser.add_argument('--error',
                      default='mae',
                      choices=['mse', 'mae', 'huber'],
                      help='error function to use in cycle-consistent and '
                      'identity losses')
  parser.add_argument('--label_smoothing', action='store_true')

  # model hyper-parameters
  parser.add_argument('--model', type=str, default='unet')
  parser.add_argument('--num_filters', type=int, default=8)
  parser.add_argument('--kernel_size', type=int, nargs='+', default=3)
  parser.add_argument('--reduction_factor', type=int, nargs='+', default=2)
  parser.add_argument('--activation', type=str, default='lrelu')
  parser.add_argument('--normalization', type=str, default='instancenorm')
  parser.add_argument('--spectral_norm', action='store_true')
  parser.add_argument('--dropout', type=float, default=0.0)
  parser.add_argument('--phase_shuffle', type=int, default=0)
  parser.add_argument('--patchgan', action='store_true')
  parser.add_argument('--input_2d',
                      action='store_true',
                      help='use 2D input instead of 3D')

  # training hyper-parameters
  parser.add_argument('--epochs', default=200, type=int)
  parser.add_argument('--batch_size', default=64, type=int)
  parser.add_argument('--g_lr',
                      default=1e-4,
                      type=float,
                      help='generators learning rate')
  parser.add_argument('--d_lr',
                      default=4e-4,
                      type=float,
                      help='discriminator learning rate')
  parser.add_argument('--g_updates',
                      default=1,
                      type=int,
                      help='number of generator updates per iteration')
  parser.add_argument('--d_updates',
                      default=1,
                      type=int,
                      help='number of discriminator updates per iteration')
  parser.add_argument('--mixed_precision', action='store_true')
  parser.add_argument('--seed', type=int, default=777)

  # plot settings
  parser.add_argument('--dpi', type=int, default=120)
  parser.add_argument('--save_plots',
                      action='store_true',
                      help='save plots to disk')
  parser.add_argument('--format',
                      type=str,
                      default='pdf',
                      help='file format when --save_plots')

  # misc
  parser.add_argument('--num_processors', default=8, type=int)
  parser.add_argument('--deconvolve_samples', action='store_true')
  parser.add_argument('--clear_output_dir', action='store_true')
  parser.add_argument('--verbose', default=1, type=int)

  params = parser.parse_args(argv)

  # nargs='+' produces a list when the flag is given; convert it to a tuple
  # and leave the scalar defaults untouched
  for name in ('kernel_size', 'reduction_factor'):
    value = getattr(params, name)
    if isinstance(value, list):
      setattr(params, name, tuple(value))
  return params


if __name__ == '__main__':
  main(parse_args())
