from utils.hydra_utils import print_cfg
from utils.main_utils import (
    init_hydra_config,
    apply_random_seed,
    init_data_loader,
    init_model,
    init_trainer,
    load_pretrained_config,
)
from utils.debug_utils import debug

import logging
import gc
import time
import torch

# Process-wide float32 matmul precision for PyTorch, applied once at import
# time, before any model is built.
torch.set_float32_matmul_precision("high")  # accepted values: "highest", "high", "medium"

import warnings
# Silence UserWarnings whose message begins with "Detected call of".
# NOTE(review): presumably targets PyTorch's lr_scheduler.step() /
# optimizer.step() ordering warning — confirm this is still the intended target.
warnings.filterwarnings("ignore", "Detected call of", UserWarning)


def main():
    """Run the extraction pass end to end, then report leftover CUDA memory.

    Steps, in order:
      1. Build the Hydra config in ``"extract"`` mode, apply the random seed
         and pass the config through ``load_pretrained_config``.
      2. Build the test data loader, the model and the trainer. Each
         ``init_*`` helper returns a config alongside the object it creates,
         and that config replaces the previous ``cfg``.
      3. Call ``trainer.test`` on the model with the test loader.
      4. Drop the large objects, run garbage collection, release cached CUDA
         memory and log the allocator counters.
    """
    cfg = init_hydra_config(mode="extract")
    apply_random_seed(cfg)
    cfg = load_pretrained_config(cfg)
    # Debug aid: uncomment to dump the resolved config and stop.
    # print_cfg(cfg)
    # raise ValueError

    cfg, test_loader = init_data_loader(
        cfg, mode="extract", is_train=False, is_test=True
    )
    # Debug aid: uncomment to inspect the test loader.
    # debug(test_loader, "test")

    cfg, model = init_model(cfg)
    # Debug aid: uncomment to log the model structure.
    # logging.info(model)

    cfg, trainer = init_trainer(cfg)

    trainer.test(model, dataloaders=test_loader)

    logging.info("Extracting Finished")

    # Drop the only local references to the heavy objects so they become
    # collectable.
    del trainer, model, test_loader

    # Collect garbage BEFORE emptying the CUDA cache. Tensors kept alive only
    # by reference cycles are handed back to PyTorch's caching allocator
    # during gc.collect(), and empty_cache() can only release blocks that are
    # already unused. Doing it the other way round leaves that memory reserved.
    gc.collect()
    torch.cuda.empty_cache()

    # NOTE(review): fixed 10 s pause kept from the original; presumably gives
    # asynchronous teardown (e.g. loader workers) time to finish — confirm
    # whether it is still needed.
    time.sleep(10)
    logging.info("Garbage collection : Done")

    # Allocator counters after cleanup: memory currently held by tensors, then
    # memory still reserved by the caching allocator.
    logging.info(torch.cuda.memory_allocated())
    logging.info(torch.cuda.memory_reserved())

    
# Run the extraction pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()