# -*- coding: utf-8 -*-
# Author: Runsheng Xu <rxx3386@ucla.edu>, Hao Xiang <haxiang@g.ucla.edu>, Yifan Lu <yifan_lu@sjtu.edu.cn>
# License: TDG-Attribution-NonCommercial-NoDistrib


import argparse
import os
import time
from tqdm import tqdm

import torch
import open3d as o3d
from torch.utils.data import DataLoader

import opencood.hypes_yaml.yaml_utils as yaml_utils
from opencood.tools import train_utils, inference_utils
from opencood.data_utils.datasets import build_dataset
from opencood.utils import eval_utils
from opencood.visualization import vis_utils
import matplotlib.pyplot as plt


def test_parser(argv=None):
    """Parse the command-line options of this per-scenario evaluation script.

    Args:
        argv: optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]``, exactly as a plain script call.

    Returns:
        argparse.Namespace holding the parsed options.
    """
    parser = argparse.ArgumentParser(
        description="Run inference and evaluation on each validation scenario")
    parser.add_argument('--model_dir', type=str, required=True,
                        help='Path to the trained model directory '
                             '(the checkpoint is loaded from here)')
    # No default: the flag is required, so a default would never be used.
    parser.add_argument('--fusion_method', required=True, type=str,
                        choices=['late', 'early', 'intermediate'],
                        help='late, early or intermediate')
    parser.add_argument('--show_vis', action='store_true',
                        help='whether to show image visualization result')
    parser.add_argument('--show_sequence', action='store_true',
                        help='whether to show video visualization result. '
                             'It can not be set together with --show_vis.')
    parser.add_argument('--save_vis', action='store_true',
                        help='whether to save visualization result')
    parser.add_argument('--save_vis_seperate', action='store_true',
                        help='whether to save the separate visualization '
                             'result under <model_dir>/<system>_VIS/')
    parser.add_argument('--save_npy', action='store_true',
                        help='whether to save prediction and gt result '
                             'as npy files')
    parser.add_argument('--global_sort_detections', action='store_true',
                        help='whether to globally sort detections by '
                             'confidence score. If set to True, it is the '
                             'mainstream AP computing method, but would '
                             'increase the tolerance for FP '
                             '(False Positives).')
    parser.add_argument('--system', type=str, default="V2X",
                        help='V2X or V2V or I2X')
    return parser.parse_args(argv)


# The name starts with "test_", so pytest would collect this function as a
# test whenever it is imported into a test module; opt it out explicitly.
test_parser.__test__ = False


# Root directory holding one sub-folder per evaluation scenario.
_DATA_DIR = "/dataset/V2XSet_I/validate_seperate/v2xset/seperate"

# Candidate scenario folders per communication setting. The first
# '_'-separated token of each name spells out how many CAVs it contains.
_SCENARIOS_BY_SYSTEM = {
    'V2V': ['two_01_17', 'two_19_19', 'two_22_47', 'two_51_24', 'two_58_19',
            'three_47_19'],
    'V2X': ['two_01_17', 'two_17_21', 'two_19_19', 'two_22_47', 'two_51_24',
            'two_53_32', 'two_58_19', 'three_07_10', 'three_47_19',
            'four_12_49'],
    'I2X': ['two_17_21', 'two_53_32', 'two_58_19', 'three_07_10',
            'three_47_19', 'four_12_49'],
}

# NOTE(review): the evaluated scenario set is pinned to I2X regardless of
# --system (the per-system selection had been commented out); --system only
# names the --save_vis_seperate output folder. Use
# _SCENARIOS_BY_SYSTEM[opt.system] instead if the flag should pick scenarios.
_EVAL_SYSTEM = 'I2X'

# Scenario-name prefix -> value written into the max_cav hyper-parameters.
_CAV_COUNT_BY_WORD = {'two': 2, 'three': 3, 'four': 4}

# Fusion strategies accepted on the command line.
_FUSION_METHODS = ('late', 'early', 'intermediate')

# IoU thresholds at which TP/FP statistics are accumulated.
_IOU_THRESHOLDS = (0.3, 0.5, 0.7)

# Number of pre-allocated box LineSets (per GT / prediction) in sequence mode.
_MAX_SEQUENCE_BOXES = 50


def _configure_scenario(hypes, scenario):
    """Point ``hypes`` at ``scenario`` and set max_cav from its name prefix.

    Mutates ``hypes`` in place and returns the max_cav value written.

    Raises:
        ValueError: if the scenario prefix is not a known CAV-count word.
    """
    cav_word = scenario.split('_')[0]
    try:
        max_cav = _CAV_COUNT_BY_WORD[cav_word]
    except KeyError:
        # Previously an unknown prefix silently kept a stale max_cav.
        raise ValueError(f'Unknown CAV count prefix {cav_word!r} in scenario '
                         f'{scenario!r}') from None
    hypes['validate_dir'] = os.path.join(_DATA_DIR, scenario)
    hypes['train_params']['max_cav'] = max_cav
    hypes['model']['args']['max_cav'] = max_cav
    return max_cav


