


import time
from tqdm import tqdm
from rekognition_online_action_detection.utils.env import setup_environment
import torch
import torch.nn as nn
from rekognition_online_action_detection.models import build_model
from rekognition_online_action_detection.optimizers import build_optimizer
from rekognition_online_action_detection.evaluation import compute_result
from rekognition_online_action_detection.engines import do_inference
import os.path as osp
from rekognition_online_action_detection.utils.checkpointer import setup_checkpointer
from rekognition_online_action_detection.utils.conloss import get_conloss
import numpy as np
import os




def _run_epoch(cfg, data_loaders, model, criterion, optimizer, scheduler,
               device, num_tasks, epoch):
    """Run one pass over every phase in ``cfg.SOLVER.PHASES``.

    In the ``'train'`` phase the optimizer and scheduler are stepped once per
    batch; in every other phase gradients are disabled and softmax scores and
    targets are collected for evaluation.

    Returns:
        det_losses: dict phase -> sum over batches of ``loss * batch_size``.
        det_pred_scores: softmax score rows gathered in non-training phases.
        det_gt_targets: the matching target rows.
    """
    num_classes = cfg.DATA.NUM_CLASSES
    det_losses = {phase: 0.0 for phase in cfg.SOLVER.PHASES}
    det_pred_scores = []
    det_gt_targets = []

    for phase in cfg.SOLVER.PHASES:
        training = phase == 'train'
        model.train(training)

        with torch.set_grad_enabled(training):
            pbar = tqdm(data_loaders[phase],
                        desc='task{}:'.format(num_tasks) + '{}ing epoch {}'.format(phase.capitalize(), epoch))
            for datas in pbar:
                # Each batch tuple holds two views back to back: the first half
                # and the second half each end with their own target tensor.
                # Both views go through the same model; their losses are
                # combined by get_conloss below.
                half = len(datas) // 2
                data, data_g = datas[:half], datas[half:]
                batch_size = data[0].shape[0]
                det_target = data[-1].to(device)
                det_target_g = data_g[-1].to(device)
                det_score = model(*[x.to(device) for x in data[:-1]])
                det_score_g = model(*[x.to(device) for x in data_g[:-1]])

                # Flatten scores and targets to rows of NUM_CLASSES values each.
                det_score = det_score.reshape(-1, num_classes)
                det_score_g = det_score_g.reshape(-1, num_classes)
                det_target = det_target.reshape(-1, num_classes)
                det_target_g = det_target_g.reshape(-1, num_classes)

                det_loss = criterion['MCE'](det_score, det_target)
                det_loss_g = criterion['MCE'](det_score_g, det_target_g)
                det_loss = get_conloss(det_loss, det_loss_g, det_score, det_score_g,
                                       cr_lambda=1, adv_lambda=0.5)
                loss_value = det_loss.item()
                det_losses[phase] += loss_value * batch_size

                pbar.set_postfix({
                    'lr': '{:.7f}'.format(scheduler.get_last_lr()[0]),
                    'det_loss': '{:.5f}'.format(loss_value),
                })

                if training:
                    optimizer.zero_grad()
                    det_loss.backward()
                    optimizer.step()
                    scheduler.step()
                else:
                    det_pred_scores.extend(det_score.softmax(dim=1).cpu().tolist())
                    det_gt_targets.extend(det_target.cpu().tolist())

    return det_losses, det_pred_scores, det_gt_targets


def _log_epoch(cfg, data_loaders, logger, epoch, det_losses,
               det_pred_scores, det_gt_targets, elapsed):
    """Log one summary line: mean train loss, optional test loss/mAP, and run time."""
    log = [
        'Epoch {:2}'.format(epoch),
        'train det_loss: {:.5f}'.format(
            det_losses['train'] / len(data_loaders['train'].dataset),
        ),
    ]
    if 'test' in cfg.SOLVER.PHASES:
        det_result = compute_result['perframe'](
            cfg,
            det_gt_targets,
            det_pred_scores,
        )
        log.append('test det_loss: {:.5f} det_mAP: {:.5f}'.format(
            det_losses['test'] / len(data_loaders['test'].dataset),
            det_result['mean_AP'],
        ))
    log.append('running time: {:.2f} sec'.format(elapsed))
    logger.info(' | '.join(log))


