


import time
from tqdm import tqdm
from rekognition_online_action_detection.utils.env import setup_environment
import torch
import torch.nn as nn
from rekognition_online_action_detection.models import build_model
from rekognition_online_action_detection.optimizers import build_optimizer
from rekognition_online_action_detection.evaluation import compute_result
from rekognition_online_action_detection.engines import do_inference
import os.path as osp
from rekognition_online_action_detection.utils.checkpointer import setup_checkpointer
from rekognition_online_action_detection.utils.icarl import icarl
from rekognition_online_action_detection.utils.utils import norm
import rekognition_online_action_detection.utils.utils as us
import numpy as np
import os
class BiasLayer(nn.Module):
    """Learnable scalar affine map ``y = alpha * x + beta``.

    ``alpha`` is initialised to 1 and ``beta`` to 0, so a freshly built layer
    is the identity. Both are single-element parameters broadcast over every
    element of the input. In this file one layer is appended per task, and
    only the newest one is unfrozen during stage 1 of ``do_perframe_det_train``.
    """

    def __init__(self):
        super().__init__()
        # One shared scale and one shared shift (nn.Parameter defaults to
        # requires_grad=True).
        self.alpha = nn.Parameter(torch.ones(1))
        self.beta = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        """Return ``alpha * x + beta``, broadcast over all elements of ``x``."""
        scaled = self.alpha * x
        return scaled + self.beta

    def printParam(self, i):
        """Print ``i`` followed by the current alpha and beta values."""
        alpha_value = self.alpha.item()
        beta_value = self.beta.item()
        print(i, alpha_value, beta_value)