def _run_inference(fusion_method, batch_data, model, dataset):
    """Dispatch one batch to the ``inference_utils`` routine for the method.

    Returns:
        The (pred_box_tensor, pred_score, gt_box_tensor) triple produced by
        the matching inference function.

    Raises:
        NotImplementedError: if ``fusion_method`` is not a supported method.
    """
    dispatch = {
        'late': inference_utils.inference_late_fusion,
        'early': inference_utils.inference_early_fusion,
        'intermediate': inference_utils.inference_intermediate_fusion,
    }
    try:
        infer = dispatch[fusion_method]
    except KeyError:
        raise NotImplementedError('Only early, late and intermediate '
                                  'fusion is supported.') from None
    return infer(batch_data, model, dataset)


def _create_sequence_visualizer():
    """Open an Open3D window plus the reusable geometries for --show_sequence.

    Returns:
        (vis, vis_pcd, vis_aabbs_gt, vis_aabbs_pred), where the two box lists
        each hold ``_MAX_SEQUENCE_BOXES`` empty LineSets.
    """
    vis = o3d.visualization.Visualizer()
    vis.create_window()

    render_option = vis.get_render_option()
    render_option.background_color = [0.05, 0.05, 0.05]
    render_option.point_size = 1.0
    render_option.show_coordinate_frame = True

    # Point cloud for the lidar points; LineSets for the bounding boxes.
    vis_pcd = o3d.geometry.PointCloud()
    vis_aabbs_gt = [o3d.geometry.LineSet() for _ in range(_MAX_SEQUENCE_BOXES)]
    vis_aabbs_pred = [o3d.geometry.LineSet()
                      for _ in range(_MAX_SEQUENCE_BOXES)]
    return vis, vis_pcd, vis_aabbs_gt, vis_aabbs_pred


def _update_sequence_visualizer(vis_state, pred_box_tensor, gt_box_tensor,
                                origin_lidar, first_frame):
    """Draw one frame (points, predicted and GT boxes) in the sequence window.

    On ``first_frame`` the point cloud is added to the window and the box
    lists are passed with ``update_mode='add'``; every frame then refreshes
    the boxes and point cloud and re-renders.
    """
    vis, vis_pcd, vis_aabbs_gt, vis_aabbs_pred = vis_state
    pcd, pred_o3d_box, gt_o3d_box = \
        vis_utils.visualize_inference_sample_dataloader(
            pred_box_tensor,
            gt_box_tensor,
            origin_lidar,
            vis_pcd,
            mode='constant')
    if first_frame:
        vis.add_geometry(pcd)
        vis_utils.linset_assign_list(vis, vis_aabbs_pred, pred_o3d_box,
                                     update_mode='add')
        vis_utils.linset_assign_list(vis, vis_aabbs_gt, gt_o3d_box,
                                     update_mode='add')

    vis_utils.linset_assign_list(vis, vis_aabbs_pred, pred_o3d_box)
    vis_utils.linset_assign_list(vis, vis_aabbs_gt, gt_o3d_box)
    vis.update_geometry(pcd)
    vis.poll_events()
    vis.update_renderer()
    # Brief pause between frames.
    time.sleep(0.001)


