"""
Script to convert R1 hdf5 data to the LeRobot dataset v2.0 format.

Example usage: uv run examples/r1/convert_r1_data_to_lerobot_single.py --raw-dirs /path/to/raw/data --tasks '<task description>' --repo-id <org>/<dataset-name>
"""

import dataclasses
from pathlib import Path
import shutil
from typing import Literal
import random
import h5py
from lerobot.common.datasets.lerobot_dataset import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
import numpy as np
import torch
import tqdm
import tyro


@dataclasses.dataclass(frozen=True)
class DatasetConfig:
    use_videos: bool = True
    tolerance_s: float = 0.0001
    image_writer_processes: int = 10
    image_writer_threads: int = 5
    video_backend: str | None = None


DEFAULT_DATASET_CONFIG = DatasetConfig()


def create_empty_dataset(
    repo_id: str,
    robot_type: str,
    mode: Literal["video", "image"] = "video",
    *,
    has_velocity: bool = False,
    has_effort: bool = False,
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
    fps: int = 50,
    image_shape: tuple[int, int, int] = (270, 480, 3),
) -> LeRobotDataset:
    """Create a fresh, empty LeRobot dataset with the R1 feature schema.

    Any existing local dataset at ``HF_LEROBOT_HOME / repo_id`` is deleted first.

    Args:
        repo_id: Dataset id, e.g. ``<org>/<dataset-name>``.
        robot_type: Robot type recorded by ``LeRobotDataset.create``.
        mode: Feature dtype used for every camera stream (``"video"`` or ``"image"``).
        has_velocity: Currently unused; kept for interface compatibility.
        has_effort: Currently unused; kept for interface compatibility.
        dataset_config: Writer/encoder settings forwarded to ``LeRobotDataset.create``.
        fps: Frame rate passed to ``LeRobotDataset.create``.
        image_shape: ``(height, width, channels)`` declared for every camera feature.

    Returns:
        The newly created, empty ``LeRobotDataset``.
    """
    # State names, in the order load_state concatenates its columns.
    motors = [
        "left_arm_0",
        "left_arm_1",
        "left_arm_2",
        "left_arm_3",
        "left_arm_4",
        "left_arm_5",
        "left_gripper",
        "mobile_base_0",
        "mobile_base_1",
        "mobile_base_2",
        "right_arm_0",
        "right_arm_1",
        "right_arm_2",
        "right_arm_3",
        "right_arm_4",
        "right_arm_5",
        "right_gripper",
        "torso_0",
        "torso_1",
        "torso_2",
        "torso_3",
    ]
    # Actions share the state's name layout (separate list so the two can diverge).
    action_motors = list(motors)
    cameras = ["head", "left_wrist", "right_wrist"]

    def vector_feature(dtype: str, names: list[str]) -> dict:
        # One flat vector whose entries are labelled by ``names``.
        return {"dtype": dtype, "shape": (len(names),), "names": [names]}

    features = {
        "observation.state": vector_feature("float32", motors),
        "observation.next_state": vector_feature("float32", motors),
        "action": vector_feature("float32", action_motors),
        "next_action": vector_feature("float32", action_motors),
        "reward": vector_feature("float32", ["reward"]),
        "terminal": vector_feature("bool", ["terminal"]),
    }

    # Each camera gets a current-frame and a "next"-frame image feature.
    for cam in cameras:
        for key in (f"observation.images.{cam}", f"observation.images.next_{cam}"):
            features[key] = {
                "dtype": mode,
                "shape": tuple(image_shape),
                "names": ["height", "width", "channels"],
            }

    dataset_root = HF_LEROBOT_HOME / repo_id
    if dataset_root.exists():
        shutil.rmtree(dataset_root)

    return LeRobotDataset.create(
        repo_id=repo_id,
        fps=fps,
        robot_type=robot_type,
        features=features,
        use_videos=dataset_config.use_videos,
        tolerance_s=dataset_config.tolerance_s,
        image_writer_processes=dataset_config.image_writer_processes,
        image_writer_threads=dataset_config.image_writer_threads,
        video_backend=dataset_config.video_backend,
    )