def train_step(cfg,
               data_loaders,
               tet_dataloaders,
               model,
               criterion,
               optimizer,
               bias_optimizer,
               scheduler,
               device,
               checkpointer,
               logger,
               num_tasks,
               m,
               get_memory,
               tet_optimizer,
               tet_device,
               tet_model,
               tet_checkpointer_root,
               stage_id
               ):
    """Run every epoch of one training stage for one incremental task.

    Stage 0 trains the backbone (``pos_encoding``, ``enc_modules`` and
    ``dec_modules``) with ``optimizer`` and steps ``scheduler`` every batch.
    Stage 1 keeps those modules in eval mode and steps only ``bias_optimizer``.
    Each epoch runs every phase in ``cfg.SOLVER.PHASES``. It then logs the
    average loss, plus per-frame mAP when a ``'test'`` phase is present. After
    the final epoch it saves a checkpoint. Finally it reshuffles the training
    dataset.

    Args:
        cfg: experiment config. Reads ``SOLVER.START_EPOCH``,
            ``SOLVER.NUM_EPOCHS``, ``SOLVER.PHASES`` and ``DATA.NUM_CLASSES``.
        data_loaders: dict mapping phase name to a loader. Must contain
            ``'train'``, whose dataset provides ``shuffle()``.
        model: the detector, optionally wrapped in ``nn.DataParallel``.
        criterion: dict of losses; ``criterion['MCE']`` is used.
        optimizer: backbone optimizer (stage 0). Also passed to the checkpointer.
        bias_optimizer: bias-layer optimizer (stage 1); may be None in stage 0.
        scheduler: LR scheduler stepped per batch in stage 0.
        device: device that inputs and targets are moved to.
        checkpointer: object with ``save_best(epoch, model, optimizer, task)``.
        logger: logger used for the per-epoch summary line.
        num_tasks: index of the current task (used in logs and checkpoint).
        stage_id: 0 for backbone training, 1 for bias-layer calibration.
        tet_dataloaders, m, get_memory, tet_optimizer, tet_device, tet_model,
        tet_checkpointer_root: accepted for interface compatibility; unused.
    """
    # nn.DataParallel does not forward attribute lookups to the wrapped model,
    # so submodules are reached through the unwrapped instance.
    net = model
    while isinstance(net, nn.DataParallel):
        net = net.module

    first_epoch = cfg.SOLVER.START_EPOCH
    # Save after the final epoch of this run. Comparing against NUM_EPOCHS
    # alone only matched the last epoch when START_EPOCH == 1.
    last_epoch = first_epoch + cfg.SOLVER.NUM_EPOCHS - 1

    for epoch in range(first_epoch, last_epoch + 1):
        det_losses = {phase: 0.0 for phase in cfg.SOLVER.PHASES}
        det_pred_scores = []
        det_gt_targets = []
        start = time.time()
        for phase in cfg.SOLVER.PHASES:
            training = phase == 'train'
            if stage_id in (0, 1):
                # Only the modules being optimized in this stage run in train
                # mode, and only during the 'train' phase.
                backbone_training = training and stage_id == 0
                bias_training = training and stage_id == 1
                net.pos_encoding.train(backbone_training)
                net.dec_modules.train(backbone_training)
                net.enc_modules.train(backbone_training)
                for bias_layer in net.list_bias_layers:
                    bias_layer.train(bias_training)

            pbar = tqdm(data_loaders[phase],
                        desc='task{}:'.format(num_tasks) + 'stage{}:'.format(stage_id) + '{}ing epoch {}'.format(phase.capitalize(), epoch))
            # Autograd graphs are only needed when parameters will be updated.
            with torch.set_grad_enabled(training):
                for data in pbar:
                    batch_size = data[0].shape[0]
                    det_target = data[-1].to(device)

                    det_score = model(*[x.to(device) for x in data[:-1]])
                    det_score = det_score.reshape(-1, cfg.DATA.NUM_CLASSES)
                    det_target = det_target.reshape(-1, cfg.DATA.NUM_CLASSES)
                    det_loss = criterion['MCE'](det_score, det_target)
                    # Weighted by batch size so dividing by the dataset
                    # length below yields a per-sample mean.
                    det_losses[phase] += det_loss.item() * batch_size

                    if stage_id == 0:
                        current_lr = scheduler.get_last_lr()[0]
                    else:
                        current_lr = bias_optimizer.defaults['lr']
                    pbar.set_postfix({
                        'lr': '{:.7f}'.format(current_lr),
                        'det_loss': '{:.5f}'.format(det_loss.item()),
                    })

                    if training:
                        stage_optimizer = optimizer if stage_id == 0 else bias_optimizer
                        stage_optimizer.zero_grad()
                        det_loss.backward()
                        stage_optimizer.step()
                        if stage_id == 0:
                            scheduler.step()
                    else:
                        det_pred_scores.extend(det_score.softmax(dim=1).cpu().tolist())
                        det_gt_targets.extend(det_target.cpu().tolist())
        end = time.time()

        log = []
        log.append('Task {:2}'.format(num_tasks))
        log.append('Stage{:2}'.format(stage_id))
        log.append('Epoch {:2}'.format(epoch))
        log.append('train det_loss: {:.5f}'.format(
            det_losses['train'] / len(data_loaders['train'].dataset),
        ))
        if 'test' in cfg.SOLVER.PHASES:
            det_result = compute_result['perframe'](
                cfg,
                det_gt_targets,
                det_pred_scores,
            )
            log.append('test det_loss: {:.5f} det_mAP: {:.5f}'.format(
                det_losses['test'] / len(data_loaders['test'].dataset),
                det_result['mean_AP'],
            ))
        log.append('running time: {:.2f} sec'.format(
            end - start,
        ))
        logger.info(' | '.join(log))

        if epoch == last_epoch:
            checkpointer.save_best(epoch, model, optimizer, num_tasks)

        data_loaders['train'].dataset.shuffle()






def _set_backbone_requires_grad(core, flag):
    """Set ``requires_grad`` on all parameters of pos_encoding, dec_modules and enc_modules."""
    for submodule in (core.pos_encoding, core.dec_modules, core.enc_modules):
        for param in submodule.parameters():
            param.requires_grad = flag


def _append_frozen_bias_layer(core, device):
    """Append a new frozen BiasLayer to ``core`` and record ``core.n_classes`` as its split."""
    bias_layer = BiasLayer()
    bias_layer.to(device)
    for param in bias_layer.parameters():
        param.requires_grad = False
    core.list_bias_layers.append(bias_layer)
    core.list_splits.append(core.n_classes)


