


from rekognition_online_action_detection.utils.parser import load_cfg
from rekognition_online_action_detection.utils.env import setup_environment
from rekognition_online_action_detection.utils.checkpointer import setup_checkpointer
from rekognition_online_action_detection.utils.logger import setup_logger
from rekognition_online_action_detection.datasets import build_data_loader
from rekognition_online_action_detection.models import build_model
from rekognition_online_action_detection.criterions import build_criterion
from rekognition_online_action_detection.optimizers import build_optimizer
from rekognition_online_action_detection.optimizers import build_scheduler
from rekognition_online_action_detection.engines import do_train
from rekognition_online_action_detection.datasets import CILSetTask
from rekognition_online_action_detection.datasets_test import CILSetTask as CILSetTask_tet
from rekognition_online_action_detection.utils.EWC import get_regularized_loss
from rekognition_online_action_detection.utils.EWC import on_task_update
from comet_ml import Experiment
import torch
import argparse
import yaml, pickle
import torch.nn as nn
import os
import random
import os.path as osp
from rekognition_online_action_detection.engines import do_inference
# Seed Python's stdlib `random` module at import time so anything drawing from
# the shared module-level generator is reproducible across runs.
# NOTE: only `random` is seeded here; NumPy / torch RNGs are not touched by this
# line (setup_environment(cfg) may handle them — confirm).
random.seed(10)


def main(cfg):
    """Set up class-incremental training from ``cfg`` and run ``train_loop``.

    Steps, in order:

    1. Load the pickled task splits from ``cfg['DATASET']['VIDEO_ROOT']``.
    2. Copy the memory sampling strategy into the model config.
    3. Set up the device, model, optimizer, checkpointer and logger.
    4. Wrap the ``'train'`` / ``'test'`` splits in ``CILSetTask`` /
       ``CILSetTask_tet``.
    5. Build the criterion and call ``checkpointer.load(model, optimizer)``.
    6. Hand everything to ``train_loop``.

    Args:
        cfg: Experiment configuration from ``load_cfg()``. It is accessed both
            as a mapping (``cfg['MEMORY']``) and by attribute
            (``cfg.DATA_LOADER``).

    Side effects:
        Writes ``cfg['MODEL']['type_sampling']`` from
        ``cfg['MEMORY']['TYPE_MEM']`` and prints the sampling strategy.
    """
    # NOTE(review): pickle.load can execute arbitrary code contained in the
    # file; VIDEO_ROOT must point at a trusted, locally produced split file.
    with open(cfg['DATASET']['VIDEO_ROOT'], 'rb') as handle:
        video_data = pickle.load(handle)

    # Copy the memory sampling strategy (MEMORY.TYPE_MEM) into the model
    # config so code that only receives cfg['MODEL'] can read it.
    type_sampling = cfg['MEMORY']['TYPE_MEM']
    cfg['MODEL']['type_sampling'] = type_sampling
    print('sampling strategy:', type_sampling)

    device = setup_environment(cfg)
    model = build_model(cfg, device)
    optimizer = build_optimizer(cfg, model)

    memory_size = cfg['MEMORY']['MEMORY_SIZE']

    checkpointer = setup_checkpointer(
        cfg, phase='train', checkpointer_task_root=cfg.MODEL.CHECKPOINT)
    logger = setup_logger(cfg, phase='train')

    # Options shared by the train and test task lists. Building them once
    # keeps the two splits from silently drifting apart.
    lstr_cfg = cfg.MODEL.LSTR
    shared_kwargs = dict(
        batch_size=cfg.DATA_LOADER.BATCH_SIZE,
        # NOTE(review): the test split is shuffled as well (unchanged from the
        # original); confirm evaluation does not depend on sample order.
        shuffle=True,
        num_workers=cfg.DATA_LOADER.NUM_WORKERS,
        pin_memory=cfg.DATA_LOADER.PIN_MEMORY,
        model_name=cfg.MODEL.MODEL_NAME,
        data_name=cfg.DATA.DATA_NAME,
        data_root=cfg.DATA.DATA_ROOT,
        visual_feature=cfg.INPUT.VISUAL_FEATURE,
        motion_feature=cfg.INPUT.MOTION_FEATURE,
        target_perframe=cfg.INPUT.TARGET_PERFRAME,
        long_memory_length=lstr_cfg.LONG_MEMORY_LENGTH,
        long_memory_sample_rate=lstr_cfg.LONG_MEMORY_SAMPLE_RATE,
        long_memory_num_samples=lstr_cfg.LONG_MEMORY_NUM_SAMPLES,
        work_memory_length=lstr_cfg.WORK_MEMORY_LENGTH,
        work_memory_sample_rate=lstr_cfg.WORK_MEMORY_SAMPLE_RATE,
        work_memory_num_samples=lstr_cfg.WORK_MEMORY_NUM_SAMPLES,
        phases=cfg.SOLVER.PHASES,
        train_enable=True,
    )

    train_cilDatasetList = CILSetTask(
        video_data['train'], memory_size, phase="train", **shared_kwargs)
    test_cilDatasetList = CILSetTask_tet(
        video_data['test'], memory_size, phase="test", tag='BatchInference',
        **shared_kwargs)

    criterion = build_criterion(cfg, device)
    checkpointer.load(model, optimizer)
    train_loop(
        cfg,
        train_cilDatasetList,
        test_cilDatasetList,
        model,
        criterion,
        optimizer,
        device,
        checkpointer,
        logger,
    )