def _extract_memory_features(data_loader, model, device, num_tasks):
    """Return one flattened model-output vector per sample of ``data_loader``.

    The model is switched to eval mode (``model.train(False)``) and run without
    gradients. The result has one row per sample (a 2-D array when every output
    has the same size), or is an empty array when the loader yields nothing.
    """
    model.train(False)
    rows = []
    with torch.set_grad_enabled(False):
        pbar = tqdm(data_loader, desc='task{}:'.format(num_tasks) + ': get memory')
        for data in pbar:
            batch_size = data[0].shape[0]
            # NOTE(review): _run_epoch splits each batch into two halves, but
            # here the whole tuple is treated as a single view (inputs are
            # data[:-1]); confirm the train loader yields one view in this mode.
            det_score = model(*[x.to(device) for x in data[:-1]])
            scores = det_score.detach().cpu().numpy()
            # reshape(-1) flattens whatever per-sample output shape the model
            # produces, instead of assuming a fixed size.
            rows.extend(scores[i].reshape(-1) for i in range(batch_size))
    return np.array(rows)


def _herding_select(features, m):
    """Greedily select row indices of ``features`` by herding.

    At step ``k`` every not-yet-selected row ``x`` is scored by the distance
    between the L2-normalised mean of (already selected rows + ``x``) and the
    L2-normalised mean of all rows; the closest row is selected. Selection stops
    after ``m + 1`` picks or once every row has been picked.

    Args:
        features: 2-D array, one feature vector per row (may be empty).
        m: exemplar count control; ``m + 1`` rows are kept.

    Returns:
        list of int row indices in selection order, without duplicates.
    """
    num_rows = len(features)
    # The count is inclusive of m (m + 1 picks), capped by the rows available.
    # NOTE(review): confirm m + 1 rather than m is the intended memory budget.
    num_select = max(0, min(int(np.floor(m)) + 1, num_rows))
    if num_select == 0:
        return []

    eps = 1e-12  # guards the normalisation against all-zero vectors
    target = features.mean(axis=0)
    target = target / max(np.linalg.norm(target), eps)

    running_sum = np.zeros_like(target)
    available = np.ones(num_rows, dtype=bool)
    selected = []
    for k in range(num_select):
        candidate_means = (features + running_sum) / (k + 1)
        # Normalise each candidate mean on its own row, so candidates are
        # compared on the same scale as the normalised target.
        norms = np.linalg.norm(candidate_means, axis=1, keepdims=True)
        candidate_means = candidate_means / np.maximum(norms, eps)
        dist = np.linalg.norm(target - candidate_means, axis=1)
        dist[~available] = np.inf  # never pick the same row twice
        idx = int(np.argmin(dist))
        selected.append(idx)
        available[idx] = False
        running_sum = running_sum + features[idx]
    return selected


