import torch
import os
import json
import math
import pandas as pd

from pathlib import Path
from kornia.geometry.linalg import transform_points
from tqdm import tqdm

from av2.torch.data_loaders.detection import DetectionDataLoader


# Target annotation rate in Hz (the native rate is 10 Hz).
fps = 2
# Keep one sweep out of every `step` to reach the target rate.
step = 10 // fps

# Dataset location; data is expected at
# {root_dir}/{dataset_name}/sensor/{split_name}.
root_dir = '/your/root/dir/'
dataset_name = 'argoverse3D'
split_name = 'train'

# Destination folder for the exported files; created up front if missing.
out_dir = '/your/our/dir/'
Path(out_dir).mkdir(parents=True, exist_ok=True)


def _quaternion_to_euler(q_w, q_x, q_y, q_z):
    """Return (roll, pitch, yaw) for a quaternion given as w, x, y, z components.

    The asin argument is clamped to [-1, 1] so that floating-point rounding
    near +/-90 degrees of pitch cannot yield NaN.
    """
    roll = torch.atan2(2.0 * (q_w * q_x + q_y * q_z), 1.0 - 2.0 * (q_x**2 + q_y**2))
    pitch = torch.asin(torch.clamp(2.0 * (q_w * q_y - q_z * q_x), -1.0, 1.0))
    yaw = torch.atan2(2.0 * (q_w * q_z + q_x * q_y), 1.0 - 2.0 * (q_y**2 + q_z**2))
    return roll, pitch, yaw


def _write_scene(scene_id, tracks):
    """Dump one scene's per-track annotation lists to {out_dir}/{scene_id}.json."""
    # Emit tracks in sorted-uuid order so the file layout is deterministic.
    scene_dict = {uuid: tracks[uuid] for uuid in sorted(tracks)}
    with open(os.path.join(out_dir, f'{scene_id}.json'), 'w') as json_file:
        json.dump(scene_dict, json_file, indent=4)


def main(
    root_dir: Path = root_dir,
    dataset_name: str = dataset_name,
    split_name: str = split_name,
    num_accumulated_sweeps: int = 1,
) -> None:
    """Export vehicle annotations from the detection data-loader, one JSON per scene.

    Dataset should live at {root_dir}/{dataset_name}/sensor/{split_name}.

    Every `step`-th sweep of the loader is read. For each cuboid whose
    category is a vehicle, the centroid is transformed from the ego frame to
    the city frame and the ego yaw is added to the cuboid yaw. When the scene
    id changes (and once more after the last sweep) a file
    `{out_dir}/{scene_id}.json` is written that maps each track uuid to the
    parallel lists `timestep`, `translation`, `rotation`, `size` and
    `attribute_label`.

    Args:
        root_dir: Root directory to the datasets.
        dataset_name: Name of the dataset (e.g., "av2").
        split_name: Name of the split (e.g., "val").
        num_accumulated_sweeps: Number of sweeps to accumulate.
    """
    data_loader = DetectionDataLoader(
        root_dir,
        dataset_name,
        split_name,
        num_accumulated_sweeps=num_accumulated_sweeps,
    )

    # A set gives O(1) membership tests while filtering cuboid categories.
    vehicle_categories = frozenset((
        'REGULAR_VEHICLE', 'LARGE_VEHICLE', 'BUS', 'BOX_TRUCK', 'TRUCK', 'VEHICULAR_TRAILER',
        'TRUCK_CAB', 'SCHOOL_BUS', 'ARTICULATED_BUS', 'MESSAGE_BOARD_TRAILER',
    ))

    current_scene_id = None
    tracks = {}  # track uuid -> dict of parallel lists for the current scene
    local_time = 0  # index of the sampled frame within the current scene

    # Index only the sweeps that are kept: skipped sweeps are never loaded and
    # no look-ahead load is needed to detect the end of a scene.
    for i in tqdm(range(0, len(data_loader), step)):
        sweep = data_loader[i]
        scene_id = sweep.sweep_uuid[0]

        # A new scene id means the previous scene is complete: save and reset.
        # NOTE(review): assumes all sweeps of a scene are contiguous in the loader.
        if scene_id != current_scene_id:
            if current_scene_id is not None:
                _write_scene(current_scene_id, tracks)
            current_scene_id = scene_id
            tracks = {}
            local_time = 0

        # Access cuboid category and keep only the vehicle cuboids.
        all_categories = sweep.cuboids.category
        vehicle_idxs = [idx for idx, value in enumerate(all_categories) if value in vehicle_categories]

        # Skip the tensor maths when there is nothing to record; the frame
        # still counts towards the scene-local time index.
        if vehicle_idxs:
            category = [all_categories[idx] for idx in vehicle_idxs]

            # Access track uuid (i.e., unique instance_id).
            all_track_uuids = sweep.cuboids.track_uuid
            track_uuid = [all_track_uuids[idx] for idx in vehicle_idxs]

            # 4x4 matrix representing the SE(3) transformation to city from ego-vehicle coordinates.
            city_SE3_ego_mat4 = sweep.city_SE3_ego.matrix()

            # Get corresponding ego Euler angles
            q_w, q_x, q_y, q_z = sweep.city_SE3_ego.quaternion[0].data.detach()
            roll_ego, pitch_ego, yaw_ego = _quaternion_to_euler(q_w, q_x, q_y, q_z)

            # Annotations in (x,y,z,l,w,h,yaw) format, restricted to vehicles.
            xyzlwh_t = sweep.cuboids.as_tensor()[vehicle_idxs]

            # Transform bbox centroids from ego to city reference frame
            xyz_t_city = transform_points(city_SE3_ego_mat4, xyzlwh_t[:, :3]).tolist()

            # Correct yaw angles with ego yaw angle; roll and pitch are the ego's.
            yaw_t = xyzlwh_t[:, -1] + yaw_ego
            rpy_t = torch.stack(
                [
                    torch.full_like(yaw_t, roll_ego.item()),
                    torch.full_like(yaw_t, pitch_ego.item()),
                    yaw_t,
                ],
                dim=1,
            ).tolist()

            # Access sizes, reordered from (l, w, h) to (h, w, l).
            sizes_hwl = xyzlwh_t[:, [5, 4, 3]].tolist()

            # Append this frame's values to each track's running lists.
            for uuid, label, xyz, rpy, hwl in zip(track_uuid, category, xyz_t_city, rpy_t, sizes_hwl):
                record = tracks.get(uuid)
                if record is None:
                    record = tracks[uuid] = {
                        'timestep': [],
                        'translation': [],
                        'rotation': [],
                        'size': [],
                        'attribute_label': [],
                    }
                record['timestep'].append(local_time)
                record['translation'].append(xyz)
                record['rotation'].append(rpy)
                record['size'].append(hwl)
                record['attribute_label'].append(label)

        local_time += 1

    # Save the final scene, which has no following scene to trigger the flush.
    if current_scene_id is not None:
        _write_scene(current_scene_id, tracks)

# Run the export with the module-level default settings when executed as a script.
if __name__ == "__main__":
    main()
