# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import time
from tqdm import tqdm
from rekognition_online_action_detection.utils.env import setup_environment
import torch
import torch.nn as nn
from rekognition_online_action_detection.models import build_model
from rekognition_online_action_detection.optimizers import build_optimizer
from rekognition_online_action_detection.evaluation import compute_result
from rekognition_online_action_detection.engines import do_inference
import os.path as osp
from rekognition_online_action_detection.utils.checkpointer import setup_checkpointer
def do_perframe_det_train(cfg,
                          data_loaders,
                          tet_dataloaders,
                          model,
                          criterion,
                          optimizer,
                          scheduler,
                          device,
                          checkpointer,
                          logger,
                          num_tasks,
                          tet_checkpointer_root='checkpoints/THUMOS/LSTR/lstr_long_512_work_8_kinetics_1x'):
    """Train a per-frame detection model and periodically evaluate saved checkpoints.

    For every epoch in ``[cfg.SOLVER.START_EPOCH, cfg.SOLVER.START_EPOCH +
    cfg.SOLVER.NUM_EPOCHS)`` the function runs each phase listed in
    ``cfg.SOLVER.PHASES`` over ``data_loaders[phase]``. Gradients are enabled
    only for the ``'train'`` phase, where the optimizer and the scheduler are
    both stepped after every batch. In every other phase the softmax scores
    and the targets are collected; they are scored with
    ``compute_result['perframe']`` when ``'test'`` is one of the phases.

    On epochs that are ``<= 10`` or a multiple of 5, the model is saved through
    ``checkpointer``, the file ``task_<num_tasks>_epoch-<epoch>.pth`` under
    ``tet_checkpointer_root`` is loaded into a separately built model, and
    ``do_inference`` is run on it. Whenever the returned mean AP exceeds the
    best value seen so far, ``checkpointer.save_best`` is called.

    Args:
        cfg: Experiment configuration. ``cfg.SOLVER.PHASES`` must contain
            ``'train'``: the train loss is logged and the train dataset is
            shuffled unconditionally.
        data_loaders: Mapping from phase name to data loader. Each batch is a
            sequence whose last element is the target and whose preceding
            elements are the model inputs.
        tet_dataloaders: Passed through unchanged to ``do_inference``.
        model: Network to train; wrapped in ``nn.DataParallel`` when more than
            one CUDA device is visible.
        criterion: Mapping of loss functions; only ``criterion['MCE']`` is used.
        optimizer: Optimizer for ``model``.
        scheduler: Learning-rate scheduler, stepped once per training batch.
        device: Device that inputs and targets are moved to.
        checkpointer: Object providing ``save`` and ``save_best``.
        logger: Logger receiving one summary line per epoch.
        num_tasks: Task identifier; used in the progress-bar text, in the
            checkpoint file name, and forwarded to the checkpointer and to
            ``do_inference``.
        tet_checkpointer_root: Directory the freshly saved checkpoint is
            reloaded from. The default preserves the previously hard-coded
            location. NOTE(review): this presumably has to match the directory
            ``checkpointer.save`` writes to -- confirm against the checkpointer.
    """
    best_meanAP = 0

    # A second model/optimizer pair that receives the reloaded checkpoints, so
    # that evaluation does not go through the (possibly DataParallel-wrapped)
    # training model.
    tet_model = build_model(cfg, device)
    tet_optimizer = build_optimizer(cfg, tet_model)
    tet_device = setup_environment(cfg)

    # Setup model on multiple GPUs
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)

    for epoch in range(cfg.SOLVER.START_EPOCH, cfg.SOLVER.START_EPOCH + cfg.SOLVER.NUM_EPOCHS):
        # Reset the per-epoch accumulators
        det_losses = {phase: 0.0 for phase in cfg.SOLVER.PHASES}
        det_pred_scores = []
        det_gt_targets = []

        start = time.time()
        for phase in cfg.SOLVER.PHASES:
            training = phase == 'train'
            model.train(training)

            with torch.set_grad_enabled(training):
                pbar = tqdm(data_loaders[phase],
                            desc='task{}:{}ing epoch {}'.format(
                                num_tasks, phase.capitalize(), epoch))
                for data in pbar:
                    batch_size = data[0].shape[0]
                    det_target = data[-1].to(device)

                    det_score = model(*[x.to(device) for x in data[:-1]])
                    # Flatten to one row per prediction before computing the loss
                    det_score = det_score.reshape(-1, cfg.DATA.NUM_CLASSES)
                    det_target = det_target.reshape(-1, cfg.DATA.NUM_CLASSES)
                    det_loss = criterion['MCE'](det_score, det_target)

                    # Read the scalar once; it feeds both the running sum and
                    # the progress bar.
                    loss_value = det_loss.item()
                    det_losses[phase] += loss_value * batch_size

                    # Output log for current batch
                    pbar.set_postfix({
                        'lr': '{:.7f}'.format(scheduler.get_last_lr()[0]),
                        'det_loss': '{:.5f}'.format(loss_value),
                    })

                    if training:
                        optimizer.zero_grad()
                        det_loss.backward()
                        optimizer.step()
                        # The scheduler advances per batch, not per epoch
                        scheduler.step()
                    else:
                        # Prepare for evaluation
                        det_pred_scores.extend(det_score.softmax(dim=1).cpu().tolist())
                        det_gt_targets.extend(det_target.cpu().tolist())
        end = time.time()

        # Output log for current epoch
        log = [
            'Epoch {:2}'.format(epoch),
            'train det_loss: {:.5f}'.format(
                det_losses['train'] / len(data_loaders['train'].dataset),
            ),
        ]
        if 'test' in cfg.SOLVER.PHASES:
            # Compute result
            det_result = compute_result['perframe'](
                cfg,
                det_gt_targets,
                det_pred_scores,
            )
            log.append('test det_loss: {:.5f} det_mAP: {:.5f}'.format(
                det_losses['test'] / len(data_loaders['test'].dataset),
                det_result['mean_AP'],
            ))
        log.append('running time: {:.2f} sec'.format(end - start))
        logger.info(' | '.join(log))

        # Save checkpoint for model and optimizer, then evaluate that checkpoint
        if epoch % 5 == 0 or epoch <= 10:
            checkpointer.save(epoch, model, optimizer, num_tasks)
            tet_checkpointer_task_root = osp.join(
                tet_checkpointer_root,
                'task_{}_epoch-{}.pth'.format(num_tasks, epoch),
            )
            tet_checkpointer = setup_checkpointer(cfg, tet_checkpointer_task_root, phase='test')
            tet_checkpointer.load(tet_model, tet_optimizer)
            mean_AP = do_inference(
                cfg,
                tet_dataloaders,
                tet_model,
                tet_device,
                logger,
                num_tasks,
                inferr=0
            )
            # Keep a copy of the best-scoring checkpoint seen so far
            if mean_AP > best_meanAP:
                checkpointer.save_best(epoch, tet_model, tet_optimizer, num_tasks)
                best_meanAP = mean_AP

        # Shuffle dataset for next epoch
        data_loaders['train'].dataset.shuffle()
