"""
test_client.py

Tests the OpenVLA deployment server using direct, low-level access to the DVRK dataset.
Randomly samples from the dataset, queries the server for each sample, and computes the MSE
between predicted and ground-truth actions after quantile (q01/q99) normalization.

Example usage:
python vla-scripts/test_client_new.py \
    --data_root suturing_eval \
    --num_samples 10 \
    --save_plots
"""

import argparse
import json
import numpy as np
import requests
import random
import time
from pathlib import Path
import json_numpy
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

from prismatic.vla.datasets.dvrk_dataset import EpisodicDatasetDvrkGeneric
from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX



def get_action_from_server(observation, server_endpoint):
    """Query the deployment server for a predicted action chunk.

    POSTs ``observation`` as a JSON body to ``server_endpoint`` with a 30 s
    timeout, prints the round-trip latency, and decodes the reply body with
    ``json_numpy``.

    Args:
        observation: JSON-serializable dict sent as the request body.
        server_endpoint: Full URL of the server's action endpoint.

    Returns:
        The decoded response, or ``None`` if sending the request, the HTTP
        status check, or decoding raised any exception (the error is printed).
    """
    try:
        t_start = time.time()
        reply = requests.post(server_endpoint, json=observation, timeout=30)
        elapsed = time.time() - t_start
        print(f"Time taken to get action from server: {elapsed:.6f} seconds")
        reply.raise_for_status()
        decoded = json_numpy.loads(reply.text)
    except Exception as err:
        # Deliberately best-effort: the caller treats None as "skip this sample".
        print(f"Error getting action from server: {err}")
        return None
    return decoded


def compute_mse(predictions, ground_truth, stats):
    """Compute the MSE between predicted and ground-truth actions in normalized space.

    Both inputs are mapped, per action dimension, onto roughly [-1, 1] with the
    dataset's quantile bounds::

        x_norm = 2 * (x - q01) / (q99 - q01 + 1e-6) - 1

    This makes every dimension contribute on a comparable scale. Dimensions
    whose bounds do not span a range (``q99 <= q01``) are set to 0 in both
    arrays and therefore contribute zero error.

    Args:
        predictions: Array-like of shape (T, D), or a single action of shape (D,).
        ground_truth: Array-like with the same shape as ``predictions``.
        stats: Mapping providing ``stats["norm_stats"]["action"]["q01"]`` and
            ``["q99"]``, each a sequence of at least D numbers.

    Returns:
        Mean squared error over all normalized elements (NumPy float32 scalar).

    Raises:
        ValueError: If the two inputs have different shapes, or if the stats
            provide fewer than D bounds.
    """
    action_stats = stats["norm_stats"]["action"]

    # Inputs are never modified in place below, so no defensive copy is needed.
    preds = np.atleast_2d(np.asarray(predictions, dtype=np.float32))
    truth = np.atleast_2d(np.asarray(ground_truth, dtype=np.float32))
    if preds.shape != truth.shape:
        raise ValueError(
            f"predictions shape {preds.shape} does not match ground_truth shape {truth.shape}"
        )

    num_dims = preds.shape[1]
    q01 = np.asarray(action_stats["q01"], dtype=np.float64)
    q99 = np.asarray(action_stats["q99"], dtype=np.float64)
    if len(q01) < num_dims or len(q99) < num_dims:
        raise ValueError(
            f"stats provide {min(len(q01), len(q99))} action bounds but inputs have {num_dims} dims"
        )
    q01, q99 = q01[:num_dims], q99[:num_dims]

    has_range = q99 > q01
    # Match per-column float32 arithmetic: the lower bound is cast to float32,
    # while the denominator is formed in float64 and then cast to float32.
    low = q01.astype(np.float32)
    span = (q99 - q01 + 1e-6).astype(np.float32)
    # Constant dimensions get a dummy span of 1 so the division never produces
    # inf/nan warnings; np.where below replaces their values with 0 anyway.
    span = np.where(has_range, span, np.float32(1))

    def _normalize(x):
        # Quantile min-max scaling to ~[-1, 1]; constant dimensions become 0.
        return np.where(has_range, 2 * (x - low) / span - 1, np.float32(0))

    return np.mean((_normalize(preds) - _normalize(truth)) ** 2)


