"""
Script to convert CALVIN data (per-step .npz files plus a language-annotation .npy file) to the LeRobot dataset v2.0 format.

Example usage: uv run examples/aloha_real/convert_aloha_data_to_lerobot.py --raw-dir /path/to/raw/data --lang-path /path/to/lang_annotations.npy --repo-id <org>/<dataset-name>
"""

import dataclasses
from pathlib import Path
from runpy import run_path
import shutil
from typing import Literal

import h5py
from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import numpy as np
import torch
import tqdm
import tyro


@dataclasses.dataclass(frozen=True)
class DatasetConfig:
    use_videos: bool = True
    tolerance_s: float = 0.0001
    image_writer_processes: int = 10
    image_writer_threads: int = 5
    video_backend: str | None = None


DEFAULT_DATASET_CONFIG = DatasetConfig()


def create_empty_dataset(
    repo_id: str,
    robot_type: str,
    mode: Literal["video", "image"] = "video",
    *,
    has_velocity: bool = False,
    has_effort: bool = False,
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
    fps: int = 10,
) -> LeRobotDataset:
    """Create an empty LeRobot dataset with the CALVIN feature layout.

    Any existing local dataset directory ``HF_LEROBOT_HOME / repo_id`` is
    deleted before the new dataset is created.

    Args:
        repo_id: Dataset identifier; also the sub-directory of ``HF_LEROBOT_HOME``.
        robot_type: Robot type recorded in the dataset metadata.
        mode: Feature dtype used for the two camera streams (``"video"`` or ``"image"``).
        has_velocity: Also declare an ``observation.velocity`` feature (7 float32 values).
        has_effort: Also declare an ``observation.effort`` feature (7 float32 values).
        dataset_config: Writer/backend options forwarded to ``LeRobotDataset.create``.
        fps: Frame rate recorded for the dataset (defaults to the previous fixed value, 10).

    Returns:
        The newly created, empty ``LeRobotDataset``.
    """
    # Per-dimension names of the 7-D action (and of the optional velocity/effort vectors).
    motors = [
        "tcp_position_x",
        "tcp_position_y",
        "tcp_position_z",
        "tcp_orientation_x",
        "tcp_orientation_y",
        "tcp_orientation_z",
        "gripper",
    ]
    # The state has one entry more than `motors` (populate_dataset stores
    # robot_obs[:7] followed by robot_obs[14:]), so it needs its own name list
    # for `names` to agree with `shape`.
    # NOTE(review): presumably robot_obs[14] is CALVIN's gripper action — confirm.
    state_names = [*motors, "gripper_action"]
    # Camera key -> (channels, height, width) declared for its image feature.
    camera_shapes = {
        "rgb_static": (3, 200, 200),
        "rgb_gripper": (3, 84, 84),
    }

    def vector_feature(names: list[str]) -> dict:
        # float32 vector whose length is derived from its names, so they cannot drift apart.
        return {"dtype": "float32", "shape": (len(names),), "names": [names]}

    features = {
        "observation.state": vector_feature(state_names),
        "action": vector_feature(motors),
    }
    if has_velocity:
        features["observation.velocity"] = vector_feature(motors)
    if has_effort:
        features["observation.effort"] = vector_feature(motors)
    for camera, shape in camera_shapes.items():
        features[f"observation.images.{camera}"] = {
            "dtype": mode,
            "shape": shape,
            "names": [
                "channels",
                "height",
                "width",
            ],
        }

    # Start from a clean local copy: remove any previous dataset with this repo_id.
    dataset_root = HF_LEROBOT_HOME / repo_id
    if dataset_root.exists():
        shutil.rmtree(dataset_root)

    return LeRobotDataset.create(
        repo_id=repo_id,
        fps=fps,
        robot_type=robot_type,
        features=features,
        use_videos=dataset_config.use_videos,
        tolerance_s=dataset_config.tolerance_s,
        image_writer_processes=dataset_config.image_writer_processes,
        image_writer_threads=dataset_config.image_writer_threads,
        video_backend=dataset_config.video_backend,
    )


def get_cameras(hdf5_files: list[Path]) -> list[str]:
    """List the non-depth camera names stored in the first HDF5 episode file.

    Only ``hdf5_files[0]`` is opened. Member names of ``/observations/images``
    that contain ``"depth"`` are skipped because depth is not handled.
    NOTE(review): not called by the CALVIN conversion path in this file.
    """
    with h5py.File(hdf5_files[0], "r") as episode_file:
        image_group = episode_file["/observations/images"]
        camera_names = [name for name in image_group if "depth" not in name]
    return camera_names


def has_velocity(hdf5_files: list[Path]) -> bool:
    """Return True when the first HDF5 episode file contains ``/observations/qvel``.

    Only ``hdf5_files[0]`` is inspected.
    NOTE(review): not called by the CALVIN conversion path in this file.
    """
    with h5py.File(hdf5_files[0], "r") as episode_file:
        present = "/observations/qvel" in episode_file
    return present


def has_effort(hdf5_files: list[Path]) -> bool:
    """Return True when the first HDF5 episode file contains ``/observations/effort``.

    Only ``hdf5_files[0]`` is inspected.
    NOTE(review): not called by the CALVIN conversion path in this file.
    """
    with h5py.File(hdf5_files[0], "r") as episode_file:
        present = "/observations/effort" in episode_file
    return present

