# -*- coding: utf-8 -*-
# Author: Yifan Lu <yifan_lu@sjtu.edu.cn>, Runsheng Xu <rxx3386@ucla.edu>, Hao Xiang <haxiang@g.ucla.edu>,
# License: TDG-Attribution-NonCommercial-NoDistrib

import argparse
import os
import time
import importlib
import torch
import open3d as o3d
from torch.utils.data import DataLoader, Subset
import numpy as np
import matplotlib
from tqdm import tqdm
# Select the non-interactive "agg" backend so figures can be written to files
# without a display. NOTE(review): this presumably has to run before the
# visualization modules imported below pull in pyplot -- confirm before moving.
matplotlib.use("agg")
import opencood.hypes_yaml.yaml_utils as yaml_utils
from opencood.tools import train_utils, inference_utils
from opencood.data_utils.datasets import build_dataset
from opencood.utils import eval_utils
from opencood.visualization import vis_utils, my_vis, simple_vis
from opencood.utils.common_utils import update_dict
# Use the 'file_system' strategy for sharing tensors between processes.
# NOTE(review): presumably chosen to avoid exhausting file descriptors with
# the DataLoader workers started in main() -- confirm.
torch.multiprocessing.set_sharing_strategy('file_system')

def test_parser():
    parser = argparse.ArgumentParser(description="synthetic data generation")
    parser.add_argument("--hypes_yaml", "-y", type=str,
                        help='data generation yaml file needed ')
    parser.add_argument('--model_dir', type=str, required=True,
                        help='Continued training path')
    parser.add_argument('--config_lsd', '-lsd', type=int, default=0,
                        help='whether use lsd config')
    
    parser.add_argument('--collab_mode', '-mode', type=str, default='general',
                        help='whether collaborate in public or local sematic space')
    
    parser.add_argument('--use_cb', type=str, default='False',
                        help='whether collaborate in public or local sematic space')
    
    parser.add_argument('--eval_epoch', type=int, help='use epoch')
    parser.add_argument('--eval_epoch_comm', type=int, help='use comm epoch')
    
    parser.add_argument('--fusion_method', type=str,
                        default='intermediate',
                        help='no, no_w_uncertainty, late, early or intermediate')
    parser.add_argument('--save_vis_interval', type=int, default=100,
                        help='interval of saving visualization')
    parser.add_argument('--save_npy', action='store_true',
                        help='whether to save prediction and gt result'
                             'in npy file')
    parser.add_argument('--range', type=str, default="102.4,102.4",
                        help="detection range is [-102.4, +102.4, -102.4, +102.4]")
    parser.add_argument('--no_score', action='store_true',
                        help="whether print the score of prediction")
    parser.add_argument('--note', default="", type=str, help="any other thing?")
    parser.add_argument('--note_comm', default="", type=str, help="comm net epoch")
    opt = parser.parse_args()
    return opt


# Fusion strategies accepted by ``--fusion_method``.
_SUPPORTED_FUSION_METHODS = ('late', 'early', 'intermediate', 'no',
                             'no_w_uncertainty', 'single')
# IoU thresholds at which true/false positives are accumulated.
_IOU_THRESHOLDS = (0.3, 0.5, 0.7)


def _parse_range_value(text):
    """Convert one component of ``--range`` to a number without ``eval``.

    Integers stay ``int`` and everything else becomes ``float`` so that the
    value is rendered in ``opt.note`` exactly as the literal was typed
    (``"100"`` -> ``100``, ``"102.4"`` -> ``102.4``).

    Raises
    ------
    ValueError
        If ``text`` is not a plain numeric literal.
    """
    text = text.strip()
    try:
        return int(text)
    except ValueError:
        return float(text)


def _apply_collab_mode(hypes, opt):
    """Rewrite ``hypes`` in place according to ``--collab_mode``/``--use_cb``."""
    if opt.collab_mode in ('pub', 'direct_pub'):
        hypes['fusion']['core_method'] = 'intermediateheter'
        hypes['model']['core_method'] = hypes['model']['core_method']['inf_pub']
        hypes['model']['args'].update({'comm_space': opt.collab_mode})

    # NOTE: this is a substring test, so any mode containing the letter 'm'
    # is handled here.
    if 'm' in opt.collab_mode:
        hypes['fusion']['core_method'] = 'intermediateheter'
        # Map every source entry onto the selected collaboration mode.
        mapping_dict = hypes['heter']['mapping_dict']
        for src in mapping_dict:
            mapping_dict[src] = opt.collab_mode

        if opt.use_cb == 'True':
            hypes['model']['core_method'] = hypes['model']['core_method']['inf_local']
            hypes['model']['args'][opt.collab_mode]['allied'] = True
            hypes['model']['args'].update({'comm_space': opt.collab_mode})


