# -*- coding: utf-8 -*-
# Author: Runsheng Xu <rxx3386@ucla.edu>, Hao Xiang <haxiang@g.ucla.edu>, Yifan Lu <yifan_lu@sjtu.edu.cn>
# License: TDG-Attribution-NonCommercial-NoDistrib

import glob
import argparse
import os
import time
from tqdm import tqdm

import torch
torch.multiprocessing.set_sharing_strategy('file_system')
import open3d as o3d
import pandas as pd
from torch.utils.data import DataLoader

import opencood.hypes_yaml.yaml_utils as yaml_utils
from opencood.tools import train_utils, inference_utils
from opencood.data_utils.datasets import build_dataset
from opencood.utils import eval_utils
from opencood.visualization import vis_utils
import matplotlib.pyplot as plt


def test_parser():
    """Build the command-line parser for checkpoint inference and parse ``sys.argv``.

    ``--model_dir``, ``--fusion_method`` and ``--noise`` are mandatory; every
    other option is a boolean flag that defaults to ``False``.

    Returns:
        argparse.Namespace: the parsed options.

    Raises:
        SystemExit: (from argparse) when a required option is missing or a
        value is outside the allowed choices.
    """
    parser = argparse.ArgumentParser(
        description="inference and AP evaluation over saved checkpoints")
    parser.add_argument('--model_dir', type=str, required=True,
                        help='Continued training path')
    # The choices mirror the checks in main(); validating here yields a usage
    # message instead of an AssertionError and still works under ``python -O``.
    # (The former ``default=`` values were dead code: the options are required.)
    parser.add_argument('--fusion_method', required=True, type=str,
                        choices=['late', 'early', 'intermediate', 'no'],
                        help='late, early, intermediate or no')
    parser.add_argument('--noise', required=True, type=str,
                        choices=['p', 'n', 'nr', 'hn', 'nn'],
                        help='inference for perfect (p) or one of the noisy '
                             'settings (n, nr, hn, nn)')
    parser.add_argument('--show_vis', action='store_true',
                        help='whether to show image visualization result')
    parser.add_argument('--show_sequence', action='store_true',
                        help='whether to show video visualization result. '
                             'It cannot be set together with --show_vis.')
    parser.add_argument('--save_vis', action='store_true',
                        help='whether to save visualization result')
    parser.add_argument('--v2v', action='store_true',
                        help='inference for v2v (result CSV is prefixed V2V)')
    parser.add_argument('--v2x', action='store_true',
                        help='inference for v2x (result CSV is prefixed V2X)')
    parser.add_argument('--save_npy', action='store_true',
                        help='whether to save prediction and gt result '
                             'in npy_test file')
    parser.add_argument('--global_sort_detections', action='store_true',
                        help='whether to globally sort detections by confidence score. '
                             'If set to True, it is the mainstream AP computing method, '
                             'but would increase the tolerance for FP (False Positives).')
    opt = parser.parse_args()
    return opt


# The name matches pytest's ``test_*`` discovery pattern; without this flag any
# test module importing it would have pytest call it with pytest's own argv.
test_parser.__test__ = False


def _apply_noise_setting(hypes, noise):
    """Configure ``hypes['wild_setting']`` in place for a noise preset.

    Args:
        hypes: loaded configuration dict; must contain ``'wild_setting'``.
        noise: one of ``'p'`` (perfect) or ``'n'``, ``'nr'``, ``'hn'``, ``'nn'``.

    Returns:
        list[int]: the random seeds to evaluate under this preset.

    Raises:
        ValueError: for an unknown preset.
    """
    wild = hypes['wild_setting']
    if noise == 'p':
        wild['async'] = False
        wild['loc_err'] = False
        return [30]
    if noise not in ('n', 'nr', 'hn', 'nn'):
        raise ValueError(f'Unknown noise setting: {noise!r}')

    # Every noisy preset enables both latency and localization error; the
    # variants below only override magnitudes/mode and, for some, the seeds.
    wild['async'] = True
    wild['loc_err'] = True
    seeds = [25, 26, 27, 28, 29]
    if noise == 'nr':
        wild['async_overhead'] = 200
        wild['async_mode'] = 'real'
    elif noise == 'hn':
        wild['async_overhead'] = 500
        wild['async_mode'] = 'real'
        wild['xyz_std'] = 0.5
        wild['ryp_std'] = 1.0
        seeds = [30]
    elif noise == 'nn':
        wild['async_overhead'] = 200
        wild['async_mode'] = 'real'
        wild['xyz_std'] = 0.2
        wild['ryp_std'] = 0.4
        seeds = [30]
    return seeds


