#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Use this script to convert your dataset into LeRobot dataset format and upload it to the Hugging Face hub,
or store it locally. LeRobot dataset format is lightweight, fast to load from, and does not require any
installation of neural net specific packages like pytorch, tensorflow, jax.

Example of how to download raw datasets, convert them into LeRobotDataset format, and push them to the hub:
```
python lerobot/scripts/push_dataset_to_hub.py \
--raw-dir data/pusht_raw \
--raw-format pusht_zarr \
--repo-id lerobot/pusht

python lerobot/scripts/push_dataset_to_hub.py \
--raw-dir data/xarm_lift_medium_raw \
--raw-format xarm_pkl \
--repo-id lerobot/xarm_lift_medium

python lerobot/scripts/push_dataset_to_hub.py \
--raw-dir data/aloha_sim_insertion_scripted_raw \
--raw-format aloha_hdf5 \
--repo-id lerobot/aloha_sim_insertion_scripted

python lerobot/scripts/push_dataset_to_hub.py \
--raw-dir data/umi_cup_in_the_wild_raw \
--raw-format umi_zarr \
--repo-id lerobot/umi_cup_in_the_wild
```

**WARNING: Updating an existing dataset**

If you want to update an existing dataset, you need to change `CODEBASE_VERSION` in `lerobot_dataset.py`
before running `push_dataset_to_hub.py`. This is especially useful if you introduce a breaking change,
intentionally or not (i.e. something not backward compatible, such as modifying the reward functions used,
deleting some frames at the end of an episode, etc.). That way, people running a previous version of the
codebase won't be affected by your change and backward compatibility is maintained.

