# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Run this command to interactively debug:
PYTHONPATH=. python cosmos_predict1/diffusion/posttrain/datasets/dataset_3D.py

Adapted from:
https://github.com/bytedance/IRASim/blob/main/dataset/dataset_3D.py
"""

import json
import pickle
import random
import traceback
import warnings

import numpy as np
import torch

from cosmos_predict1.diffusion.training.datasets.dataset_3D import Dataset_3D
from cosmos_predict1.utils import log


class Dataset_3DBinary(Dataset_3D):
    """``Dataset_3D`` variant whose action vectors mix computed and stored values.

    For every item, the first six action columns are the relative actions
    computed by ``Dataset_3D._get_actions`` (scaled by ``c_act_scaler``); the
    remaining columns are copied from the annotation JSON's ``"action"`` list.
    """

    # Maximum number of consecutive failed loads tolerated by ``__getitem__``.
    # Each failure resamples a random index; without a bound, a dataset in which
    # every sample is broken would retry forever.
    max_resample_attempts = 100

    def __init__(
        self,
        train_annotation_path,
        val_annotation_path,
        test_annotation_path,
        video_path,
        sequence_interval,
        num_frames,
        cam_ids,
        accumulate_action,
        video_size,
        val_start_frame_interval,
        debug=False,
        normalize=False,
        pre_encode=False,
        do_evaluate=False,
        load_t5_embeddings=False,
        load_action=True,
        mode="train",
    ):
        """Dataset class for loading 3D robot action-conditional data.

        This dataset loads robot trajectories consisting of RGB video frames, robot states
        (arm positions and gripper states), and computes relative actions between
        consecutive frames.

        All arguments are forwarded unchanged to ``Dataset_3D``; see that class for
        their meaning.
        """

        super().__init__(
            train_annotation_path=train_annotation_path,
            val_annotation_path=val_annotation_path,
            test_annotation_path=test_annotation_path,
            video_path=video_path,
            sequence_interval=sequence_interval,
            num_frames=num_frames,
            cam_ids=cam_ids,
            accumulate_action=accumulate_action,
            video_size=video_size,
            val_start_frame_interval=val_start_frame_interval,
            debug=debug,
            normalize=normalize,
            pre_encode=pre_encode,
            do_evaluate=do_evaluate,
            load_t5_embeddings=load_t5_embeddings,
            load_action=load_action,
            mode=mode,
        )

        # NOTE(review): no binarization happens in this class; presumably the gripper
        # column of the JSON "action" entries is already 0/1 — confirm.
        log.info("Dataset_3DBinary: in this dataset, we binarize the gripper state to 0 or 1.")

    def _get_json_action(self, label, frame_ids):
        """Return the stored JSON actions for every frame id except the last.

        Args:
            label: Parsed annotation JSON; its ``"action"`` entry must be indexable
                by the values in ``frame_ids``.
            frame_ids: Frame indices of the clip.

        Returns:
            ``torch.Tensor`` holding rows ``frame_ids[:-1]`` of ``label["action"]``
            (dtype is whatever numpy infers from the JSON values).
        """
        all_action = np.array(label["action"])
        return torch.from_numpy(all_action[frame_ids[:-1]])

    def __getitem__(self, index, cam_id=None, return_video=False):
        """Load one clip, resampling a random index whenever loading fails.

        Args:
            index: Index into ``self.samples``; negative values follow normal
                Python sequence semantics.
            cam_id: Camera forwarded to ``_get_obs`` on the first attempt only;
                resampled retries use ``None``.
            return_video: Unused; kept for interface compatibility.

        Returns:
            dict with ``action`` (only when ``load_action``), ``video``,
            ``annotation_file``, ``__key__``, ``t5_text_embeddings``,
            ``t5_text_mask``, ``fps``, ``image_size``, ``num_frames`` and
            ``padding_mask``.

        Raises:
            NotImplementedError: if the dataset was built with ``pre_encode=True``.
            IndexError: if ``index`` is out of range.
            RuntimeError: after ``max_resample_attempts`` consecutive failed loads.
        """
        # A configuration error, not bad data: raise it instead of resampling forever.
        if self.pre_encode:
            raise NotImplementedError("Pre-encoded videos are not supported for this dataset.")

        for _ in range(self.max_resample_attempts):
            # Look the sample up outside the ``try`` so an out-of-range index raises
            # IndexError (sequence protocol) instead of being treated as invalid data.
            sample = self.samples[index]
            if index < 0:
                index += len(self.samples)  # np.random.seed rejects negative seeds
            if self.mode != "train":
                # Deterministic per-index randomness outside training.
                np.random.seed(index)
                random.seed(index)

            try:
                return self._load_sample(sample, cam_id)
            except Exception:
                self.wrong_number += 1
                warnings.warn(
                    f"Invalid data encountered: {sample['ann_file']}. Skipped "
                    f"(by randomly sampling another sample in the same dataset). "
                    f"Invalid samples so far: {self.wrong_number}."
                )
                warnings.warn("FULL TRACEBACK:")
                warnings.warn(traceback.format_exc())

            # Resample outside the ``except`` block so later failures are not chained
            # onto this exception (which made each logged traceback grow).
            index = np.random.randint(len(self.samples))
            cam_id = None

        raise RuntimeError(
            f"Dataset_3DBinary: {self.max_resample_attempts} consecutive samples failed to load."
        )

    def _load_sample(self, sample, cam_id):
        """Build the output dict for one entry of ``self.samples``.

        Any exception propagates to ``__getitem__``, which treats it as invalid data.
        """
        ann_file = sample["ann_file"]
        frame_ids = sample["frame_ids"]
        with open(ann_file, "r") as f:
            label = json.load(f)

        # Computed even when ``load_action`` is False, so a label that cannot
        # produce states/actions is still rejected as invalid.
        arm_states, gripper_states = self._get_robot_states(label, frame_ids)
        actions = self._get_actions(arm_states, gripper_states, self.accumulate_action)
        actions *= self.c_act_scaler

        data = {}
        if self.load_action:
            computed = actions.float()
            # Columns 0-5: computed relative actions; remaining columns: raw JSON actions.
            json_action = self._get_json_action(label, frame_ids).float()
            json_action[:, :6] = computed[:, :6]
            data["action"] = json_action

        video, cam_id = self._get_obs(label, frame_ids, cam_id, pre_encode=False)
        # Rearrange from [T, C, H, W] to [C, T, H, W]
        data["video"] = video.permute(1, 0, 2, 3).to(dtype=torch.uint8)

        data["annotation_file"] = ann_file
        data["__key__"] = label["episode_id"] if "episode_id" in label else label["original_path"]

        # Just add these to fit the interface.
        # NOTE(review): the .cuda() calls below require a CUDA device inside
        # __getitem__ and are unsafe in forked DataLoader workers — confirm intended.
        if self.load_t5_embeddings:
            # Swap only the ".json" suffix; str.replace would also rewrite any
            # ".json" appearing earlier in the path.
            if ann_file.endswith(".json"):
                t5_embedding_path = ann_file[: -len(".json")] + ".pickle"
            else:
                t5_embedding_path = ann_file
            # NOTE(review): pickle.load can execute arbitrary code; only load trusted files.
            with open(t5_embedding_path, "rb") as f:
                data["t5_text_embeddings"] = torch.from_numpy(pickle.load(f)[0]).cuda()
        else:
            data["t5_text_embeddings"] = torch.zeros(512, 1024, dtype=torch.bfloat16).cuda()
        data["t5_text_mask"] = torch.ones(512, dtype=torch.int64).cuda()
        data["fps"] = 4
        data["image_size"] = 256 * torch.ones(4).cuda()  # TODO: Does this matter?
        data["num_frames"] = self.sequence_length
        data["padding_mask"] = torch.zeros(1, 256, 256).cuda()

        return data


if __name__ == "__main__":
    # Interactive smoke test: load a few samples, print summaries, then drop into a shell.
    dataset = Dataset_3DBinary(
        train_annotation_path="datasets/bridge/annotation/train",
        val_annotation_path="datasets/bridge/annotation/val",
        test_annotation_path="datasets/bridge/annotation/test",
        video_path="datasets/bridge/",
        sequence_interval=1,
        num_frames=2,
        cam_ids=[0],
        accumulate_action=False,
        video_size=[256, 320],
        val_start_frame_interval=1,
        mode="train",
        load_t5_embeddings=True,
    )

    indices = [0, 13, 200, -1]
    for idx in indices:
        # Load once: every dataset[idx] call re-reads the sample.
        sample = dataset[idx]
        # Only keys that __getitem__ actually returns; "action" already holds the
        # merged computed/JSON actions.
        print(
            f"{idx=} "
            f"{sample['video'].sum()=}\n"
            f"{sample['video'].shape=}\n"
            f"{sample['annotation_file']=}\n"
            f"{sample['__key__']=}\n"
            f"{sample['action'].sum()=}\n"
            "---"
        )

    try:
        from IPython import embed
    except ImportError:
        # Fall back to the stdlib REPL when IPython is not installed.
        import code

        code.interact(local=dict(globals()))
    else:
        embed()