def _noise_tag(wild_setting):
    """Return the ``(latency, rpy, xyz, mode)`` labels used in result CSV names.

    Values are taken from ``wild_setting`` only when localization error is
    enabled; otherwise the fallbacks ``'no'``, ``''``, ``''``, ``''`` are used.
    """
    is_noisy = wild_setting.get('loc_err', False)

    def pick(key, fallback):
        return wild_setting[key] if is_noisy and key in wild_setting else fallback

    return (pick('async_overhead', 'no'), pick('ryp_std', ''),
            pick('xyz_std', ''), pick('async_mode', ''))


def _run_fusion_inference(fusion_method, batch_data, model, dataset):
    """Dispatch one batch to the ``inference_utils`` helper for ``fusion_method``.

    Returns:
        tuple: ``(pred_box_tensor, pred_score, gt_box_tensor)`` as produced by
        the selected helper.

    Raises:
        NotImplementedError: for an unsupported fusion method.
    """
    # Resolve by name so a helper is only looked up when it is actually used.
    fn_name = {
        'late': 'inference_late_fusion',
        'no': 'inference_no_fusion',
        'early': 'inference_early_fusion',
        'intermediate': 'inference_intermediate_fusion',
    }.get(fusion_method)
    if fn_name is None:
        raise NotImplementedError('Only early, late, intermediate and no '
                                  'fusion are supported.')
    return getattr(inference_utils, fn_name)(batch_data, model, dataset)


def _create_sequence_visualizer():
    """Open an Open3D window plus the reusable geometries for video mode.

    Returns:
        tuple: ``(vis, vis_pcd, vis_aabbs_gt, vis_aabbs_pred)``. The two box
        lists each hold 50 pre-allocated ``LineSet`` objects, the maximum
        number of boxes drawn.
    """
    vis = o3d.visualization.Visualizer()
    vis.create_window()

    vis.get_render_option().background_color = [0.05, 0.05, 0.05]
    vis.get_render_option().point_size = 1.0
    vis.get_render_option().show_coordinate_frame = True

    # used to visualize lidar points
    vis_pcd = o3d.geometry.PointCloud()
    vis_aabbs_gt = [o3d.geometry.LineSet() for _ in range(50)]
    vis_aabbs_pred = [o3d.geometry.LineSet() for _ in range(50)]
    return vis, vis_pcd, vis_aabbs_gt, vis_aabbs_pred


