import torch
from options import *
from config import *
from model import *
import numpy as np
from dataset_loader import *
from sklearn.metrics import roc_curve,auc,precision_recall_curve, precision_score, recall_score
import warnings
warnings.filterwarnings("ignore")
from sklearn.metrics import precision_score, recall_score


def test(net, config, wandb_viz, test_loader, test_info, step, model_file = None):
    with torch.no_grad():
        net.eval()
        net.flag = "Test"
        if model_file is not None:
            net.load_state_dict(torch.load(model_file))

        load_iter = iter(test_loader)
        frame_gt = np.load("frame_label/xd_gt.npy")
        frame_predict = None
        cls_label = []
        cls_pre = []
        for i in range(len(test_loader.dataset)//5):

            _data, _label, _ = next(load_iter)
            
            _data = _data.cuda()
            _label = _label.cuda()
            cls_label.append(int(_label[0]))
            res = net(_data)   
        
            a_predict = res["frame"].cpu().numpy().mean(0)   
            cls_pre.append(1 if a_predict.max()>0.5 else 0)          
            fpre_ = np.repeat(a_predict, 16)
            if frame_predict is None:         
                frame_predict = fpre_
            else:
                frame_predict = np.concatenate([frame_predict, fpre_])   

        return frame_predict
