from utils.hydra_utils import print_cfg
from utils.main_utils import (
    init_hydra_config,
    apply_random_seed,
    init_data_loader,
    init_model,
    init_trainer,
)
from utils.debug_utils import debug

import logging
import gc
import time
import torch

# Allow faster, slightly lower-precision float32 matmul kernels where the
# backend supports them. Accepted values: "highest", "high", "medium".
torch.set_float32_matmul_precision("high")

import warnings

# Suppress UserWarnings whose message starts with "Detected call of"
# (the message argument is a regex matched against the start of the text).
# NOTE(review): presumably this targets a known noisy framework warning — confirm.
warnings.filterwarnings(
    "ignore",
    message="Detected call of",
    category=UserWarning,
)


def main():
    """Run the pretraining pipeline end to end.

    Steps:
      1. Build the Hydra config in ``"pretrain"`` mode and apply the random seed.
      2. Create the training loader, plus a validation loader when
         ``cfg.TEST.KNN_VALIDATION`` is truthy.
      3. Build the model and trainer, then call ``trainer.fit`` with the
         loaders passed positionally (train first, then optional validation).
      4. Drop references, run the garbage collector, and release cached CUDA
         memory. Finally, log how much CUDA memory is still allocated and
         reserved.
    """
    cfg = init_hydra_config(mode="pretrain")
    apply_random_seed(cfg)
    # print_cfg(cfg)

    # Order matters: these are unpacked positionally into ``trainer.fit``.
    data_loaders = []

    cfg, train_loader = init_data_loader(cfg, mode="pretrain", is_train=True, is_test=False)
    data_loaders.append(train_loader)
    # debug(train_loader, "train")

    # Bind up front so the ``del`` in the cleanup section below does not raise
    # UnboundLocalError when KNN validation is disabled.
    val_loader = None
    if cfg.TEST.KNN_VALIDATION:
        # Only the cfg returned by the training-loader call is kept.
        _, val_loader = init_data_loader(cfg, mode="pretrain", is_train=False, is_test=False)
        data_loaders.append(val_loader)
        # debug(val_loader, "val")

        # _, test_loader = init_data_loader(cfg, mode="pretrain", is_train=False, is_test=True)
        # debug(test_loader, "test")

    cfg, model = init_model(cfg)

    cfg, trainer = init_trainer(cfg)

    epochs = cfg.TRAINER.max_epochs
    precision = cfg.TRAINER.precision
    logging.info("Start Training: Total Epoch - %s, Precision: %s", epochs, precision)

    trainer.fit(model, *data_loaders)

    logging.info("Training Finished")

    # Drop the last local references to the heavy objects. Run the cyclic
    # garbage collector first, so tensors kept alive only by reference cycles
    # are actually freed. Only then ask the CUDA caching allocator to release
    # its now-unused blocks.
    del trainer, model, data_loaders, train_loader, val_loader
    gc.collect()
    torch.cuda.empty_cache()
    # NOTE(review): fixed pause kept from the original. It presumably gives
    # background workers time to shut down — confirm it is still needed.
    time.sleep(10)
    logging.info("Garbage collection : Done")

    logging.info("CUDA memory allocated: %d bytes", torch.cuda.memory_allocated())
    logging.info("CUDA memory reserved: %d bytes", torch.cuda.memory_reserved())
    
if __name__ == "__main__":
    # Script entry point: start pretraining only when executed directly, not on import.
    main()