def _evaluate_checkpoint(opt, model, data_loader, dataset, device,
                         iou_thresholds, npy_save_path, vis_dir):
    """Run the already-loaded ``model`` over ``data_loader`` and collect TP/FP stats.

    Also saves npy files / images and drives the video window according to
    the ``opt`` flags.

    Returns:
        dict: ``{iou: {'tp': [...], 'fp': [...], 'gt': int, 'score': [...]}}``
        for every threshold in ``iou_thresholds``, filled by
        ``eval_utils.caluclate_tp_fp``.
    """
    # Also stores the confidence score of each prediction.
    result_stat = {iou: {'tp': [], 'fp': [], 'gt': 0, 'score': []}
                   for iou in iou_thresholds}

    vis = None
    if opt.show_sequence:
        vis, vis_pcd, vis_aabbs_gt, vis_aabbs_pred = _create_sequence_visualizer()

    try:
        for i, batch_data in tqdm(enumerate(data_loader)):
            with torch.no_grad():
                batch_data = train_utils.to_device(batch_data, device)
                pred_box_tensor, pred_score, gt_box_tensor = \
                    _run_fusion_inference(opt.fusion_method, batch_data,
                                          model, dataset)

                for iou in iou_thresholds:
                    eval_utils.caluclate_tp_fp(pred_box_tensor,
                                               pred_score,
                                               gt_box_tensor,
                                               result_stat,
                                               iou)

                if opt.save_npy:
                    inference_utils.save_prediction_gt(
                        pred_box_tensor,
                        gt_box_tensor,
                        batch_data['ego']['origin_lidar'][0],
                        i,
                        npy_save_path)

                if opt.show_vis or opt.save_vis:
                    vis_save_path = ''
                    if opt.save_vis:
                        vis_save_path = os.path.join(vis_dir, '%05d.png' % i)
                    dataset.visualize_result(pred_box_tensor,
                                             gt_box_tensor,
                                             batch_data['ego']['origin_lidar'],
                                             opt.show_vis,
                                             vis_save_path,
                                             dataset=dataset)

                if opt.show_sequence:
                    pcd, pred_o3d_box, gt_o3d_box = \
                        vis_utils.visualize_inference_sample_dataloader(
                            pred_box_tensor,
                            gt_box_tensor,
                            batch_data['ego']['origin_lidar'],
                            vis_pcd,
                            mode='constant'
                            )
                    if i == 0:
                        # First frame: register the geometries with the window.
                        vis.add_geometry(pcd)
                        vis_utils.linset_assign_list(vis,
                                                     vis_aabbs_pred,
                                                     pred_o3d_box,
                                                     update_mode='add')
                        vis_utils.linset_assign_list(vis,
                                                     vis_aabbs_gt,
                                                     gt_o3d_box,
                                                     update_mode='add')

                    vis_utils.linset_assign_list(vis,
                                                 vis_aabbs_pred,
                                                 pred_o3d_box)
                    vis_utils.linset_assign_list(vis,
                                                 vis_aabbs_gt,
                                                 gt_o3d_box)
                    vis.update_geometry(pcd)
                    vis.poll_events()
                    vis.update_renderer()
                    time.sleep(0.001)
    finally:
        # A new window is opened per checkpoint, so close it here instead of
        # leaking every window but the last until the script ends.
        if vis is not None:
            vis.destroy_window()

    return result_stat