def do_perframe_det_train(cfg,
                          data_loaders,
                          tet_dataloaders,
                          model,
                          criterion,
                          optimizer,
                          scheduler,
                          device,
                          checkpointer,
                          logger,
                          num_tasks,
                          m,
                          get_memory,
                          tet_checkpointer_root='checkpoints/THUMOS/LSTR/lstr_long_512_work_8_kinetics_1x',
                          ):
    """Train the per-frame detector for one task, or select exemplars for it.

    ``get_memory`` selects one of two modes.

    Training mode (``get_memory`` falsy):
        Runs ``cfg.SOLVER.NUM_EPOCHS`` epochs starting at
        ``cfg.SOLVER.START_EPOCH`` over every phase in ``cfg.SOLVER.PHASES``.
        Each batch is split into two halves, both fed to ``model``; their
        ``criterion['MCE']`` losses are combined by ``get_conloss``.

        On every epoch ``<= 10`` and every epoch divisible by 5, the function:

        1. saves a checkpoint;
        2. loads the per-epoch checkpoint from ``tet_checkpointer_root`` into a
           separate evaluation model;
        3. evaluates it with ``do_inference`` on ``tet_dataloaders``;
        4. stores it with ``checkpointer.save_best`` whenever the returned mean
           AP improves.

        ``data_loaders['train'].dataset.shuffle()`` is called after each epoch.
        Returns None.

    Memory mode (``get_memory`` truthy):
        Runs ``model`` without gradients over ``data_loaders['train']``,
        flattens each sample's output into one feature vector, and selects
        exemplars by herding. Returns ``(exemplar_features, exemplar_set)``.

    Args:
        cfg: config; reads SOLVER.START_EPOCH, SOLVER.NUM_EPOCHS,
            SOLVER.PHASES and DATA.NUM_CLASSES, and is forwarded to the
            builders, ``compute_result`` and ``do_inference``.
        data_loaders: dict phase -> data loader. ``data_loaders['train']``'s
            dataset must provide ``shuffle()`` and ``video_data``.
        tet_dataloaders: loaders forwarded to ``do_inference``.
        model: model to train; wrapped in ``nn.DataParallel`` when more than
            one CUDA device is visible.
        criterion: dict of losses; ``criterion['MCE']`` is used.
        optimizer, scheduler: stepped once per training batch.
        device: device inputs and targets are moved to.
        checkpointer: provides ``save`` and ``save_best``.
        logger: receives one ``info`` line per epoch.
        num_tasks: task id used in progress text, checkpoint names and the
            ``'task_<n>'`` key of ``video_data``.
        m: exemplar count control; memory mode keeps up to ``m + 1`` exemplars.
        get_memory: selects memory mode when truthy.
        tet_checkpointer_root: directory the per-epoch evaluation checkpoints
            (``task_<n>_epoch-<e>.pth``) are loaded from.

    Returns:
        None in training mode; ``(exemplar_features, exemplar_set)`` in memory
        mode.
    """
    # Separate model/optimizer used only to reload saved checkpoints for
    # evaluation. NOTE(review): built in both modes (and setup_environment is
    # called) to keep the original call order; only training mode uses them.
    tet_model = build_model(cfg, device)
    tet_optimizer = build_optimizer(cfg, tet_model)
    tet_device = setup_environment(cfg)

    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    if get_memory:
        videos = data_loaders['train'].dataset.video_data['task_' + str(num_tasks)]
        features = _extract_memory_features(data_loaders['train'], model, device, num_tasks)
        selected = _herding_select(features, m)
        # NOTE(review): row i of `features` is paired with videos[i]; this
        # assumes the train loader yields samples in the same order as
        # video_data['task_<n>'], one sample per entry -- confirm.
        exemplar_features = [features[idx] for idx in selected]
        exemplar_set = [videos[idx] for idx in selected]
        return exemplar_features, exemplar_set

    best_meanAP = 0
    first_epoch = cfg.SOLVER.START_EPOCH
    for epoch in range(first_epoch, first_epoch + cfg.SOLVER.NUM_EPOCHS):
        start = time.time()
        det_losses, det_pred_scores, det_gt_targets = _run_epoch(
            cfg, data_loaders, model, criterion, optimizer, scheduler,
            device, num_tasks, epoch)
        end = time.time()
        _log_epoch(cfg, data_loaders, logger, epoch, det_losses,
                   det_pred_scores, det_gt_targets, end - start)

        # Checkpoint and evaluate every epoch up to 10, then every 5th epoch.
        if epoch % 5 == 0 or epoch <= 10:
            checkpointer.save(epoch, model, optimizer, num_tasks)
            # NOTE(review): tet_checkpointer_root must be the directory that
            # checkpointer.save writes to; that is not checked here.
            ckpt_path = osp.join(tet_checkpointer_root,
                                 'task_' + str(num_tasks) + '_epoch-' + str(epoch) + '.pth')
            tet_checkpointer = setup_checkpointer(cfg, ckpt_path, phase='test')
            tet_checkpointer.load(tet_model, tet_optimizer)
            mean_AP = do_inference(
                cfg,
                tet_dataloaders,
                tet_model,
                tet_device,
                logger,
                num_tasks,
                inferr=0
            )
            if mean_AP > best_meanAP:
                checkpointer.save_best(epoch, tet_model, tet_optimizer, num_tasks)
                best_meanAP = mean_AP

        data_loaders['train'].dataset.shuffle()
    return None
