import os
import argparse
from configs.defaults import get_cfgs_defaults
import torch

from trainer import GaussianSQVAETrainer, VmfSQVAETrainer
from util import set_seeds, get_loader, get_loader_ecmi


def arg_parse():
    """Build the command-line parser for main.py and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the fields ``config_file``, ``timestamp``,
        ``save``, ``dbg``, ``gpu``, ``ecmi``, ``n``, ``K``, ``d_dict``,
        ``seed`` and ``S_seed``.
    """
    parser = argparse.ArgumentParser(description="main.py")
    add = parser.add_argument

    # Run / checkpoint control.
    add("-c", "--config_file", default="", help="config file")
    add("-ts", "--timestamp", default="",
        help="saved path (random seed + date)")
    add("--save", action="store_true", help="save trained model")
    add("--dbg", action="store_true", help="print losses per epoch")
    add("--gpu", default="0", help="index of gpu to be used")

    # eCMI experimental settings.
    add("--ecmi", action="store_true",
        help="determine whether we use the eCMI settings or not")
    add("--n", "-n", type=int, default=None,
        help="Number of training examples for the eCMI settings")

    # Quantization dictionary sizes (copied into cfgs.quantization by load_config).
    add("--K", "-K", type=int, default=128, help="Number of size_dict")
    add("--d_dict", type=int, default=64, help="Number of dim_dict")

    # Random seeds.
    add("--seed", type=int, default=0,
        help="seed number for randomness (on 2n examples if we use the eCMI settings)")
    add("--S_seed", type=int, default=0,
        help="seed number for randomness on train/test data split (ecmi)")

    return parser.parse_args()


def load_config(args):
    """Build the frozen experiment config from defaults, a config file and CLI args.

    The file named by ``args.config_file`` is resolved relative to the
    ``configs`` directory next to this script, merged over the defaults, and
    then the command-line overrides are applied before the config is frozen.

    Args:
        args: Namespace from ``arg_parse``; reads ``config_file``, ``seed``,
            ``K``, ``d_dict``, ``save`` and ``dbg``.

    Returns:
        ``(cfgs, flgs)``: the frozen config and its ``flags`` sub-node.

    Raises:
        FileNotFoundError: if the resolved config path is not an existing
            file (this includes the default empty ``--config_file``).
    """
    config_path = os.path.join(os.path.dirname(__file__), "configs", args.config_file)
    print(config_path)
    # With the default ``--config_file ""`` the joined path is the configs
    # directory itself; fail with an explicit message instead of an opaque
    # error from trying to open a directory.
    if not os.path.isfile(config_path):
        raise FileNotFoundError(
            f"Config file not found: {config_path!r} (pass it with -c/--config_file)")

    cfgs = get_cfgs_defaults()
    cfgs.merge_from_file(config_path)

    # Command-line overrides take precedence over the file.
    cfgs.train.seed = args.seed
    cfgs.quantization.size_dict = args.K
    cfgs.quantization.dim_dict = args.d_dict
    cfgs.flags.save = args.save
    cfgs.flags.noprint = not args.dbg

    # Keep the base path as ``path_data``; ``path`` becomes the run-specific
    # subdirectory (printed as the checkpoint path by the script body).
    cfgs.path_data = cfgs.path
    cfgs.path = os.path.join(cfgs.path, cfgs.path_specific)

    # NOTE(review): this model-name check is case-insensitive, while the
    # trainer dispatch in the script body matches exact names.
    if cfgs.model.name.lower() == "vmfsqvae":
        # VmfSQVAE models get one extra dictionary dimension.
        cfgs.quantization.dim_dict += 1

    # var_q is False only for the "gaussian_1" and "vmf" parametrizations.
    cfgs.flags.var_q = cfgs.model.param_var_q not in ("gaussian_1", "vmf")

    cfgs.freeze()
    flgs = cfgs.flags
    return cfgs, flgs


def main():
    """Parse CLI args, build config and data loaders, then train and/or evaluate.

    Training (``trainer.main_loop``) runs only when ``--timestamp`` is empty.
    When ``--save`` is set, ``trainer.load(args.timestamp)`` is called and the
    loaded model is evaluated with ``trainer.test()``.

    Raises:
        ValueError: if ``cfgs.model.name`` is not a known model name.
    """
    print("main.py")

    ## Experimental setup
    args = arg_parse()
    if args.gpu != "":
        # Only takes effect if set before torch initializes CUDA.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    cfgs, flgs = load_config(args)
    print("[Checkpoint path] " + cfgs.path)
    print(cfgs)

    ## Data loader
    if args.ecmi:
        # NOTE(review): set_seeds is not called on this path; presumably
        # get_loader_ecmi / the trainer seed from args.seed and args.S_seed -- confirm.
        train_loader, val_loader, test_loader = get_loader_ecmi(
            args, cfgs.dataset.name, cfgs.path_dataset, cfgs.train.bs, cfgs.nworker)
        print("Complete dataload under the eCMI settings")
    else:
        set_seeds(args.seed)
        train_loader, val_loader, test_loader = get_loader(
            cfgs.dataset.name, cfgs.path_dataset, cfgs.train.bs, cfgs.nworker)
        print("Complete dataload")

    ## Trainer
    print("=== {} ===".format(cfgs.model.name.upper()))
    trainer_classes = {
        "GaussianSQVAE": GaussianSQVAETrainer,
        "VmfSQVAE": VmfSQVAETrainer,
    }
    trainer_cls = trainer_classes.get(cfgs.model.name)
    if trainer_cls is None:
        raise ValueError(f"Undefined model: {cfgs.model.name!r}")
    trainer = trainer_cls(args, cfgs, flgs, train_loader, val_loader, test_loader)

    ## Main
    if args.timestamp == "":
        trainer.main_loop()
    if flgs.save:
        trainer.load(args.timestamp)
        print("Best models were loaded!!")
        trainer.test()


if __name__ == "__main__":
    main()

