import sys
import warnings
from dataclasses import dataclass, field
from typing import List, Literal, Optional

import numpy as np
import tyro

from gr00t.data.dataset import LeRobotSingleDataset
from gr00t.data.embodiment_tags import EMBODIMENT_TAG_MAPPING
from gr00t.eval.robot import RobotInferenceClient
from gr00t.experiment.data_config import load_data_config
from gr00t.model.policy import BasePolicy, Gr00tPolicy
from gr00t.utils.eval import calc_mse_for_single_trajectory
# Ignore FutureWarning from this point on. The filter is installed after the
# imports above, so warnings emitted while importing them are not affected.
warnings.simplefilter("ignore", category=FutureWarning)
"""
Example command:
NOTE: pass --model-path to load the model checkpoint in this script;
        otherwise RobotInferenceClient is used with --host/--port
        (defaults: localhost:5555).
python scripts/eval_policy.py --plot --model-path nvidia/GR00T-N1.5-3B
"""
@dataclass
class ArgsConfig:
    """Configuration for evaluating a policy."""

    # The string literal following each field is that field's description; the
    # texts are kept verbatim. Fields whose default is None are annotated as
    # Optional[...] so the declared type agrees with the default value (the
    # command-line parser derives argument types from these annotations).

    host: str = "localhost"
    """Host to connect to."""

    port: int = 5555
    """Port to connect to."""

    plot: bool = False
    """Whether to plot the images."""

    modality_keys: List[str] = field(default_factory=lambda: ["right_arm", "left_arm"])
    """Modality keys to evaluate."""

    data_config: str = "fourier_gr1_arms_only"
    """
    Data config to use, e.g. so100, fourier_gr1_arms_only, unitree_g1, etc.
    Or a path to a custom data config file. e.g. "module:ClassName" format.
    See gr00t/experiment/data_config.py for more details.
    """

    steps: int = 150
    """Number of steps to evaluate."""

    trajs: int = 1
    """Number of trajectories to evaluate."""

    start_traj: int = 0
    """Start trajectory to evaluate."""

    action_horizon: Optional[int] = None
    """Action horizon to evaluate. If None, will use the data config's action horizon."""

    video_backend: Literal["decord", "torchvision_av"] = "decord"
    """Video backend to use for various codec options. h264: decord or av: torchvision_av"""

    dataset_path: str = "demo_data/robot_sim.PickNPlace/"
    """Path to the dataset."""

    # The set of accepted tags is taken from EMBODIMENT_TAG_MAPPING at import time.
    embodiment_tag: Literal[tuple(EMBODIMENT_TAG_MAPPING.keys())] = "gr1"
    """Embodiment tag to use."""

    model_path: Optional[str] = None
    """Path to the model checkpoint."""

    denoising_steps: int = 4
    """Number of denoising steps to use."""

    save_plot_path: Optional[str] = None
    """Path to save the plot."""

    plot_state: bool = False
    """Whether to plot the state."""
def _print_entries(entries):
    """Print every key of ``entries`` with its shape (ndarray) or its value (anything else)."""
    for key, value in entries.items():
        if isinstance(value, np.ndarray):
            print(key, value.shape)
        else:
            print(key, value)


def main(args: ArgsConfig):
    """Evaluate a policy on dataset trajectories and print the per-trajectory and average MSE.

    The policy is a locally loaded ``Gr00tPolicy`` when ``args.model_path`` is
    set, otherwise a ``RobotInferenceClient`` built from ``args.host`` and
    ``args.port``. The dataset is opened with the modality config reported by
    the policy, a few diagnostic summaries are printed, and
    ``calc_mse_for_single_trajectory`` is run for trajectory ids
    ``args.start_traj`` .. ``args.start_traj + args.trajs - 1``.

    Args:
        args: Parsed evaluation settings. ``args`` is not modified; when
            ``args.action_horizon`` is None the horizon is derived from the
            data config and kept in a local variable.

    Raises:
        SystemExit: Always raised at the end of a completed run (exit status 0)
            to terminate the process.
    """
    data_config = load_data_config(args.data_config)

    # Resolve the action horizon locally instead of writing back into `args`.
    action_horizon = args.action_horizon
    if action_horizon is None:
        action_horizon = len(data_config.action_indices)
        print(f"Using action_horizon={action_horizon} from data config '{args.data_config}'")

    policy: BasePolicy
    if args.model_path is not None:
        # Imported here because only the local-checkpoint branch uses torch directly.
        import torch

        modality_config = data_config.modality_config()
        modality_transform = data_config.transform()
        policy = Gr00tPolicy(
            model_path=args.model_path,
            modality_config=modality_config,
            modality_transform=modality_transform,
            embodiment_tag=args.embodiment_tag,
            denoising_steps=args.denoising_steps,
            device="cuda" if torch.cuda.is_available() else "cpu",
        )
    else:
        policy = RobotInferenceClient(host=args.host, port=args.port)

    # The dataset is always built from the modality config the policy reports.
    modality = policy.get_modality_config()
    print("Current modality config: \n", modality)

    dataset = LeRobotSingleDataset(
        dataset_path=args.dataset_path,
        modality_configs=modality,
        video_backend=args.video_backend,
        video_backend_kwargs=None,
        transforms=None,
        embodiment_tag=args.embodiment_tag,
    )

    # Diagnostic output: dataset length, the first item, and get_step_data(0, 0).
    print(len(dataset))
    _print_entries(dataset[0])
    _print_entries(dataset.get_step_data(0, 0))

    print("Total trajectories:", len(dataset.trajectory_lengths))
    print("All trajectories:", dataset.trajectory_lengths)
    print("Running on all trajs with modality keys:", args.modality_keys)

    all_mse = []
    for traj_id in range(args.start_traj, args.start_traj + args.trajs):
        print("Running trajectory:", traj_id)
        mse = calc_mse_for_single_trajectory(
            policy,
            dataset,
            traj_id,
            modality_keys=args.modality_keys,
            steps=args.steps,
            action_horizon=action_horizon,
            plot=args.plot,
            plot_state=args.plot_state,
            save_plot_path=args.save_plot_path,
        )
        print("MSE:", mse)
        all_mse.append(mse)

    if all_mse:
        print("Average MSE across all trajs:", np.mean(all_mse))
    else:
        # np.mean([]) would yield nan plus a RuntimeWarning; report the situation instead.
        print("No trajectories were evaluated (trajs <= 0); skipping average MSE.")
    print("Done")

    # sys.exit() rather than the `exit()` helper: the latter is injected by the
    # `site` module for interactive use and is not guaranteed to be defined.
    sys.exit()
if __name__ == "__main__":
    # Script entry point: build an ArgsConfig from the command line with tyro
    # and hand it straight to main().
    main(tyro.cli(ArgsConfig))
