# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import warnings

import numpy as np
import torch

from gr00t.data.dvrk_dataset import EpisodicDatasetDvrkGeneric
from gr00t.data.dataset import LeRobotSingleDataset
from gr00t.experiment.data_config import DATA_CONFIG_MAP
from gr00t.model.policy import BasePolicy, Gr00tPolicy
from gr00t.utils.eval import calc_mse_for_single_trajectory
from gr00t.data.schema import EmbodimentTag

# Ignore every FutureWarning raised from this point on in the process.
warnings.simplefilter("ignore", category=FutureWarning)

"""
Example command:

python scripts/eval_policy.py --plot \
    --modality_keys right_arm right_hand \
    --steps 250 \
    --trajs 1000 \
    --action_horizon 16 \
    --video_backend decord \
    --dataset_path demo_data/robot_sim.PickNPlace/ \
    --embodiment_tag gr1 \
    --data_config gr1_arms_waist \
    --model_path <path/to/checkpoint>
Pass --model_path to load the model checkpoint in this script. The argument
parser defines no --host/--port options, so those flags are rejected.


python scripts/eval_policy.py \
    --plot \
    --data_config dvrk \
    --steps 150 \
    --trajs 1 \
    --action_horizon 16 \
    --model_path checkpoints/suturing_downsample_2/checkpoint-42000 \
    --modality_keys psm1 psm2
"""

def build_arg_parser(data_config_choices) -> argparse.ArgumentParser:
    """Build the command-line parser for the offline policy evaluation script.

    Args:
        data_config_choices: Iterable of data-config names accepted by
            ``--data_config``.

    Returns:
        The configured ``argparse.ArgumentParser``. No arguments are parsed
        here.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--plot", action="store_true", help="plot images")
    parser.add_argument("--modality_keys", nargs="+", type=str, default=["right_arm", "right_hand"])
    parser.add_argument(
        "--data_config",
        type=str,
        default="dvrk",
        choices=list(data_config_choices),
        help="data config name",
    )
    parser.add_argument("--steps", type=int, default=150, help="number of steps to run")
    parser.add_argument(
        "--trajs",
        type=int,
        default=1,
        help="number of trajectories to run (trajectory ids 0..N-1)",
    )
    parser.add_argument(
        "--action_horizon",
        type=int,
        default=16,
        help="action horizon passed to the dataset",
    )
    parser.add_argument(
        "--eval_action_horizon",
        type=int,
        default=8,
        help="action horizon passed to calc_mse_for_single_trajectory",
    )
    parser.add_argument(
        "--robot_base_dirs",
        nargs="+",
        type=str,
        default=["your/path/to/dvrk_data"],
        help="one or more dVRK data directories passed to the dataset as robot_base_dir_list",
    )
    # --video_backend and --dataset_path are still accepted so existing command
    # lines keep parsing, but nothing in this script reads their values.
    parser.add_argument("--video_backend", type=str, default="decord")
    parser.add_argument("--dataset_path", type=str, default="demo_data/robot_sim.PickNPlace/")
    parser.add_argument(
        "--embodiment_tag",
        type=str,
        help="The embodiment tag for the model.",
        default="new_embodiment",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default=None,
        help="Path to the model checkpoint directory loaded by Gr00tPolicy (required by this script).",
    )
    return parser


def main() -> None:
    """Evaluate a Gr00tPolicy checkpoint on dVRK trajectories and print the MSE.

    Parses the command line, builds the policy and the
    ``EpisodicDatasetDvrkGeneric`` dataset, calls
    ``calc_mse_for_single_trajectory`` once per trajectory id in
    ``range(args.trajs)``, and prints each MSE followed by their mean.
    """
    parser = build_arg_parser(DATA_CONFIG_MAP.keys())
    args = parser.parse_args()

    # The policy is always constructed locally below, so fail early with a
    # readable message instead of deep inside model loading.
    # NOTE(review): assumes Gr00tPolicy cannot load with model_path=None -- confirm.
    if args.model_path is None:
        parser.error("--model_path is required: this script loads the policy checkpoint locally")

    data_config = DATA_CONFIG_MAP[args.data_config]

    policy: BasePolicy = Gr00tPolicy(
        model_path=args.model_path,
        modality_config=data_config.modality_config(),
        modality_transform=data_config.transform(),
        embodiment_tag=args.embodiment_tag,
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    # Show the modalities the policy reports as supported.
    print(policy.get_modality_config())

    dataset = EpisodicDatasetDvrkGeneric(
        robot_base_dir_list=list(args.robot_base_dirs),
        action_horizon=args.action_horizon,
        cutting_action_pad_size=10,
        # A second transform() call, as in the original flow, so the dataset
        # does not share the transform object handed to the policy.
        transforms=data_config.transform(),
        # NOTE(review): fixed tag, independent of --embodiment_tag -- confirm intended.
        embodiment_tag=EmbodimentTag.NEW_EMBODIMENT,
        downsample_factor=2,
        batch_size=1,
    )

    all_mse = []
    for traj_id in range(args.trajs):
        print("Running trajectory:", traj_id)
        mse = calc_mse_for_single_trajectory(
            policy,
            dataset,
            # Evaluate the trajectory announced above rather than a fixed index.
            traj_id,
            modality_keys=args.modality_keys,
            steps=args.steps,
            action_horizon=args.eval_action_horizon,
            plot=args.plot,
        )
        print("MSE:", mse)
        all_mse.append(mse)

    if all_mse:
        print("Average MSE across all trajs:", np.mean(all_mse))
    else:
        # np.mean([]) would return NaN and emit a RuntimeWarning.
        print("No trajectories were evaluated (--trajs < 1).")
    print("Done")


if __name__ == "__main__":
    main()
