# -*- coding: utf-8 -*-
# Author: Yifan Lu <yifan_lu@sjtu.edu.cn>, Runsheng Xu <rxx3386@ucla.edu>, Hao Xiang <haxiang@g.ucla.edu>,
# License: TDG-Attribution-NonCommercial-NoDistrib

import argparse
import copy
import os
import time
from typing import OrderedDict
import importlib
import torch
import open3d as o3d
from torch.utils.data import DataLoader, Subset
import numpy as np
import opencood.hypes_yaml.yaml_utils as yaml_utils
from opencood.tools import train_utils, inference_utils
from opencood.data_utils.datasets import build_dataset
from opencood.utils import eval_utils
from opencood.visualization import vis_utils, my_vis, simple_vis
from opencood.utils.common_utils import update_dict
# Share tensors between worker processes (main() builds a DataLoader with
# num_workers=4) through the file system rather than file descriptors;
# presumably to avoid "too many open files" errors — TODO confirm.
torch.multiprocessing.set_sharing_strategy("file_system")

def test_parser():
    parser = argparse.ArgumentParser(description="synthetic data generation")
    parser.add_argument("--hypes_yaml", "-y", type=str,
                        help='data generation yaml file needed ')
    parser.add_argument('--model_dir', type=str, required=True,
                        help='Continued training path')
    parser.add_argument('--config_lsd', '-lsd', type=int, default=0,
                        help='whether use lsd config')
    
    parser.add_argument('--collab_mode', '-mode', type=str, default='general',
                        help='whether collaborate in public or local sematic space')
    
    parser.add_argument('--use_cb', type=str, default='False',
                        help='whether collaborate in public or local sematic space')
    
    parser.add_argument('--eval_epoch', type=int, help='use epoch')
    parser.add_argument('--eval_epoch_comm', type=int, help='use comm epoch')
    
    parser.add_argument('--fusion_method', type=str,
                        default='intermediate',
                        help='no, no_w_uncertainty, late, early or intermediate')
    parser.add_argument('--save_vis_interval', type=int, default=100,
                        help='interval of saving visualization')
    parser.add_argument('--save_npy', action='store_true',
                        help='whether to save prediction and gt result'
                             'in npy file')
    parser.add_argument('--range', type=str, default="102.4,102.4",
                        help="detection range is [-102.4, +102.4, -102.4, +102.4]")
    parser.add_argument('--no_score', action='store_true',
                        help="whether print the score of prediction")
    parser.add_argument('--note', default="", type=str, help="any other thing?")
    parser.add_argument('--note_comm', default="", type=str, help="comm net epoch")
    opt = parser.parse_args()
    return opt


# IoU thresholds at which TP/FP statistics and AP are computed.
_IOU_THRESHOLDS = (0.3, 0.5, 0.7)

# Sub-dataset labels that always get an entry (and a result line), even if unseen.
_SUBDATASET_LABELS = ('d1', 'd2', 'd3', 'd4')


def _empty_result_stat():
    """Return a fresh per-IoU TP/FP accumulator in the layout passed to eval_utils."""
    return {iou: {'tp': [], 'fp': [], 'gt': 0, 'score': []} for iou in _IOU_THRESHOLDS}


def _configure_collab_mode(hypes, opt):
    """Adjust fusion/model settings in ``hypes`` (in place) for ``opt.collab_mode``.

    - Always starts from fusion core_method 'intermediatehetersubdata'.
    - 'pub' / 'direct_pub': fusion 'intermediateheter'; the model core_method
      is taken from its 'inf_pub' entry and ``comm_space`` is set to the mode.
    - Any mode containing 'm': fusion 'intermediateheter' and every value of
      ``hypes['heter']['mapping_dict']`` is remapped to the mode. With
      ``--use_cb True`` the model core_method comes from 'inf_local', the
      mode's model args get ``allied=True`` and ``comm_space`` is the mode.
    """
    hypes['fusion']['core_method'] = 'intermediatehetersubdata'
    if opt.collab_mode in ('pub', 'direct_pub'):
        hypes['fusion']['core_method'] = 'intermediateheter'
        hypes['model']['core_method'] = hypes['model']['core_method']['inf_pub']
        hypes['model']['args'].update({'comm_space': opt.collab_mode})

    if 'm' in opt.collab_mode:
        hypes['fusion']['core_method'] = 'intermediateheter'
        mapping_dict = hypes['heter']['mapping_dict']
        # Re-assigning existing keys only, so mutating during iteration is safe.
        for src in mapping_dict:
            mapping_dict[src] = opt.collab_mode

        if opt.use_cb == 'True':
            hypes['model']['core_method'] = hypes['model']['core_method']['inf_local']
            hypes['model']['args'][opt.collab_mode]['allied'] = True
            hypes['model']['args'].update({'comm_space': opt.collab_mode})


