# coding:utf-8

def _build_parser():
	"""Create the command-line parser whose namespace ``run`` consumes.

	Returns:
		argparse.ArgumentParser: parser with all model/training/IO options.
	"""
	import argparse

	parser = argparse.ArgumentParser()

	# Model and optimisation hyperparameters.
	parser.add_argument("--gp", type=float, default=1, help="Maxgp Weight")
	parser.add_argument('--tau', type=float, default=1, help="Softmax Temperature")
	parser.add_argument("--ema", type=float, default=1e-4, help="Parameter Averaging 1-gamma. For example, if ema = 1e-4, gamma = 0.9999.")
	parser.add_argument("--batch_size", type=int, default=64)
	parser.add_argument("--layers", type=int, default=5, help="generator transformer layers")
	parser.add_argument("--dislayers", type=int, default=5, help="discriminator transformer layers")
	parser.add_argument("--heads", type=int, default=4, help="transformer attention heads")
	parser.add_argument("--droprate", type=float, default=0.25, help="Training droprate")

	parser.add_argument("--gen_approx_mode", type=str, default="gumbel")
	parser.add_argument("--dis_approx_mode", type=str, default="gumbel")

	# an idea from https://arxiv.org/pdf/1909.13188.pdf
	parser.add_argument("--dr", type=float, default=1, help="Discriminator score regularizer Weight")
	parser.add_argument("--va", action='store_true', help="Enable vocabulary attention")

	# Run identification and checkpoint restoring.
	parser.add_argument('--name', type=str, default=None,
		help='The name of your model, used for tensorboard, etc. Default: runYYYYMMDD_HHMMSS (initialized by current time)')
	# Implicit string concatenation instead of backslash continuations, so the
	# literal no longer embeds source-code indentation.
	parser.add_argument('--restore', type=str, default=None,
		help='Checkpoints name to load. '
			'"NAME_last" for the last checkpoint of model named NAME. "NAME_best" means the best checkpoint. '
			'You can also use "last" and "best", use last model you run by default. '
			'Attention: "NAME_last" and "NAME_best" are not guaranteed to work when 2 models with same name run in the same time. '
			'"last" and "best" are not guaranteed to work when 2 models run in the same time. '
			'Default: None (don\'t load anything)')
	parser.add_argument('--mode', type=str, default="train",
		help='"train" or "test". Default: train')

	# Data and word vectors.
	parser.add_argument('--datapath', type=str, default='./mscoco_data',
		help='Directory for data set. Default: ./mscoco_data')
	parser.add_argument('--wvpath', type=str, default="resources://Glove300d",
		help="Directory for pretrained wordvector. Default: resources://Glove300d")
	parser.add_argument('--epoch', type=int, default=100,
		help="Epoch for training. Default: 100")

	# Output locations and runtime switches.
	parser.add_argument('--out_dir', type=str, default="./output",
		help='Output directory for test output. Default: ./output')
	parser.add_argument('--log_dir', type=str, default="./tensorboard",
		help='Log directory for tensorboard. Default: ./tensorboard')
	parser.add_argument('--model_dir', type=str, default="./model",
		help='Checkpoints directory for model. Default: ./model')
	parser.add_argument('--cache_dir', type=str, default="./cache",
		help='Directory for cache. Default: ./cache')
	parser.add_argument('--cpu', action="store_true",
		help='Use cpu.')
	parser.add_argument('--debug', action='store_true',
		help='Enter debug mode (using ptvsd).')
	parser.add_argument('--cache', action='store_true',
		help='Use cache for speeding up load data and wordvec. (It may cause problems when you switch dataset.)')
	parser.add_argument('--seed', type=int, default=0,
		help='Random seed')
	parser.add_argument('--batch_per_epoch', type=int, default=1500,
		help='Batch number for one epoch')

	# store_false flags: the destination defaults to True and passing the flag sets it to False.
	parser.add_argument('--no_restore_optimizer', action='store_false', dest="restore_optimizer",
		help='Do not restore parameters of optimizer')
	parser.add_argument('--no_restore_other_weights', action='store_false', dest="restore_other_weights",
		help='Do not restore parameters in model.param.other weights. Not used in this model.')

	return parser


def _seed_everything(seed):
	"""Seed the Python ``random``, PyTorch and NumPy RNGs with ``seed``."""
	import random

	import numpy as np
	import torch

	random.seed(seed)
	torch.manual_seed(seed)
	np.random.seed(seed)


def run():
	"""Parse the command line, assemble the experiment config and launch ``main``.

	Parsed options are copied into a ``utils.Storage`` object together with
	hyperparameters that are hard-coded below. The RNGs are then seeded with
	``--seed``, and ``main.main(args, load_exclude_set, restoreCallback)`` is called.
	"""
	import time

	from utils import Storage

	parser = _build_parser()
	args = Storage()
	cargs = parser.parse_args()

	# Values copied from the command line. Edit these assignments directly to
	# bypass the command line.
	args.name = cargs.name or time.strftime("run%Y%m%d_%H%M%S", time.localtime())

	args.restore = cargs.restore
	args.mode = cargs.mode
	args.datapath = cargs.datapath
	# The literal string "None" or an empty string is stored as a real None.
	args.wvpath = None if cargs.wvpath in ("None", "") else cargs.wvpath
	args.epochs = cargs.epoch
	args.out_dir = cargs.out_dir
	args.log_dir = cargs.log_dir
	args.model_dir = cargs.model_dir
	args.cache_dir = cargs.cache_dir
	args.debug = cargs.debug
	args.cache = cargs.cache
	args.cuda = not cargs.cpu

	# Checkpoint-restore options. The two flags come from the command line;
	# load_exclude_set and restoreCallback are fixed here and passed to main().
	args.restore_optimizer = cargs.restore_optimizer
	args.restore_other_weights = cargs.restore_other_weights
	load_exclude_set = []
	restoreCallback = None

	# Model hyperparameters: some are taken from the command line, the rest are hard-coded.
	args.gen_approx_mode = cargs.gen_approx_mode
	args.dis_approx_mode = cargs.dis_approx_mode

	args.batch_per_epoch = cargs.batch_per_epoch
	args.embedding_size = 300
	args.va = cargs.va # vocabulary attention
	args.rc = 4  # relative embedding
	args.tf_size = 128
	args.tf_hidden_size = 256
	args.n_heads = cargs.heads
	args.n_layers = cargs.layers
	args.n_dislayers = cargs.dislayers
	args.windows = [-1] * 12
	args.z_size = 64
	args.droprate = cargs.droprate
	args.dis_droprate = 0.25
	args.input_droprate = 0
	args.ema_factor = cargs.ema
	args.tau = cargs.tau
	args.gp = cargs.gp
	args.dr = cargs.dr

	# Optimiser settings (hard-coded).
	args.dis_lr = 5e-4
	args.gen_lr = 2e-4

	args.batch_size = cargs.batch_size
	args.grad_clip = 5
	args.show_sample = [0]  # sample(s) shown in tensorboard
	args.max_sent_length = 40
	args.checkpoint_steps = 6
	args.checkpoint_max_to_keep = 100

	print("seed", cargs.seed)
	_seed_everything(cargs.seed)

	from main import main
	main(args, load_exclude_set, restoreCallback)

if __name__ == '__main__':
	# Launch only when this file is executed as a script, not when it is imported.
	run()
