import os
import argparse
from configs.defaults import get_cfgs_defaults
import torch

from trainer import OTVAETrainer, OTVAEMaskTrainer
from util import set_seeds, get_loader, get_logger


def arg_parse():
    """Parse the command-line options of the OT-VAE training script.

    The OT-related options (``--eps``, ``--ot_iter``, ``--temp``, ``--eta``)
    and ``--path_specific`` default to ``None`` so that ``load_config`` can
    distinguish "not given on the command line" from an explicit value and
    fall back to the value in the config file.

    Returns:
        argparse.Namespace: the parsed options (read from ``sys.argv``).
    """
    parser = argparse.ArgumentParser(
        description='OT-VAE training on vision dataset',
        add_help=True,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add = parser.add_argument

    # General run options.
    add("-c", "--config_file", default="", help="config file")
    add("-ts", "--timestamp", default="", help="saved path (random seed + date)")
    add("--save", action="store_true", help="save trained model")
    add("--dbg", action="store_true", help="print losses per epoch")
    add("--gpu", default="0", help="index of gpu to be used")
    add("--seed", type=int, default=0, help="seed number for randomness")

    # OT parameters: None means "keep the value from the config file".
    add("--eps", type=float, default=None,
        help="EPS in the optimal transport")
    add("--ot_iter", type=int, default=None,
        help="nb of iterations in the optimal transport")
    add("--temp", type=float, default=None,
        help="Initial temperature in the optimal transport (which will be updated through backprop)")
    add("--eta", type=float, default=None,
        help="ot loss")
    add("--path_specific", type=str, default=None,
        help="output directory name")

    return parser.parse_args()


def load_config(args):
    """Build the frozen experiment config from defaults, a config file and CLI args.

    The config file is resolved relative to the ``configs`` directory next to
    this script. Seed and flags always come from ``args``; the optional
    overrides (``path_specific``, ``eps``, ``ot_iter``, ``temp``, ``eta``) replace
    the config-file value only when the corresponding argument is not ``None``.
    ``cfgs.path`` is then extended with ``cfgs.path_specific``.

    Args:
        args: the namespace returned by ``arg_parse``.

    Returns:
        tuple: ``(cfgs, flgs)``, where ``cfgs`` is the frozen config
        (presumably a yacs ``CfgNode`` given ``merge_from_file``/``freeze``)
        and ``flgs`` is ``cfgs.flags``.
    """
    cfgs = get_cfgs_defaults()
    cfgs.merge_from_file(
        os.path.join(os.path.dirname(__file__), "configs", args.config_file))

    # Values that always come from the command line.
    cfgs.train.seed = args.seed
    cfgs.flags.save = args.save
    cfgs.flags.noprint = not args.dbg
    # Snapshot of the base path, taken before path_specific is appended below.
    cfgs.path_data = cfgs.path

    # Optional overrides: (config node, key, CLI value). The current value is
    # read even when the CLI value wins nothing, so a missing key still fails.
    overrides = (
        (cfgs, "path_specific", args.path_specific),
        (cfgs.quantization, "eps", args.eps),
        (cfgs.quantization, "ot_iter", args.ot_iter),
        (cfgs.quantization, "temp", args.temp),
        (cfgs.loss, "eta", args.eta),
    )
    for node, key, cli_value in overrides:
        setattr(node, key,
                cli_value if cli_value is not None else getattr(node, key))

    cfgs.path = os.path.join(cfgs.path, cfgs.path_specific)
    cfgs.freeze()
    return cfgs, cfgs.flags


if __name__ == "__main__":
    print("main.py")

    ## Experimental setup
    args = arg_parse()
    # Set before the first torch.cuda call below so the GPU restriction applies.
    if args.gpu != "":
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    cfgs, flgs = load_config(args)
    print("[Checkpoint path] "+cfgs.path)

    ## Device
    # NOTE(review): `device` is not used in this script; the trainers presumably
    # select their own device — confirm before removing.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seeds(args.seed)

    ## Data loader
    train_loader, val_loader, test_loader = get_loader(
        cfgs.dataset.name, cfgs.path_dataset, cfgs.train.bs, cfgs.nworker)
    print("Complete dataload")

    ## Trainer
    # Map config model names to their trainer classes.
    trainer_classes = {
        "OTVAE": OTVAETrainer,
        "OTVAEMask": OTVAEMaskTrainer,
    }
    print("=== {} ===".format(cfgs.model.name.upper()))
    if cfgs.model.name not in trainer_classes:
        # Name the offending value so a config typo is easy to spot.
        raise ValueError("Undefined model: {!r} (expected one of {})".format(
            cfgs.model.name, sorted(trainer_classes)))
    trainer = trainer_classes[cfgs.model.name](
        cfgs, flgs, train_loader, val_loader, test_loader)

    ## Main
    # An empty timestamp means "train now"; a non-empty one skips training.
    if args.timestamp == "":
        trainer.main_loop()
    if flgs.save:
        # Reload the checkpoint for this timestamp (empty = the run just trained)
        # and evaluate it — presumably on test_loader.
        trainer.load(args.timestamp)
        print("Best models were loaded!!")
        res_test = trainer.test()
    elif args.timestamp != "":
        # Previously this combination silently did nothing after building the trainer.
        print("[Warning] --timestamp was given without --save: "
              "no training, loading or testing was performed.")

