

import os.path as osp
from rekognition_online_action_detection.utils.parser import load_cfg
from rekognition_online_action_detection.utils.env import setup_environment
from rekognition_online_action_detection.utils.checkpointer import setup_checkpointer
from rekognition_online_action_detection.utils.logger import setup_logger
from rekognition_online_action_detection.models import build_model
from rekognition_online_action_detection.optimizers import build_optimizer
from rekognition_online_action_detection.datasets_test import CILSetTask
from rekognition_online_action_detection.engines import do_inference
import yaml, pickle
import random

# Seed the standard-library `random` generator with a fixed value at import
# time. This call only affects the stdlib `random` module.
random.seed(10)


def main(cfg):
    """Build the model and the test task set, then run per-task inference.

    The function loads the pickled video data referenced by
    ``cfg['DATASET']['VIDEO_ROOT']``, records the memory sampling type on the
    model configuration, builds the model and optimizer, wraps the ``'test'``
    split in a ``CILSetTask`` and passes everything to ``tet_loop``.

    Args:
        cfg: Configuration object (as returned by ``load_cfg``). It is read
            both by key (e.g. ``cfg['MODEL']``) and by attribute
            (e.g. ``cfg.SOLVER.PHASES``). ``cfg['MODEL']['type_sampling']``
            is written as a side effect.
    """
    pickle_location = cfg['DATASET']['VIDEO_ROOT']
    model_section = cfg['MODEL']

    # NOTE(review): pickle.load can execute arbitrary code while
    # deserializing; VIDEO_ROOT must only point at trusted files.
    with open(pickle_location, 'rb') as pickle_file:
        video_data = pickle.load(pickle_file)

    # Evaluated exactly as before; the resulting count is not used further
    # in this function.
    train_class_count = len(video_data['train'][0].keys())

    # Expose the memory sampling type to the model configuration before the
    # model is built.
    sampling_mode = cfg['MEMORY']['TYPE_MEM']
    model_section['type_sampling'] = sampling_mode
    print('sampling strategy:', sampling_mode)

    device = setup_environment(cfg)
    model = build_model(cfg, device)
    optimizer = build_optimizer(cfg, model)

    memory_size = cfg['MEMORY']['MEMORY_SIZE']

    # Keyword arguments forwarded to CILSetTask, gathered in one place.
    task_options = dict(
        batch_size=cfg.DATA_LOADER.BATCH_SIZE,
        shuffle=True,
        num_workers=cfg.DATA_LOADER.NUM_WORKERS,
        pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
        model_name=cfg.MODEL.MODEL_NAME,
        data_name=cfg.DATA.DATA_NAME,
        data_root=cfg.DATA.DATA_ROOT,
        visual_feature=cfg.INPUT.VISUAL_FEATURE,
        motion_feature=cfg.INPUT.MOTION_FEATURE,
        target_perframe=cfg.INPUT.TARGET_PERFRAME,
        long_memory_length=cfg.MODEL.LSTR.LONG_MEMORY_LENGTH,
        long_memory_sample_rate=cfg.MODEL.LSTR.LONG_MEMORY_SAMPLE_RATE,
        long_memory_num_samples=cfg.MODEL.LSTR.LONG_MEMORY_NUM_SAMPLES,
        work_memory_length=cfg.MODEL.LSTR.WORK_MEMORY_LENGTH,
        work_memory_sample_rate=cfg.MODEL.LSTR.WORK_MEMORY_SAMPLE_RATE,
        work_memory_num_samples=cfg.MODEL.LSTR.WORK_MEMORY_NUM_SAMPLES,
        phase="test",
        phases=cfg.SOLVER.PHASES,
        train_enable=True,
        tag='BatchInference',
    )
    epochs = cfg.SOLVER.NUM_EPOCHS

    test_tasks = CILSetTask(video_data['test'], memory_size, **task_options)

    tet_loop(cfg, test_tasks, model, optimizer, epochs)


def tet_loop(cfg, tet_cilDatasetList, model, optimizer, epochs,
             checkpointer_root='checkpoints/THUMOS/LSTR/'
                               'lstr_long_512_work_8_kinetics_1x'):
    """Run inference on every task, loading that task's checkpoint first.

    For each task index ``j`` in ``range(tet_cilDatasetList.num_tasks)`` the
    checkpoint ``<checkpointer_root>/task_<j>_best.pth`` is loaded into
    ``model``/``optimizer``, the next item is pulled from the task set, and
    ``do_inference`` is called on its data loaders.

    Args:
        cfg: Configuration object forwarded to the environment, logger,
            checkpointer and inference helpers.
        tet_cilDatasetList: Iterable task set exposing ``num_tasks``; each
            iteration step must yield a 2-tuple whose first element is the
            data loaders for that task.
        model: Model that receives the checkpoint weights and is evaluated.
        optimizer: Optimizer passed to ``checkpointer.load`` together with
            the model.
        epochs: Accepted for interface compatibility; not used here.
        checkpointer_root: Directory containing the ``task_<j>_best.pth``
            files. Defaults to the path that was previously hard-coded, so
            existing callers are unaffected.
    """
    device = setup_environment(cfg)
    task_iterator = iter(tet_cilDatasetList)
    num_tasks = tet_cilDatasetList.num_tasks
    logger = setup_logger(cfg, phase='test')

    for task_id in range(num_tasks):
        checkpoint_path = osp.join(
            checkpointer_root, 'task_{}_best.pth'.format(task_id))
        checkpointer = setup_checkpointer(cfg, checkpoint_path, phase='test')
        checkpointer.load(model, optimizer)

        # The second element of each yielded pair is not needed for inference.
        data_loaders_i, _ = next(task_iterator)

        # NOTE(review): the meaning of `inferr=1` is defined by do_inference
        # and is not visible from this file -- confirm before changing.
        do_inference(
            cfg,
            data_loaders_i,
            model,
            device,
            logger,
            task_id,
            inferr=1
        )


# Script entry point: load the configuration via `load_cfg` and pass it to
# `main`. Nothing runs here when the module is merely imported.
if __name__ == '__main__':
    main(load_cfg())