def _override_detection_range(hypes, opt):
    """Replace the range entries in ``hypes`` with ``--range`` and re-parse.

    Appends ``_{x}_{y}`` to ``opt.note`` and returns the re-parsed ``hypes``.

    Raises
    ------
    ValueError
        If ``--range`` is malformed or the yaml parser named by
        ``hypes['yaml_parser']`` does not exist in ``yaml_utils``.
    """
    parts = opt.range.split(',')
    if len(parts) < 2:
        raise ValueError(
            f'--range expects "x,y" (e.g. "102.4,102.4"), got {opt.range!r}')
    x_max = _parse_range_value(parts[0])
    y_max = _parse_range_value(parts[1])
    opt.note += f"_{x_max}_{y_max}"

    # Keep the z limits (indices 2 and 5) of the configured range.
    base_range = hypes['postprocess']['anchor_args']['cav_lidar_range']
    new_cav_range = [-x_max, -y_max, base_range[2],
                     x_max, y_max, base_range[5]]

    # replace all appearance
    hypes = update_dict(hypes, {
        "cav_lidar_range": new_cav_range,
        "lidar_range": new_cav_range,
        "gt_range": new_cav_range,
        "local_range": new_cav_range
    })

    # reload anchor: re-run the yaml parser named in the config.
    parser_name = hypes["yaml_parser"]
    parser_func = getattr(yaml_utils, parser_name, None)
    if parser_func is None:
        raise ValueError(
            f'yaml parser {parser_name!r} not found in '
            'opencood.hypes_yaml.yaml_utils')
    return parser_func(hypes)


def _load_checkpoint(opt, hypes, model):
    """Load model weights according to ``--use_cb`` and ``--collab_mode``.

    Returns
    -------
    resume_epoch
        The epoch reported by the loader, or the string ``'_indcomb'`` when
        the codebook model was assembled from the individual modality models.
    model
        The model with the loaded weights.

    Raises
    ------
    ValueError
        If ``opt.use_cb`` is neither ``'True'`` nor ``'False'``, or if
        codebook loading is requested for an unsupported ``collab_mode``.
    """
    saved_path = opt.model_dir
    # 'heter' is optional in the config (see main), so do not assume it here.
    if 'heter' in hypes:
        modality_names = set(hypes['heter']['mapping_dict'].values())
    else:
        modality_names = set()

    if opt.use_cb == 'False':
        # Ordinary model: load the parameters directly.
        resume_epoch, model = train_utils.load_saved_model(
            saved_path, model, opt.eval_epoch)

        if resume_epoch == 0:
            # Domain adapter case: no checkpoint in the folder, so load the
            # pretrained parameters from each modality's own folder.
            for modality_name in modality_names:
                modality_hypes = hypes['model']['args'][modality_name]
                if 'model_dir' not in modality_hypes:
                    continue
                modality_hypes['model_dir'] = os.path.join(
                    opt.model_dir, modality_hypes['model_dir'])

                _, model = train_utils.load_modality_saved_model(
                    modality_hypes['model_dir'], model, modality_name)
                print('Load stage0 with trained adapter!')
        else:
            print('Collab in valliance mode')

    elif opt.use_cb == 'True':
        # Codebook collaboration: a local "m*" space or the public space.
        print('Collab with codebook')
        if 'm' in opt.collab_mode:
            resume_epoch, model = train_utils.load_modality_saved_model(
                saved_path, model, opt.collab_mode, True, opt.eval_epoch)
        elif opt.collab_mode in ('pub', 'direct_pub'):
            # The public space loads the full set of parameters.
            resume_epoch, model = train_utils.load_saved_model(
                saved_path, model, opt.eval_epoch)
        else:
            # Previously this fell through and crashed later with a
            # NameError on ``resume_epoch``.
            raise ValueError(
                '--use_cb True requires --collab_mode to be "pub", '
                f'"direct_pub" or a mode containing "m", got '
                f'{opt.collab_mode!r}')

        model.pub_cb_path = os.path.join(opt.model_dir, model.pub_cb_path)
        model.pub_query_emb_path = os.path.join(opt.model_dir,
                                                model.pub_query_emb_path)

        if resume_epoch == 0:
            # No pretrained collaborative model: load from the local models.
            for modality_name in modality_names:
                _, model = train_utils.load_modality_saved_model(
                    os.path.join(
                        opt.model_dir,
                        hypes['model']['args'][modality_name]['model_dir']),
                    model, modality_name, True)
            model.pub_codebook = torch.load(model.pub_cb_path)
            model.pub_query_embeddings = torch.load(model.pub_query_emb_path)
            resume_epoch = '_indcomb'  # individual combine
        else:
            model = train_utils.load_pub_cb(model, resume_epoch)
    else:
        raise ValueError('Args of opt.use_cb with "False" or "True" only')

    return resume_epoch, model


