import numpy as np
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from pathlib import Path
from kornia.geometry.linalg import transform_points
from tqdm import tqdm
from av2.torch.data_loaders.detection import DetectionDataLoader

from utils import hexbin_plot, compute_alphashape

# Location of the input dataset; the detection loader reads from
# {root_dir}/{dataset_name}/sensor/{split_name}.
root_dir = '/your/root/dir'
dataset_name = 'argoverse3D'
split_name = 'train'

# Base output directory. The 'bounds' and 'maps_with_bounds' output directories
# are derived from it with str.replace, and all three are created up front.
# NOTE: str.replace substitutes *every* 'maps' in the path, so 'maps' should
# appear only in the final path component.
out_dir = '/your/out/dir/maps'
for _dir_to_create in (
    out_dir,
    out_dir.replace('maps', 'bounds'),
    out_dir.replace('maps', 'maps_with_bounds'),
):
    os.makedirs(_dir_to_create, exist_ok=True)
del _dir_to_create


def _sweep_xyz_in_city(sweep) -> np.ndarray:
    """Return the lidar (x, y, z) points of one sweep in city coordinates.

    Args:
        sweep: One item produced by ``DetectionDataLoader``.

    Returns:
        Array of shape (3, N), with one column per lidar point.
    """
    # 4x4 matrix representing the SE(3) transformation to city from ego-vehicle coordinates.
    city_SE3_ego_mat4 = sweep.city_SE3_ego.matrix()
    # Lidar (x,y,z) in meters and intensity (i); only the first three columns are transformed.
    lidar_tensor = sweep.lidar.as_tensor()
    return transform_points(city_SE3_ego_mat4, lidar_tensor[:, :3]).numpy().T


def main(
    root_dir: Path = root_dir,
    dataset_name: str = dataset_name,
    split_name: str = split_name,
    num_accumulated_sweeps: int = 1,
) -> None:
    """Accumulate lidar sweeps per scene and save a bounds shape and a map plot for each.

    Dataset should live at {root_dir}/{dataset_name}/sensor/{split_name}.

    Consecutive sweeps sharing the same scene (log) id are transformed to city
    coordinates and stacked column-wise into a single (3, N_total) array. For
    each scene, that array is passed to ``compute_alphashape`` (with
    ``save_path=<bounds dir>/{scene_id}.pkl``) and then to ``hexbin_plot``
    (with ``save_path_png=<maps_with_bounds dir>/{scene_id}.png`` and the hull
    returned by ``compute_alphashape``). The output directories are derived from
    the module-level ``out_dir``.

    Args:
        root_dir: Root directory to the datasets.
        dataset_name: Name of the dataset (e.g., "av2").
        split_name: Name of the split (e.g., "val").
        num_accumulated_sweeps: Number of sweeps to accumulate.
    """
    from itertools import groupby

    data_loader = DetectionDataLoader(
        root_dir,
        dataset_name,
        split_name,
        num_accumulated_sweeps=num_accumulated_sweeps,
    )

    # Loop-invariant output locations (same derivation as the module-level setup).
    bounds_dir = out_dir.replace('maps', 'bounds')
    maps_with_bounds_dir = out_dir.replace('maps', 'maps_with_bounds')
    # Cheap safety net so main() does not rely solely on import-time side effects.
    os.makedirs(bounds_dir, exist_ok=True)
    os.makedirs(maps_with_bounds_dir, exist_ok=True)

    # groupby merges only *consecutive* sweeps with equal scene ids, matching the
    # previous "next sweep has a different scene" check. Unlike indexing
    # data_loader[i + 1], it does not load every sweep a second time just to peek
    # at its scene id.
    sweeps_by_scene = groupby(tqdm(data_loader), key=lambda sweep: sweep.sweep_uuid[0])
    for scene_id, scene_sweeps in sweeps_by_scene:
        # The group iterator must be consumed before groupby advances; the list does that.
        scene_points = np.hstack([_sweep_xyz_in_city(sweep) for sweep in scene_sweeps])
        concave_hull = compute_alphashape(
            scene_points,
            save_path=os.path.join(bounds_dir, f'{scene_id}.pkl'),
            downsample=300,
        )
        hexbin_plot(
            scene_points,
            save_path_png=os.path.join(maps_with_bounds_dir, f'{scene_id}.png'),
            shape=concave_hull,
        )


if __name__ == "__main__":
    # Script entry point: run main() with its default arguments, which are taken
    # from the module-level root_dir / dataset_name / split_name settings.
    main()
