# -*- coding: utf-8 -*-
# Author: Runsheng Xu <rxx3386@ucla.edu>
# License: TDG-Attribution-NonCommercial-NoDistrib


import os
import argparse
from torch.utils.data import DataLoader

from opencood.hypes_yaml.yaml_utils import load_yaml
from opencood.visualization import vis_utils
from opencood.data_utils.datasets.early_fusion_vis_dataset import \
    EarlyFusionVisDataset


def vis_parser():
    """
    Parse the command-line options of the visualization script.

    Returns
    -------
    opt : argparse.Namespace
        Namespace with two string attributes: ``color_mode`` (default
        ``"intensity"``) and ``system`` (default ``"V2X"``).
    """
    parser = argparse.ArgumentParser(description="data visualization")
    # The two adjacent literals are concatenated by the parser, so the first
    # one must end with a space to keep the help text readable.
    parser.add_argument('--color_mode', type=str, default="intensity",
                        help='lidar color rendering mode, e.g. intensity, '
                             'z-value or constant.')
    # Only these three settings have a scenario list in the script body;
    # rejecting anything else here yields a usage error instead of a
    # failure further down.
    parser.add_argument('--system', type=str, default="V2X",
                        choices=("V2X", "V2V", "I2X"),
                        help='V2X or V2V or I2X')
    return parser.parse_args()


# Root directory under which the "<system>_VIS" output folder is created.
# NOTE(review): machine-specific absolute path kept from the original script.
OUTPUT_ROOT = "/home/user/V2X-M2C/opencood"
# Directory that contains one sub-folder per separated validation scenario.
DATA_ROOT = "/dataset/V2XSet_I/validate_seperate/v2xset/seperate"

# Scenario folders to visualize for each value of ``--system``.
# Folder names follow "<cav count word>_<id part>_<id part>".
SCENARIOS_BY_SYSTEM = {
    "V2V": ['two_01_17', "two_19_19", "two_22_47", "two_51_24",
            "two_58_19", "three_47_19"],
    "V2X": ['two_01_17', "two_17_21", "two_19_19", "two_22_47",
            "two_51_24", "two_53_32", "two_58_19", "three_07_10",
            "three_47_19", "four_12_49"],
    "I2X": ["two_17_21", "two_53_32", "two_58_19", "three_07_10",
            "three_47_19", "four_12_49"],
}

# Maps the leading word of a scenario folder name to the ``max_cav`` value.
MAX_CAV_BY_WORD = {'two': 2, 'three': 3, 'four': 4}


def scenarios_for_system(system):
    """
    Return the scenario folder names configured for a system setting.

    Parameters
    ----------
    system : str
        One of the keys of ``SCENARIOS_BY_SYSTEM``.

    Returns
    -------
    scenarios : list of str
        Folder names relative to ``DATA_ROOT``.

    Raises
    ------
    ValueError
        If ``system`` has no scenario list.
    """
    try:
        return SCENARIOS_BY_SYSTEM[system]
    except KeyError:
        raise ValueError(
            "unknown system {!r}; expected one of {}".format(
                system, sorted(SCENARIOS_BY_SYSTEM)))


def parse_scenario_name(scenario):
    """
    Split a scenario folder name into its ``max_cav`` value and scene id.

    Parameters
    ----------
    scenario : str
        Folder name such as ``"two_01_17"``.

    Returns
    -------
    max_cav : int
        Number encoded by the first underscore-separated field.
    scene_id : str
        Second and third fields joined by ``"_"``, e.g. ``"01_17"``.

    Raises
    ------
    ValueError
        If the name has fewer than three fields or the first field is not
        a key of ``MAX_CAV_BY_WORD``.
    """
    fields = scenario.split('_')
    if len(fields) < 3:
        raise ValueError(
            "scenario name {!r} does not have three '_'-separated "
            "fields".format(scenario))
    if fields[0] not in MAX_CAV_BY_WORD:
        # Fail loudly rather than silently reusing the max_cav value left
        # over from the previously processed scenario.
        raise ValueError(
            "scenario name {!r} has an unknown cav count prefix "
            "{!r}".format(scenario, fields[0]))
    return MAX_CAV_BY_WORD[fields[0]], fields[1] + '_' + fields[2]


def main():
    """
    Visualize every scenario of the selected system setting.

    For each scenario folder the loaded yaml parameters are updated with
    the scenario's ``validate_dir`` and ``max_cav``, an
    ``EarlyFusionVisDataset`` is built and wrapped in a ``DataLoader``, and
    the loader is handed to ``vis_utils.visualize_sequence_dataloader``
    together with a per-scene output directory.
    """
    opt = vis_parser()

    # The yaml file is resolved relative to this script, not to OUTPUT_ROOT.
    script_dir = os.path.dirname(os.path.realpath(__file__))
    params = load_yaml(os.path.join(script_dir,
                                    '../hypes_yaml/visualization.yaml'))

    # Resolve the scenario list first so that an unsupported system is
    # reported before any directory is created.
    scenarios = scenarios_for_system(opt.system)

    save_root = os.path.join(OUTPUT_ROOT, opt.system + "_VIS")
    os.makedirs(save_root, exist_ok=True)

    for scenario in scenarios:
        max_cav, scene_id = parse_scenario_name(scenario)

        # The same params dict is reused and overwritten for each scenario.
        params['validate_dir'] = os.path.join(DATA_ROOT, scenario)
        params['train_params']['max_cav'] = max_cav

        opencda_dataset = EarlyFusionVisDataset(params, visualize=True,
                                                work='val')
        data_loader = DataLoader(
            opencda_dataset,
            batch_size=1,
            num_workers=8,
            collate_fn=opencda_dataset.collate_batch_train,
            shuffle=False,
            pin_memory=False)

        scene_save_path = os.path.join(save_root, scene_id)
        os.makedirs(scene_save_path, exist_ok=True)

        vis_utils.visualize_sequence_dataloader(
            data_loader,
            params['postprocess']['order'],
            save_path=scene_save_path,
            color_mode=opt.color_mode)


if __name__ == '__main__':
    main()