def _group_frames_by_video(video_data):
    """Regroup per-window records into per-video, per-frame lists.

    ``video_data`` holds ``[video_name, visual, flow, extra, target, scores]``
    records in loader order. Consecutive records with the same name form one
    video. For each record, only the last ``target.shape[0]`` frames of the
    feature window are kept, which aligns the features with the target rows.

    Returns:
        Five parallel lists ``(scores, visual, flow, target, index)``. Each has
        one entry per video, and each entry is a per-frame list. ``index`` holds
        the frame positions ``window * target_len + j`` within the video.
    """
    task_scores, task_visual, task_flow, task_target, task_index = [], [], [], [], []
    current_name = None
    window_idx = 0
    for name, visual, flow, _extra, target, scores in video_data:
        if not task_scores or name != current_name:
            # A new video starts: open fresh per-video accumulators.
            current_name = name
            window_idx = 0
            video_scores, video_visual, video_flow, video_target, video_index = [], [], [], [], []
            task_scores.append(video_scores)
            task_visual.append(video_visual)
            task_flow.append(video_flow)
            task_target.append(video_target)
            task_index.append(video_index)
        target_len = target.shape[0]
        offset = visual.shape[0] - target_len
        for j in range(target_len):
            video_visual.append(np.array(visual[offset + j].cpu().detach().numpy()))
            video_flow.append(np.array(flow[offset + j].cpu().detach().numpy()))
            video_target.append(np.array(target[j].cpu().detach().numpy()))
            video_scores.append(scores[j])
            video_index.append(window_idx * target_len + j)
        window_idx += 1
    return task_scores, task_visual, task_flow, task_target, task_index


