import os
import argparse
import collections
from datetime import datetime

from config.hparams import *
from pretrain import VQAModel
from train import VisDialModel


def pretrain_model(args):
    """Prepare the hyper-parameters for a VQA pre-training run and start it.

    The module-level ``HPARAMS`` mapping is updated in place (no copy is
    made): ``root_dir`` is extended with an ``<encoder>-<decoder>`` tag and a
    timestamped sub-directory, and ``save_dirpath`` / ``result_dirpath`` are
    extended with the folder names taken from ``args``. The result is frozen
    into an ``HParams`` namedtuple and handed to ``VQAModel``.

    Args:
        args: Parsed command-line namespace; ``save_folder`` and
            ``result_folder`` are read from it.

    Raises:
        NotImplementedError: If ``args.result_folder`` is ``None``.
    """
    config = HPARAMS  # alias, not a copy: every update below hits HPARAMS

    # The architecture tag is appended with no separator in between.
    # NOTE(review): presumably the configured root_dir ends with "/" - confirm.
    config["root_dir"] = config["root_dir"] + "%s-%s" % (
        config["encoder"],
        config["decoder"],
    )

    # Each run gets its own timestamped directory below the tagged root.
    run_stamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    config["root_dir"] = os.path.join(config["root_dir"], "%s/" % run_stamp)

    if args.result_folder is None:
        raise NotImplementedError

    config.update(
        save_dirpath=config["save_dirpath"] + "/" + args.save_folder,
        result_dirpath=config["result_dirpath"] + "/vqa/" + args.result_folder,
    )

    # Freeze the mapping into a namedtuple with alphabetically sorted fields.
    hparams_cls = collections.namedtuple("HParams", sorted(config))
    frozen = hparams_cls(**config)

    VQAModel(frozen).train()


def train_model(args):
    """Prepare the hyper-parameters for a Visual Dialog run and start it.

    The module-level ``HPARAMS`` mapping is updated in place (no copy is
    made): ``root_dir`` is extended with an ``<encoder>-<decoder>`` tag and a
    timestamped sub-directory, and ``save_dirpath`` / ``result_dirpath`` are
    extended with ``/visdial/`` plus the folder names taken from ``args``.
    The result is frozen into an ``HParams`` namedtuple and handed to
    ``VisDialModel``.

    Args:
        args: Parsed command-line namespace. ``save_folder`` and
            ``result_folder`` are required; ``load_pthpath`` is optional.
            When ``load_pthpath`` is absent or ``None`` the configured
            ``HPARAMS["load_pthpath"]`` is left untouched instead of
            failing with ``AttributeError``.

    Raises:
        NotImplementedError: If ``args.result_folder`` is ``None``.
    """
    hparams = HPARAMS  # alias, not a copy: every update below hits HPARAMS

    # The architecture tag is appended with no separator in between.
    # NOTE(review): presumably the configured root_dir ends with "/" - confirm.
    root_dir = hparams["root_dir"]
    root_dir += "%s-%s" % (hparams["encoder"], hparams["decoder"])
    hparams.update(root_dir=root_dir)

    # Each run gets its own timestamped directory below the tagged root.
    timestamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    hparams["root_dir"] = os.path.join(hparams["root_dir"], "%s/" % timestamp)

    if args.result_folder is None:
        raise NotImplementedError

    save_dirpath = hparams["save_dirpath"] + "/visdial/" + args.save_folder
    result_dirpath = hparams["result_dirpath"] + "/visdial/" + args.result_folder
    hparams.update(save_dirpath=save_dirpath)
    hparams.update(result_dirpath=result_dirpath)

    # The namespace may not carry a checkpoint name at all (the attribute is
    # only present when the caller/parser provides it), so look it up
    # defensively and only rewrite the path when a name was actually given.
    checkpoint_name = getattr(args, "load_pthpath", None)
    if checkpoint_name is not None:
        load_pthpath = (
            hparams["load_pthpath"]
            + "checkpoints/testTypeRegMItune/"
            + checkpoint_name
        )
        hparams.update(load_pthpath=load_pthpath)

    # Freeze the mapping into a namedtuple with alphabetically sorted fields.
    hparams = collections.namedtuple("HParams", sorted(hparams.keys()))(**hparams)

    model = VisDialModel(hparams)
    model.train()


def main(args):
    """Run VQA pre-training or Visual Dialog training.

    Args:
        args: Parsed command-line namespace. A truthy ``args.pretrain``
            selects ``pretrain_model``; anything falsy selects
            ``train_model``.
    """
    if not args.pretrain:
        print("VISUAL DIALOG MODEL TRAIN\n")
        train_model(args)
        return

    print("VQA MODEL PRE-TRAIN\n")
    pretrain_model(args)

def str2bool(value):
    """Convert a command-line flag value to a boolean.

    With a plain ``type=str`` every non-empty value, including the text
    ``"False"``, is truthy. This converter maps the explicit negatives
    (empty string, ``0``, ``false``, ``f``, ``no``, ``n``, ``off``; case and
    surrounding whitespace are ignored) to ``False``. Any other string stays
    ``True``, so invocations that previously enabled the flag still do.

    Args:
        value: Raw option string, or an already-converted ``bool``.

    Returns:
        The boolean meaning of ``value``.
    """
    if isinstance(value, bool):
        return value
    return value.strip().lower() not in {"", "0", "false", "f", "no", "n", "off"}


if __name__ == "__main__":
    arg_parser = argparse.ArgumentParser(description="Variational Disentangled Attention for Regularized Visual Dialog")
    # str2bool (not str) so that "--pretrain False" really disables pre-training.
    arg_parser.add_argument("--pretrain", dest="pretrain", type=str2bool, default=False,
                            help="Pre-train VQA model or not")
    arg_parser.add_argument("--save_folder", dest="save_folder", type=str, default="save_folder",
                            help="Save folder name")
    arg_parser.add_argument("--result_folder", dest="result_folder", type=str, default="result_folder",
                            help="Result folder name")
    # train_model reads args.load_pthpath; without this option the attribute
    # does not exist on the parsed namespace.
    arg_parser.add_argument("--load_pthpath", dest="load_pthpath", type=str, default=None,
                            help="Checkpoint name to load for Visual Dialog training")

    args = arg_parser.parse_args()
    main(args)