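"""Train or evaluate a model described by a configuration file.

Training is the default; pass ``--eval`` (usually with ``--load``) to evaluate
a saved model instead. The tensorpack log directory is
``<--logdir>/<config file name>``, with an ``-eval`` suffix when evaluating.
"""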

import argparse
import sys
import os
from functools import partial

import tensorflow as tf
from tensorpack import logger, StagingInput, QueueInput, SmartInit, \
                       ScheduledHyperParamSetter, DictRestore

from nets import get_model
from utils import parse_config
from attacker import get_attacker
from trainer import train, eval
# from learning_schedules import get_learning_rate
from utils import get_learning_rate_setter
from dataset import get_dataset
from nets import AdvModel

# Command-line interface. Every run is driven by a --config file; the other
# flags control checkpoint loading, logging location and GPU selection.
parser = argparse.ArgumentParser()
parser.add_argument('--load', help='Path to a model to load for evaluation or resuming training.')
parser.add_argument('--starting-epoch', help='The epoch to start with. Useful when resuming training.',
                    type=int, default=1)
parser.add_argument('--logdir', help='Directory suffix for models and training stats.',
                    default='', type=str)
parser.add_argument('--eval', action='store_true', help='Evaluate a model instead of training.')
# Comma-separated GPU ids, e.g. "0,2"; exported through CUDA_VISIBLE_DEVICES below.
parser.add_argument('--gpus', help='GPU used', default='0', type=str)
parser.add_argument('--saver-max-to-keep', help='Maximum models for saver to keep',
                    default=10, type=int)
parser.add_argument('--keep-checkpoint-every-n-hours', help='Keep ckpt every n hours',
                    default=0.5, type=float)
# The script cannot do anything without a configuration file, so let argparse
# report a missing flag instead of failing later on os.path.isfile(None).
parser.add_argument('--config', help='Path of configuration file.', type=str,
                    required=True)
args = parser.parse_args()

# CUDA only honours CUDA_VISIBLE_DEVICES (plural). This must be set before
# TensorFlow initializes its GPU devices, i.e. before the first session is built.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpus

def _prepare_logdir(config_file):
  """Create (if needed) and register the tensorpack log directory.

  The directory is ``<--logdir>/<config basename>`` for training and
  ``<--logdir>/<config basename>-eval`` for evaluation. An empty ``--logdir``
  (the default) resolves to the current working directory.

  Raises:
    ValueError: if a non-empty ``--logdir`` is not an existing directory.
  """
  base_dir = args.logdir
  # os.path.isdir('') is False, so only validate an explicitly given directory;
  # os.path.join('', name) already yields a path relative to the working dir.
  if base_dir and not os.path.isdir(base_dir):
    raise ValueError("%s is not a directory" % base_dir)

  run_name = os.path.basename(config_file)
  if args.eval:
    run_name = '%s-eval' % run_name
  logdir = os.path.join(base_dir, run_name)
  if not os.path.isdir(logdir):
    os.mkdir(logdir)
  # 'k' = keep an existing directory's contents.
  logger.set_logger_dir(logdir, 'k')
  logger.info(' '.join(sys.argv))


def _activation_sessinit(activation_config):
  """Return a DictRestore initializer for activation variables in a checkpoint.

  Every checkpoint variable whose name contains ``activation_config['key_words']``
  is restored, except names containing 'Momentum' (presumably optimizer slot
  variables).

  Raises:
    ValueError: if no checkpoint variable matches.
  """
  key_words = activation_config['key_words']
  checkpoint_path = activation_config['checkpoint_path']
  reader = tf.train.NewCheckpointReader(checkpoint_path)
  restore_dict = {
      name: reader.get_tensor(name)
      for name in reader.get_variable_to_shape_map()
      if key_words in name and 'Momentum' not in name
  }
  if not restore_dict:
    raise ValueError("No activation parameters found")
  logger.info("Init activation parameters: %s\n" % ' '.join(restore_dict.keys()))
  return DictRestore(restore_dict)


def _build_model(configs):
  """Instantiate the model from ``configs`` and configure its input geometry."""
  # Activation variables are frozen when they are warm-started from a checkpoint.
  model = get_model(configs['model'],
                    robust_activation_trainable='activation_init' not in configs)
  if isinstance(model, AdvModel):
    attacker = get_attacker(configs['attacker'],
                            num_classes=configs['model']['num_classes'])
    model.set_attacker(attacker)
  model.set_height(configs['image_height'])
  model.set_width(configs['image_width'])
  # 'caltech' inputs use a single channel; every other dataset uses three.
  model.set_num_channels(1 if configs['train_dataset']['name'] == 'caltech' else 3)
  return model


def _temperature_callbacks(configs):
  """Create the non-trainable 'temperature' variable; return its schedule callbacks.

  With a ``temperature`` config section the variable starts at ``init_value``
  and is linearly scheduled over ``steps``; otherwise it is fixed at 1.0.
  """
  callbacks = []
  if 'temperature' in configs:
    temperature_config = configs['temperature']
    temp = tf.get_variable('temperature',
                           initializer=temperature_config['init_value'],
                           trainable=False)
    callbacks.append(ScheduledHyperParamSetter('temperature',
                                               temperature_config['steps'],
                                               interp='linear'))
  else:
    temp = tf.get_variable('temperature', initializer=1.0, trainable=False)
  # Registered exactly once for both branches to avoid a duplicate summary op.
  tf.summary.scalar('temperature-summary', temp)
  return callbacks


def _run_eval(model, configs):
  """Evaluate ``model`` on the validation set using weights from ``--load``."""
  sessinit = SmartInit(args.load)
  ds = get_dataset(config=configs['val_dataset'],
                   batch_size=configs['batch_size'],
                   is_training=False, shuffle=False)
  eval(model, sessinit, ds)


def _run_train(model, configs, sessinit, auto_resume):
  """Configure the optimizer, dataflows and callbacks, then train ``model``."""
  if 'optimizer' in configs:
    opt_config = configs['optimizer']
    opt_type = opt_config.pop('type')
    model.set_optimizer(opt_type, opt_config)

  if 'weight_decay' in configs:
    model.set_weight_decay(configs['weight_decay'])

  # Devices exposed through CUDA_VISIBLE_DEVICES are renumbered from 0, so the
  # selected --gpus are addressed as 0..n-1; int() still rejects malformed ids.
  gpu_ids = [int(gpu_id) for gpu_id in args.gpus.strip().split(',')]
  gpus = list(range(len(gpu_ids)))

  batch_size = configs['batch_size']
  train_dataflow = get_dataset(config=configs['train_dataset'],
                               batch_size=batch_size,
                               is_training=True, shuffle=True)
  # Measure the epoch length before wrapping the dataflow in input queues.
  steps_per_epoch = len(train_dataflow)
  train_dataflow = StagingInput(QueueInput(train_dataflow))
  get_val_dataflow_func = partial(get_dataset, config=configs['val_dataset'],
                                  batch_size=batch_size,
                                  is_training=False, shuffle=False)

  callbacks = list(get_learning_rate_setter(configs['learning_rate']))
  callbacks += _temperature_callbacks(configs)

  train(gpus=gpus,
        extra_callbacks=callbacks,
        saver_max_to_keep=args.saver_max_to_keep,
        keep_checkpoint_every_n_hours=args.keep_checkpoint_every_n_hours,
        attacker_step_size=configs['attacker_step_size'],
        attacker_epsilon=configs['attacker_epsilon'],
        model=model,
        sessinit=sessinit,
        auto_resume=auto_resume,
        train_dataflow=train_dataflow,
        steps_per_epoch=steps_per_epoch,
        get_val_dataflow_func=get_val_dataflow_func,
        starting_epoch=args.starting_epoch,
        max_epoch=configs['max_epoch'])


if __name__ == '__main__':
  config_file = args.config
  # Guard against a missing --config as well as a path that does not exist.
  if not config_file or not os.path.isfile(config_file):
    raise ValueError("%s not found" % config_file)
  _prepare_logdir(config_file)

  configs = parse_config(config_file)

  if 'activation_init' in configs:
    # Warm-start activation variables; auto-resume is disabled in this case
    # (presumably so a resumed checkpoint does not replace them).
    sessinit = _activation_sessinit(configs['activation_init'])
    auto_resume = False
  else:
    sessinit = None
    auto_resume = True

  model = _build_model(configs)

  if args.eval:
    _run_eval(model, configs)
  else:
    _run_train(model, configs, sessinit, auto_resume)

