# ---------------------------------------------
# Copyright (c) OpenMMLab. All rights reserved.
# ---------------------------------------------
#  Modified by Zhiqi Li
# ---------------------------------------------
#  NeurIPS 2025 submission
#  Anonymous Author(s)
#  We modified the code according to SSR and retain the copyright statement here.
# ---------------------------------------------
import sys
sys.path.append('')
import numpy as np
import argparse
import mmcv
import os
import copy
import torch
torch.multiprocessing.set_sharing_strategy('file_system')
import warnings
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                         wrap_fp16_model)

from mmdet3d.apis import single_gpu_test
from mmdet3d.datasets import build_dataset
from projects.mmdet3d_plugin.datasets.builder import build_dataloader
from mmdet3d.models import build_model
from mmdet.apis import set_random_seed
# from projects.mmdet3d_plugin.bevformer.apis.test import custom_multi_gpu_test
from projects.mmdet3d_plugin.SSR.apis.test import custom_multi_gpu_test
from projects.mmdet3d_plugin.SSR.utils.plan_loss import *
from projects.mmdet3d_plugin.SSR.SSR_head import obtain_map_information, obtain_curent_traffic_object_information, obtain_future_traffic_object_information
from mmdet.datasets import replace_ImageToTensor
import time
import os.path as osp
import json
import pandas as pd
from tqdm import tqdm

import warnings
warnings.filterwarnings("ignore")

