import os
import argparse
from configs.defaults import get_cfgs_defaults
import torch

from trainer import HierSQVAETrainer
from util import set_seeds, get_loader

def arg_parse():
    """Parse the command-line options of this script.

    Returns:
        argparse.Namespace with the attributes ``config_file``, ``timestamp``,
        ``save``, ``dbg``, ``gpu``, ``seed``, ``size_dict`` and ``num_layer``.
    """
    # (option strings, add_argument keyword arguments), in registration order.
    # The order only affects the generated --help text.
    option_specs = [
        (("-c", "--config_file"),
         dict(default="", help="config file")),
        (("-ts", "--timestamp"),
         dict(default="", help="saved path (random seed + date)")),
        (("--save",),
         dict(action="store_true", help="save trained model")),
        # NOTE(review): with store_false, ``dbg`` is True by default and passing
        # --dbg sets it to False; load_config() stores ``not args.dbg`` in
        # flags.noprint, so passing the flag appears to *suppress* printing,
        # which reads opposite to the help text. Confirm the intent.
        (("--dbg",),
         dict(action="store_false", help="print losses per epoch")),
        (("--gpu",),
         dict(default="0", help="index of gpu to be used")),
        (("--seed",),
         dict(type=int, default=0, help="seed number for randomness")),
        # For thorough comparison (0 means "keep the config-file value").
        (("--size_dict",),
         dict(type=int, default=0, help="Codebook size (the number of code vectors per a layer)")),
        (("--num_layer",),
         dict(type=int, default=0, help="The number of latent layers (only for RSQVAE and RVQVAE)")),
    ]

    cli = argparse.ArgumentParser(description="main.py")
    for option_strings, options in option_specs:
        cli.add_argument(*option_strings, **options)
    return cli.parse_args()


def load_config(args):
    """Build the frozen experiment config from defaults, a config file and CLI args.

    The config file is resolved relative to the ``configs`` directory next to
    this script and merged over ``get_cfgs_defaults()``. Command-line values
    (seed, save/print flags and the optional ``--num_layer`` / ``--size_dict``
    overrides) are then applied, derived fields are filled in, and the config
    is frozen.

    Args:
        args: Namespace returned by ``arg_parse()``.

    Returns:
        (cfgs, flgs): the frozen config node and its ``flags`` sub-node.
    """
    cfgs = get_cfgs_defaults()
    config_path = os.path.join(os.path.dirname(__file__), "configs", args.config_file)
    print(config_path)
    cfgs.merge_from_file(config_path)
    cfgs.train.seed = args.seed
    cfgs.flags.save = args.save
    cfgs.flags.noprint = not args.dbg
    # Remember the data root before cfgs.path is specialised to this experiment.
    cfgs.path_data = cfgs.path
    cfgs.path = os.path.join(cfgs.path, cfgs.path_specific)
    if cfgs.model.name.lower() == "vmfsqvae":
        # dim_dict may be a scalar or a per-layer list (it is indexed as a list
        # below); `list += 1` would raise TypeError, so bump every entry.
        dim_dict = cfgs.quantization.dim_dict
        if isinstance(dim_dict, (list, tuple)):
            cfgs.quantization.dim_dict = type(dim_dict)(d + 1 for d in dim_dict)
        else:
            cfgs.quantization.dim_dict = dim_dict + 1
    cfgs.flags.var_q = cfgs.model.param_var_q not in ("gaussian_1", "vmf")

    ## For thorough comparison
    blocks_sq = cfgs.network.blocks_sq
    if "," not in blocks_sq and "x" in blocks_sq:
        # Compact "<res>x<num_layer>" form: optionally override the layer count,
        # then broadcast the first per-layer hyper-parameter to every layer.
        res, num_layer = blocks_sq.split("x")
        num_layer = int(num_layer)
        if args.num_layer > 0:
            cfgs.network.blocks_sq = res + "x" + str(args.num_layer)
            num_layer = args.num_layer
        cfgs.model.log_param_q_init = [cfgs.model.log_param_q_init[0]] * num_layer
        cfgs.quantization.size_dict = [cfgs.quantization.size_dict[0]] * num_layer
        cfgs.quantization.dim_dict = [cfgs.quantization.dim_dict[0]] * num_layer
        print(cfgs.network.blocks_sq)
    elif args.num_layer > 0:
        # Previously the override was dropped silently.
        print("[Warning] --num_layer={} ignored: network.blocks_sq '{}' is not of "
              "the form '<res>x<num_layer>'".format(args.num_layer, blocks_sq))
    if args.size_dict > 0:
        # Same codebook size for every layer.
        cfgs.quantization.size_dict = [args.size_dict] * len(cfgs.quantization.size_dict)
        print(cfgs.quantization.size_dict)

    cfgs.freeze()
    flgs = cfgs.flags
    return cfgs, flgs


if __name__ == "__main__":
    print("main.py")

    ## Experimental setup
    args = arg_parse()
    # Set before any CUDA call (torch.cuda.is_available below) so that only
    # the requested GPU is visible to this process.
    if args.gpu != "":
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    cfgs, flgs = load_config(args)
    print("[Checkpoint path] " + cfgs.path)
    print(cfgs)

    ## Device
    # NOTE(review): `device` is not used below; the trainer presumably picks
    # its own device. Confirm before removing.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    set_seeds(args.seed)

    ## Data loader
    train_loader, val_loader, test_loader = get_loader(
        cfgs.dataset.name, cfgs.path_dataset, cfgs.train.bs, cfgs.nworker)
    print("Complete dataload")

    ## Trainer
    print("=== {} ===".format(cfgs.model.name.upper()))
    if cfgs.model.name == "GaussianSQVAE":
        # Only HierSQVAETrainer is imported at module level, so this name was
        # unbound (NameError). Import lazily so other models are unaffected.
        # NOTE(review): confirm `trainer` actually exports GaussianSQVAETrainer.
        from trainer import GaussianSQVAETrainer
        trainer = GaussianSQVAETrainer(cfgs, flgs, train_loader, val_loader, test_loader)
    elif cfgs.model.name in ["HierSQVAE", "ResSQVAE", "SQVAE2"]:
        trainer = HierSQVAETrainer(cfgs, flgs, train_loader, val_loader, test_loader)
    else:
        raise ValueError("Undefined model: {!r}".format(cfgs.model.name))

    ## Main
    # An empty timestamp means "train a new run"; a non-empty one names an
    # existing run to reuse instead of training.
    if args.timestamp == "":
        trainer.main_loop()
    # NOTE(review): with a timestamp but without --save, nothing below runs, so
    # the script does no work at all. Confirm this is intended.
    if flgs.save:
        trainer.load(args.timestamp)
        print("Best models were loaded!!")
        trainer.test()
        trainer.generate_reconstructions_paper(nrows=8, ncols=8)