def populate_dataset(
    dataset: LeRobotDataset,
    start_end_list,
    prompts,
    npz_path,
    task: str,
    episodes: list[int] | None = None,
) -> LeRobotDataset:
    """Append one LeRobot episode per annotated ``(start, end)`` window to ``dataset``.

    Every step index in ``start..end`` (both ends inclusive) is read from
    ``npz_path / "episode_<index:07d>.npz"`` and added as a frame. Then
    ``dataset.save_episode()`` is called once for the window.

    Args:
        dataset: Dataset to append to; it is mutated in place and also returned.
        start_end_list: Sequence of ``(start, end)`` step-index pairs, one per episode.
        prompts: Per-episode instruction, indexed like ``start_end_list``; stored
            as each frame's ``"task"``.
        npz_path: Directory holding the per-step ``.npz`` files (joined with ``/``,
            so presumably a ``pathlib.Path``).
        task: Not used here; frames are labelled with ``prompts`` instead.
        episodes: Indices into ``start_end_list`` to convert; all when ``None``.

    Returns:
        The same ``dataset`` object.
    """
    selected = range(len(start_end_list)) if episodes is None else episodes
    camera_keys = ("rgb_static", "rgb_gripper")

    for ep_idx in tqdm.tqdm(selected):
        first_step, last_step = start_end_list[ep_idx]

        for step in range(first_step, last_step + 1):
            with np.load(npz_path / f'episode_{step:07d}.npz') as step_data:
                robot_obs = step_data['robot_obs']
                # Keep robot_obs[:7] and robot_obs[14:], dropping indices 7..13.
                # NOTE(review): presumably 7..13 are arm joint states in CALVIN's layout — confirm.
                state = np.concatenate([robot_obs[:7], robot_obs[14:]])
                frame = {
                    "observation.state": torch.from_numpy(state).float(),
                    "action": torch.from_numpy(step_data["rel_actions"][:7]).float(),
                    "task": prompts[ep_idx],
                }
                frame.update(
                    {f"observation.images.{cam}": step_data[cam] for cam in camera_keys}
                )
                dataset.add_frame(frame)
        dataset.save_episode()

    return dataset

def port_calvin(
    raw_dir: Path,
    lang_path: Path,
    repo_id: str,
    raw_repo_id: str | None = None,
    task: str = "DEBUG",
    *,
    episodes: list[int] | None = None,
    push_to_hub: bool = True,
    is_mobile: bool = False,
    mode: Literal["video", "image"] = "image",
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
):
    """Convert CALVIN per-step ``.npz`` files into a LeRobot dataset.

    One LeRobot episode is written per language-annotated window listed in
    ``lang_path``, and each frame is labelled with that window's instruction.

    Args:
        raw_dir: Directory containing the per-step ``episode_XXXXXXX.npz`` files.
        lang_path: Language-annotation file holding ``info.indx`` (the
            ``(start, end)`` step windows) and ``language.ann`` (their instructions).
        repo_id: Output dataset id; any existing local copy is replaced.
        raw_repo_id: Unused; kept for interface compatibility.
        task: Passed to ``populate_dataset``; frames are labelled with the
            annotated instructions instead.
        episodes: Subset of annotation-window indices to convert (all when omitted).
        push_to_hub: Upload the finished dataset to the Hugging Face Hub.
        is_mobile: Unused; kept for interface compatibility.
        mode: Storage mode for the camera features (``"image"`` or ``"video"``).
        dataset_config: Writer/backend options for the new dataset.

    Raises:
        NotADirectoryError: ``raw_dir`` is not an existing directory.
        FileNotFoundError: ``lang_path`` is not an existing file.
    """
    raw_dir = Path(raw_dir)
    lang_path = Path(lang_path)
    # Validate inputs before anything touches an existing output dataset.
    if not raw_dir.is_dir():
        raise NotADirectoryError(f"CALVIN data directory not found: {raw_dir}")
    if not lang_path.is_file():
        raise FileNotFoundError(f"Language annotation file not found: {lang_path}")
    print(raw_dir)

    # NOTE: allow_pickle=True lets np.load unpickle arbitrary Python objects,
    # which can execute code; only load annotation files from a trusted source.
    annotations = np.load(lang_path, allow_pickle=True).item()
    start_end_list = annotations['info']['indx']
    prompts = annotations['language']['ann']

    # create_empty_dataset removes any existing local dataset for repo_id, so
    # the old copy is only deleted once the inputs above have loaded successfully.
    dataset = create_empty_dataset(
        repo_id,
        robot_type="calvin",
        mode=mode,
        has_effort=False,
        has_velocity=False,
        dataset_config=dataset_config,
    )
    dataset = populate_dataset(
        dataset,
        start_end_list,
        prompts,
        raw_dir,
        task=task,
        episodes=episodes,
    )
    dataset.consolidate()

    if push_to_hub:
        dataset.push_to_hub()


if __name__ == "__main__":
    # Script entry point: tyro parses the command-line arguments for port_calvin and runs it.
    tyro.cli(port_calvin)