def parse_args():
    """Parse the command-line arguments of this script.

    Besides parsing, this function has one side effect: if ``LOCAL_RANK`` is
    not already present in the environment it is set from ``--local_rank``.

    Returns:
        argparse.Namespace: The parsed arguments. When the deprecated
        ``--options`` flag is given, its value is copied to
        ``eval_options``.

    Raises:
        ValueError: If both ``--options`` and ``--eval-options`` are given.
    """
    parser = argparse.ArgumentParser(
        description='MMDet test (and eval) a model')
    parser.add_argument('config', help='test config file path')
    parser.add_argument('checkpoint', help='checkpoint file')
    parser.add_argument('--json_dir', help='json parent dir name file') # NOTE: json file parent folder name
    parser.add_argument('--out', help='output result file in pickle format')
    # NOTE: adjacent string literals are concatenated verbatim, so every
    # fragment that is continued on the next line must end with a space.
    parser.add_argument(
        '--fuse-conv-bn',
        action='store_true',
        help='Whether to fuse conv and bn, this will slightly increase '
        'the inference speed')
    parser.add_argument(
        '--format-only',
        action='store_true',
        help='Format the output results without performing evaluation. It is '
        'useful when you want to format the result to a specific format and '
        'submit it to the test server')
    parser.add_argument(
        '--eval',
        type=str,
        nargs='+',
        help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
        ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
    parser.add_argument('--show', action='store_true', help='show results')
    parser.add_argument(
        '--show-dir', help='directory where results will be saved')
    parser.add_argument(
        '--gpu-collect',
        action='store_true',
        help='whether to use gpu to collect results.')
    parser.add_argument(
        '--tmpdir',
        help='tmp directory used for collecting results from multiple '
        'workers, available when gpu-collect is not specified')
    parser.add_argument('--seed', type=int, default=0, help='random seed')
    parser.add_argument(
        '--deterministic',
        action='store_true',
        help='whether to set deterministic options for CUDNN backend.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function (deprecated), '
        'change to --eval-options instead.')
    parser.add_argument(
        '--eval-options',
        nargs='+',
        action=DictAction,
        help='custom options for evaluation, the key-value pair in xxx=yyy '
        'format will be kwargs for dataset.evaluate() function')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    parser.add_argument('--local_rank', type=int, default=0)
    args = parser.parse_args()

    # Only fill in LOCAL_RANK when the launcher has not exported it already.
    if 'LOCAL_RANK' not in os.environ:
        os.environ['LOCAL_RANK'] = str(args.local_rank)

    if args.options and args.eval_options:
        raise ValueError(
            '--options and --eval-options cannot be both specified, '
            '--options is deprecated in favor of --eval-options')
    if args.options:
        # Keep the deprecated flag working by forwarding it to its successor.
        warnings.warn('--options is deprecated in favor of --eval-options')
        args.eval_options = args.options
    return args


def _import_plugin_modules(cfg, config_path):
    """Import the plugin package named by the config, if plugins are enabled.

    Nothing happens unless ``cfg.plugin`` exists and is truthy. The package is
    derived from the directory part of ``cfg.plugin_dir`` when that attribute
    exists, otherwise from the directory part of ``config_path``; ``/``
    separators are turned into dots to form the module path.
    """
    if not getattr(cfg, 'plugin', False):
        return
    import importlib
    source = cfg.plugin_dir if hasattr(cfg, 'plugin_dir') else config_path
    module_path = '.'.join(os.path.dirname(source).split('/'))
    print(module_path)
    # Imported for its side effects only (registry updates).
    importlib.import_module(module_path)


def _configure_test_data(cfg):
    """Switch ``cfg.data.test`` to test mode and return ``samples_per_gpu``.

    Handles both a single dataset config (dict) and a list of dataset configs
    (concatenated test set). ``samples_per_gpu`` is popped from the dataset
    config(s); for a list the maximum is used. When it is greater than 1 the
    pipelines are rewritten with ``replace_ImageToTensor``.
    """
    test_cfg = cfg.data.test
    samples_per_gpu = 1
    if isinstance(test_cfg, dict):
        test_cfg.test_mode = True
        samples_per_gpu = test_cfg.pop('samples_per_gpu', 1)
        if samples_per_gpu > 1:
            # Replace 'ImageToTensor' to 'DefaultFormatBundle'
            test_cfg.pipeline = replace_ImageToTensor(test_cfg.pipeline)
    elif isinstance(test_cfg, list):
        for ds_cfg in test_cfg:
            ds_cfg.test_mode = True
        samples_per_gpu = max(
            [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in test_cfg])
        if samples_per_gpu > 1:
            for ds_cfg in test_cfg:
                ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
    return samples_per_gpu


def _build_reward_ops(head_cfg):
    """Instantiate the six reward operators from the ``pts_bbox_head`` config.

    The returned dict is keyed by the reward-table column each operator feeds.
    """
    # Read every config entry up front so a missing key fails before any
    # operator is constructed.
    critic_bound_cfg = head_cfg["critic_reward_bound"]
    critic_col_cfg = head_cfg["critic_reward_col"]
    intrinsic_bound_cfg = head_cfg["intrinsic_reward_bound"]
    intrinsic_col_cfg = head_cfg["intrinsic_reward_col"]
    return {
        "imitation_reward": CriticImitationConstrain(),
        "endpoint_reward": CriticEndPointConstrain(),
        "critic_bound_reward": CriticMapBoundConstrain(**critic_bound_cfg),
        "critic_collision_reward": CriticCollisionConstrain(**critic_col_cfg),
        "intrinsic_bound_reward": IntrinsicMapBoundConstrain(**intrinsic_bound_cfg),
        "intrinsic_collision_reward": IntrinsicCollisionConstrain(**intrinsic_col_cfg),
    }


def _compute_frame_rewards(data, reward_ops):
    """Evaluate every reward operator on one dataloader batch.

    Returns a dict holding ``scene_token``, ``frame_idx`` and, for each reward
    column, the mean of the operator output as a Python float.
    """
    scene_token = data["img_metas"][0].data[0][0]["scene_token"]
    frame_idx = data["frame_idx"][0].item()

    # Shape remarks below are carried over from the original author's notes
    # and are not verified in this file.
    map_gt_bboxes_list = data["map_gt_bboxes_3d"].data[0]
    map_gt_labels_list = data["map_gt_labels_3d"].data[0]
    gt_bboxes_list = data["gt_bboxes_3d"][0].data[0]
    gt_labels_list = data["gt_labels_3d"][0].data[0]
    gt_attr_labels = data["gt_attr_labels"][0].data[0]  # List[Tensor] [N, 34]
    ego_fut_trajs = data["ego_fut_trajs"][0].data[0][0]  # [1, 6, 2]
    ego_fut_masks = data["ego_fut_masks"][0].data[0][0][0]  # [1, 6]
    ego_fut_cmd = data["ego_fut_cmd"][0].data[0][0][0]  # [1, 3]

    # The ego command tensor is only used to decide which device to work on.
    device = ego_fut_cmd.device

    map_gt, map_type_gt = obtain_map_information(
        map_gt_bboxes_list, map_gt_labels_list, num_classes=3, device=device
    )
    agent_cur_position, agent_type = obtain_curent_traffic_object_information(
        gt_bboxes_list, gt_labels_list, device=device
    )  # (B, N, 2), (B, N, 10)
    agent_fut_position, agent_fut_mask = obtain_future_traffic_object_information(
        gt_attr_labels, fut_ts=6, device=device
    )  # (B, N, T, 2), (B, N, T)

    # The ground-truth ego trajectory is passed as both of the first two
    # arguments of the imitation and endpoint operators.
    reward_tensors = {
        "imitation_reward": reward_ops["imitation_reward"](
            ego_fut_trajs, ego_fut_trajs, ego_fut_masks
        ),
        "endpoint_reward": reward_ops["endpoint_reward"](
            ego_fut_trajs, ego_fut_trajs, ego_fut_masks
        ),
        "critic_bound_reward": reward_ops["critic_bound_reward"](
            ego_fut_trajs, map_gt, map_type_gt, reward_augment=False
        ),
        "critic_collision_reward": reward_ops["critic_collision_reward"](
            ego_fut_trajs,
            agent_cur_position,
            agent_type,
            agent_fut_position,
            agent_fut_mask,
        ),
        "intrinsic_bound_reward": reward_ops["intrinsic_bound_reward"](
            map_gt, map_type_gt
        ),
        "intrinsic_collision_reward": reward_ops["intrinsic_collision_reward"](
            agent_cur_position, agent_type
        ),
    }

    row = {"scene_token": scene_token, "frame_idx": frame_idx}
    for column, reward in reward_tensors.items():
        row[column] = reward.mean().cpu().item()
    return row


def main():
    """Compute per-frame reward statistics on the test set and save to Excel.

    The function loads the config, builds the test dataset and dataloader,
    evaluates the critic/intrinsic reward operators on the ground-truth
    annotations of every batch and writes one row per batch to an ``.xlsx``
    file. The ``checkpoint`` argument and the ``--out``/``--eval``/``--show``
    family of options are parsed and validated but no model is built or
    evaluated here.

    Raises:
        AssertionError: If none of ``--out``, ``--eval``, ``--format-only``,
            ``--show`` or ``--show-dir`` is given.
        ValueError: If ``--eval`` and ``--format-only`` are combined, or if
            ``--out`` does not end in ``.pkl``/``.pickle``.
    """
    args = parse_args()

    assert args.out or args.eval or args.format_only or args.show \
        or args.show_dir, \
        ('Please specify at least one operation (save/eval/format/show the '
         'results / save the results) with the argument "--out", "--eval"'
         ', "--format-only", "--show" or "--show-dir"')

    if args.eval and args.format_only:
        raise ValueError('--eval and --format_only cannot be both specified')

    if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
        raise ValueError('The output file must be a pkl file.')

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    # import modules from string list.
    if cfg.get('custom_imports', None):
        from mmcv.utils import import_modules_from_strings
        import_modules_from_strings(**cfg['custom_imports'])

    # import modules from plugin/xx, registry will be updated
    _import_plugin_modules(cfg, args.config)

    # set cudnn_benchmark
    if cfg.get('cudnn_benchmark', False):
        torch.backends.cudnn.benchmark = True

    cfg.model.pretrained = None
    # in case the test dataset is concatenated
    samples_per_gpu = _configure_test_data(cfg)

    # init distributed env first, since logger depends on the dist info.
    distributed = args.launcher != 'none'
    if distributed:
        init_dist(args.launcher, **cfg.dist_params)

    # set random seeds
    if args.seed is not None:
        set_random_seed(args.seed, deterministic=args.deterministic)

    # Build the dataloader
    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(
        dataset,
        samples_per_gpu=samples_per_gpu,
        workers_per_gpu=cfg.data.workers_per_gpu,
        dist=distributed,
        shuffle=False,
        nonshuffler_sampler=cfg.data.nonshuffler_sampler,
    )

    reward_ops = _build_reward_ops(cfg.model.pts_bbox_head)

    # Save path setup.
    # NOTE(review): the destination is a hard-coded absolute path; when the
    # script is launched distributed every process would write this same
    # file - confirm single-process use is intended.
    save_dir = "/opt/nvme1/zhengxj/projects/world-model-RL-store/SSR/reward"
    file_name = "reward_test_2.xlsx"
    save_path = os.path.join(save_dir, file_name)
    # Create the directory before the dataset pass so an unwritable
    # location fails immediately instead of after all rewards are computed.
    os.makedirs(save_dir, exist_ok=True)

    # One list per output column; key order defines the spreadsheet layout.
    reward_table = {
        "scene_token": [],
        "frame_idx": [],
        "imitation_reward": [],
        "endpoint_reward": [],
        "critic_bound_reward": [],
        "critic_collision_reward": [],
        "intrinsic_bound_reward": [],
        "intrinsic_collision_reward": [],
    }

    for data in tqdm(data_loader):
        row = _compute_frame_rewards(data, reward_ops)
        for column, values in reward_table.items():
            values.append(row[column])

    # Save rewards to an Excel file
    pd.DataFrame(reward_table).to_excel(save_path, index=False)


# Run the reward export only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()

# python /opt/nvme0/zhengxj/projects/world-model-RL-ac/SSR/tools/test_reward.py projects/configs/SSR/SSR_e2e.py epoch_12_ema.pth --launcher=none --eval=bbox --tmpdir=tmp