For instance, Pusht has many versions to maintain backward compatibility between LeRobot codebase versions:
- [v1.0](https://huggingface.co/datasets/lerobot/pusht/tree/v1.0)
- [v1.1](https://huggingface.co/datasets/lerobot/pusht/tree/v1.1)
- [v1.2](https://huggingface.co/datasets/lerobot/pusht/tree/v1.2)
- [v1.3](https://huggingface.co/datasets/lerobot/pusht/tree/v1.3)
- [v1.4](https://huggingface.co/datasets/lerobot/pusht/tree/v1.4)
- [v1.5](https://huggingface.co/datasets/lerobot/pusht/tree/v1.5) <-- last version
- [main](https://huggingface.co/datasets/lerobot/pusht/tree/main)  <-- points to the last version

However, you will need to update the version of ALL the other datasets so that they have the new
`CODEBASE_VERSION` as a branch in their Hugging Face dataset repository. Don't worry, there is an easy way
that doesn't require running `push_dataset_to_hub.py`. You can just "branch out" from the `main` branch of
the HF dataset repo by running the script below, which is the equivalent of a `git checkout -b` (so no copy
or upload is needed):

```python
import os

from huggingface_hub import create_branch, hf_hub_download
from huggingface_hub.utils._errors import RepositoryNotFoundError

from lerobot import available_datasets
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION

os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"  # makes it easier to see the print-out below

NEW_CODEBASE_VERSION = "v1.5"  # REPLACE THIS WITH YOUR DESIRED VERSION

for repo_id in available_datasets:
    # First check if the newer version already exists.
    try:
        hf_hub_download(
            repo_id=repo_id, repo_type="dataset", filename=".gitattributes", revision=NEW_CODEBASE_VERSION
        )
        print(f"Found existing branch for {repo_id}. Please contact a member of the core LeRobot team.")
        print("Exiting early")
        break
    except RepositoryNotFoundError:
        # Now create a branch.
        create_branch(repo_id, repo_type="dataset", branch=NEW_CODEBASE_VERSION, revision=CODEBASE_VERSION)
        print(f"{repo_id} successfully updated")

```

On the other hand, if you are pushing a new dataset, you don't need to worry about any of the instructions
above, nor to be compatible with previous codebase versions.
"""

import argparse
import json
import shutil
import warnings
from pathlib import Path
from typing import Any

import torch
from huggingface_hub import HfApi, create_branch
from safetensors.torch import save_file

from lerobot.common.datasets.compute_stats import compute_stats
from lerobot.common.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset
from lerobot.common.datasets.utils import flatten_dict


def get_from_raw_to_lerobot_format_fn(raw_format: str):
    if raw_format == "pusht_zarr":
        from lerobot.common.datasets.push_dataset_to_hub.pusht_zarr_format import from_raw_to_lerobot_format
    elif raw_format == "umi_zarr":
        from lerobot.common.datasets.push_dataset_to_hub.umi_zarr_format import from_raw_to_lerobot_format
    elif raw_format == "aloha_hdf5":
        from lerobot.common.datasets.push_dataset_to_hub.aloha_hdf5_format import from_raw_to_lerobot_format
    elif raw_format == "dora_parquet":
        from lerobot.common.datasets.push_dataset_to_hub.dora_parquet_format import from_raw_to_lerobot_format
    elif raw_format == "xarm_pkl":
        from lerobot.common.datasets.push_dataset_to_hub.xarm_pkl_format import from_raw_to_lerobot_format
    elif raw_format == "cam_png":
        from lerobot.common.datasets.push_dataset_to_hub.cam_png_format import from_raw_to_lerobot_format
    else:
        raise ValueError(
            f"The selected {raw_format} can't be found. Did you add it to `lerobot/scripts/push_dataset_to_hub.py::get_from_raw_to_lerobot_format_fn`?"
        )

    return from_raw_to_lerobot_format


def save_meta_data(
    info: dict[str, Any], stats: dict, episode_data_index: dict[str, list], meta_data_dir: Path
):
    """Write the dataset meta data files into `meta_data_dir`.

    Three files are written: `info.json`, `stats.safetensors` and
    `episode_data_index.safetensors`. The directory (including missing parents) is
    created when it does not exist yet.

    Args:
        info: Content dumped as json into `info.json`.
        stats: Statistics, passed through `flatten_dict` before being saved.
        episode_data_index: Mapping whose values are each converted with `torch.tensor`.
        meta_data_dir: Destination directory.
    """
    meta_data_dir.mkdir(parents=True, exist_ok=True)

    # info -> indented json
    with (meta_data_dir / "info.json").open("w") as info_file:
        json.dump(info, info_file, indent=4)

    # stats -> safetensors, after going through `flatten_dict`
    flat_stats = flatten_dict(stats)
    save_file(flat_stats, meta_data_dir / "stats.safetensors")

    # episode_data_index -> safetensors, one tensor per key
    index_tensors = {name: torch.tensor(values) for name, values in episode_data_index.items()}
    save_file(index_tensors, meta_data_dir / "episode_data_index.safetensors")


def push_meta_data_to_hub(repo_id: str, meta_data_dir: str | Path, revision: str | None):
    """Upload the folder `meta_data_dir` to the dataset repository `repo_id`.

    All meta data files are expected to be stored in this single directory. On the
    Hugging Face repository, they are uploaded into a "meta_data" directory at the root,
    on the given `revision`.
    """
    upload_options = {
        "folder_path": meta_data_dir,
        "path_in_repo": "meta_data",
        "repo_id": repo_id,
        "revision": revision,
        "repo_type": "dataset",
    }
    HfApi().upload_folder(**upload_options)


def push_videos_to_hub(repo_id: str, videos_dir: str | Path, revision: str | None):
    """Upload the mp4 files of `videos_dir` to the dataset repository `repo_id`.

    All mp4 files are expected to be stored in this single directory. On the Hugging Face
    repository, they are uploaded into a "videos" directory at the root, on the given
    `revision`. Only files matching "*.mp4" are uploaded.
    """
    upload_options = {
        "folder_path": videos_dir,
        "path_in_repo": "videos",
        "repo_id": repo_id,
        "revision": revision,
        "repo_type": "dataset",
        "allow_patterns": "*.mp4",
    }
    HfApi().upload_folder(**upload_options)


def push_dataset_to_hub(
    raw_dir: Path,
    raw_format: str,
    repo_id: str,
    push_to_hub: bool = True,
    local_dir: Path | None = None,
    fps: int | None = None,
    video: bool = True,
    batch_size: int = 32,
    num_workers: int = 8,
    episodes: list[int] | None = None,
    force_override: bool = False,
    cache_dir: Path = Path("/tmp"),
    tests_data_dir: Path | None = None,
):
    # Check repo_id is well formated
    if len(repo_id.split("/")) != 2:
        raise ValueError(
            f"`repo_id` is expected to contain a community or user id `/` the name of the dataset (e.g. 'lerobot/pusht'), but instead contains '{repo_id}'."
        )
    user_id, dataset_id = repo_id.split("/")

    # Robustify when `raw_dir` is str instead of Path
    raw_dir = Path(raw_dir)
    if not raw_dir.exists():
        raise NotADirectoryError(
            f"{raw_dir} does not exists. Check your paths or run this command to download an existing raw dataset on the hub:"
            f"python lerobot/common/datasets/push_dataset_to_hub/_download_raw.py --raw-dir your/raw/dir --repo-id your/repo/id_raw"
        )

    if local_dir:
        # Robustify when `local_dir` is str instead of Path
        local_dir = Path(local_dir)

        # Send warning if local_dir isn't well formated
        if local_dir.parts[-2] != user_id or local_dir.parts[-1] != dataset_id:
            warnings.warn(
                f"`local_dir` ({local_dir}) doesn't contain a community or user id `/` the name of the dataset that match the `repo_id` (e.g. 'data/lerobot/pusht'). Following this naming convention is advised, but not mandatory.",
                stacklevel=1,
            )

        # Check we don't override an existing `local_dir` by mistake
        if local_dir.exists():
            if force_override:
                shutil.rmtree(local_dir)
            else:
                raise ValueError(f"`local_dir` already exists ({local_dir}). Use `--force-override 1`.")

        meta_data_dir = local_dir / "meta_data"
        videos_dir = local_dir / "videos"
    else:
        # Temporary directory used to store images, videos, meta_data
        meta_data_dir = Path(cache_dir) / "meta_data"
        videos_dir = Path(cache_dir) / "videos"

    if raw_format is None:
        # TODO(rcadene, adilzouitine): implement auto_find_raw_format
        raise NotImplementedError()
        # raw_format = auto_find_raw_format(raw_dir)

    # convert dataset from original raw format to LeRobot format
    from_raw_to_lerobot_format = get_from_raw_to_lerobot_format_fn(raw_format)
    hf_dataset, episode_data_index, info = from_raw_to_lerobot_format(
        raw_dir, videos_dir, fps, video, episodes
    )

    lerobot_dataset = LeRobotDataset.from_preloaded(
        repo_id=repo_id,
        hf_dataset=hf_dataset,
        episode_data_index=episode_data_index,
        info=info,
        videos_dir=videos_dir,
    )
    stats = compute_stats(lerobot_dataset, batch_size, num_workers)

    if local_dir:
        hf_dataset = hf_dataset.with_format(None)  # to remove transforms that cant be saved
        hf_dataset.save_to_disk(str(local_dir / "train"))

    if push_to_hub or local_dir:
        # mandatory for upload
        save_meta_data(info, stats, episode_data_index, meta_data_dir)

    if push_to_hub:
        hf_dataset.push_to_hub(repo_id, revision="main")
        push_meta_data_to_hub(repo_id, meta_data_dir, revision="main")
        if video:
            push_videos_to_hub(repo_id, videos_dir, revision="main")
        create_branch(repo_id, repo_type="dataset", branch=CODEBASE_VERSION)

    if tests_data_dir:
        # get the first episode
        num_items_first_ep = episode_data_index["to"][0] - episode_data_index["from"][0]
        test_hf_dataset = hf_dataset.select(range(num_items_first_ep))
        episode_data_index = {k: v[:1] for k, v in episode_data_index.items()}

        test_hf_dataset = test_hf_dataset.with_format(None)
        test_hf_dataset.save_to_disk(str(tests_data_dir / repo_id / "train"))

        tests_meta_data = tests_data_dir / repo_id / "meta_data"
        save_meta_data(info, stats, episode_data_index, tests_meta_data)

        # copy videos of first episode to tests directory
        episode_index = 0
        tests_videos_dir = tests_data_dir / repo_id / "videos"
        tests_videos_dir.mkdir(parents=True, exist_ok=True)
        for key in lerobot_dataset.video_frame_keys:
            fname = f"{key}_episode_{episode_index:06d}.mp4"
            shutil.copy(videos_dir / fname, tests_videos_dir / fname)

    if local_dir is None:
        # clear cache
        shutil.rmtree(meta_data_dir)
        shutil.rmtree(videos_dir)

    return lerobot_dataset


def main():
    """Parse the command line options and forward them to `push_dataset_to_hub`."""
    # Every entry is (flag, keyword arguments given to `add_argument`). Entries are
    # registered in the order they are listed.
    cli_options = [
        (
            "--raw-dir",
            {
                "type": Path,
                "required": True,
                "help": "Directory containing input raw datasets (e.g. `data/aloha_mobile_chair_raw` or `data/pusht_raw).",
            },
        ),
        # TODO(rcadene): add automatic detection of the format
        (
            "--raw-format",
            {
                "type": str,
                "required": True,
                "help": "Dataset type (e.g. `pusht_zarr`, `umi_zarr`, `aloha_hdf5`, `xarm_pkl`, `dora_parquet`).",
            },
        ),
        (
            "--repo-id",
            {
                "type": str,
                "required": True,
                "help": "Repositery identifier on Hugging Face: a community or a user name `/` the name of the dataset (e.g. `lerobot/pusht`, `cadene/aloha_sim_insertion_human`).",
            },
        ),
        (
            "--local-dir",
            {
                "type": Path,
                "help": "When provided, writes the dataset converted to LeRobotDataset format in this directory  (e.g. `data/lerobot/aloha_mobile_chair`).",
            },
        ),
        (
            "--push-to-hub",
            {
                "type": int,
                "default": 1,
                "help": "Upload to hub.",
            },
        ),
        (
            "--fps",
            {
                "type": int,
                "help": "Frame rate used to collect videos. If not provided, use the default one specified in the code.",
            },
        ),
        (
            "--video",
            {
                "type": int,
                "default": 1,
                "help": "Convert each episode of the raw dataset to an mp4 video. This option allows 60 times lower disk space consumption and 25 faster loading time during training.",
            },
        ),
        (
            "--batch-size",
            {
                "type": int,
                "default": 32,
                "help": "Batch size loaded by DataLoader for computing the dataset statistics.",
            },
        ),
        (
            "--num-workers",
            {
                "type": int,
                "default": 8,
                "help": "Number of processes of Dataloader for computing the dataset statistics.",
            },
        ),
        (
            "--episodes",
            {
                "type": int,
                "nargs": "*",
                "help": "When provided, only converts the provided episodes (e.g `--episodes 2 3 4`). Useful to test the code on 1 episode.",
            },
        ),
        (
            "--force-override",
            {
                "type": int,
                "default": 0,
                "help": "When set to 1, removes provided output directory if it already exists. By default, raises a ValueError exception.",
            },
        ),
        (
            "--tests-data-dir",
            {
                "type": Path,
                "help": (
                    "When provided, save tests artifacts into the given directory "
                    "(e.g. `--tests-data-dir tests/data` will save to tests/data/{--repo-id})."
                ),
            },
        ),
    ]

    parser = argparse.ArgumentParser()
    for flag, argument_kwargs in cli_options:
        parser.add_argument(flag, **argument_kwargs)

    # The parsed namespace attributes are passed as keyword arguments.
    parsed = parser.parse_args()
    push_dataset_to_hub(**vars(parsed))


# Run the command line interface only when this file is executed as a script (not on import).
if __name__ == "__main__":
    main()
