


from rekognition_online_action_detection.utils.parser import load_cfg
from rekognition_online_action_detection.utils.env import setup_environment
from rekognition_online_action_detection.utils.checkpointer import setup_checkpointer
from rekognition_online_action_detection.utils.logger import setup_logger
from rekognition_online_action_detection.datasets import build_data_loader
from rekognition_online_action_detection.models import build_model
from rekognition_online_action_detection.criterions import build_criterion
from rekognition_online_action_detection.optimizers import build_optimizer
from rekognition_online_action_detection.optimizers import build_scheduler
from rekognition_online_action_detection.engines import do_train
from rekognition_online_action_detection.datasets.cls_dataset import CILSetTask as CILSetTask
from rekognition_online_action_detection.datasets.downsample_cls_dataset import CILSetTask as CILSetTask_down_sampling
from rekognition_online_action_detection.datasets_test import CILSetTask as CILSetTask_tet
import torch
import argparse
import yaml, pickle
import torch.nn as nn
import os
import random
import os.path as osp
from rekognition_online_action_detection.engines import do_inference
# Seed Python's built-in ``random`` module once at import time so that any
# sampling done through ``random`` is repeatable. Only this RNG is seeded
# here; NumPy/torch seeding, if any, must happen elsewhere (possibly in
# setup_environment — unverified).
random.seed(10)




def main(cfg):
    """Build every training component from ``cfg`` and run class-incremental training.

    Steps, in order:

    1. Load the pickled video split.
    2. Record the memory sampling strategy in the model config.
    3. Build the environment, model and optimizer.
    4. Build the checkpointer and logger.
    5. Construct the four task-set objects (train, validation, test,
       down-sampled train).
    6. Build the criterion.
    7. Restore the checkpoint into ``model``/``optimizer``.
    8. Hand everything to ``train_loop``.

    Args:
        cfg: experiment configuration returned by ``load_cfg()``. It is read
            both by key (``cfg['MEMORY']``) and by attribute
            (``cfg.DATA_LOADER``). As a side effect,
            ``cfg['MODEL']['type_sampling']`` is set to
            ``cfg['MEMORY']['TYPE_MEM']`` before the model is built.
    """
    video_path = cfg['DATASET']['VIDEO_ROOT']
    conf_model = cfg['MODEL']

    # SECURITY: pickle.load can execute arbitrary code stored in the file.
    # VIDEO_ROOT must point at a trusted, locally generated pickle.
    with open(video_path, 'rb') as handle:
        video_data = pickle.load(handle)

    # Store the sampling strategy in the model config before build_model
    # receives cfg below.
    type_sampling = cfg['MEMORY']['TYPE_MEM']
    conf_model['type_sampling'] = type_sampling
    print('sampling strategy:', type_sampling)

    device = setup_environment(cfg)
    model = build_model(cfg, device)
    memory_size = cfg['MEMORY']['MEMORY_SIZE']
    num_sel_frames = cfg['MEMORY']['NUM_SEL_FRAMES']

    optimizer = build_optimizer(cfg, model)

    checkpointer = setup_checkpointer(
        cfg, phase='train', checkpointer_task_root=cfg.MODEL.CHECKPOINT)
    logger = setup_logger(cfg, phase='train')

    # Keyword arguments shared by all four task-set constructors; each call
    # below adds only its phase-specific extras.
    common_kwargs = dict(
        batch_size=cfg.DATA_LOADER.BATCH_SIZE,
        shuffle=True,
        num_workers=cfg.DATA_LOADER.NUM_WORKERS,
        pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
        model_name=cfg.MODEL.MODEL_NAME,
        data_name=cfg.DATA.DATA_NAME,
        data_root=cfg.DATA.DATA_ROOT,
        visual_feature=cfg.INPUT.VISUAL_FEATURE,
        motion_feature=cfg.INPUT.MOTION_FEATURE,
        target_perframe=cfg.INPUT.TARGET_PERFRAME,
        long_memory_length=cfg.MODEL.LSTR.LONG_MEMORY_LENGTH,
        long_memory_sample_rate=cfg.MODEL.LSTR.LONG_MEMORY_SAMPLE_RATE,
        long_memory_num_samples=cfg.MODEL.LSTR.LONG_MEMORY_NUM_SAMPLES,
        work_memory_length=cfg.MODEL.LSTR.WORK_MEMORY_LENGTH,
        work_memory_sample_rate=cfg.MODEL.LSTR.WORK_MEMORY_SAMPLE_RATE,
        work_memory_num_samples=cfg.MODEL.LSTR.WORK_MEMORY_NUM_SAMPLES,
        phases=cfg.SOLVER.PHASES,
        train_enable=True,
    )
    # Training-style sets also take the number of frames kept per memory
    # sample; the test set instead gets the 'BatchInference' tag.
    train_kwargs = dict(common_kwargs, phase='train', num_sel_frames=num_sel_frames)
    test_kwargs = dict(common_kwargs, phase='test', tag='BatchInference')

    train_cilDatasetList = CILSetTask(video_data['train'], memory_size, **train_kwargs)
    # NOTE(review): the "validation" set is built exactly like the training
    # set (same 'train' split, phase='train'); confirm this is intended.
    val_cilDatasetList = CILSetTask(video_data['train'], memory_size, **train_kwargs)
    test_cilDatasetList = CILSetTask_tet(video_data['test'], memory_size, **test_kwargs)
    down_samplecilDatasetList = CILSetTask_down_sampling(
        video_data['train'], memory_size, **train_kwargs)

    criterion = build_criterion(cfg, device)
    # Restore model/optimizer state from cfg.MODEL.CHECKPOINT before training.
    checkpointer.load(model, optimizer)
    train_loop(
        cfg,
        train_cilDatasetList,
        test_cilDatasetList,
        val_cilDatasetList,
        down_samplecilDatasetList,
        model,
        criterion,
        optimizer,
        device,
        checkpointer,
        logger,
        type_sampling,
        memory_size,
        num_sel_frames
    )