def _run_inference(fusion_method, batch_data, model, dataset):
    """Dispatch one batch to the inference routine of ``fusion_method``."""
    if fusion_method == 'late':
        return inference_utils.inference_late_fusion(
            batch_data, model, dataset)
    if fusion_method == 'early':
        return inference_utils.inference_early_fusion(
            batch_data, model, dataset)
    if fusion_method == 'intermediate':
        return inference_utils.inference_intermediate_fusion(
            batch_data, model, dataset)
    if fusion_method == 'no':
        return inference_utils.inference_no_fusion(
            batch_data, model, dataset)
    if fusion_method == 'no_w_uncertainty':
        return inference_utils.inference_no_fusion_w_uncertainty(
            batch_data, model, dataset)
    if fusion_method == 'single':
        return inference_utils.inference_no_fusion(
            batch_data, model, dataset, single_gt=True)
    raise NotImplementedError(
        'Only single, no, no_w_uncertainty, early, late and intermediate '
        'fusion is supported.')


def main():
    """Run inference over the test set and report AP at IoU 0.3/0.5/0.7.

    Steps: parse the command line, load and adjust the yaml configuration,
    build the model and load its checkpoint, iterate over the test loader
    accumulating true/false positives, periodically save BEV visualizations
    (and optionally npy dumps) under ``opt.model_dir``, and finally append a
    result line to ``result.txt`` (``comm_module/result.txt`` when
    ``--config_lsd`` is non-zero).

    Raises
    ------
    ValueError
        For an unsupported ``--fusion_method``, ``--use_cb``, ``--range`` or
        codebook/``--collab_mode`` combination.
    """
    opt = test_parser()

    # Explicit check instead of ``assert`` so it survives ``python -O``.
    if opt.fusion_method not in _SUPPORTED_FUSION_METHODS:
        raise ValueError(
            f'--fusion_method must be one of {_SUPPORTED_FUSION_METHODS}, '
            f'got {opt.fusion_method!r}')

    # Any non-zero --config_lsd enables the comm module path everywhere.
    use_comm = bool(opt.config_lsd)

    hypes = yaml_utils.load_yaml(None, opt)
    _apply_collab_mode(hypes, opt)

    if 'heter' in hypes:
        hypes = _override_detection_range(hypes, opt)

    hypes['validate_dir'] = hypes['test_dir']
    if "OPV2V" in hypes['test_dir'] or "v2xsim" in hypes['test_dir']:
        assert "test" in hypes['validate_dir'], \
            'test_dir is expected to point at a "test" split'

    # This is used in visualization
    # left hand: OPV2V, V2XSet
    # right hand: V2X-Sim 2.0 and DAIR-V2X
    left_hand = "OPV2V" in hypes['test_dir'] or "V2XSET" in hypes['test_dir']
    print(f"Left hand visualizing: {left_hand}")

    if 'box_align' in hypes.keys():
        hypes['box_align']['val_result'] = hypes['box_align']['test_result']

    print('Creating Model')
    model = train_utils.create_model(hypes)
    # Prefer the GPU, fall back to the CPU when CUDA is unavailable.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    print('Loading Model from checkpoint')
    saved_path = opt.model_dir
    resume_epoch, model = _load_checkpoint(opt, hypes, model)

    print(f"resume from {resume_epoch} epoch.")
    opt.note += f"_epoch{resume_epoch}"

    if use_comm:
        # From here on ``resume_epoch`` refers to the comm module epoch, and
        # any ``--note_comm`` given on the command line is overwritten.
        resume_epoch, model = train_utils.load_saved_model_comm(
            saved_path, model, opt.eval_epoch_comm)
        print(f"comm resume from {resume_epoch} epoch.")
        opt.note_comm = opt.note + f"_comm_ep{resume_epoch}"

    model.to(device)
    model.eval()

    # setting noise
    np.random.seed(303)

    # build dataset for each noise setting
    print('Dataset Building')
    opencood_dataset = build_dataset(hypes, visualize=True, train=False)
    data_loader = DataLoader(opencood_dataset,
                             batch_size=1,
                             num_workers=4,
                             collate_fn=opencood_dataset.collate_batch_test,
                             shuffle=False,
                             pin_memory=False,
                             drop_last=False)

    # Create the dictionary for evaluation
    result_stat = {iou: {'tp': [], 'fp': [], 'gt': 0, 'score': []}
                   for iou in _IOU_THRESHOLDS}

    if use_comm:
        infer_info = opt.fusion_method + opt.note_comm
        vis_save_path_root = os.path.join(opt.model_dir, 'comm_module',
                                          f'vis_{infer_info}')
    else:
        infer_info = opt.fusion_method + opt.note
        vis_save_path_root = os.path.join(opt.model_dir, f'vis_{infer_info}')

    npy_save_path = os.path.join(opt.model_dir, 'npy')
    if opt.save_npy:
        os.makedirs(npy_save_path, exist_ok=True)

    # Wrapping the loader (not the enumerate) lets tqdm show a total.
    for i, batch_data in enumerate(tqdm(data_loader)):
        print(f"{infer_info}_{i}")
        if batch_data is None:
            continue
        with torch.no_grad():
            batch_data = train_utils.to_device(batch_data, device)
            infer_result = _run_inference(opt.fusion_method, batch_data,
                                          model, opencood_dataset)

            pred_box_tensor = infer_result['pred_box_tensor']
            gt_box_tensor = infer_result['gt_box_tensor']
            pred_score = infer_result['pred_score']

            for iou_thresh in _IOU_THRESHOLDS:
                eval_utils.caluclate_tp_fp(pred_box_tensor,
                                           pred_score,
                                           gt_box_tensor,
                                           result_stat,
                                           iou_thresh)

            if opt.save_npy:
                inference_utils.save_prediction_gt(
                    pred_box_tensor,
                    gt_box_tensor,
                    batch_data['ego']['origin_lidar'][0],
                    i,
                    npy_save_path)

            if not opt.no_score:
                infer_result.update({'score_tensor': pred_score})

            if getattr(opencood_dataset, "heterogeneous", False):
                cav_box_np, agent_modality_list = \
                    inference_utils.get_cav_box(batch_data)
                infer_result.update({"cav_box_np": cav_box_np,
                                     "agent_modality_list": agent_modality_list})

            # An interval of 0 disables visualization instead of raising
            # ZeroDivisionError.
            is_vis_step = (opt.save_vis_interval != 0
                           and i % opt.save_vis_interval == 0)
            has_boxes = (pred_box_tensor is not None
                         or gt_box_tensor is not None)
            if is_vis_step and has_boxes:
                os.makedirs(vis_save_path_root, exist_ok=True)
                vis_save_path = os.path.join(vis_save_path_root,
                                             'bev_%05d.png' % i)
                simple_vis.visualize(infer_result,
                                     batch_data['ego']['origin_lidar'][0],
                                     hypes['postprocess']['gt_range'],
                                     vis_save_path,
                                     method='bev',
                                     left_hand=left_hand)
        torch.cuda.empty_cache()

    ap_30, ap_50, ap_70 = eval_utils.eval_final_results(result_stat,
                                                        opt.model_dir,
                                                        infer_info)

    result_file = 'result.txt'
    if use_comm:
        result_file = 'comm_module/result.txt'
        # The sub-directory may not exist yet if nothing was visualized.
        os.makedirs(os.path.join(saved_path, 'comm_module'), exist_ok=True)

    with open(os.path.join(saved_path, result_file), 'a+') as f:
        msg = ''
        if ('m' in opt.collab_mode
                or opt.collab_mode in ('pub', 'direct_pub', 'hete')):
            msg += f'{opt.collab_mode}: '
        msg += 'Epoch: {} | AP @0.3: {:.04f} | AP @0.5: {:.04f} | AP @0.7: {:.04f}\n'.format(resume_epoch, ap_30, ap_50, ap_70)
        f.write(msg)
        print(msg)

# Run the inference pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
