import pdb
import numpy as np
import torch.utils.data as data
import utils
from options import *
from config import *
from model import *
from uncertainty_utils import evaluate_prediction

import os
import time
import glob
# from dataset_loader import *
from dataset_loader import *



def _seed_worker(worker_id, base_seed):
    """Seed NumPy's global RNG inside a DataLoader worker process.

    Kept at module level (not a lambda or closure) so that
    ``functools.partial(_seed_worker, base_seed=...)`` can be pickled when the
    DataLoader starts workers with a spawn-based multiprocessing method.

    Args:
        worker_id: Worker index supplied by ``torch.utils.data.DataLoader``.
        base_seed: Run-level seed. Each worker is seeded with
            ``base_seed + worker_id`` (wrapped into NumPy's valid 32-bit seed
            range) so workers do not share an identical random stream.
    """
    np.random.seed((base_seed + worker_id) % 2**32)


def main():
    """Evaluate a saved WSAD checkpoint on the abnormal UCF-Crime training split.

    Parses CLI options, builds the model on GPU 0, and calls
    ``evaluate_prediction`` with a non-shuffled, batch-size-1 loader over the
    abnormal (``is_normal=False``) training videos. The directory
    ``config.ensemble_path/config.ensemble_run`` is created and handed to
    ``evaluate_prediction``.
    """
    import functools

    # Bind ``torch`` explicitly; ``import torch.utils.data as data`` does not,
    # and relying on a star import to provide it is fragile.
    import torch

    args = parse_args()
    if args.debug:
        pdb.set_trace()

    config = Config(args)

    gpus = [0]
    torch.cuda.set_device(f"cuda:{gpus[0]}")

    # Forwarded to evaluate_prediction; presumably the training step of the
    # checkpoint being evaluated -- TODO confirm against evaluate_prediction.
    step = 3000

    worker_init_fn = None
    if config.seed >= 0:
        utils.set_seed(config.seed)
        # Seed the main process's NumPy RNG (as the original code did) ...
        np.random.seed(config.seed)
        # ... and give the DataLoader a real callable. The original assigned
        # the *return value* of np.random.seed(...), which is always None, so
        # workers were never seeded.
        worker_init_fn = functools.partial(_seed_worker, base_seed=config.seed)

    # The feature length is forced to 1024 regardless of the CLI/config value.
    config.len_feature = 1024
    # NOTE(review): the net is built with flag="Train" even though this script
    # only evaluates; confirm that is intended for WSAD.
    net = WSAD(config.len_feature, flag="Train", a_nums=60, n_nums=60)
    net = net.cuda()

    ensemble_path = os.path.join(config.ensemble_path, config.ensemble_run)
    os.makedirs(ensemble_path, exist_ok=True)

    abnormal_train_dataset = UCF_crime(
        root_dir=config.root_dir,
        mode='Train',
        modal=config.modal,
        num_segments=200,
        len_feature=config.len_feature,
        is_normal=False,
    )
    abnormal_train_loader_for_uncertainty = data.DataLoader(
        abnormal_train_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=config.num_workers,
        worker_init_fn=worker_init_fn,
    )

    # NOTE(review): the checkpoint path has a double ".pkl.pkl" extension;
    # confirm it matches the file on disk.
    evaluate_prediction(
        net,
        config,
        abnormal_train_loader_for_uncertainty,
        ensemble_path,
        config.seed,
        config.run,
        step,
        model_file="models/ucf_new_code_memory_un_th_test_0_725.pkl.pkl",
    )


if __name__ == "__main__":
    main()

