# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import os
import os.path as osp
from tqdm import tqdm

import torch
import numpy as np
import pickle as pkl

from rekognition_online_action_detection.datasets import build_dataset
from rekognition_online_action_detection.evaluation import compute_result


def do_perframe_det_batch_inference(cfg,
                                    data_loader,
                                    model,
                                    device,
                                    logger,
                                    num_task,
                                    inferr,
                                    save_root='checkpoints/THUMOS/LSTR/lstr_long_512_work_8_kinetics_1x'):
    """Run batched per-frame inference for one task and return its mean AP.

    Every batch yielded by ``data_loader`` must end with four items, in order:
    ``target``, ``sessions``, ``query_indices`` and ``num_frames``. All
    preceding items are moved to ``device`` and passed positionally to
    ``model``. Softmaxed scores and targets are scattered into one
    ``(num_frames, cfg.DATA.NUM_CLASSES)`` array per session. They are then
    concatenated across sessions and scored with ``compute_result['perframe']``.

    Args:
        cfg: Config object; reads ``cfg.DATA.NUM_CLASSES`` and
            ``cfg.DATA.METRICS``.
        data_loader: Iterable of batches laid out as described above.
        model: Module that is switched to eval mode and called on the inputs.
        device: Device the model inputs are moved to.
        logger: Logger used to report the resulting mean AP.
        num_task: Task identifier; used in the progress bar, the log line and
            the dump filename.
        inferr: When equal to 1, the per-frame scores and targets are pickled
            to ``<save_root>/task_<num_task>_best.pkl``.
        save_root: Directory the pickle is written to. It is created if it is
            missing. The default preserves the previously hard-coded path.

    Returns:
        The ``'mean_AP'`` entry of the result computed by
        ``compute_result['perframe']``.

    Raises:
        ValueError: If ``data_loader`` yields no samples, so there is nothing
            to evaluate.
    """
    # Setup model to test mode
    model.eval()
    num_classes = cfg.DATA.NUM_CLASSES
    # Per-session accumulators, keyed by the session entry of each sample.
    pred_scores = {}
    gt_targets = {}
    with torch.no_grad():
        pbar = tqdm(data_loader, desc='task{}BatchInference'.format(num_task))
        for data in pbar:
            target = data[-4]

            score = model(*[x.to(device) for x in data[:-4]])
            # NOTE(review): assumes the model output is (batch, frames, classes),
            # so the per-sample slice lines up with query_indices — confirm.
            score = score.softmax(dim=-1).cpu().numpy()

            for bs, (session, query_indices, num_frames) in enumerate(zip(*data[-3:])):
                # Both dicts are only ever populated here, together, so a
                # single membership check covers both.
                if session not in pred_scores:
                    pred_scores[session] = np.zeros((num_frames, num_classes))
                    gt_targets[session] = np.zeros((num_frames, num_classes))

                if query_indices[0] == 0:
                    # A window starting at frame 0 fills every frame it covers.
                    pred_scores[session][query_indices] = score[bs]
                    gt_targets[session][query_indices] = target[bs]
                else:
                    # Later windows contribute only their final frame.
                    pred_scores[session][query_indices[-1]] = score[bs][-1]
                    gt_targets[session][query_indices[-1]] = target[bs][-1]

    # Save scores and targets
    if inferr == 1:
        os.makedirs(save_root, exist_ok=True)
        dump_path = osp.join(save_root, 'task_{}_best.pkl'.format(num_task))
        # Use a context manager so the file handle is always closed.
        with open(dump_path, 'wb') as f:
            pkl.dump({
                'cfg': cfg,
                'perframe_pred_scores': pred_scores,
                'perframe_gt_targets': gt_targets,
            }, f)

    if not pred_scores:
        # np.concatenate([]) would fail with an opaque message; be explicit.
        raise ValueError(
            'task_{}: data_loader yielded no samples; nothing to evaluate'.format(num_task))

    # Compute results
    result = compute_result['perframe'](
        cfg,
        np.concatenate(list(gt_targets.values()), axis=0),
        np.concatenate(list(pred_scores.values()), axis=0),
    )
    logger.info('task_{}:Action detection perframe m{}: {:.5f}'.format(
        num_task, cfg.DATA.METRICS, result['mean_AP']
    ))
    return result['mean_AP']