def prepare_observation(data_sample):
    """Build the JSON-serializable observation payload sent to the server.

    Args:
        data_sample: Dataset item with ``["observation"]["image_primary"]``,
            ``["observation"]["wrist_left"]``, ``["observation"]["wrist_right"]``
            (array-likes) and ``["task"]["language_instruction"]`` (bytes or str).

    Returns:
        Dict with keys ``full_image``, ``left_wrist_image``, ``right_wrist_image``
        (nested Python lists) and ``instruction`` (str).
    """
    obs = data_sample["observation"]
    instruction = data_sample["task"]["language_instruction"]
    # The instruction may arrive as bytes or already as text; only bytes need
    # decoding (calling .decode() on a str would raise AttributeError).
    if isinstance(instruction, (bytes, bytearray)):
        instruction = instruction.decode("utf-8")
    else:
        instruction = str(instruction)

    return {
        "full_image": np.asarray(obs["image_primary"]).tolist(),
        "left_wrist_image": np.asarray(obs["wrist_left"]).tolist(),
        "right_wrist_image": np.asarray(obs["wrist_right"]).tolist(),
        "instruction": instruction,
    }


def compute_trajectory(starting_pose, actions):
    """Integrate relative actions into a sequence of absolute poses.

    For each action, its first three entries (dx, dy, dz) are added to the
    pose's first three entries. Its next three (droll, dpitch, dyaw) are added
    to pose entries 3-5. Pose entries beyond index 5 are carried through
    unchanged. Orientation deltas are summed component-wise; no rotation
    composition is performed.

    Args:
        starting_pose: 1-D array-like pose with at least 6 entries. It is not
            modified.
        actions: Iterable of relative actions, each with at least 6 entries.

    Returns:
        List of ``len(actions) + 1`` poses as NumPy arrays. Element 0 is a copy
        of ``starting_pose``; element k is the pose after the first k actions.
        Integer-typed starting poses are promoted to float64.
    """
    current_pose = np.array(starting_pose)  # np.array copies by default
    # An integer pose would make `+=` with float deltas raise a casting error.
    if not np.issubdtype(current_pose.dtype, np.floating):
        current_pose = current_pose.astype(np.float64)

    trajectory = [current_pose.copy()]
    for action in actions:
        delta = np.asarray(action)
        current_pose[:3] += delta[:3]  # position deltas
        current_pose[3:6] += delta[3:6]  # orientation deltas
        trajectory.append(current_pose.copy())

    return trajectory

def plot_actions_psm2(
    qpos_psm2,
    actions_psm2,
    gt_actions_psm2,
    save_plot=False,
    filename="psm2_trajectory.png",
    factor=1000,
    n_bins=7,
):
    """Scatter predicted vs. ground-truth PSM2 action positions in 3D.

    Each action row's (x, y, z) is shifted by ``qpos_psm2[:3] * factor``.
    Predictions are drawn in red and ground truth in green. Axes are labelled
    in mm. NOTE(review): the default factor of 1000 presumably converts qpos
    from m to mm while actions are already in mm; confirm the units.

    Args:
        qpos_psm2: Array-like with at least 3 entries; the first three are
            used as the reference position.
        actions_psm2: Predicted actions, 2-D array-like with at least 3 columns.
        gt_actions_psm2: Ground-truth actions, same layout as ``actions_psm2``.
        save_plot: If True, write the figure to ``filename``.
        filename: Output path (str or Path) used when ``save_plot`` is True.
        factor: Scale applied to ``qpos_psm2`` before it is added as an offset.
        n_bins: Maximum number of major ticks per axis.

    The figure is always closed before returning, so nothing is displayed.
    """
    qpos = np.asarray(qpos_psm2)
    pred = np.asarray(actions_psm2)
    gt = np.asarray(gt_actions_psm2)
    print(f"actions_psm2 shape: {pred.shape}")
    print(f"qpos_psm2 shape: {qpos.shape}")

    # The same offset is applied to both point sets so they share one frame.
    offset = qpos[:3] * factor
    pred_xyz = pred[:, :3] + offset
    gt_xyz = gt[:, :3] + offset

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, projection="3d")
        ax.scatter(pred_xyz[:, 0], pred_xyz[:, 1], pred_xyz[:, 2], c="r", label="psm2 trajectory")
        ax.scatter(gt_xyz[:, 0], gt_xyz[:, 1], gt_xyz[:, 2], c="g", label="gt psm2 trajectory")
        ax.set_xlabel("X (mm)")
        ax.set_ylabel("Y (mm)")
        ax.set_zlabel("Z (mm)")
        ax.legend()
        for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
            axis.set_major_locator(plt.MaxNLocator(n_bins))
        if save_plot:
            # Save this specific figure rather than whatever pyplot considers current.
            fig.savefig(filename)
    finally:
        # Release the figure even if plotting/saving fails, so long runs don't
        # accumulate open figures.
        plt.close(fig)

