import cv2
import time
import torch
import numpy as np
from tools import builder
from utils.logger import *
from utils.transforms import get_transforms


def _select_device(args):
    """Return the device inference runs on, matching how the model is placed.

    When ``args.use_gpu`` is truthy this is ``args.local_rank`` passed through
    unchanged (the same value the model is moved to with ``.to()``); otherwise
    it is the CPU.
    """
    if args.use_gpu:
        return args.local_rank
    return torch.device('cpu')


def _load_image(img_path, device):
    """Read an image, apply the test transform, add a batch dim, move to *device*.

    Raises:
        FileNotFoundError: if ``cv2.imread`` cannot read *img_path*. cv2 signals
            failure by returning ``None`` rather than raising, so it is checked
            explicitly here.
    """
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError(f'cv2.imread could not read image: {img_path!r}')
    # NOTE(review): cv2.imread yields BGR channel order; confirm that
    # get_transforms()['test'] expects BGR rather than RGB.
    img = get_transforms()['test'](img)
    return img.unsqueeze(0).to(device)


def run_net(args, config, train_writer=None, val_writer=None,
            save_path='generated_points.npy'):
    """Generate a point cloud from a text or image condition and save it.

    Builds the model from ``config.model`` and puts it in eval mode. It then
    runs one of two generators:

    - ``text_condition_generation(args.query)`` when ``args.text`` is set;
    - ``image_condition_generation`` on the image at ``args.img_path`` when
      ``args.img`` is set.

    Finally it prints the elapsed time and writes the result with ``np.save``.

    Args:
        args: namespace read for ``log_name``, ``use_gpu``, ``local_rank``,
            ``text``, ``query``, ``img`` and ``img_path``.
        config: configuration whose ``model`` entry is passed to
            ``builder.model_builder``.
        train_writer: unused here (presumably kept so run_net entry points
            share one signature — verify against callers).
        val_writer: unused here.
        save_path: destination ``.npy`` file for the generated points.

    Raises:
        NotImplementedError: if neither ``args.text`` nor ``args.img`` is set.
        FileNotFoundError: if the image at ``args.img_path`` cannot be read.
    """
    # NOTE(review): logger is not used below; the call is kept in case
    # get_logger has setup side effects — confirm before removing.
    logger = get_logger(args.log_name)
    device = _select_device(args)

    # build model
    base_model = builder.model_builder(config.model)
    if args.use_gpu:
        base_model.to(device)
    base_model.eval()

    start_time = time.time()
    # eval() only changes layer behaviour; no_grad() additionally stops autograd
    # from recording the generation, saving memory and keeping the output
    # convertible with .numpy().
    with torch.no_grad():
        if args.text:
            text_query = args.query
            predict_points = base_model.text_condition_generation(text_query)
            print(text_query)
        elif args.img:
            # Image goes to the same device as the model (previously a
            # hard-coded .cuda(), which broke CPU runs and local_rank != 0).
            img = _load_image(args.img_path, device)
            predict_points = base_model.image_condition_generation(img)
        else:
            raise NotImplementedError(
                'run_net requires either args.text or args.img to be set')

    if args.use_gpu:
        # CUDA kernels are launched asynchronously; wait for them so the
        # measured time covers the actual generation work.
        torch.cuda.synchronize(device)
    end_time = time.time()
    print('running time: ', end_time - start_time)
    np.save(save_path, predict_points.detach().cpu().numpy())
    print(f'Successfully save completion data to {save_path}')