def load_raw_images_per_camera(ep: h5py.File, cameras: list[str]) -> dict[str, np.ndarray]:
    """Load every frame of each requested camera stream into memory.

    Frames are read from ``/obs/rgb/<camera>/img``. A 4-D dataset is treated as
    already-decoded frames and read in one go; otherwise each element is treated
    as an encoded image buffer and decoded with OpenCV.

    Args:
        ep: Open HDF5 episode file.
        cameras: Names of the camera streams to load.

    Returns:
        Mapping from camera name to that camera's stacked frames.

    Raises:
        ValueError: If an encoded frame cannot be decoded.
    """
    imgs_per_cam = {}

    for camera in cameras:
        img_ds = ep[f"/obs/rgb/{camera}/img"]

        if img_ds.ndim == 4:
            # Uncompressed: load all images in RAM at once.
            imgs_per_cam[camera] = img_ds[:]
            continue

        import cv2  # only needed for compressed streams

        # Load one compressed image after the other and decode it.
        # NOTE(review): cv2.imdecode yields BGR channel order; whether a BGR->RGB
        # conversion is needed depends on how the frames were encoded — confirm.
        frames = []
        for frame_idx, data in enumerate(img_ds):
            img = cv2.imdecode(data, cv2.IMREAD_COLOR)
            if img is None:
                # imdecode signals failure by returning None; fail here instead of
                # producing an object array that breaks much later.
                raise ValueError(f"Failed to decode frame {frame_idx} of camera {camera!r}")
            frames.append(img)
        imgs_per_cam[camera] = np.array(frames)

    return imgs_per_cam


def load_state(ep: h5py.File):
    """Assemble the robot state matrix of one episode as a torch tensor.

    Columns, in order: first 6 left-arm joint positions, left gripper position,
    first two components of the chassis linear velocity, last component of the
    chassis angular velocity, first 6 right-arm joint positions, right gripper
    position, and every torso joint position. Rows follow the first axis of the
    HDF5 datasets.
    """
    left_arm = ep['obs/joint_state/left_arm/joint_position'][:, :6]
    left_grip = ep['obs/gripper_state/left_gripper/gripper_position'][:].reshape(-1, 1)
    base_linear = ep['obs/chassis_odom/linear_velocity'][:, :2]
    base_angular = ep['obs/chassis_odom/angular_velocity'][:, -1:]
    right_arm = ep['obs/joint_state/right_arm/joint_position'][:, :6]
    right_grip = ep['obs/gripper_state/right_gripper/gripper_position'][:].reshape(-1, 1)
    torso = ep['obs/joint_state/torso/joint_position'][:]

    stacked = np.concatenate(
        (left_arm, left_grip, base_linear, base_angular, right_arm, right_grip, torso),
        axis=-1,
    )
    return torch.from_numpy(stacked)

def load_reward(ep_path):
    reward_path = str(ep_path).replace('rollout', 'reward').replace('.h5', '_reward.npy').replace('_no_depth', '')
    reward = np.load(reward_path).astype(np.float32)
    return reward

def load_action(ep: h5py.File):
    """Return the episode's action streams joined along the last axis as a torch tensor.

    Streams are concatenated in this order: left arm, left gripper, mobile base,
    right arm, right gripper, torso.
    """
    stream_keys = (
        'action/left_arm',
        'action/left_gripper',
        'action/mobile_base',
        'action/right_arm',
        'action/right_gripper',
        'action/torso',
    )
    joined = np.concatenate([ep[key][:] for key in stream_keys], axis=-1)
    return torch.from_numpy(joined)