def _parse_range(range_str):
    """Parse ``--range`` ("X,Y") into the x/y half-extents of the detection area.

    ``ast.literal_eval`` is used instead of ``eval``: it keeps numeric
    literals typed as written ("102" -> 102, "102.4" -> 102.4, so derived
    output names are unchanged) but cannot execute arbitrary expressions
    taken from the command line.

    Raises:
        ValueError: fewer than two comma-separated values, or non-numeric values.
    """
    import ast

    parts = range_str.split(',')
    if len(parts) < 2:
        raise ValueError(f"--range expects 'X,Y', got {range_str!r}")
    halves = []
    for part in parts[:2]:
        value = ast.literal_eval(part.strip())
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            raise ValueError(f"--range values must be numbers, got {part!r}")
        halves.append(value)
    return halves[0], halves[1]


def _override_detection_range(hypes, opt):
    """Apply ``--range`` to every lidar/gt range in ``hypes`` and re-run the yaml parser.

    Also appends "_{x_max}_{y_max}" to ``opt.note``. The z-limits are kept
    from ``hypes['postprocess']['anchor_args']['cav_lidar_range']``.

    Returns:
        The hypes dict returned by the parser named in ``hypes['yaml_parser']``
        (re-run so that anchors are regenerated for the new range).

    Raises:
        ValueError: ``hypes['yaml_parser']`` names no function in yaml_utils.
    """
    x_half, y_half = _parse_range(opt.range)
    x_min, x_max = -x_half, x_half
    y_min, y_max = -y_half, y_half
    opt.note += f"_{x_max}_{y_max}"

    old_range = hypes['postprocess']['anchor_args']['cav_lidar_range']
    new_cav_range = [x_min, y_min, old_range[2], x_max, y_max, old_range[5]]

    # Replace every occurrence of these keys (delegated to update_dict).
    hypes = update_dict(hypes, {
        "cav_lidar_range": new_cav_range,
        "lidar_range": new_cav_range,
        "gt_range": new_cav_range
    })

    # Reload anchors by re-running the configured parser; fail loudly if it
    # does not exist instead of hitting an unbound name.
    parser_name = hypes["yaml_parser"]
    parser_func = getattr(yaml_utils, parser_name, None)
    if not callable(parser_func):
        raise ValueError(f"yaml_parser {parser_name!r} not found in yaml_utils")
    return parser_func(hypes)


def _load_checkpoint(model, hypes, opt):
    """Load weights into ``model`` according to ``--use_cb`` / ``--collab_mode``.

    Returns:
        (resume_epoch, model). ``resume_epoch`` is what the train_utils loader
        returned, or the string '_indcomb' when the codebook path found no
        collaborative checkpoint (resume_epoch == 0) and the per-modality
        local checkpoints were combined instead.

    Raises:
        ValueError: ``--use_cb`` is neither 'True' nor 'False', or it is
            'True' with a collab_mode that has no loading rule.
    """
    saved_path = opt.model_dir
    if opt.use_cb == 'False':
        # Plain model: load the parameters directly.
        resume_epoch, model = train_utils.load_saved_model(saved_path, model, opt.eval_epoch)
        print('collab in valliance mode')
        return resume_epoch, model
    if opt.use_cb != 'True':
        raise ValueError('Args of opt.use_cb with "False" or "True" only')

    # Codebook collaboration: a modality ('m...') or the public space ('pub').
    print('Collab with codebook')
    if 'm' in opt.collab_mode:
        resume_epoch, model = train_utils.load_modality_saved_model(
            saved_path, model, opt.collab_mode, True, opt.eval_epoch)
    elif opt.collab_mode in ('pub', 'direct_pub'):
        # The public space loads the full parameter set.
        resume_epoch, model = train_utils.load_saved_model(saved_path, model, opt.eval_epoch)
    else:
        # Without this, resume_epoch would be unbound below.
        raise ValueError(
            "--use_cb True requires --collab_mode 'pub', 'direct_pub' or a "
            f"modality name containing 'm', got {opt.collab_mode!r}")

    model.pub_cb_path = os.path.join(opt.model_dir, model.pub_cb_path)
    model.pub_query_emb_path = os.path.join(opt.model_dir, model.pub_query_emb_path)

    if resume_epoch == 0:
        # No pretrained collaborative model: assemble it from each modality's local checkpoint.
        for modality_name in set(hypes['heter']['mapping_dict'].values()):
            _, model = train_utils.load_modality_saved_model(
                os.path.join(opt.model_dir, hypes['model']['args'][modality_name]['model_dir']),
                model, modality_name, True)
        model.pub_codebook = torch.load(model.pub_cb_path)
        model.pub_query_embeddings = torch.load(model.pub_query_emb_path)
        resume_epoch = '_indcomb'  # individual combine
    else:
        model = train_utils.load_pub_cb(model, resume_epoch)
    return resume_epoch, model


