import numpy as np
import pandas as pd
import os
import json
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from tqdm import tqdm
from utils import *

# Temporal subsampling: annotations are kept only every `step`-th frame, so the
# exported data runs at (roughly) `fps` Hz instead of the native rate.
fps = 2  # desired annotation frequency in Hz, native is 10
step = 10//fps  # stride, in native frames, between two kept frames (integer division)
# Folder that is listed to find the label files (one file is read per entry).
data_root = '/path/to/KITTI/training/label_02'
# Root of the export; the converted JSON files are written to `save_dir`.
out_root = '/your/out/dir'
save_dir = os.path.join(out_root, 'full')
os.makedirs(save_dir, exist_ok=True)  # created up front; no error if it already exists

# KITTI annotation format
# Column names assigned, in order, to the space-separated fields of each label
# file when it is parsed (the files are read without a header row).
columns = [
    "frame", "track_id", "type", "truncated", "occluded", "alpha", "bbox_xmin", "bbox_ymin", "bbox_xmax", "bbox_ymax", 
    "dim_height", "dim_width", "dim_length", "loc_x", "loc_y", "loc_z", "rotation_y", "score"
    ]
# Admissible vehicle labels
# Only rows whose "type" is one of these values are transformed and exported.
vehicle_labels = ['Car', 'Truck', 'Bus', 'Van']


# Names of the IMU columns that make up one pose, in the positional order
# expected by `imut_to_imu0` (lat, lon, alt, roll, pitch, yaw).
_POSE_FIELDS = ('lat', 'lon', 'alt', 'roll', 'pitch', 'yaw')


def _read_pose(imu_data, index):
    """Return the (lat, lon, alt, roll, pitch, yaw) tuple stored at row ``index``.

    ``imu_data`` is whatever `read_gps_imu_file` returns; it is indexed by
    column name and then positionally with ``.iloc``. An out-of-range
    ``index`` raises ``IndexError``.
    """
    return tuple(imu_data[field].iloc[[index]].item() for field in _POSE_FIELDS)


# Convert every label file found in `data_root`: vehicle boxes of the kept
# frames are re-expressed relative to the pose at frame 0 and the result is
# written as one JSON file per sequence.
# sorted(): os.listdir returns entries in arbitrary order; sorting makes runs reproducible.
for sequence_id in tqdm(sorted(os.listdir(data_root))):
    # The output name is derived by swapping the '.txt' suffix, so anything
    # else in the folder (hidden files, sub-folders, ...) is not a label file.
    if not sequence_id.endswith('.txt'):
        continue

    sequence_file = os.path.join(data_root, sequence_id)
    sequence_ann = pd.read_csv(sequence_file, sep=" ", names=columns, header=None)

    # Rows that end up in the export: admissible vehicle types on frames that
    # survive the temporal subsampling. Computed once and reused everywhere.
    keep = sequence_ann['type'].isin(vehicle_labels) & (sequence_ann['frame'] % step == 0)

    # Skip if there are no valid instances (before doing any further I/O)
    if not keep.any():
        continue

    # Read IMU poses to obtain ego-motion
    # NOTE(review): str.replace rewrites every 'label' occurrence in the full
    # path (e.g. 'label_02' -> 'oxts_02'); confirm this matches the on-disk layout.
    imu_data = read_gps_imu_file(sequence_file.replace('label', 'oxts'))

    # IMU values at time 0 (Camera Reference Frame 0 will be used as the global reference frame)
    pose_0 = _read_pose(imu_data, 0)
    yaw_0 = pose_0[-1]

    # Read camera extrinsics once per sequence: they depend only on the
    # sequence file, not on the frame, so they are hoisted out of the frame loop.
    T_velo_cam, T_imu_velo = read_calib_file(sequence_file.replace('label', 'calib'))
    # Inverse extrinsics chain applied to the label coordinates before the
    # ego-motion compensation (presumably camera -> velodyne -> IMU).
    T_label_to_imu = np.linalg.inv(T_imu_velo) @ np.linalg.inv(T_velo_cam)

    # Iterate over the kept frames (ascending) and swap relative coordinates
    # with absolute ones. `rows` holds the index labels of that frame's boxes.
    for frame_num, rows in sequence_ann[keep].groupby('frame').groups.items():
        # IMU values at the current time
        pose_t = _read_pose(imu_data, frame_num)
        yaw_t = pose_t[-1]

        # Box centres as homogeneous column vectors, shape (4, n_boxes)
        xyz = sequence_ann.loc[rows, ['loc_x', 'loc_y', 'loc_z']].to_numpy()
        points = np.hstack((xyz, np.ones((len(xyz), 1)))).T

        # Transform points in camera frame of timestep 0
        points = T_label_to_imu @ points
        points = imut_to_imu0(*pose_0, *pose_t, points)
        points = T_velo_cam @ (T_imu_velo @ points)

        # Put the transformed values back in the dataframe. The second and
        # third axes are swapped on purpose: loc_y receives row 2 and loc_z
        # receives row 1 of the transformed points.
        sequence_ann.loc[rows, 'loc_x'] = points[0]
        sequence_ann.loc[rows, 'loc_y'] = points[2]
        sequence_ann.loc[rows, 'loc_z'] = points[1]

        # Correct rotations: flip the sign, then add the ego yaw change since frame 0
        sequence_ann.loc[rows, 'rotation_y'] = (
            (yaw_t - yaw_0) - sequence_ann.loc[rows, 'rotation_y'].to_numpy()
        )

    # Convert into Car4Cast format: one entry per track_id
    vehicles = sequence_ann[keep]
    scene_dict = vehicles.groupby('track_id').apply(
        lambda track: {
            # Frame index expressed in subsampled steps
            'timestep': (track['frame'] // step).tolist(),
            'rotation': [[0, 0, yaw] for yaw in track['rotation_y']],
            'translation': track[['loc_x', 'loc_y', 'loc_z']].to_numpy().tolist(),
            # Sizes are exported in (height, width, length) order
            'size': track[['dim_height', 'dim_width', 'dim_length']].to_numpy().tolist(),
            'attribute_label': track['type'].tolist(),
        }
    ).to_dict()

    # Save scene dict as JSON
    with open(os.path.join(save_dir, sequence_id.replace('.txt', '.json')), 'w') as json_file:
        json.dump(scene_dict, json_file, indent=4)
