import numpy as np
import torch
import os
import argparse

# OpenPI specific imports
from openpi.training import config as openpi_config
from openpi.policies import policy_config
from openpi.shared import download
from openpi.training.dvrk_dataset import EpisodicDatasetDvrkGeneric
from openpi.utils.eval_utils import calc_mse_for_single_trajectory_pi0 # Import the new function
import openpi.transforms as _transforms

# Ensure the project root is in the Python path if running script directly
# import sys
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


def _resolve_config_and_checkpoint(config_name, checkpoint_path):
    """Select the training config and obtain the checkpoint directory.

    Args:
        config_name: Name passed to ``openpi_config.get_config``. May be empty
            when ``checkpoint_path`` is given, in which case "pi0_base" is
            used as a default guess.
        checkpoint_path: Explicit checkpoint location handed to
            ``download.maybe_download``. When empty, the location is derived
            from ``config_name`` under ``s3://openpi-assets/checkpoints/``.

    Returns:
        A ``(cfg, checkpoint_dir)`` tuple: the object returned by
        ``openpi_config.get_config`` and the value returned by
        ``download.maybe_download``.

    Raises:
        ValueError: If both ``config_name`` and ``checkpoint_path`` are empty.
    """
    if not config_name and not checkpoint_path:
        raise ValueError("Either --config-name or --checkpoint-path must be provided.")

    if checkpoint_path:
        print(f"Using checkpoint path: {checkpoint_path}")
        # The config is not read from the checkpoint itself; without an explicit
        # name we fall back to the base config, which may not match a fine-tuned
        # checkpoint.
        effective_name = config_name if config_name else "pi0_base"  # Default guess
        print(f"Using config name: {effective_name} (may be default)")
        cfg = openpi_config.get_config(effective_name)
        return cfg, download.maybe_download(checkpoint_path)

    # Only a config name was given: derive the checkpoint location from it
    # (adjust the pattern if your checkpoint naming differs).
    print(f"Using config name: {config_name}")
    cfg = openpi_config.get_config(config_name)
    s3_path = f"s3://openpi-assets/checkpoints/{config_name}"
    print(f"Attempting to download checkpoint from: {s3_path}")
    return cfg, download.maybe_download(s3_path)


def main(args):
    """Evaluate a trained OpenPI policy on one episode of the dVRK dataset.

    Loads the config and checkpoint, builds the policy, loads the dataset,
    and reports the action MSE returned by
    ``calc_mse_for_single_trajectory_pi0`` for the requested trajectory.

    Args:
        args: Parsed command-line namespace. The attributes read are
            ``config_name``, ``checkpoint_path``, ``device``,
            ``action_horizon``, ``traj_id``, ``steps``, ``plot``,
            ``save_video`` and ``video_fps``.

    Returns:
        None. Results are printed. If ``traj_id`` is outside the dataset's
        episode range an error is printed and the function returns early.

    Raises:
        ValueError: If neither ``config_name`` nor ``checkpoint_path`` is set.
    """
    # --- Configuration ---
    cfg, checkpoint_dir = _resolve_config_and_checkpoint(
        args.config_name, args.checkpoint_path
    )

    # --- Load Policy ---
    print("Loading policy...")
    # NOTE: the requested device is not forwarded to create_trained_policy below,
    # so this check only warns about a missing GPU; it does not move the policy.
    if "cuda" in args.device and not torch.cuda.is_available():
        print("Warning: CUDA specified but not available. Using CPU.")

    # Rename dataset sample keys (values) to the keys given on the left before
    # the sample reaches the policy.
    repack_transform = _transforms.Group(
        inputs=[
            _transforms.RepackTransform(
                {
                    "left_image": "observation.images.left",
                    # "right_image": "observation.images.right", # Don't use
                    "endo_psm1_image": "observation.images.endo_psm1",
                    "endo_psm2_image": "observation.images.endo_psm2",
                    "state": "observation.state",
                    "actions": "action",
                    "prompt": "prompt",
                    "actions_is_pad": "action_is_pad",
                }
            )
        ]
    )
    policy = policy_config.create_trained_policy(
        train_config=cfg,
        checkpoint_dir=checkpoint_dir,
        repack_transforms=repack_transform,
    )
    print("Policy loaded.")

    # --- Load Dataset ---
    print("Loading dataset...")
    eval_dataset = EpisodicDatasetDvrkGeneric(
        robot_base_dirs=None,  # Let the dataset class use its own default directories
        action_horizon=args.action_horizon,
        cutting_action_pad_size=10,  # Keep default or make arg
    )
    num_episodes = len(eval_dataset.episode_list)
    print(f"Dataset loaded with {num_episodes} episodes.")
    print(f"Dataset action horizon: {eval_dataset.action_horizon}")
    if eval_dataset.action_horizon != args.action_horizon:
        print(
            f"Warning: Specified action_horizon ({args.action_horizon}) differs from "
            f"dataset's default ({eval_dataset.action_horizon}). "
            f"Using {args.action_horizon} for inference interval."
        )

    # --- Run Evaluation ---
    # Reject negative indices as well as too-large ones: the documented valid
    # range is 0..num_episodes-1, and a negative value must not be passed on
    # to the evaluator.
    if not 0 <= args.traj_id < num_episodes:
        print(
            f"Error: Trajectory ID {args.traj_id} is out of bounds for the dataset "
            f"(0-{num_episodes - 1})."
        )
        return

    print(f"\nStarting evaluation for trajectory ID: {args.traj_id}")
    mse = calc_mse_for_single_trajectory_pi0(
        policy=policy,
        dataset=eval_dataset,
        traj_id=args.traj_id,
        steps=args.steps,
        action_horizon=args.action_horizon,  # Use specified horizon for inference interval
        plot=args.plot,
        save_video=args.save_video,
        video_fps=args.video_fps,
        video_frame_key="observation.images.left",  # Or make this an arg
    )

    print(f"\nEvaluation finished for trajectory ID: {args.traj_id}")
    print(f"Overall Action MSE: {mse}")


if __name__ == "__main__":
    # Command-line entry point: declare the options, parse argv, and hand the
    # resulting namespace to main().
    cli = argparse.ArgumentParser(description="Evaluate OpenPI policy on dVRK dataset.")

    # Which policy to load. main() requires at least one of these two.
    cli.add_argument(
        "--config-name",
        type=str,
        default=None,
        help="Name of the policy config (e.g., 'pi0_base').",
    )
    cli.add_argument(
        "--checkpoint-path",
        type=str,
        default=None,
        help="Path to the policy checkpoint (local or S3 URI). Overrides config name for checkpoint loading.",
    )

    # What to evaluate.
    cli.add_argument(
        "--traj-id",
        type=int,
        required=True,
        help="Index of the trajectory (episode) to evaluate.",
    )
    cli.add_argument(
        "--action-horizon",
        type=int,
        default=10,
        help="Action horizon (inference interval). Should ideally match policy training.",
    )
    cli.add_argument(
        "--steps",
        type=int,
        default=300,
        help="Maximum number of steps to evaluate in the trajectory.",
    )

    # Optional outputs.
    cli.add_argument(
        "--plot",
        action="store_true",
        help="Generate and save trajectory plots.",
    )
    cli.add_argument(
        "--save-video",
        action="store_true",
        help="Save a video of the evaluation.",
    )
    cli.add_argument(
        "--video-fps",
        type=int,
        default=10,
        help="FPS for the saved video.",
    )

    # Runtime.
    cli.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device for policy inference ('cuda' or 'cpu').",
    )

    main(cli.parse_args())