from model_farsecnn import LitFARSECNN
from model_detect_farsecnn import LitDetectFARSECNN
from lightning.pytorch.loggers.neptune import NeptuneLogger
from lightning.pytorch.trainer.trainer import Trainer
import torch
import os
from utils.farsecnn_utils import load_farsecnn


DATASET = 'NCars'
# 'neptune' enables the NeptuneLogger below; any other value (including None)
# passes logger=None to the Trainer.
LOG_MODE = None

# Test-time configuration. CHECKPOINT_DIR is a placeholder and must be set to
# the directory holding the trained checkpoint before running.
CFG_PATH = 'configs/model/farsecnn_A.yaml'
CHECKPOINT_DIR = ''
BATCH_SIZE = 16


def find_checkpoint(checkpoint_dir: str) -> str:
    """Return the path of the first regular file in *checkpoint_dir*.

    Entries are sorted by name before one is picked, because ``os.listdir``
    returns them in arbitrary order; without sorting, the tested checkpoint
    could differ between machines/filesystems. Sub-directories are skipped so
    a directory is never handed to the checkpoint loader.

    Args:
        checkpoint_dir: Directory expected to contain the checkpoint file.

    Returns:
        ``os.path.join(checkpoint_dir, name)`` for the alphabetically first
        regular file ``name`` in the directory.

    Raises:
        ValueError: if *checkpoint_dir* is empty (placeholder never filled in).
        FileNotFoundError: if the directory does not exist or holds no files.
    """
    if not checkpoint_dir:
        raise ValueError(
            "checkpoint_dir is not set; point it at the directory "
            "holding the checkpoint to test"
        )
    for name in sorted(os.listdir(checkpoint_dir)):
        path = os.path.join(checkpoint_dir, name)
        if os.path.isfile(path):
            return path
    raise FileNotFoundError(f"no checkpoint file found in {checkpoint_dir!r}")


def main() -> None:
    """Build the FARSE-CNN module, load the checkpoint and run Lightning's test loop."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    checkpoint_path = find_checkpoint(CHECKPOINT_DIR)

    # Datasets whose name contains 'Gen1' use the detection variant; all
    # others use LitFARSECNN.
    model_cls = LitDetectFARSECNN if 'Gen1' in DATASET else LitFARSECNN
    net = model_cls(CFG_PATH, bs=BATCH_SIZE, log_mode=LOG_MODE, dataset=DATASET).to(device)
    net = load_farsecnn(net, checkpoint_path)
    print("Testing model: " + checkpoint_path)

    if LOG_MODE == 'neptune':
        # Credentials and run identifiers are placeholders to be filled in.
        logger = NeptuneLogger(
            api_key='',
            project='',
            with_id='',
            prefix='',
        )
    else:
        logger = None

    trainer = Trainer(logger=logger)
    # No dataloader is passed here, so the module is presumably expected to
    # provide its own test dataloader — TODO confirm in the model classes.
    trainer.test(model=net, verbose=True)


if __name__ == '__main__':
    main()