def do_perframe_det_train(cfg,
                          batch_size,
                          memory_video_ratio,
                          memory_frame_ratio,
                          train_train_data_loaders_i,
                          train_val_data_loaders_i,
                          tet_dataloaders,
                          model,
                          criterion,
                          optimizer,
                          scheduler,
                          device,
                          checkpointer,
                          logger,
                          num_tasks,
                          m,
                          get_memory
                          ):
    """Train on one incremental task, or build its exemplar memory.

    When ``get_memory == False``:
      * Stage 0 unfreezes the backbone, appends a frozen BiasLayer and trains
        on ``train_train_data_loaders_i``.
      * If ``train_val_data_loaders_i`` is not None, stage 1 follows. It
        reloads the task checkpoint from disk, freezes the backbone and trains
        only the newest BiasLayer on ``train_val_data_loaders_i``.
      * Returns None.

    Otherwise the model is run in eval mode over
    ``train_train_data_loaders_i['train']``. Frames are grouped per video, and
    ``icarl`` picks the exemplar frames. Returns a dict mapping each video name
    of this task to ``[visual_feats, flow_feats, targets]`` of those frames.

    Args:
        batch_size: kept for interface compatibility. Dataset indices are now
            tracked from the actual batch sizes yielded by the loader.
        memory_video_ratio: unused.
        memory_frame_ratio: fraction of each video's frames passed to ``icarl``
            as the save budget.
        train_train_data_loaders_i / train_val_data_loaders_i: dicts of
            loaders for the training stage(s); the second may be None.
        model: the detector. Wrapped in ``nn.DataParallel`` when more than one
            CUDA device is visible.
        get_memory: False to train; any other value builds the memory.
        tet_dataloaders, m: forwarded to ``train_step``.
    """
    checkpoint_root = 'checkpoints/THUMOS/LSTR/lstr_long_512_work_8_kinetics_1x'
    videos = train_train_data_loaders_i['train'].dataset.video_data['task_' + str(num_tasks)]
    tet_checkpointer_root = checkpoint_root
    # NOTE(review): the tet_* objects are forwarded to train_step, which does
    # not use them. They are still built here in case build_model /
    # setup_environment have side effects later code relies on. TODO confirm
    # and drop if not.
    tet_model = build_model(cfg, device)
    tet_optimizer = build_optimizer(cfg, tet_model)
    tet_device = setup_environment(cfg)

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    # DataParallel hides the wrapped model's attributes; use `core` for them.
    core = model
    while isinstance(core, nn.DataParallel):
        core = core.module

    # `== False` (rather than `not`) keeps the original dispatch: None selects
    # the memory branch.
    if get_memory == False:
        # Stage 0: train the backbone with a fresh, frozen bias layer attached.
        _set_backbone_requires_grad(core, True)
        _append_frozen_bias_layer(core, device)
        train_step(cfg,
                   train_train_data_loaders_i,
                   tet_dataloaders,
                   model,
                   criterion,
                   optimizer,
                   None,
                   scheduler,
                   device,
                   checkpointer,
                   logger,
                   num_tasks,
                   m,
                   get_memory,
                   tet_optimizer,
                   tet_device,
                   tet_model,
                   tet_checkpointer_root,
                   stage_id=0)
        if train_val_data_loaders_i is not None:
            # Stage 1: reload the task checkpoint, freeze the backbone and fit
            # only the newest bias layer.
            checkpointer_task_root = osp.join(checkpoint_root, 'task_' + str(num_tasks) + "_best.pth")
            checkpointer = setup_checkpointer(cfg, checkpointer_task_root, phase='train')
            checkpointer.load(model, optimizer)
            model = model.to(device)
            for bias_layer in core.list_bias_layers:
                bias_layer.to(device)
            _set_backbone_requires_grad(core, False)
            new_bias_layer = core.list_bias_layers[-1]
            for param in new_bias_layer.parameters():
                param.requires_grad = True
            bias_optimizer = torch.optim.SGD(new_bias_layer.parameters(), lr=0.005)
            train_step(cfg,
                       train_val_data_loaders_i,
                       tet_dataloaders,
                       model,
                       criterion,
                       optimizer,
                       bias_optimizer,
                       scheduler,
                       device,
                       checkpointer,
                       logger,
                       num_tasks,
                       m,
                       get_memory,
                       tet_optimizer,
                       tet_device,
                       tet_model,
                       tet_checkpointer_root,
                       stage_id=1)
            for i, bias_layer in enumerate(core.list_bias_layers):
                bias_layer.printParam(i)
            for param in new_bias_layer.parameters():
                param.requires_grad = False
        return None

    # Memory construction: score every training window with the frozen model.
    loader = train_train_data_loaders_i['train']
    dataset_inputs = loader.dataset.inputs
    video_data = []
    model.eval()
    with torch.no_grad():
        pbar = tqdm(loader, desc='task{}:'.format(num_tasks) + ': get memory')
        # Running position in the dataset. This assumes the loader iterates
        # sequentially (no shuffling) — TODO confirm with the caller.
        sample_offset = 0
        for data in pbar:
            bs = data[0].shape[0]
            det_score = model(*[x.to(device) for x in data[:-1]])
            det_score_np = det_score.cpu().detach().numpy()
            for i in range(bs):
                video_data.append([
                    dataset_inputs[sample_offset + i][0],  # video name
                    data[0][i],
                    data[1][i],
                    data[2][i],
                    data[3][i],
                    np.array(det_score_np[i]),
                ])
            sample_offset += bs

    grouped_scores, grouped_visual, grouped_flow, grouped_target, grouped_index = \
        _group_frames_by_video(video_data)

    task_selected_features = {}
    per_video = zip(grouped_scores, grouped_index, grouped_visual, grouped_flow, grouped_target)
    for video_idx, (v_scores, v_index, v_visual, v_flow, v_target) in enumerate(per_video):
        v_scores = np.array(v_scores)
        num_all_save = v_scores.shape[0] * memory_frame_ratio
        save_idx = icarl(v_scores, num_all_save, v_index)
        # NOTE(review): relies on `videos` listing this task's videos in the
        # same order the loader yields them — verify against the dataset.
        task_selected_features[videos[video_idx]] = [
            [v_visual[k] for k in save_idx],
            [v_flow[k] for k in save_idx],
            [v_target[k] for k in save_idx],
        ]
    return task_selected_features