def visualize_trajectory(trajectory, ground_truth_trajectory=None, title=None):
    """Draw the positional part of a trajectory as a 3D line plot.

    Only the first three entries (x, y, z) of every pose are used. The
    predicted path is a solid blue line. Its first point is marked with a green
    circle ("Start") and its last with a red cross ("End"). When a ground-truth
    trajectory is supplied, its positions are overlaid as a dashed green line.

    Args:
        trajectory: Sequence of poses (position followed by orientation).
        ground_truth_trajectory: Optional sequence of poses to compare against.
        title: Plot title; falls back to "Arm Trajectory" when falsy.

    Returns:
        The matplotlib Figure. It is neither shown, saved, nor closed here.
    """
    xyz = np.array([pose[:3] for pose in trajectory])
    xs, ys, zs = xyz[:, 0], xyz[:, 1], xyz[:, 2]

    figure = plt.figure(figsize=(10, 8))
    axes = figure.add_subplot(111, projection='3d')

    # Predicted path plus its endpoints (drawing order fixes legend order).
    axes.plot(xs, ys, zs, 'b-', label='Predicted')
    axes.scatter(xs[0], ys[0], zs[0], c='g', marker='o', s=50, label='Start')
    axes.scatter(xs[-1], ys[-1], zs[-1], c='r', marker='x', s=50, label='End')

    if ground_truth_trajectory is not None:
        gt_xyz = np.array([pose[:3] for pose in ground_truth_trajectory])
        axes.plot(gt_xyz[:, 0], gt_xyz[:, 1], gt_xyz[:, 2], 'g--', label='Ground Truth')

    axes.set_xlabel('X position')
    axes.set_ylabel('Y position')
    axes.set_zlabel('Z position')
    axes.set_title(title if title else 'Arm Trajectory')

    axes.legend()
    plt.tight_layout()
    return figure

def _print_debug_stats(starting_pose, prediction, ground_truth_actions):
    """Print shapes, column-wise maxima and means of one sample's arrays.

    Args:
        starting_pose: NumPy array of the sample's arm state.
        prediction: Predicted action chunk as a 2-D NumPy array.
        ground_truth_actions: Ground-truth action chunk as a 2-D NumPy array.
    """
    # axis=0 reduces over rows, giving one value per column (action dimension).
    print(f"Starting pose shape: {starting_pose.shape}, per-column max (axis=0): {np.max(starting_pose, axis=0)}")
    print(f"Prediction shape: {prediction.shape}, per-column max (axis=0): {np.max(prediction, axis=0)}")
    print(f"Ground truth actions shape: {ground_truth_actions.shape}, per-column max (axis=0): {np.max(ground_truth_actions, axis=0)}")

    prediction_col_mean = np.mean(prediction, axis=0)
    ground_truth_col_mean = np.mean(ground_truth_actions, axis=0)
    prediction_row_mean = np.mean(prediction, axis=1)
    ground_truth_row_mean = np.mean(ground_truth_actions, axis=1)
    print(f"Prediction mean (axis=0): {prediction_col_mean}, max over columns: {np.max(prediction_col_mean)}")
    print(f"Ground truth mean (axis=0): {ground_truth_col_mean}, max over columns: {np.max(ground_truth_col_mean)}")
    print(f"Prediction mean (axis=1): {prediction_row_mean}, max over rows: {np.max(prediction_row_mean)}")
    print(f"Ground truth mean (axis=1): {ground_truth_row_mean}, max over rows: {np.max(ground_truth_row_mean)}")


def _report_results(all_mses):
    """Print summary statistics of per-sample MSEs, or a notice if there are none."""
    if not all_mses:
        print("No valid predictions were made.")
        return
    print("\nTest Results:")
    print(f"Average MSE across {len(all_mses)} samples: {np.mean(all_mses):.6f}")
    print(f"Min MSE: {np.min(all_mses):.6f}")
    print(f"Max MSE: {np.max(all_mses):.6f}")
    print(f"Median MSE: {np.median(all_mses):.6f}")
    print(f"Standard Deviation: {np.std(all_mses):.6f}")


