import torch
import numpy as np
from dataset_loader import XDVideo
from options import parse_args
import pdb
from config import Config
import utils
import os
from model import WSAD
from tqdm import tqdm
from dataset_loader import data
from sklearn.metrics import roc_curve,auc,precision_recall_curve
import matplotlib.pyplot as plt

def valid(net, config, test_loader, model_file = None):
    """Run ``net`` over the test loader and report ensemble frame-level metrics.

    The frame-level AUC and AP printed here are computed on ``mean_data``: the
    element-wise mean of three prediction files loaded from ``frame_label/``
    (presumably one per training run — confirm). They are NOT computed on the
    scores ``net`` produces in this call. The forward pass is still used to:

    * compute a video-level accuracy from ``net``'s own frame scores
      (a video is predicted positive if any score exceeds 0.5), and
    * dump the per-frame ``A_sim`` / ``N_sim`` outputs to ``.npy`` files.

    Args:
        net: model whose forward returns a dict containing ``"frame"``,
            ``"A_sim"`` and ``"N_sim"`` tensors.
        config: run configuration; not read inside this function.
        test_loader: DataLoader over the test set. Each batch is assumed to hold
            all crops of a single video (batch_size=5 in this script) — TODO
            confirm against the dataset layout.
        model_file: optional checkpoint path loaded into ``net`` first.

    Side effects:
        Reads ``frame_label/xd_gt.npy`` and the three ensemble prediction files.
        Writes the ensemble mean and the A/N similarity arrays under
        ``frame_label/``. Prints ``auc``, ``accuracy`` and ``ap_score``.
    """
    with torch.no_grad():
        net.eval()
        net.flag = "Test"
        if model_file is not None:
            net.load_state_dict(torch.load(model_file))

        load_iter = iter(test_loader)
        frame_gt = np.load("frame_label/xd_gt.npy")

        # Ensemble members. All paths are relative to the working directory,
        # like xd_gt.npy above.
        ensemble_files = [
            'frame_label/xd_un_alpha_0_1_th_0_3_tem_q_0.03_0.95_0_725.npy',
            'frame_label/xd_un_alpha_0_1_th_0_3_tem_q_0.03_0.95_1_725.npy',
            'frame_label/xd_un_alpha_0_1_th_0_3_tem_q_0.03_0.95_2_725.npy',
        ]
        data_stack = np.stack([np.load(path) for path in ensemble_files], axis=0)
        mean_data = np.mean(data_stack, axis=0)

        cls_label = []
        cls_pre = []
        # Collect per-batch arrays and concatenate once after the loop.
        # Concatenating inside the loop would be quadratic in the frame count.
        ab_sim_parts = []
        n_sim_parts = []

        # With batch_size=5 this is one batch per video. Any trailing partial
        # batch is not consumed.
        num_batches = len(test_loader.dataset) // 5
        for _ in range(num_batches):
            _data, _label, _, _ = next(load_iter)

            res = net(_data.cuda())
            # Average over the batch dimension (assumed: the crops of one video).
            a_predict = res["frame"].cpu().numpy().mean(0)
            ab_sim = res["A_sim"].cpu().numpy().mean(0)
            n_sim = res["N_sim"].cpu().numpy().mean(0)

            cls_label.append(int(_label[0]))
            cls_pre.append(1 if a_predict.max() > 0.5 else 0)

            # Expand each score 16x to frame resolution (presumably 16 frames
            # per feature snippet).
            ab_sim_parts.append(np.repeat(ab_sim, 16))
            n_sim_parts.append(np.repeat(n_sim, 16))

        ab_sim_all = np.concatenate(ab_sim_parts) if ab_sim_parts else np.empty(0)
        n_sim_all = np.concatenate(n_sim_parts) if n_sim_parts else np.empty(0)

        np.save('frame_label/xd_un_alpha_0_1_th_0_3_tem_q_0.03_0.95_en_725.npy', mean_data)
        np.save('frame_label/xd_un_alpha_0_1_th_0_3_tem_q_0.03_0.95_2_725_ab_sim.npy', ab_sim_all)
        np.save('frame_label/xd_un_alpha_0_1_th_0_3_tem_q_0.03_0.95_2_725_n_sim.npy', n_sim_all)

        fpr, tpr, _ = roc_curve(frame_gt, mean_data)
        auc_score = auc(fpr, tpr)
        print("auc:{}".format(auc_score))

        correct_num = np.sum(np.array(cls_label) == np.array(cls_pre), axis=0)
        accuracy = correct_num / (len(cls_pre))

        precision, recall, _ = precision_recall_curve(frame_gt, mean_data)
        # Trapezoidal area under the PR curve. This differs slightly from
        # sklearn's step-wise average_precision_score.
        ap_score = auc(recall, precision)

        print("accuracy:{}".format(accuracy))
        print("ap_score:{}".format(ap_score))
         
if __name__ == "__main__":
    # Script entry point: build the model and the XD-Violence test loader, then
    # evaluate a fixed checkpoint from args.model_path.
    args = parse_args()
    if args.debug:
        pdb.set_trace()
    config = Config(args)
    worker_init_fn = None
    config.len_feature = 1024
    if config.seed >= 0:
        utils.set_seed(config.seed)
        # Seed NumPy in this process. The previous
        # `worker_init_fn = np.random.seed(config.seed)` did this as a side
        # effect but assigned None, so the DataLoader got no init function.
        np.random.seed(config.seed)

        def worker_init_fn(worker_id):
            """Seed NumPy deterministically inside each DataLoader worker."""
            np.random.seed(config.seed + worker_id)

    net = WSAD(config.len_feature, flag = "Test", a_nums = 60, n_nums = 60)
    net = net.cuda()
    test_loader = data.DataLoader(
        XDVideo(root_dir = config.root_dir, mode = 'Test', modal = config.modal, num_segments = config.num_segments, len_feature = config.len_feature),
            batch_size = 5,
            shuffle = False, num_workers = config.num_workers,
            worker_init_fn = worker_init_fn)
    valid(net, config, test_loader, model_file = os.path.join(args.model_path, "xd_un_alpha_0_1_th_0_3_tem_q_0.03_0.95_2_725.pkl"))