def train_loop(cfg,
               train_cilDatasetList,
               tet_cilDatasetList,
               val_cilDatasetList,
               down_samplecilDatasetList,
               model,
               criterion,
               optimizer,
               device,
               checkpointer,
               logger,
               type_sampling,
               memory_size,
               num_sel_frames,
               checkpointer_root='checkpoints/THUMOS/LSTR/lstr_long_512_work_8_kinetics_1x'):
    """Train task by task, collecting a replay memory after each task.

    For every task index ``j`` (``tet_cilDatasetList.num_tasks`` tasks):

    1. Pull task ``j`` from the test, train, validation and down-sampled
       iterators, in that order.
    2. Build a scheduler sized to the task's train loader.
    3. Run ``do_train`` on the task's train loaders.
    4. Reload the checkpoint ``<checkpointer_root>/task_<j>_best.pth``.
    5. Run ``do_train(..., get_memory=True)`` on the down-sampled loaders.
       Its return value is stored under ``'task_<j>'`` in a memory dict,
       which is attached to ``train_cilDatasetList.memory``.

    Args:
        cfg: experiment configuration (read by key and by attribute).
        train_cilDatasetList, tet_cilDatasetList, val_cilDatasetList,
        down_samplecilDatasetList: iterables yielding one task per ``next``.
            The test iterator yields 2-tuples; the others yield 3-tuples
            whose first item is the task's loaders.
        model: network; ``model.classifier.out_features`` is read once.
        criterion, optimizer, device, logger: forwarded to ``do_train``.
        checkpointer: forwarded to ``do_train``. It is replaced with a
            task-specific checkpointer after each task's first training pass.
        type_sampling, num_sel_frames: accepted for interface compatibility;
            not used in this function.
        memory_size: total memory budget used to derive ``m``.
        checkpointer_root: directory holding the per-task ``*_best.pth``
            checkpoints. The default preserves the previously hard-coded
            path.
    """
    memory = {}
    batch_size = cfg.DATA_LOADER.BATCH_SIZE
    memory_video_ratio = cfg['MEMORY']['MEMORY_VIDEO_RATIO']
    memory_frame_ratio = cfg['MEMORY']['MEMORY_FRAME_RATIO']
    iter_train = iter(train_cilDatasetList)
    iter_test = iter(tet_cilDatasetList)
    iter_val = iter(val_cilDatasetList)
    iter_down_sample = iter(down_samplecilDatasetList)
    num_tasks = tet_cilDatasetList.num_tasks
    # Budget passed to do_train as ``m``. The divisor (32 + 4) is kept from
    # the original code; its meaning is not visible here (presumably a
    # per-sample storage cost — TODO confirm).
    m = memory_size // model.classifier.out_features // (32 + 4)

    for j in range(num_tasks):
        # Advance every iterator once per task so they stay aligned on task j.
        tet_data_loaders_i, _ = next(iter_test)
        data_loaders_i, _, _ = next(iter_train)
        # Validation loaders are fetched to keep the iterator in step but are
        # not used below.
        _val_loaders, _, _ = next(iter_val)
        down_sample_data_loaders_i, _, _ = next(iter_down_sample)
        scheduler = build_scheduler(
            cfg, optimizer, len(data_loaders_i['train']))

        do_train(
            cfg,
            batch_size,
            memory_video_ratio,
            memory_frame_ratio,
            data_loaders_i,
            tet_data_loaders_i,
            model,
            criterion,
            optimizer,
            scheduler,
            device,
            checkpointer,
            logger,
            j,
            m,
            get_memory=False
        )

        # Reload the checkpoint stored as task_<j>_best.pth before selecting
        # memory samples. NOTE: ``checkpointer`` is rebound here, so the next
        # task's first do_train call also receives this task-specific one.
        checkpointer_task_root = osp.join(checkpointer_root, 'task_' + str(j) + '_best.pth')
        checkpointer = setup_checkpointer(cfg, phase='train', checkpointer_task_root=checkpointer_task_root)
        checkpointer.load(model, optimizer)

        selected_features = do_train(
            cfg,
            batch_size,
            memory_video_ratio,
            memory_frame_ratio,
            down_sample_data_loaders_i,
            tet_data_loaders_i,
            model,
            criterion,
            optimizer,
            scheduler,
            device,
            checkpointer,
            logger,
            j,
            m,
            get_memory=True
        )
        memory['task_' + str(j)] = selected_features
        # The same dict object is attached every iteration, so the train set
        # sees the memory accumulated over all finished tasks.
        train_cilDatasetList.memory = memory
        print('finished task:', j)






if __name__ == '__main__':
    # Parse the experiment configuration, then run the full training pipeline.
    config = load_cfg()
    main(config)