def main():
    """Sample random DVRK dataset items, query the server, and report action MSE.

    For each sampled index:
    - fetch the raw sample;
    - send its images and instruction to the server's ``/act`` endpoint;
    - compute the normalized MSE against the ground-truth actions;
    - optionally plot predicted vs. ground-truth positions.

    Summary statistics are printed at the end.
    """
    parser = argparse.ArgumentParser(description="Low-level test client for OpenVLA server")
    parser.add_argument("--data_root", type=str, required=True, help="Path to DVRK dataset")
    parser.add_argument("--server_url", type=str, default="127.0.0.1", help="Server URL")
    parser.add_argument("--port", type=int, default=8777, help="Server port")
    parser.add_argument("--num_samples", type=int, default=100, help="Number of samples to test")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument("--visualize", action="store_true", help="Visualize trajectories")
    parser.add_argument("--save_plots", action="store_true", help="Save trajectory plots to files")
    parser.add_argument("--plot_dir", type=str, default="trajectory_plots", help="Directory to save plots")
    parser.add_argument(
        "--stats_path",
        type=str,
        default="assets/suturing_1-9_w_no_throw_for_1-2/stats.json",
        help="Path to normalization stats JSON (norm_stats.action.q01/q99)",
    )
    args = parser.parse_args()

    # Seed both RNGs so the sampled indices are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)

    server_endpoint = f"http://{args.server_url}:{args.port}/act"
    print(f"Using server endpoint: {server_endpoint}")

    # plot_dir must exist as a name whenever plotting runs (including
    # --visualize without --save_plots); only create the directory when saving.
    plot_dir = Path(args.plot_dir)
    if args.save_plots:
        plot_dir.mkdir(exist_ok=True, parents=True)
        print(f"Saving plots to {plot_dir}")

    with open(args.stats_path, "r", encoding="utf-8") as f:
        stats = json.load(f)

    # Initialize dataset without batch transform
    print(f"Loading DVRK dataset from {args.data_root}")
    dataset = EpisodicDatasetDvrkGeneric(
        robot_base_dir_list=[args.data_root],
        action_horizon=NUM_ACTIONS_CHUNK,  # Typical horizon length
        batch_transform=None,  # No batch transform
        batch_size=1,
        image_aug=False,  # No augmentation for testing
        skip_images=False,  # Need images for prediction
        norm_stats=stats
    )

    total_length = len(dataset.flattened_indices)
    print(f"Dataset contains {total_length} samples")

    all_mses = []
    sample_count = 0
    sample_indices = random.sample(range(total_length), min(args.num_samples, total_length))

    print(f"Testing server with {len(sample_indices)} random samples...")
    for idx in sample_indices:
        # Fetch the raw sample directly via the dataset's internal accessor.
        try:
            data_sample = dataset.__fetch_idx__(idx)
        except Exception as e:
            print(f"Error fetching sample at index {idx}: {e}")
            continue

        starting_pose = data_sample["arm_state"]
        ground_truth_actions = data_sample["action"]

        observation = prepare_observation(data_sample)
        prediction = get_action_from_server(observation, server_endpoint)
        if prediction is None:
            print(f"Warning: Failed to get prediction for sample {idx}, skipping")
            continue
        prediction = np.array(prediction)

        mse = compute_mse(prediction, ground_truth_actions, stats)
        all_mses.append(mse)

        _print_debug_stats(starting_pose, prediction, ground_truth_actions)

        if args.visualize or args.save_plots:
            # NOTE(review): plot_actions_psm2 closes its figure before returning,
            # so --visualize alone currently produces no on-screen output.
            # compute_trajectory / visualize_trajectory are available for
            # cumulative-trajectory plots.
            plot_actions_psm2(
                starting_pose[0][:3],
                prediction[:, :3],
                ground_truth_actions[:, :3],
                save_plot=args.save_plots,
                filename=plot_dir / f"trajectory_final_{idx}_gt.png",
            )

        sample_count += 1
        if sample_count % 10 == 0:
            print(f"Processed {sample_count}/{len(sample_indices)} samples")
            print(f"Current average MSE: {np.mean(all_mses):.6f}")

        # Small delay to avoid overwhelming the server
        time.sleep(0.1)

    _report_results(all_mses)


if __name__ == "__main__":
    main()