import shutil
from create_data_utils import *
from nuscenes.nuscenes import NuScenes as V2XSimDataset
import tqdm
import cv2
import pickle
import argparse

def data_parser():
    """Parse the command-line options of the dataset conversion script.

    Options are read from ``sys.argv``:

    * ``--dataroot`` (str, default ``''``): root path of the V2XSim dataset.
    * ``--save_dir`` (str, default ``''``): where the reformatted data is written.
    * ``--x_size`` / ``--y_size`` (int, default ``512``): image width / height.

    Returns:
        argparse.Namespace: with attributes ``dataroot``, ``save_dir``,
        ``x_size`` and ``y_size``.
    """
    cli = argparse.ArgumentParser(description="synthetic data generation")

    # Input / output locations.
    cli.add_argument("--dataroot", type=str, default='',
                     help='root path of V2XSim dataset')
    cli.add_argument('--save_dir', type=str, default='',
                     help='save path of reformat V2XSim dataset')

    # Both image dimensions share the same type and default value.
    for flag, description in (("--x_size", "image width"),
                              ("--y_size", "image height")):
        cli.add_argument(flag, type=int, default=512, help=description)

    return cli.parse_args()

if __name__ == "__main__":
    # Reformat the V2X-Sim dataset found under --dataroot into the layout
    #   <save_dir>/<train|test>/<scene name without '_'>/<agent_idx>/
    #       %06d.bin            copy of the agent's lidar file
    #       %06d_bev.npy        BEV segmentation label returned by get_seg_label
    #       %06d_camera%d.npy   camera image, converted to RGB and resized
    #       %06d.pkl            pickled dict with poses, camera params, vehicles
    # where %06d is a per-scene frame counter starting at 0.
    #
    # NOTE: `os`, `np` and the helpers (chmk_dir, covert_degree,
    # get_lidar_world_pose, get_seg_label, get_transformation_matrix,
    # get_ego_annotations) are expected to come from the star import of
    # create_data_utils at the top of the file.
    opt = data_parser()
    dataroot = opt.dataroot
    save_dir = opt.save_dir
    x_size = opt.x_size
    y_size = opt.y_size

    chmk_dir(save_dir)
    v2x_sim = V2XSimDataset(version='v2.0', dataroot=dataroot, verbose=True)

    # Sensor channels per agent. The lidar channel MUST stay first in both
    # lists: index 0 is handled as lidar below, every other index as a camera.
    # For agents 1..5 the agent index is appended to each prefix.
    sensors_cav = ['LIDAR_TOP_id_', 'CAM_FRONT_id_', 'CAM_FRONT_LEFT_id_', 'CAM_FRONT_RIGHT_id_', 'CAM_BACK_id_',
                   'CAM_BACK_LEFT_id_',
                   'CAM_BACK_RIGHT_id_']
    # Agent 0 uses its own fixed channel names.
    sensors_rsu = ['LIDAR_TOP_id_0', 'CAM_id_0_0', 'CAM_id_0_1', 'CAM_id_0_2', 'CAM_id_0_3']

    # Retrieve all scenes.
    pbar = tqdm.tqdm(total=100, leave=True)
    for scene_idx in range(100):
        # Scenes 0..79 are written to the train split, the rest to test.
        split_dir = os.path.join(save_dir, 'train' if scene_idx < 80 else 'test')
        chmk_dir(split_dir)
        my_scene = v2x_sim.scene[scene_idx]
        scene_dir = os.path.join(split_dir, ''.join(my_scene['name'].split('_')))
        chmk_dir(scene_dir)

        my_sample = v2x_sim.get('sample', my_scene['first_sample_token'])

        # Retrieve all samples of the scene; `timestamp` is a frame counter
        # used in the output file names, not the sample's own timestamp.
        # NOTE(review): the sample whose 'next' is '' (the last one of the
        # scene) is never exported by this loop condition - confirm intended.
        timestamp = 0
        while my_sample['next'] != '':
            ann_boxes = v2x_sim.get_boxes(my_sample['data']['LIDAR_TOP_id_0'])

            # Retrieve all agents.
            for agent_idx in range(6):
                if agent_idx == 0:
                    sensors = sensors_rsu
                else:
                    sensors = [sensor + str(agent_idx) for sensor in sensors_cav]

                # Skip agents whose sensors are not all present in this sample.
                if not set(sensors).issubset(my_sample['data']):
                    continue

                # Create agent directory
                agent_dir = os.path.join(scene_dir, str(agent_idx))
                chmk_dir(agent_dir)

                param_save_dict = dict()

                for (i, sensor) in enumerate(sensors):
                    sampled_data = v2x_sim.get('sample_data', my_sample['data'][sensor])
                    old_path = os.path.join(dataroot, sampled_data['filename'])

                    if i == 0:  # For lidar
                        new_path = os.path.join(agent_dir, '%06d.bin' % timestamp)
                        if not os.path.isfile(new_path):
                            shutil.copy(old_path, new_path)

                        # The first sensor is the lidar, which serves as the pose of ego.
                        # Stored as translation followed by covert_degree(rotation).
                        ego_pose = v2x_sim.get("ego_pose", sampled_data["ego_pose_token"])
                        save_ego_pose = ego_pose['translation'][:]
                        save_ego_pose.extend(covert_degree(ego_pose['rotation']))

                        # Calibrated lidar pose, stored in the same format.
                        calibrated_pose = v2x_sim.get('calibrated_sensor', sampled_data['calibrated_sensor_token'])
                        save_calibrated_pose = calibrated_pose['translation'][:]
                        save_calibrated_pose.extend(covert_degree(calibrated_pose['rotation']))

                        # get lidar pose
                        lidar_pose = get_lidar_world_pose(ego_pose, calibrated_pose)

                        # get bev segmentation label; only decode the source
                        # file when the output still has to be written.
                        new_bev_path = os.path.join(agent_dir, '%06d_bev.npy' % timestamp)
                        if not os.path.isfile(new_bev_path):
                            bev_data = v2x_sim.get('sample_data', my_sample['data']['BEV_TOP_id_%d' % agent_idx])
                            bev_data_path = os.path.join(dataroot, bev_data['filename'])
                            np.save(new_bev_path, get_seg_label(bev_data_path))

                    else:  # For cameras
                        # Cameras are numbered from 0 in list order (lidar excluded).
                        cam_name = 'camera%d' % (i - 1)
                        new_path = os.path.join(agent_dir, '%06d_%s.npy' % (timestamp, cam_name))
                        if not os.path.isfile(new_path):
                            img = cv2.imread(old_path)
                            # cv2.imread signals failure by returning None
                            # instead of raising; fail with the offending path.
                            if img is None:
                                raise OSError('cannot read camera image: %s' % old_path)
                            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                            img = cv2.resize(img, (x_size, y_size))
                            np.save(new_path, img)

                        # Uses `ego_pose` set by the lidar branch (index 0) of
                        # this same agent.
                        calibrate_param = v2x_sim.get("calibrated_sensor",
                                                      sampled_data['calibrated_sensor_token'])
                        cam_pose_list, cam_extrinsic = get_transformation_matrix(ego_pose, calibrate_param)
                        param_save_dict[cam_name] = {
                            'cords': cam_pose_list,
                            'extrinsic': cam_extrinsic.tolist(),
                            'intrinsic': calibrate_param['camera_intrinsic'],
                        }

                # Record ego_pose and calibrated_pose
                param_save_dict['lidar_pose'] = lidar_pose
                param_save_dict['true_ego_pose'] = save_ego_pose  # TODO: lidar_pose
                param_save_dict['calibrated_pose'] = save_calibrated_pose

                # Get annotations
                param_save_dict['vehicles'] = get_ego_annotations(save_ego_pose, ann_boxes, range=32.5)
                params_path = os.path.join(agent_dir, '%06d.pkl' % timestamp)
                # Opening with 'wb' creates the file or overwrites an existing one.
                with open(params_path, 'wb') as file:
                    pickle.dump(param_save_dict, file)

            my_sample = v2x_sim.get('sample', my_sample['next'])
            timestamp = timestamp + 1
        pbar.set_description("[Finish scence %d]" % (scene_idx))
        pbar.update(1)
    # Release the progress bar once all scenes are processed.
    pbar.close()