# coding:utf-8
import logging
import json
import torch
import numpy as np

from cotk.wordvector import Glove
from cotk_extend import PredefinedLanguageGeneration

from utils import debug, try_cache, cuda_init, Storage
from NAGAN import NAGAN

def main(args, load_exclude_set, restoreCallback):
	"""Set up logging, load the dataset and word vectors, then run NAGAN.

	Args:
		args: Run configuration, read by attribute (``debug``, ``cuda``,
			``datapath``, ``max_sent_length``, ``wvpath``, ``embedding_size``,
			``cache``, ``mode``, and ``cache_dir`` when ``cache`` is truthy).
			It is also logged through ``json.dumps``, so it must be
			JSON-serializable.
		load_exclude_set: Forwarded unchanged to the model as
			``param.volatile.load_exclude_set``.
		restoreCallback: Forwarded unchanged to the model as
			``param.volatile.restoreCallback``.

	Raises:
		ValueError: If ``args.mode`` is not one of ``train``, ``test``,
			``interpolate``, ``record``, ``adddiff`` or ``walkgrad``. The check
			happens only after the data is loaded and the model is built.
	"""
	logger = logging.getLogger()
	# basicConfig() is a no-op while the root logger still has handlers, so the
	# most recently installed one is dropped first. The guard avoids an
	# IndexError when no handler has been installed yet.
	if logger.handlers:
		logger.handlers.pop()
	# NOTE(review): filename=0 is falsy, so basicConfig appears to fall back to
	# its default stream handler instead of opening a file -- confirm intent.
	logging.basicConfig(
		filename=0,
		level=logging.DEBUG,
		format='%(asctime)s %(filename)s[line:%(lineno)d] %(message)s',
		datefmt='%H:%M:%S')

	if args.debug:
		debug()
	logging.info(json.dumps(args, indent=2))

	cuda_init(0, args.cuda)

	volatile = Storage()
	data_class = PredefinedLanguageGeneration
	data_arg = Storage()
	data_arg.file_id = args.datapath
	data_arg.max_sent_length = args.max_sent_length

	wordvec_class = Glove

	def load_dataset(data_arg, wvpath, embedding_size):
		"""Build the data manager and an embedding matrix for its frequent vocabulary."""
		dm = data_class(**data_arg)
		if wvpath:
			# Pretrained vectors are available: load them for the frequent vocab.
			wv = wordvec_class(wvpath)
			wvvec = wv.load_matrix(embedding_size, dm.frequent_vocab_list)
		else:
			# No word-vector path given: use small random Gaussian embeddings.
			wvvec = np.random.randn(len(dm.frequent_vocab_list), embedding_size) * 0.1
		return dm, wvvec

	if args.cache:
		# presumably try_cache stores/reuses the loader result under
		# args.cache_dir -- verify against utils.try_cache.
		dm, wordvec = try_cache(load_dataset, (data_arg, args.wvpath, args.embedding_size), args.cache_dir, name="predefined_wordvec")
	else:
		dm, wordvec = load_dataset(data_arg, args.wvpath, args.embedding_size)

	# Objects handed to the model alongside the configuration in args.
	volatile.dm = dm
	volatile.wordvec = wordvec
	volatile.load_exclude_set = load_exclude_set
	volatile.restoreCallback = restoreCallback

	param = Storage()
	param.args = args
	param.volatile = volatile

	model = NAGAN(param)
	if args.mode == "train":
		# Training is always followed by a test pass.
		model.train_process()
		model.test_process()
	elif args.mode == "test":
		model.test_process()
	elif args.mode == "interpolate":
		model.interpolate()
	elif args.mode == "record":
		model.record()
	elif args.mode == "adddiff":
		model.adddiff()
	elif args.mode == "walkgrad":
		model.walkgrad()
	else:
		raise ValueError("Unknown mode")