def _run_inference(fusion_method, batch_data, model, dataset):
    """Run one batch through the inference_utils routine for ``fusion_method``.

    Raises:
        NotImplementedError: ``fusion_method`` has no inference routine.
    """
    if fusion_method == 'late':
        return inference_utils.inference_late_fusion(batch_data, model, dataset)
    if fusion_method == 'early':
        return inference_utils.inference_early_fusion(batch_data, model, dataset)
    if fusion_method == 'intermediate':
        return inference_utils.inference_intermediate_fusion(batch_data, model, dataset)
    if fusion_method == 'no':
        return inference_utils.inference_no_fusion(batch_data, model, dataset)
    if fusion_method == 'no_w_uncertainty':
        return inference_utils.inference_no_fusion_w_uncertainty(batch_data, model, dataset)
    if fusion_method == 'single':
        # Same routine as 'no', called with single_gt=True.
        return inference_utils.inference_no_fusion(batch_data, model, dataset, single_gt=True)
    raise NotImplementedError('Only single, no, no_w_uncertainty, early, late and intermediate '
                              'fusion is supported.')


def main():
    """Run inference on the test split and report AP per sub-dataset label.

    Steps:
    1. Parse CLI args and adapt the hypes for the collab mode and ``--range``.
    2. Build the model and load its checkpoint (plus the comm module when
       ``--config_lsd`` is non-zero).
    3. Run the chosen fusion method over the test set with batch_size=1.
    4. Accumulate TP/FP per sub-dataset label at IoU 0.3/0.5/0.7.
    5. Optionally save npy predictions and periodic BEV visualizations.
    6. Append one AP line per sub-dataset to ``result.txt``
       (``comm_module/result.txt`` with ``--config_lsd``) in ``--model_dir``.
    """
    opt = test_parser()

    assert opt.fusion_method in ['late', 'early', 'intermediate', 'no', 'no_w_uncertainty', 'single']

    hypes = yaml_utils.load_yaml(None, opt)
    _configure_collab_mode(hypes, opt)

    if 'heter' in hypes:
        hypes = _override_detection_range(hypes, opt)

    hypes['validate_dir'] = hypes['test_dir']
    if "OPV2V" in hypes['test_dir'] or "v2xsim" in hypes['test_dir']:
        assert "test" in hypes['validate_dir']

    # Used in visualization.
    # left hand: OPV2V, V2XSet; right hand: V2X-Sim 2.0 and DAIR-V2X
    left_hand = "OPV2V" in hypes['test_dir'] or "V2XSET" in hypes['test_dir']
    print(f"Left hand visualizing: {left_hand}")

    if 'box_align' in hypes:
        hypes['box_align']['val_result'] = hypes['box_align']['test_result']

    print('Creating Model')
    model = train_utils.create_model(hypes)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Loading Model from checkpoint')
    saved_path = opt.model_dir
    resume_epoch, model = _load_checkpoint(model, hypes, opt)
    print(f"resume from {resume_epoch} epoch.")
    opt.note += f"_epoch{resume_epoch}"

    if opt.config_lsd:
        # The comm-module epoch replaces resume_epoch in the final report.
        resume_epoch, model = train_utils.load_saved_model_comm(saved_path, model, opt.eval_epoch_comm)
        print(f"comm resume from {resume_epoch} epoch.")
        opt.note_comm = opt.note + f"_comm_ep{resume_epoch}"

    if torch.cuda.is_available():
        model.cuda()
    model.eval()

    # Fixed seed so any numpy-driven noise is reproducible across runs.
    np.random.seed(303)

    print('Dataset Building')
    opencood_dataset = build_dataset(hypes, visualize=True, train=False)
    data_loader = DataLoader(opencood_dataset,
                             batch_size=1,
                             num_workers=4,
                             collate_fn=opencood_dataset.collate_batch_test,
                             shuffle=False,
                             pin_memory=False,
                             drop_last=False)

    # Per-sub-dataset TP/FP accumulators; unseen labels are added on demand.
    subdataset_result_stat_dict = {label: _empty_result_stat() for label in _SUBDATASET_LABELS}

    # Any non-zero --config_lsd selects the comm-module outputs, consistent with
    # the truthiness check used when loading the comm module above.
    if opt.config_lsd:
        infer_info = opt.fusion_method + opt.note_comm
        vis_save_path_root = os.path.join(opt.model_dir, 'comm_module', f'vis_{infer_info}')
        result_file = 'comm_module/result.txt'
    else:
        infer_info = opt.fusion_method + opt.note
        vis_save_path_root = os.path.join(opt.model_dir, f'vis_{infer_info}')
        result_file = 'result.txt'

    npy_save_path = os.path.join(opt.model_dir, 'npy')
    if opt.save_npy:
        os.makedirs(npy_save_path, exist_ok=True)

    for i, batch_data in enumerate(data_loader):
        print(f"{infer_info}_{i}")
        if batch_data is None:
            continue
        with torch.no_grad():
            batch_data = train_utils.to_device(batch_data, device)
            infer_result = _run_inference(opt.fusion_method, batch_data, model, opencood_dataset)

            pred_box_tensor = infer_result['pred_box_tensor']
            gt_box_tensor = infer_result['gt_box_tensor']
            pred_score = infer_result['pred_score']

            # batch_size is 1 at inference, so the first ego label covers the batch.
            subdataset_label = batch_data['ego']['subdataset_label_list'][0]
            result_stat = subdataset_result_stat_dict.get(subdataset_label)
            if result_stat is None:
                result_stat = _empty_result_stat()
                subdataset_result_stat_dict[subdataset_label] = result_stat
            for iou_thresh in _IOU_THRESHOLDS:
                eval_utils.caluclate_tp_fp(pred_box_tensor,
                                           pred_score,
                                           gt_box_tensor,
                                           result_stat,
                                           iou_thresh)

            if opt.save_npy:
                inference_utils.save_prediction_gt(pred_box_tensor,
                                                   gt_box_tensor,
                                                   batch_data['ego']['origin_lidar'][0],
                                                   i,
                                                   npy_save_path)

            if not opt.no_score:
                infer_result.update({'score_tensor': pred_score})

            if getattr(opencood_dataset, "heterogeneous", False):
                cav_box_np, agent_modality_list = inference_utils.get_cav_box(batch_data)
                infer_result.update({"cav_box_np": cav_box_np,
                                     "agent_modality_list": agent_modality_list})

            # save_vis_interval == 0 disables visualization instead of dividing by zero.
            if (opt.save_vis_interval and i % opt.save_vis_interval == 0) and \
                    (pred_box_tensor is not None or gt_box_tensor is not None):
                os.makedirs(vis_save_path_root, exist_ok=True)
                vis_save_path = os.path.join(vis_save_path_root, 'bev_%05d.png' % i)
                simple_vis.visualize(infer_result,
                                     batch_data['ego']['origin_lidar'][0],
                                     hypes['postprocess']['gt_range'],
                                     vis_save_path,
                                     method='bev',
                                     left_hand=left_hand)
        torch.cuda.empty_cache()

    if opt.config_lsd:
        # result_file lives under comm_module/, which may not exist yet.
        os.makedirs(os.path.join(saved_path, 'comm_module'), exist_ok=True)
    result_path = os.path.join(saved_path, result_file)
    show_mode = ('m' in opt.collab_mode) or \
        opt.collab_mode in ('pub', 'direct_pub', 'hete')

    for subdataset_label, result_stat in subdataset_result_stat_dict.items():
        # NOTE(review): infer_info is identical for every sub-dataset; if
        # eval_final_results writes files named after it, later sub-datasets
        # may overwrite earlier ones — confirm.
        ap_30, ap_50, ap_70 = eval_utils.eval_final_results(result_stat,
                                                            opt.model_dir, infer_info)
        msg = f'Results on sub dataset {subdataset_label}: '
        if show_mode:
            msg += f'{opt.collab_mode}: '
        msg += 'Epoch: {} | AP @0.3: {:.04f} | AP @0.5: {:.04f} | AP @0.7: {:.04f}\n'.format(
            resume_epoch, ap_30, ap_50, ap_70)
        with open(result_path, 'a+') as f:
            f.write(msg)
        print(msg)

if __name__ == "__main__":
    # Script entry point: parse the CLI arguments, run inference and report AP.
    main()
