#!/usr/bin/env python3
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0

import argparse
import numpy as np
import cv2
import matplotlib.pyplot as plt


from minimal_client import RobotInferenceClient


def main():
    """Run one round-trip against the GR00T inference server for the DVRK robot.

    Parses CLI options, connects a ``RobotInferenceClient``, prints the server's
    modality configuration, sends a single sample observation and prints/plots
    the returned action chunk.
    """
    parser = argparse.ArgumentParser(description="DVRK Client for GR00T Inference Service")
    parser.add_argument("--host", type=str, default="localhost", help="Host address for the server")
    parser.add_argument("--port", type=int, default=5555, help="Port number for the server")
    parser.add_argument("--image_path", type=str, help="Path to sample images (optional)")
    parser.add_argument("--instruction", type=str, default="needle pickup",
                        help="Task instruction for the robot")
    parser.add_argument("--stats_path", type=str,
                        default="assets/final_suturing_2x_downsample/stats.json",
                        help="Path to statistics JSON file for de-normalization")
    args = parser.parse_args()

    # Create a client connection to the server; the client is given the stats file
    # so it can de-normalize actions itself.
    print(f"Connecting to server at {args.host}:{args.port}...")
    client = RobotInferenceClient(host=args.host, port=args.port, stats_path=args.stats_path)

    # Get modality config from server to understand what it expects
    print("Retrieving modality configuration from server...")
    modality_configs = client.get_modality_config()
    print("Available modality configurations:")
    for key in modality_configs:
        print(f"  - {key}")

    # Create sample observations
    print("Creating sample observations...")
    obs = create_sample_observations(args.image_path, args.instruction)

    # Print observation shapes
    print("\nObservation shapes:")
    for key, value in obs.items():
        if isinstance(value, np.ndarray):
            print(f"  - {key}: {value.shape}")
        else:
            print(f"  - {key}: {value}")

    # Get action from server
    print("\nSending observations to server and getting actions...")
    action = client.get_action(obs)

    # Print action shapes. np.shape also works if the server returns plain lists.
    print("\nReceived actions (already denormalized):")
    for key, value in action.items():
        print(f"  - {key}: {np.shape(value)}")

    # The client already returns de-normalized actions and there is no separate
    # normalized copy. Passing the same dict twice would plot identical data under
    # both a "Normalized" and a "De-normalized" title, so plot it once.
    visualize_data(obs, action)

def _load_rgb_image(path, shape):
    """Load one image from disk as a batched RGB uint8 array.

    Args:
        path: File path passed to ``cv2.imread``.
        shape: Target ``(batch, height, width, channels)`` shape; only height and
            width are used for resizing.

    Returns:
        Array of shape ``(1, height, width, 3)``.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file.
    """
    img = cv2.imread(path)
    if img is None:
        # cv2.imread reports missing/unreadable files by returning None instead of raising.
        raise FileNotFoundError(f"could not read image '{path}'")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    # cv2.resize expects dsize as (width, height).
    img = cv2.resize(img, (shape[2], shape[1]))
    return np.expand_dims(img, axis=0)


def create_sample_observations(image_path=None, instruction="grasp the tissue"):
    """
    Create sample observations for the DVRK robot.

    If ``image_path`` is given, ``main.jpg``, ``endo_psm1.jpg`` and ``endo_psm2.jpg``
    are loaded from it. If any of them cannot be read, all three images fall back
    to random noise.

    Args:
        image_path: Optional directory containing real images.
        instruction: Task instruction.

    Returns:
        Dictionary of observations: three uint8 image batches, two float32
        zero state vectors of shape (1, 8), and the instruction wrapped in a list.
    """
    # Image dimensions based on DVRK configuration: (batch, height, width, channels)
    main_img_shape = (1, 540, 960, 3)
    wrist_img_shape = (1, 480, 640, 3)
    camera_names = ("main", "endo_psm1", "endo_psm2")
    camera_shapes = (main_img_shape, wrist_img_shape, wrist_img_shape)

    images = None
    if image_path:
        try:
            images = [
                _load_rgb_image(f"{image_path}/{name}.jpg", shape)
                for name, shape in zip(camera_names, camera_shapes)
            ]
        except (OSError, cv2.error) as e:
            print(f"Error loading images from {image_path}: {e}")
            print("Using random images instead.")

    if images is None:
        # No path given or loading failed: generate random images
        images = [np.random.randint(0, 256, shape, dtype=np.uint8) for shape in camera_shapes]

    main_img, endo_psm1_img, endo_psm2_img = images

    # Create state vectors (zeros as in the dataset)
    # The state vectors for PSM1 and PSM2 are 8-dimensional each
    state_psm1 = np.zeros((1, 8), dtype=np.float32)
    state_psm2 = np.zeros((1, 8), dtype=np.float32)

    return {
        "video.main": main_img,
        "video.endo_psm1": endo_psm1_img,
        "video.endo_psm2": endo_psm2_img,
        "state.psm1": state_psm1,
        "state.psm2": state_psm2,
        "annotation.human.task_description": [instruction],
    }

