from utils.hydra_utils import print_cfg
from utils.main_utils import (
    init_hydra_config,
    apply_random_seed,
    init_data_loader,
    init_model,
    init_trainer,
    load_pretrained_config,
)
from utils.debug_utils import debug

import logging
import os


def main():
    """Run the fine-tuning pipeline: configure, train, reload the checkpoint, test.

    Steps, in order:
      1. Build the Hydra config in ``"finetune"`` mode, apply the random seed,
         merge in the pretrained config, and print the result.
      2. Build the train and validation loaders, then the model and trainer.
      3. Fit the model with the train loader followed by the validation loader.
      4. Set ``cfg.FINETUNED_LOAD_FROM`` to
         ``<CKPT_PATH>/<EXPR_NAME>/model.ckpt``, rebuild the model from that
         config, and run ``trainer.test`` on the test loader.
    """
    # --- configuration -----------------------------------------------------
    # The seed is applied before the pretrained config is merged in.
    cfg = init_hydra_config(mode="finetune")
    apply_random_seed(cfg)
    cfg = load_pretrained_config(cfg)
    print_cfg(cfg)

    # --- data --------------------------------------------------------------
    # Only the training call's returned config is kept; the validation call
    # reuses it and its own returned config is discarded.
    cfg, train_loader = init_data_loader(
        cfg, mode="finetune", is_train=True, is_test=False
    )
    _, val_loader = init_data_loader(
        cfg, mode="finetune", is_train=False, is_test=False
    )

    # --- model & trainer ---------------------------------------------------
    # Each initializer may hand back an updated config, which is carried forward.
    cfg, model = init_model(cfg)
    cfg, trainer = init_trainer(cfg)

    logging.info(
        f"Start Training: Total Epoch - {cfg.TRAINER.max_epochs}, Precision: {cfg.TRAINER.precision}"
    )
    # Positional order matters: model, then training loader, then validation loader.
    trainer.fit(model, train_loader, val_loader)

    # --- evaluation --------------------------------------------------------
    # NOTE(review): assumes training wrote its checkpoint to exactly
    # <CKPT_PATH>/<EXPR_NAME>/model.ckpt (and presumably that it is the best
    # one, given the variable name) — confirm against the trainer's
    # checkpoint configuration.
    checkpoint_file = os.path.join(cfg.CKPT_PATH, cfg.EXPR_NAME, "model.ckpt")
    cfg.FINETUNED_LOAD_FROM = checkpoint_file
    _, best_model = init_model(cfg)

    _, test_loader = init_data_loader(
        cfg, mode="finetune", is_train=False, is_test=True
    )
    trainer.test(best_model, dataloaders=test_loader)
    

# Run the fine-tuning pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()