import torch
import numpy as np
from dataset_loader import UCF_crime
from options import parse_args
import pdb
from config import Config
import utils
import os
from model import WSAD
from dataset_loader import data
from sklearn.metrics import precision_recall_curve, roc_curve, auc
from sklearn.metrics import recall_score


# Category keys used to group per-video predictions and ground-truth slices.
UCF_CATEGORIES = (
    "Abuse", "Arrest", "Arson", "Assault", "Burglary", "Explosion",
    "Fighting", "RoadAccidents", "Robbery", "Shooting", "Shoplifting",
    "Stealing", "Vandalism", "Normal",
)
# Number of consecutive loader items whose outputs are averaged into one
# video-level prediction (presumably the feature crops of one video, stored
# consecutively in the test list -- TODO confirm against dataset_loader).
NUM_CROPS = 10
# Each segment-level score is repeated this many times so the predictions
# line up with the frame-level ground truth array.
FRAMES_PER_SEGMENT = 16


def seed_worker(worker_id):
    """DataLoader ``worker_init_fn``: seed NumPy reproducibly in each worker.

    Inside a worker, ``torch.initial_seed()`` is derived from the main
    process RNG plus ``worker_id``. Each worker therefore gets a distinct but
    reproducible NumPy seed (the recipe from PyTorch's reproducibility notes).
    Defined at module level so it can be pickled for spawn-based workers.
    """
    np.random.seed(torch.initial_seed() % 2**32)


def category_of(video_name):
    """Return the ``UCF_CATEGORIES`` key a test video name belongs to.

    Names containing "Normal" map to "Normal". Otherwise the last three
    characters are stripped (presumably a 3-digit video index, e.g.
    "Abuse028" -> "Abuse"; confirm against the dataset's naming).
    """
    return "Normal" if "Normal" in video_name else video_name[:-3]


@torch.no_grad()
def predict_frames(net, test_loader, frame_gt, device="cuda"):
    """Run ``net`` over ``test_loader`` and return frame-level anomaly scores.

    Every ``NUM_CROPS`` consecutive loader items are treated as one video.
    Their ``net(...)["frame"]`` outputs are concatenated along dim 0 and
    averaged. Each resulting segment score is then repeated
    ``FRAMES_PER_SEGMENT`` times.

    Args:
        net: callable returning a dict with a ``"frame"`` tensor.
        test_loader: iterable of ``(features, label, names)`` batches
            (batch size 1; ``label`` is not used).
        frame_gt: frame-level ground truth, sliced per video in loader order.
        device: device the features are moved to before calling ``net``.

    Returns:
        ``(frame_predict, pdict, gdict)``. ``frame_predict`` holds all videos'
        frame scores concatenated in loader order. ``pdict``/``gdict`` map
        ``category -> {video_name: scores / ground-truth slice}``.

    Raises:
        ValueError: if the loader yields fewer than ``NUM_CROPS`` items.
    """
    pdict = {k: {} for k in UCF_CATEGORIES}
    gdict = {k: {} for k in UCF_CATEGORIES}
    video_scores = []  # one array per video; concatenated once at the end
    crop_scores = []   # model outputs for the crops of the current video
    offset = 0         # start index of the current video inside frame_gt

    for i, (features, _label, names) in enumerate(test_loader):
        crop_scores.append(net(features.to(device))["frame"])
        if (i + 1) % NUM_CROPS != 0:
            continue

        scores = torch.cat(crop_scores, dim=0).mean(0).cpu().numpy()
        crop_scores = []
        n_frames = len(scores) * FRAMES_PER_SEGMENT
        frame_scores = np.repeat(scores, FRAMES_PER_SEGMENT)

        name = names[0]  # batch_size=1: the name batch has a single entry
        category = category_of(name)
        pdict[category][name] = frame_scores
        gdict[category][name] = frame_gt[offset:offset + n_frames]
        offset += n_frames
        video_scores.append(frame_scores)

    if crop_scores:
        print(f"Warning: ignoring {len(crop_scores)} trailing item(s) that do "
              f"not form a complete group of {NUM_CROPS}")
    if not video_scores:
        raise ValueError(
            f"test loader produced no complete group of {NUM_CROPS} items")
    return np.concatenate(video_scores), pdict, gdict


def frame_metrics(frame_gt, frame_predict):
    """Return ``(roc_auc, pr_auc)`` for frame-level scores against ground truth.

    Raises:
        ValueError: if the prediction and ground-truth lengths differ.
    """
    if len(frame_predict) != len(frame_gt):
        raise ValueError(
            f"prediction length {len(frame_predict)} does not match "
            f"ground-truth length {len(frame_gt)}")
    fpr, tpr, _ = roc_curve(frame_gt, frame_predict)
    precision, recall, _ = precision_recall_curve(frame_gt, frame_predict)
    # NOTE: trapezoidal area under the PR curve, kept for comparability with
    # previously reported "AP" numbers. sklearn's average_precision_score is
    # the non-interpolated alternative.
    return auc(fpr, tpr), auc(recall, precision)


if __name__ == "__main__":
    args = parse_args()
    if args.debug:
        pdb.set_trace()

    config = Config(args)
    worker_init_fn = None
    config.len_feature = 1024

    if config.seed >= 0:
        utils.set_seed(config.seed)
        # The original assigned np.random.seed(...)'s return value (None) to
        # worker_init_fn. Keep the main-process seeding it performed and pass
        # a real callable to the DataLoader.
        np.random.seed(config.seed)
        worker_init_fn = seed_worker

    test_loader = data.DataLoader(
        UCF_crime(
            root_dir=config.root_dir,
            mode='Test',
            modal=config.modal,
            num_segments=config.num_segments,
            len_feature=config.len_feature,
        ),
        batch_size=1,
        shuffle=False,
        num_workers=config.num_workers,
        worker_init_fn=worker_init_fn,
    )

    # List of model checkpoints
    model_files = [
        'models/ucf_ucf_uncertainty_sch_0.2_0.1_525_3000.pkl',
        'models/ucf_ucf_uncertainty_sch_0.2_0.1_625_3000.pkl',
        'models/ucf_ucf_uncertainty_sch_0.2_0.1_725_3000.pkl',
    ]

    # Ground truth does not depend on the checkpoint: load it once.
    frame_gt = np.load("frame_label/gt-ucf.npy")

    for model_file in model_files:
        print(f"\nRunning validation for {model_file}")
        net = WSAD(input_size=config.len_feature, flag="Test", a_nums=60, n_nums=60).cuda()
        net.eval()
        net.flag = "Test"
        net.load_state_dict(torch.load(model_file, map_location='cuda:0'))

        frame_predict, ucf_pdict, ucf_gdict = predict_frames(net, test_loader, frame_gt)

        # Save outputs with model-specific filenames
        base_name = os.path.splitext(os.path.basename(model_file))[0]
        np.save(f'frame_label/{base_name}_predict.npy', frame_predict)
        # Per-category dicts, if needed:
        # np.save(f'frame_label/{base_name}_pdict.npy', ucf_pdict)
        # np.save(f'frame_label/{base_name}_gdict.npy', ucf_gdict)

        auc_score, ap_score = frame_metrics(frame_gt, frame_predict)
        print(f"AUC for {base_name}: {auc_score:.4f}")
        print(f"AP  for {base_name}: {ap_score:.4f}")