def train_loop(
        cfg,
        train_cilDatasetList,
        tet_cilDatasetList,
        model,
        criterion,
        optimizer,
        device,
        checkpointer,
        logger,
        checkpointer_root='checkpoints/THUMOS/LSTR/lstr_long_512_work_8_kinetics_1x',
):
    """Train sequentially on each incremental task.

    For every task ``j`` in ``range(tet_cilDatasetList.num_tasks)``:

    1. Take the next test loaders, then the next train loaders.
    2. Build a scheduler sized to ``len(data_loaders_i['train'])``.
    3. Run ``do_train``.
    4. Reload ``<checkpointer_root>/task_<j>_best.pth`` into ``model`` and
       ``optimizer``.
    5. Store ``on_task_update(...)``'s result in ``model.reg_params``.

    Args:
        cfg: Experiment configuration.
        train_cilDatasetList: Iterable yielding ``(data_loaders, num_next_classes)``
            per training task; ``data_loaders['train']`` must support ``len``.
        tet_cilDatasetList: Iterable of test task loaders with a ``num_tasks``
            attribute, yielding ``(data_loaders, num_next_classes)`` per task.
        model, criterion, optimizer, device, logger: Passed through to
            ``do_train`` / ``on_task_update``.
        checkpointer: Checkpointer used by the first task's ``do_train``. After
            each task it is replaced by one rooted at that task's best
            checkpoint, and that replacement is passed to the next task.
        checkpointer_root: Directory holding the per-task
            ``task_<j>_best.pth`` files. It defaults to the path previously
            hard-coded here; pass the run's own output directory when training
            a different configuration.

    Raises:
        RuntimeError: If either task list runs out before ``num_tasks`` tasks
            have been processed.
    """
    test_tasks = iter(tet_cilDatasetList)
    train_tasks = iter(train_cilDatasetList)
    num_tasks = tet_cilDatasetList.num_tasks

    for j in range(num_tasks):
        # The test split is advanced before the train split (same order as
        # before) in case iterating a task list has side effects.
        try:
            tet_data_loaders_i, _ = next(test_tasks)
            data_loaders_i, _ = next(train_tasks)
        except StopIteration:
            raise RuntimeError(
                f'CIL task lists exhausted at task {j} of {num_tasks}; the '
                f'train and test task lists must provide the same number of '
                f'tasks.'
            ) from None

        scheduler = build_scheduler(
            cfg, optimizer, len(data_loaders_i['train']))

        do_train(
            cfg,
            data_loaders_i,
            tet_data_loaders_i,
            model,
            criterion,
            optimizer,
            scheduler,
            device,
            checkpointer,
            logger,
            j,
        )

        # Reload task j's best checkpoint (presumably written by do_train —
        # confirm) before computing the regularization state for this task.
        best_path = osp.join(checkpointer_root, f'task_{j}_best.pth')
        checkpointer = setup_checkpointer(
            cfg, phase='train', checkpointer_task_root=best_path)
        checkpointer.load(model, optimizer)
        model.reg_params = on_task_update(
            cfg, data_loaders_i, device, optimizer, model, j, criterion,
            scheduler)




if __name__ == '__main__':
    # Entry point when run as a script: load the experiment configuration,
    # then hand it to main() to set up and run training.
    config = load_cfg()
    main(config)