def _evaluate_scenario(opt, hypes, scenario, device):
    """Run inference over one scenario, save requested outputs, report AP.

    Args:
        opt: parsed command-line options.
        hypes: hyper-parameter dict; mutated for this scenario.
        scenario: scenario folder name under ``_DATA_DIR``.
        device: torch device the batches are moved to.
    """
    div_scenario = scenario.split('_')
    number_scene = div_scenario[1] + '_' + div_scenario[2]

    check_max_cav = _configure_scenario(hypes, scenario)
    print(f'Start Scenario {scenario} === Setted max_cav {check_max_cav}')

    print('Dataset Building')
    opencood_dataset = build_dataset(hypes, visualize=True, work='val')
    print(f"{len(opencood_dataset)} samples found.")
    data_loader = DataLoader(opencood_dataset,
                             batch_size=1,
                             num_workers=16,
                             collate_fn=opencood_dataset.collate_batch_test,
                             shuffle=False,
                             pin_memory=False,
                             drop_last=False)

    print('Creating Model')
    # Rebuilt per scenario since max_cav in hypes['model']['args'] changes
    # (presumably consumed by create_model).
    model = train_utils.create_model(hypes)
    # we assume gpu is necessary
    if torch.cuda.is_available():
        model.cuda()

    print('Loading Model from checkpoint')
    _, model = train_utils.load_saved_model(opt.model_dir, model)
    model.eval()

    # Frame indices restart at 0 for every scenario, so per-frame outputs go
    # to scenario-specific sub-folders; otherwise each scenario would
    # overwrite the files of the previous one.
    npy_save_path = os.path.join(opt.model_dir, 'npy', scenario)
    vis_save_dir = os.path.join(opt.model_dir, 'vis', scenario)
    seperate_vis_dir = os.path.join(opt.model_dir,
                                    f'{opt.system}_VIS/{number_scene}')
    # eval_final_results runs once per scenario; give each call its own
    # directory. Assumes it writes its result file directly under the path
    # it receives -- verify in eval_utils.
    eval_save_dir = os.path.join(opt.model_dir, 'eval', scenario)
    for enabled, path in ((opt.save_npy, npy_save_path),
                          (opt.save_vis, vis_save_dir),
                          (opt.save_vis_seperate, seperate_vis_dir),
                          (True, eval_save_dir)):
        if enabled:
            os.makedirs(path, exist_ok=True)

    # Per-IoU accumulators; also store the confidence score per prediction.
    result_stat = {iou: {'tp': [], 'fp': [], 'gt': 0, 'score': []}
                   for iou in _IOU_THRESHOLDS}

    vis_state = _create_sequence_visualizer() if opt.show_sequence else None
    try:
        with torch.no_grad():
            for i, batch_data in enumerate(tqdm(data_loader)):
                batch_data = train_utils.to_device(batch_data, device)
                pred_box_tensor, pred_score, gt_box_tensor = _run_inference(
                    opt.fusion_method, batch_data, model, opencood_dataset)

                for iou in _IOU_THRESHOLDS:
                    eval_utils.caluclate_tp_fp(pred_box_tensor,
                                               pred_score,
                                               gt_box_tensor,
                                               result_stat,
                                               iou)

                if opt.save_npy:
                    inference_utils.save_prediction_gt(
                        pred_box_tensor,
                        gt_box_tensor,
                        batch_data['ego']['origin_lidar'][0],
                        i,
                        npy_save_path)

                if opt.show_vis or opt.save_vis:
                    vis_save_path = (os.path.join(vis_save_dir, '%05d.png' % i)
                                     if opt.save_vis else '')
                    opencood_dataset.visualize_result(
                        pred_box_tensor,
                        gt_box_tensor,
                        batch_data['ego']['origin_lidar'],
                        opt.show_vis,
                        vis_save_path,
                        dataset=opencood_dataset)

                if opt.show_vis or opt.save_vis_seperate:
                    vis_save_path = (
                        os.path.join(seperate_vis_dir, '%05d.png' % i)
                        if opt.save_vis_seperate else '')
                    opencood_dataset.seperate_visualize_result(
                        pred_box_tensor,
                        gt_box_tensor,
                        batch_data['ego']['each_lidar'],
                        opt.show_vis,
                        vis_save_path,
                        dataset=opencood_dataset)

                if vis_state is not None:
                    _update_sequence_visualizer(
                        vis_state,
                        pred_box_tensor,
                        gt_box_tensor,
                        batch_data['ego']['origin_lidar'],
                        first_frame=(i == 0))

        eval_utils.eval_final_results(result_stat,
                                      eval_save_dir,
                                      opt.global_sort_detections)
    finally:
        # Close the Open3D window even if inference or evaluation fails.
        if vis_state is not None:
            vis_state[0].destroy_window()


def main():
    """Evaluate the model in ``--model_dir`` on every pinned scenario.

    For each scenario the hyper-parameters are adjusted (validate_dir,
    max_cav), the dataset and model are rebuilt, and AP statistics are
    computed at the IoU thresholds in ``_IOU_THRESHOLDS``.

    Raises:
        ValueError: for an unsupported fusion method, or when --show_vis and
            --show_sequence are combined.
    """
    opt = test_parser()
    # Explicit checks instead of ``assert`` so they survive ``python -O``.
    if opt.fusion_method not in _FUSION_METHODS:
        raise ValueError(f'Unsupported fusion method {opt.fusion_method!r}; '
                         f'expected one of {_FUSION_METHODS}')
    if opt.show_vis and opt.show_sequence:
        raise ValueError('you can only visualize the results in single '
                         'image mode or video mode')

    hypes = yaml_utils.load_yaml(None, opt)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    for scenario in _SCENARIOS_BY_SYSTEM[_EVAL_SYSTEM]:
        _evaluate_scenario(opt, hypes, scenario, device)


if __name__ == "__main__":
    # Script entry point: parse the CLI options and evaluate every scenario.
    main()