def _plot_action_dims(ax, actions, title, max_steps=16, max_dims=7):
    """Plot the leading dimensions of an action chunk over time on one axes.

    Args:
        ax: Matplotlib axes to draw on.
        actions: Array-like indexed as (timestep, dim); a 1-D input is treated
            as a single dimension.
        title: Subplot title.
        max_steps: Maximum number of timesteps to plot.
        max_dims: Maximum number of dimensions to plot.
    """
    actions = np.asarray(actions)
    if actions.ndim == 1:
        actions = actions[:, np.newaxis]
    # Derive the horizon from THIS array so x and y always have matching lengths.
    timesteps = min(max_steps, actions.shape[0])
    for i in range(min(max_dims, actions.shape[1])):
        ax.plot(range(timesteps), actions[:timesteps, i], label=f"Dim {i}")
    ax.set_title(title)
    ax.set_xlabel("Timestep")
    ax.set_ylabel("Action Value")
    ax.legend()


def visualize_data(obs, action, denormalized_action=None):
    """
    Visualize the observations and actions and save them as PNG files.

    Writes ``dvrk_observations.png``. It also writes ``dvrk_actions_comparison.png``
    when ``denormalized_action`` is given, or ``dvrk_actions.png`` otherwise. It then
    calls ``plt.show()``. Any error is printed rather than raised, because
    visualization is a best-effort extra.

    Args:
        obs: Dictionary of observations (uses the three ``video.*`` keys).
        action: Dictionary of normalized actions (``action.psm1``/``action.psm2``).
        denormalized_action: Dictionary of de-normalized actions (optional).
    """
    try:
        # Observation figure: the first batch element of each camera stream.
        obs_fig, obs_axes = plt.subplots(1, 3, figsize=(15, 5))
        views = (
            ("video.main", "Main View"),
            ("video.endo_psm1", "PSM1 Endoscopic View"),
            ("video.endo_psm2", "PSM2 Endoscopic View"),
        )
        for ax, (key, title) in zip(obs_axes, views):
            ax.imshow(obs[key][0])
            ax.set_title(title)
            ax.axis("off")
        obs_fig.tight_layout()
        # Save via its own handle. Previously this used a shared `fig` name that the
        # comparison branch rebound, so the action plot was written to this file.
        obs_fig.savefig("dvrk_observations.png")

        if denormalized_action:
            # Two rows of plots: normalized and denormalized
            act_fig, act_axes = plt.subplots(2, 2, figsize=(15, 10))
            _plot_action_dims(act_axes[0, 0], action["action.psm1"], "PSM1 Actions (Normalized)")
            _plot_action_dims(act_axes[0, 1], action["action.psm2"], "PSM2 Actions (Normalized)")
            _plot_action_dims(act_axes[1, 0], denormalized_action["action.psm1"],
                              "PSM1 Actions (De-normalized)")
            _plot_action_dims(act_axes[1, 1], denormalized_action["action.psm2"],
                              "PSM2 Actions (De-normalized)")
            act_fig.tight_layout()
            act_fig.savefig("dvrk_actions_comparison.png")
            print("\nVisualization saved to 'dvrk_observations.png' and 'dvrk_actions_comparison.png'")
        else:
            act_fig, act_axes = plt.subplots(1, 2, figsize=(12, 6))
            _plot_action_dims(act_axes[0], action["action.psm1"], "PSM1 Actions")
            _plot_action_dims(act_axes[1], action["action.psm2"], "PSM2 Actions")
            act_fig.tight_layout()
            act_fig.savefig("dvrk_actions.png")
            print("\nVisualization saved to 'dvrk_observations.png' and 'dvrk_actions.png'")

        # Show the plots if running in an interactive environment
        plt.show()

    except Exception as e:
        # Deliberately best-effort: a plotting failure must not abort the client run.
        print(f"Error visualizing data: {e}")

if __name__ == "__main__":
    # Entry point: only run the client demo when executed as a script, not on import.
    main()