def load_raw_episode_data(
    ep_path: Path,
) -> tuple[dict[str, np.ndarray], torch.Tensor, torch.Tensor, np.ndarray]:
    """Load images, state, action and reward for a single episode.

    Args:
        ep_path: Path to the episode ``.h5`` file. Its reward file is located by
            ``load_reward``.

    Returns:
        ``(imgs_per_cam, state, action, reward)``:

        * ``imgs_per_cam`` maps each of ``head``/``left_wrist``/``right_wrist`` to its frames,
        * ``state`` is the tensor returned by ``load_state``,
        * ``action`` is the tensor returned by ``load_action``,
        * ``reward`` is the float32 numpy array returned by ``load_reward``.
    """
    # The reward lives in a separate .npy file. Loading it first fails before the
    # (large) HDF5 episode is read if it is missing.
    reward = load_reward(ep_path)
    with h5py.File(ep_path, "r") as ep:
        state = load_state(ep)
        action = load_action(ep)
        imgs_per_cam = load_raw_images_per_camera(
            ep,
            [
                "head",
                "left_wrist",
                "right_wrist",
            ],
        )

    return imgs_per_cam, state, action, reward


def _convert_episode(
    dataset: LeRobotDataset,
    task: str,
    ep_path: Path,
    ep_idx: int,
    horizon: int,
) -> bool:
    """Add every frame of one episode to ``dataset`` and save it; return False if skipped."""
    imgs_per_cam, state, action, reward = load_raw_episode_data(ep_path)
    num_frames = state.shape[0]

    if reward.shape[0] != num_frames:
        print(f"mismatch episode {ep_path}")
        return False
    # Streams shorter than the state would raise IndexError in the frame loop below.
    if action.shape[0] < num_frames or any(
        img_array.shape[0] < num_frames for img_array in imgs_per_cam.values()
    ):
        print(f"mismatch episode {ep_path}")
        return False
    if num_frames == 0:
        # Nothing to save, and reward.max() below would fail on an empty array.
        print(f"empty episode {ep_path}")
        return False

    last = num_frames - 1
    terminal = False
    for i in range(num_frames):
        # The "next" sample is `horizon` steps ahead, clamped to the last frame.
        next_i = min(last, i + horizon)
        terminal = i == last
        frame = {
            "observation.state": np.float32(state[i]),
            "observation.next_state": np.float32(state[next_i]),
            "action": np.float32(action[i]),
            "next_action": np.float32(action[next_i]),
            "reward": np.float32(reward[i : i + 1]),
            "terminal": np.bool_([terminal]),
            "task": task,
        }
        for camera, img_array in imgs_per_cam.items():
            frame[f"observation.images.{camera}"] = img_array[i]
            frame[f"observation.images.next_{camera}"] = img_array[next_i]
        dataset.add_frame(frame)

    print(f"Episode: {ep_idx}  Return={reward.sum()}  Reward Max={reward.max()} Reward Min={reward.min()}, Terminal={terminal}")
    dataset.save_episode()
    return True


def populate_dataset(
    dataset: LeRobotDataset,
    hdf5_files: dict[int, list],
    horizon: int = 20,
    max_episodes_per_task: int | None = 51,
) -> LeRobotDataset:
    """Convert the selected episode files into frames and save them to ``dataset``.

    Each frame pairs step ``i`` with a "next" step ``min(i + horizon, T - 1)``
    (state, action and camera images). It also carries the step's reward and a
    terminal flag that is True only on the episode's last frame. Episodes whose
    stream lengths are inconsistent, or which are empty, are reported and skipped.

    Args:
        dataset: Dataset to append to (e.g. from ``create_empty_dataset``).
        hdf5_files: Mapping from a task index to a ``[task, episode_paths]`` pair.
        horizon: Step offset between a frame and its "next" frame.
        max_episodes_per_task: Maximum number of episodes saved per task; ``None``
            saves all of them. The default of 51 matches the previous hard-coded
            ``count > 50`` cut-off.

    Returns:
        The same ``dataset`` object.
    """
    all_episodes = 0
    for task, h5files in hdf5_files.values():
        print(f"Task {task} size: {len(h5files)}")
        all_episodes += len(h5files)

    print("Num of episodes in total:", all_episodes)

    with tqdm.tqdm(total=all_episodes) as pbar:
        for task, h5files in hdf5_files.values():
            saved = 0
            for ep_idx, ep_path in enumerate(h5files):
                if max_episodes_per_task is not None and saved >= max_episodes_per_task:
                    # Count the files skipped by the cap so the bar still reaches its total.
                    pbar.update(len(h5files) - ep_idx)
                    break
                print(ep_path)
                if _convert_episode(dataset, task, ep_path, ep_idx, horizon):
                    saved += 1
                pbar.update(1)

    return dataset


