


import time
from tqdm import tqdm
from rekognition_online_action_detection.utils.env import setup_environment
import torch
import torch.nn as nn
from rekognition_online_action_detection.models import build_model
from rekognition_online_action_detection.optimizers import build_optimizer
from rekognition_online_action_detection.evaluation import compute_result
from rekognition_online_action_detection.engines import do_inference
import os.path as osp
from rekognition_online_action_detection.utils.checkpointer import setup_checkpointer
from rekognition_online_action_detection.utils.icarl import icarl
from rekognition_online_action_detection.utils.utils import norm
import rekognition_online_action_detection.utils.utils as us
import numpy as np
from rekognition_online_action_detection.models.temporal_head import build_temporal_head
import os
def norm_list(lst):
    """Return a new list holding the entries of ``lst`` rescaled to sum to 1.

    Each entry is converted to ``float`` and divided by the sum of all
    entries. The input is not modified. An empty input yields an empty list;
    a non-empty input whose entries sum to zero raises ``ZeroDivisionError``.
    """
    denominator = sum(lst)
    return [float(value) / denominator for value in lst]



def do_perframe_det_train(cfg,
                          batch_size,
                          memory_video_ratio,
                          memory_frame_ratio,
                          data_loaders,
                          tet_dataloaders,
                          model,
                          criterion,
                          optimizer,
                          scheduler,
                          device,
                          checkpointer,
                          logger,
                          num_tasks,
                          m,
                          get_memory
                          ):
    """Train the per-frame detector on one task, or pick its memory exemplars.

    The function has two modes, selected by ``get_memory``:

    * ``get_memory`` falsy -- training mode. A new pair of attention modules
      and a new pair of temporal heads are created, appended to the model's
      per-task lists, and the model is run through ``cfg.SOLVER.PHASES`` for
      ``cfg.SOLVER.NUM_EPOCHS`` epochs. After the last epoch the model is
      checkpointed, reloaded into a freshly built model, evaluated with
      ``do_inference`` and saved again through ``checkpointer.save_best``.
    * ``get_memory`` truthy -- memory mode. The model is run once (no
      gradients) over ``data_loaders['train']``; frames are grouped per video
      and split into "background" frames (``target[0] == 1``) and all other
      frames; ``icarl`` and ``us.sort_back_action`` then choose which frames
      to keep.

    Args:
        cfg: Project config. Reads ``cfg.SOLVER.START_EPOCH``,
            ``cfg.SOLVER.NUM_EPOCHS``, ``cfg.SOLVER.PHASES`` and
            ``cfg.DATA.NUM_CLASSES``; also forwarded to the project builders.
        batch_size: Batch size of the train loader. Only used in memory mode,
            to map (batch number, position in batch) to an index into
            ``data_loaders['train'].dataset.inputs``.
            NOTE(review): this assumes the loader yields samples in dataset
            order with full batches -- confirm against the caller.
        memory_video_ratio: Unused by this function.
        memory_frame_ratio: Multiplier applied to the per-video frame counts
            to obtain the number of frames requested from ``icarl`` /
            ``us.sort_back_action``.
        data_loaders: Mapping from phase name to data loader. The ``'train'``
            dataset must provide ``video_data``, ``inputs`` and ``shuffle()``.
        tet_dataloaders: Forwarded unchanged to ``do_inference``.
        model: Detector. Must provide the per-task containers
            ``multihead_attention_list``, ``task_head``,
            ``work_multihead_attention_list``, ``work_task_head``,
            ``multihead_attention_weights``, ``acc``, ``param_wgt`` and
            ``work_param_wgt``.
        criterion: Mapping holding the loss under the key ``'MCE'``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: LR scheduler stepped once per training batch.
        device: Device the batches and the new modules are moved to.
        checkpointer: Object providing ``save`` and ``save_best``.
        logger: Logger used for the per-epoch summary line.
        num_tasks: Index of the current task (used in names and array sizes).
        m: Unused by this function.
        get_memory: Mode switch, see above.

    Returns:
        ``None`` in training mode. In memory mode, a dict mapping
        ``videos[i]`` to ``[visual_features, flow_features, targets]`` (three
        parallel lists of NumPy arrays) for every video that contains at
        least one non-background frame.
    """
    alpha = 0.95
    beta = 0.45
    videos = data_loaders['train'].dataset.video_data['task_' + str(num_tasks)]
    tet_checkpointer_root = 'checkpoints/THUMOS/LSTR/lstr_long_512_work_8_kinetics_1x'
    best_meanAP = 0
    tet_model = build_model(cfg, device)
    tet_optimizer = build_optimizer(cfg, tet_model)
    tet_device = setup_environment(cfg)

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
    # nn.DataParallel does not forward custom attributes of the wrapped
    # module, so all per-task bookkeeping goes through the unwrapped module.
    net = model
    while isinstance(net, nn.DataParallel):
        net = net.module

    if not get_memory:
        parm_now = np.zeros(num_tasks)
        parm_now_work = np.zeros(num_tasks)

        # Fresh modules for the new task (construction order kept as-is so
        # random initialisation is unchanged).
        multihead_attention = nn.MultiheadAttention(embed_dim=1024, num_heads=1)
        work_multihead_attention = nn.MultiheadAttention(embed_dim=1024, num_heads=1)
        task_head = build_temporal_head(cfg)
        work_task_head = build_temporal_head(cfg)
        for module in (multihead_attention, task_head,
                       work_multihead_attention, work_task_head):
            for param in module.parameters():
                param.requires_grad = True
            module.to(device)
        net.multihead_attention_list.append(multihead_attention)
        net.task_head.append(task_head)
        net.work_multihead_attention_list.append(work_multihead_attention)
        net.work_task_head.append(work_task_head)

        if len(net.multihead_attention_list) == 1:
            net.multihead_attention_weights.append(1)
        else:
            # Weight earlier tasks by their recorded accuracy and give the
            # new task a fixed 0.5 before normalising to a sum of 1.
            attention_weights = net.acc.copy()
            attention_weights.append(0.5)
            net.multihead_attention_weights = norm_list(attention_weights)

        first_epoch = cfg.SOLVER.START_EPOCH
        # The loop below runs NUM_EPOCHS epochs starting at START_EPOCH, so
        # this (not NUM_EPOCHS itself) is the number of the final epoch.
        last_epoch = first_epoch + cfg.SOLVER.NUM_EPOCHS - 1
        for epoch in range(first_epoch, last_epoch + 1):
            det_losses = {phase: 0.0 for phase in cfg.SOLVER.PHASES}
            det_pred_scores = []
            det_gt_targets = []

            # Snapshot the newest heads' weights for the drift term below.
            # NOTE(review): this also freezes the newest heads at the start
            # of every epoch, and the snapshot arrays are float arrays of
            # length num_tasks while `.weight` is presumably a tensor --
            # confirm both are intended.
            for idx, param in enumerate(net.task_head[-1].parameters()):
                param.requires_grad = False
                parm_now[idx] = task_head[idx].weight
            for idx, param in enumerate(net.work_task_head[-1].parameters()):
                param.requires_grad = False
                parm_now_work[idx] = work_task_head[idx].weight

            start = time.time()
            for phase in cfg.SOLVER.PHASES:
                training = phase == 'train'
                model.train(training)

                with torch.set_grad_enabled(training):
                    pbar = tqdm(data_loaders[phase],
                                desc='task{}:'.format(num_tasks)+'{}ing epoch {}'.format(phase.capitalize(), epoch))
                    for data in pbar:
                        cur_batch_size = data[0].shape[0]
                        det_target = data[-1].to(device)

                        det_score = model(*[x.to(device) for x in data[:-1]])
                        det_score = det_score.reshape(-1, cfg.DATA.NUM_CLASSES)
                        det_target = det_target.reshape(-1, cfg.DATA.NUM_CLASSES)
                        det_loss_0 = criterion['MCE'](det_score, det_target)
                        # NOTE(review): a second forward pass whose raw output
                        # is added to the loss; `.item()` / `.backward()`
                        # below need a scalar, so this presumably should be
                        # a scalar-valued term -- confirm.
                        det_score_1 = model(*[x.to(device) for x in data[:-1]])
                        # Drift of the stored head weights from the snapshot.
                        # Computed with NumPy, so it contributes no gradient.
                        det_score_2 = (np.linalg.norm(net.param_wgt - parm_now)
                                       + np.linalg.norm(net.work_param_wgt - parm_now_work))
                        det_loss = det_loss_0 + det_score_1 * alpha + det_score_2 * beta
                        # Accumulate only after the loss of THIS batch exists.
                        det_losses[phase] += det_loss.item() * cur_batch_size

                        pbar.set_postfix({
                            'lr': '{:.7f}'.format(scheduler.get_last_lr()[0]),
                            'det_loss': '{:.5f}'.format(det_loss.item()),
                        })

                        if training:
                            optimizer.zero_grad()
                            det_loss.backward()
                            optimizer.step()
                            scheduler.step()
                        else:
                            # Keep predictions for the per-frame mAP.
                            det_score = det_score.softmax(dim=1).cpu().tolist()
                            det_target = det_target.cpu().tolist()
                            det_pred_scores.extend(det_score)
                            det_gt_targets.extend(det_target)
            end = time.time()

            log = []
            log.append('Epoch {:2}'.format(epoch))
            log.append('train det_loss: {:.5f}'.format(
                det_losses['train'] / len(data_loaders['train'].dataset),
            ))
            if 'test' in cfg.SOLVER.PHASES:
                det_result = compute_result['perframe'](
                    cfg,
                    det_gt_targets,
                    det_pred_scores,
                )
                log.append('test det_loss: {:.5f} det_mAP: {:.5f}'.format(
                    det_losses['test'] / len(data_loaders['test'].dataset),
                    det_result['mean_AP'],
                ))
            log.append('running time: {:.2f} sec'.format(
                end - start,
            ))
            logger.info(' | '.join(log))

            if epoch == last_epoch:
                # Save, reload into the evaluation model, evaluate, then
                # re-save as "best" and drop the intermediate checkpoint.
                checkpointer.save(epoch, model, optimizer, num_tasks)
                tet_checkpointer_task_root = osp.join(tet_checkpointer_root,
                                                      'task_' + str(num_tasks) + "_epoch-" + str(epoch) + ".pth")
                tet_checkpointer = setup_checkpointer(cfg, tet_checkpointer_task_root, phase='test')
                tet_checkpointer.load(tet_model, tet_optimizer)
                tet_model = tet_model.to(tet_device)
                for attention in tet_model.multihead_attention_list:
                    attention.to(tet_device)
                mean_AP = do_inference(
                    cfg,
                    tet_dataloaders,
                    tet_model,
                    tet_device,
                    logger,
                    num_tasks,
                    inferr=0
                )
                checkpointer.save_best(epoch, tet_model, tet_optimizer, num_tasks)
                best_meanAP = mean_AP
                os.remove(tet_checkpointer_task_root)

            data_loaders['train'].dataset.shuffle()

        # Record this task's accuracy and freeze everything added for it.
        net.acc.append(best_meanAP)
        for param in net.multihead_attention_list[-1].parameters():
            param.requires_grad = False
        for param in net.work_multihead_attention_list[-1].parameters():
            param.requires_grad = False
        for idx, param in enumerate(net.task_head[-1].parameters()):
            param.requires_grad = False
            net.param_wgt[idx] = task_head[idx].weight
        for idx, param in enumerate(net.work_task_head[-1].parameters()):
            param.requires_grad = False
            net.work_param_wgt[idx] = work_task_head[idx].weight
        return None

    # ------------------------------------------------------------------
    # Memory mode: score every training sample, then select exemplars.
    # ------------------------------------------------------------------
    task_selected_features = {}

    # One record per sample: [video name, data[0], data[1], data[2], data[3],
    # model scores]. Going by how they are used below, data[0] / data[1] are
    # the visual / flow feature sequences and data[3] is the per-frame target.
    video_data = []
    model.train(False)
    with torch.set_grad_enabled(False):
        pbar = tqdm(data_loaders['train'],
                    desc='task{}:'.format(num_tasks) + ': get memory')
        for batch_idx, data in enumerate(pbar, start=1):
            det_score = model(*[x.to(device) for x in data[:-1]])
            for i in range(data[0].shape[0]):
                video_data.append([
                    data_loaders["train"].dataset.inputs[(batch_idx - 1) * batch_size + i][0],
                    data[0][i],
                    data[1][i],
                    data[2][i],
                    data[3][i],
                    np.array(det_score[i].cpu().detach().numpy()),
                ])

    # Group consecutive samples that belong to the same video.
    groups = []
    group = None
    current_name = None
    window_idx = 0
    for name, visual, flow, _, target, scores in video_data:
        if group is None or name != current_name:
            current_name = name
            window_idx = 0
            group = {
                'back_scores': [], 'back_index': [],
                'action_scores': [], 'action_index': [],
                'visual': [], 'flow': [], 'target': [],
            }
            groups.append(group)
        num_target = target.shape[0]
        # Targets cover only the trailing `num_target` frames of the feature
        # sequence, so feature indices are shifted by this offset.
        offset = visual.shape[0] - num_target
        for j in range(num_target):
            group['visual'].append(np.array(visual[j + offset].cpu().detach().numpy()))
            group['flow'].append(np.array(flow[j + offset].cpu().detach().numpy()))
            group['target'].append(np.array(target[j].cpu().detach().numpy()))
            # Position of this frame in the video's flattened frame lists.
            frame_idx = window_idx * num_target + j
            if target[j][0] == 1:
                group['back_scores'].append(scores[j])
                group['back_index'].append(frame_idx)
            else:
                group['action_scores'].append(scores[j])
                group['action_index'].append(frame_idx)
        window_idx += 1

    for i, group in enumerate(groups):
        all_visual = group['visual']
        all_flow = group['flow']
        all_target = group['target']
        back_scores = np.array(group['back_scores'])
        action_scores = np.array(group['action_scores'])

        # First ask icarl for twice the final budget of each kind ...
        num_back_save = 2 * back_scores.shape[0] * memory_frame_ratio
        num_action_save = 2 * action_scores.shape[0] * memory_frame_ratio
        back_save_idx = icarl(back_scores, num_back_save, group['back_index'])
        if action_scores.size == 0:
            # Videos without any non-background frame contribute nothing.
            continue
        action_save_idx = icarl(action_scores, num_action_save, group['action_index'])

        back_visual_save = [all_visual[k] for k in back_save_idx]
        back_flow_save = [all_flow[k] for k in back_save_idx]
        action_visual_save = [all_visual[k] for k in action_save_idx]
        action_flow_save = [all_flow[k] for k in action_save_idx]

        norm_visual_back = norm(back_visual_save)
        norm_flow_back = norm(back_flow_save)
        norm_visual_action = norm(action_visual_save)
        norm_flow_action = norm(action_flow_save)
        norm_back = np.concatenate((norm_visual_back, norm_flow_back), axis=1)
        norm_action = np.concatenate((norm_visual_action, norm_flow_action), axis=1)

        # ... then let sort_back_action narrow the candidates to the budget.
        selected_idx = us.sort_back_action(
            norm_back,
            norm_action,
            back_scores.shape[0] * memory_frame_ratio,
            action_scores.shape[0] * memory_frame_ratio,
            back_save_idx,
            action_save_idx,
        )
        # NOTE(review): assumes `videos` lists the task's videos in the same
        # order in which they are encountered in the loader -- confirm.
        task_selected_features[videos[i]] = [
            [all_visual[k] for k in selected_idx],
            [all_flow[k] for k in selected_idx],
            [all_target[k] for k in selected_idx],
        ]
    return task_selected_features

