import torch
from options import *
from config import *
from model import *
import numpy as np
from dataset_loader_baseline import *
from sklearn.metrics import roc_curve,auc,precision_recall_curve, precision_score, recall_score
import warnings
warnings.filterwarnings("ignore")


def test(net, config, wandb_viz, test_loader, test_info, step, model_file = None):
    """Evaluate ``net`` on ``test_loader`` and log frame-level metrics.

    Every 10 consecutive items yielded by ``test_loader`` are treated as one
    group (presumably the crops of a single video -- confirm against the
    dataset class): their ``res["frame"]`` outputs are concatenated along
    dim 0 and averaged, and each averaged score is repeated 16 times to form
    the frame-level prediction. The predictions of all groups are joined and
    compared against the ground truth in ``frame_label/gt-ucf.npy``.

    Args:
        net: Model to evaluate. It is switched to eval mode, its ``flag``
            attribute is set to ``"Test"``, and calling it must return a
            mapping with a ``"frame"`` tensor.
        config: Not used by this function; kept for interface compatibility.
        wandb_viz: Object whose ``run.log(dict)`` receives the metric dict.
        test_loader: Iterable yielding ``(data, label, name)`` triples.
            ``label`` must convert with ``int()``, so one sample per item.
        test_info: Dict of lists; ``"step"``, ``"auc"`` and ``"ap"`` are
            appended to in place.
        step: Value appended to ``test_info["step"]``.
        model_file: Optional checkpoint path; when given, its state dict is
            loaded into ``net`` before evaluation.

    Raises:
        ValueError: If the loader yields fewer than 10 items, so no
            prediction could be produced.
    """
    # Number of consecutive loader items averaged into one prediction.
    crops_per_video = 10
    # Each averaged score is expanded to this many frames.
    frames_per_snippet = 16
    # Split index into the ground-truth array. Judging by the variable and
    # log-key names, frames before it come from abnormal videos and frames
    # after it from normal videos -- NOTE(review): confirm for gt-ucf.npy.
    ab_frame_count = 461856

    with torch.no_grad():
        net.eval()
        net.flag = "Test"
        if model_file is not None:
            # NOTE(review): no map_location is given, so the checkpoint is
            # restored onto the device it was saved from.
            net.load_state_dict(torch.load(model_file))

        frame_gt = np.load("frame_label/gt-ucf.npy")

        cls_label = []
        cls_pre = []
        # Per-group frame scores; joined once after the loop instead of
        # re-copying the whole prediction array for every group.
        video_scores = []
        temp_predict = torch.zeros((0)).cuda()

        for i, (_data, _label, _name) in enumerate(test_loader):
            _data = _data.cuda()
            _label = _label.cuda()

            res = net(_data)
            temp_predict = torch.cat([temp_predict, res["frame"]], dim=0)

            # Keep accumulating until a full group has been collected.
            if (i + 1) % crops_per_video != 0:
                continue

            # The label of the last item in the group is used for the group.
            cls_label.append(int(_label))
            a_predict = temp_predict.mean(0).cpu().numpy()
            cls_pre.append(1 if a_predict.max() > 0.5 else 0)
            video_scores.append(np.repeat(a_predict, frames_per_snippet))
            temp_predict = torch.zeros((0)).cuda()

        if not video_scores:
            raise ValueError(
                "test_loader yielded fewer than "
                f"{crops_per_video} items; no prediction was produced"
            )
        frame_predict = np.concatenate(video_scores)

        # Frame-level ROC AUC.
        fpr, tpr, _ = roc_curve(frame_gt, frame_predict)
        auc_score = auc(fpr, tpr)

        # Group-level classification accuracy (max score thresholded at 0.5).
        corrent_num = np.sum(np.array(cls_label) == np.array(cls_pre), axis=0)
        accuracy = corrent_num / len(cls_pre)

        # Area under the precision-recall curve.
        pr_precision, pr_recall, _ = precision_recall_curve(frame_gt, frame_predict)
        ap_score = auc(pr_recall, pr_precision)

        # Mean predicted score over frames labelled 1 and frames labelled 0.
        avg_ab_score = np.mean(frame_predict[frame_gt == 1])
        avg_n_score = np.mean(frame_predict[frame_gt == 0])

        # Mean score of label-0 frames in the first part of the array.
        frame_gt_ab = frame_gt[:ab_frame_count]
        frame_predict_ab = frame_predict[:ab_frame_count]
        avg_normal_frames_from_ab_video = np.mean(frame_predict_ab[frame_gt_ab == 0])

        # Mean score of label-0 frames in the remaining part of the array.
        frame_gt_n = frame_gt[ab_frame_count:]
        frame_predict_n = frame_predict[ab_frame_count:]
        avg_normal_frames_from_n_video = np.mean(frame_predict_n[frame_gt_n == 0])

        log_dict = {
            'roc_auc': auc_score,
            'accuracy': accuracy,
            'pr_auc': ap_score,
            'avg_anomaly_score_ab': avg_ab_score,
            'avg_anomaly_score_n': avg_n_score,
            'avg_anomaly_score_n_from_ab_vid': avg_normal_frames_from_ab_video,
            'avg_anomaly_score_n_from_n_vid': avg_normal_frames_from_n_video,
        }

        # Precision / recall at fixed score thresholds 0.1, 0.2, ..., 0.9.
        for thresh in np.arange(0.1, 1.0, 0.1):
            y_pred = (frame_predict >= thresh).astype(int)
            log_dict[f'precision_{thresh:.1f}'] = precision_score(
                frame_gt, y_pred, zero_division=0
            )
            log_dict[f'recall_{thresh:.1f}'] = recall_score(
                frame_gt, y_pred, zero_division=0
            )

        wandb_viz.run.log(log_dict)
        # Accuracy is only logged to wandb; it is not recorded in test_info.
        test_info["step"].append(step)
        test_info["auc"].append(auc_score)
        test_info["ap"].append(ap_score)
        