def port_r1(
    raw_dirs: list[Path],
    repo_id: str,
    tasks: tuple[str, ...] = ("DEBUG",),
    num_episodes: int = -1,
    *,
    push_to_hub: bool = False,
    success_only: bool = False,
    mode: Literal["video", "image"] = "image",
    dataset_config: DatasetConfig = DEFAULT_DATASET_CONFIG,
    horizon: int = 20,
    seed: int | None = None,
):
    """Convert directories of R1 HDF5 episodes into a single LeRobot dataset.

    Args:
        raw_dirs: Directories containing ``collected_data_*.h5`` episode files, one per task.
        repo_id: Target dataset id; an existing local copy is replaced.
        tasks: Task description for each entry of ``raw_dirs`` (same length and order).
        num_episodes: If > 0, randomly sample this many episodes from each directory;
            otherwise use every episode.
        push_to_hub: Push the finished dataset with ``dataset.push_to_hub()``.
        success_only: Only use files matching ``collected_data_*success.h5``.
        mode: Feature dtype for camera streams.
        dataset_config: Writer/encoder settings for the dataset.
        horizon: Step offset between a frame and its "next" frame.
        seed: Seed for episode sampling when ``num_episodes > 0``; ``None`` is non-deterministic.

    Raises:
        ValueError: If ``tasks`` and ``raw_dirs`` differ in length, or a directory has
            fewer than ``num_episodes`` episode files.
    """
    # zip() would otherwise silently drop unmatched directories (e.g. several
    # raw_dirs with the single default task).
    if len(tasks) != len(raw_dirs):
        raise ValueError(
            f"Got {len(tasks)} task(s) for {len(raw_dirs)} raw dir(s); "
            "pass exactly one task per raw dir."
        )

    rng = random.Random(seed)
    postfix = "success" if success_only else ""
    total_hdf5files = {}
    for task_idx, (task, raw_dir) in enumerate(zip(tasks, raw_dirs)):
        raw_hdf5_files = sorted(raw_dir.glob(f"collected_data_*{postfix}.h5"))
        if not raw_hdf5_files:
            print(f"Warning: no episode files found in {raw_dir}")
        if num_episodes > 0:
            # Use a random subset of the data.
            if len(raw_hdf5_files) < num_episodes:
                raise ValueError(
                    f"{raw_dir} has {len(raw_hdf5_files)} episode(s), "
                    f"fewer than num_episodes={num_episodes}"
                )
            hdf5_files = rng.sample(raw_hdf5_files, num_episodes)
        else:
            # Use all data.
            hdf5_files = raw_hdf5_files
        total_hdf5files[task_idx] = [task, hdf5_files]

    # create_empty_dataset removes any existing local copy of repo_id.
    dataset = create_empty_dataset(
        repo_id,
        robot_type="r1",
        mode=mode,
        has_effort=False,
        has_velocity=False,
        dataset_config=dataset_config,
    )
    dataset = populate_dataset(
        dataset,
        total_hdf5files,
        horizon=horizon,
    )

    if push_to_hub:
        dataset.push_to_hub()


if __name__ == "__main__":
    # Build the command-line interface from port_r1's signature with tyro and run it.
    tyro.cli(port_r1)