def main():
    """Evaluate a window of saved checkpoints and write AP tables to CSV.

    1. Parse CLI options and load the YAML config for ``--model_dir``.
    2. Apply the ``--noise`` preset.
    3. For every seed of that preset:
       a. Rebuild the validation dataset and model.
       b. Run inference with each selected checkpoint.
       c. Compute AP@0.3/0.5/0.7 with ``eval_utils.eval_final_results``.
       d. Rewrite the result CSV in ``--model_dir`` after every checkpoint,
          so partial results survive an interruption.

    Raises:
        FileNotFoundError: if no checkpoint falls inside the selected window.
    """
    opt = test_parser()
    assert opt.fusion_method in ['late', 'early', 'intermediate', 'no']
    assert opt.noise in ['p', 'n', 'nr', 'hn', 'nn']
    assert not (opt.show_vis and opt.show_sequence), 'you can only visualize ' \
                                                    'the results in single ' \
                                                    'image mode or video mode'

    hypes = yaml_utils.load_yaml(None, opt)
    if 'max_cav' not in hypes['train_params']:
        hypes['train_params']['max_cav'] = 5

    saved_path = opt.model_dir
    seeds = _apply_noise_setting(hypes, opt.noise)
    hypes['fusion_method'] = opt.fusion_method
    # NOTE: --v2v / --v2x currently only select the CSV file-name prefix.

    iou_thresholds = (0.3, 0.5, 0.7)

    # Slice [st:ed) of the sorted checkpoint list to evaluate. Windows used
    # in earlier runs: (10, 15), (20, 25), (25, 30), (35, 40).
    st, ed = 1, 5

    # NOTE(review): plain lexicographic sort, so 'epoch10' sorts before
    # 'epoch2'; together with the fixed [-6:-4] epoch label below this assumes
    # checkpoint names whose epoch part has a uniform width — confirm.
    file_list = sorted(glob.glob(os.path.join(saved_path, '*epoch*.pth')))
    checkpoints = file_list[st:ed]
    if not checkpoints:
        raise FileNotFoundError(
            f'No checkpoints at positions [{st}, {ed}) among '
            f'{len(file_list)} "*epoch*.pth" files in {saved_path}')

    npy_save_path = os.path.join(opt.model_dir, 'npy')
    if opt.save_npy:
        os.makedirs(npy_save_path, exist_ok=True)
    vis_dir = os.path.join(opt.model_dir, 'vis')
    if opt.save_vis:
        os.makedirs(vis_dir, exist_ok=True)

    for seed in seeds:
        epochs, ap_30, ap_50, ap_70 = [], [], [], []
        hypes['wild_setting']['seed'] = seed

        print(f'Dataset Building for seed {seed}')
        opencood_dataset = build_dataset(hypes, visualize=True, work='val')
        print(f"{len(opencood_dataset)} samples found.")
        data_loader = DataLoader(opencood_dataset,
                                 batch_size=1,
                                 num_workers=16,
                                 collate_fn=opencood_dataset.collate_batch_test,
                                 shuffle=False,
                                 pin_memory=False,
                                 drop_last=False)

        print('Creating Model')
        model = train_utils.create_model(hypes)
        # CUDA is used when available; otherwise everything runs on CPU.
        if torch.cuda.is_available():
            model.cuda()
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        print('Loading Model from checkpoint')
        latency, rpy, xyz, mode = _noise_tag(hypes['wild_setting'])

        print(f'Inference {checkpoints[0][-6:-4]} to {checkpoints[-1][-6:-4]}')

        for file in checkpoints:
            epoch = file[-6:-4]
            print(f'Inference about {epoch}th result')
            model = train_utils.load_saved_all_model(file, model)
            model.eval()

            result_stat = _evaluate_checkpoint(opt, model, data_loader,
                                               opencood_dataset, device,
                                               iou_thresholds, npy_save_path,
                                               vis_dir)

            ap30_, ap50_, ap70_ = eval_utils.eval_final_results(
                result_stat,
                opt.model_dir,
                opt.global_sort_detections,
                epoch)
            epochs.append(epoch)
            ap_30.append(ap30_)
            ap_50.append(ap50_)
            ap_70.append(ap70_)

            df = pd.DataFrame({'Epoch': epochs,
                               'AP_3': ap_30,
                               'AP_5': ap_50,
                               'AP_7': ap_70})

            if opt.fusion_method == 'no':
                prefix = 'no_fusion'
            elif opt.v2v:
                prefix = 'V2V'
            elif opt.v2x:
                prefix = 'V2X'
            else:
                prefix = f'S{seed}'
            # NOTE(review): only the default prefix embeds the seed, so with a
            # multi-seed preset (e.g. --noise n) the no_fusion/V2V/V2X CSVs of
            # later seeds overwrite earlier ones — confirm this is intended.
            df.to_csv(os.path.join(
                saved_path,
                f'{prefix}_{latency}_{rpy}_{xyz}_{mode}_{st}~{ed}.csv'),
                index=False)


if __name__ == "__main__":
    # Run the inference pipeline only when executed as a script, not on import.
    main()
