"""
test_client.py

Tests the OpenVLA deployment server using direct, low-level access to the DVRK dataset.
Randomly samples from the dataset and computes the mean squared error (MSE) between predicted and ground-truth actions.

Example usage:
python vla-scripts/test_client.py \
    --data_root processed_suturing_data_zipped_pi \
    --num_samples 10
"""

import argparse
import json
import numpy as np
import requests
import random
import time
from pathlib import Path
import json_numpy

from prismatic.vla.datasets.dvrk_dataset import EpisodicDatasetDvrkGeneric
from prismatic.vla.constants import ACTION_DIM, ACTION_TOKEN_BEGIN_IDX, IGNORE_INDEX, NUM_ACTIONS_CHUNK, PROPRIO_DIM, STOP_INDEX



def get_action_from_server(observation, server_endpoint, timeout=30):
    """Send an observation to the server and return its predicted actions.

    Args:
        observation: JSON-serializable dict posted as the request body
            (as built by ``prepare_observation``).
        server_endpoint: Full URL of the server's action endpoint.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        The response body decoded with ``json_numpy.loads``, or ``None`` if the
        request failed, the server returned an HTTP error status, or the body
        could not be decoded. Callers treat ``None`` as "skip this sample".
    """
    try:
        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # so it is the appropriate clock for measuring request latency.
        start_time = time.perf_counter()
        response = requests.post(
            server_endpoint,
            json=observation,
            timeout=timeout,
        )
        elapsed = time.perf_counter() - start_time
        print(f"Time taken to get action from server: {elapsed:.6f} seconds")
        response.raise_for_status()
        return json_numpy.loads(response.text)
    except Exception as e:
        # Deliberate best-effort boundary: report the failure and let the
        # caller skip this sample instead of aborting the whole run.
        print(f"Error getting action from server: {e}")
        return None


def compute_mse(predictions, ground_truth):
    """Compute the mean squared error between predictions and ground truth.

    Both inputs are converted to float32 arrays. If their shapes differ but
    they contain the same number of elements (e.g. an extra leading batch
    dimension of size 1, or ``(N, 1)`` vs ``(N,)``), they are compared
    element-wise after flattening. This prevents NumPy broadcasting from
    silently building a larger difference array and reporting a wrong MSE.

    Args:
        predictions: Array-like of predicted values.
        ground_truth: Array-like of target values.

    Returns:
        The MSE as a Python float.

    Raises:
        ValueError: If the inputs contain different numbers of elements.
    """
    pred = np.asarray(predictions, dtype=np.float32)
    target = np.asarray(ground_truth, dtype=np.float32)
    if pred.shape != target.shape:
        if pred.size != target.size:
            raise ValueError(
                f"Shape mismatch: predictions {pred.shape} vs ground truth {target.shape}"
            )
        # Same element count, different layout: align element-wise.
        pred = pred.reshape(-1)
        target = target.reshape(-1)
    return float(np.mean((pred - target) ** 2))


def prepare_observation(data_sample):
    """Build the JSON request payload for the server from a dataset sample.

    Args:
        data_sample: Dict with ``data_sample["observation"]`` holding the
            ``"image_primary"``, ``"wrist_left"`` and ``"wrist_right"`` images
            (array-like) and ``data_sample["task"]["language_instruction"]``
            holding the instruction as ``bytes`` or ``str``.

    Returns:
        Dict with ``full_image``, ``left_wrist_image``, ``right_wrist_image``
        (nested Python lists, so they are JSON-serializable) and
        ``instruction`` (a ``str``).
    """
    obs = data_sample["observation"]
    instruction = data_sample["task"]["language_instruction"]
    # The instruction may arrive as raw bytes or as an already-decoded string;
    # calling .decode() unconditionally would crash on str.
    if isinstance(instruction, (bytes, bytearray)):
        instruction = instruction.decode("utf-8")
    else:
        instruction = str(instruction)

    return {
        "full_image": np.asarray(obs["image_primary"]).tolist(),
        "left_wrist_image": np.asarray(obs["wrist_left"]).tolist(),
        "right_wrist_image": np.asarray(obs["wrist_right"]).tolist(),
        "instruction": instruction,
    }


def _report_results(all_mses):
    """Print summary statistics over the collected per-sample MSE values.

    Args:
        all_mses: Sequence of per-sample MSE values; may be empty.
    """
    if not all_mses:
        print("No valid predictions were made.")
        return

    # Accumulate the statistics in float64 regardless of per-sample dtype.
    mses = np.asarray(all_mses, dtype=np.float64)
    print("\nTest Results:")
    print(f"Average MSE across {mses.size} samples: {mses.mean():.6f}")
    print(f"Min MSE: {mses.min():.6f}")
    print(f"Max MSE: {mses.max():.6f}")
    print(f"Median MSE: {np.median(mses):.6f}")
    print(f"Standard Deviation: {mses.std():.6f}")


def main():
    """Run the test client: sample the dataset, query the server, report MSE.

    Command-line options are listed by ``--help``. A sample is skipped and
    counted if it cannot be fetched, if the server does not answer, or if its
    prediction cannot be compared with the ground truth. The run continues
    after a skip.
    """
    parser = argparse.ArgumentParser(description="Low-level test client for OpenVLA server")
    parser.add_argument("--data_root", type=str, required=True, help="Path to DVRK dataset")
    parser.add_argument("--server_url", type=str, default="127.0.0.1", help="Server URL")
    parser.add_argument("--port", type=int, default=8777, help="Server port")
    parser.add_argument("--num_samples", type=int, default=100, help="Number of samples to test")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--request_delay",
        type=float,
        default=0.1,
        help="Seconds to sleep after each successful request (avoids overwhelming the server)",
    )
    args = parser.parse_args()
    if args.num_samples < 0:
        parser.error("--num_samples must be non-negative")
    if args.request_delay < 0:
        parser.error("--request_delay must be non-negative")

    # Seed both RNGs so the sampled indices (and any NumPy randomness) are reproducible.
    random.seed(args.seed)
    np.random.seed(args.seed)

    server_endpoint = f"http://{args.server_url}:{args.port}/act"
    print(f"Using server endpoint: {server_endpoint}")

    # Load the dataset with no batch transform and no augmentation.
    print(f"Loading DVRK dataset from {args.data_root}")
    dataset = EpisodicDatasetDvrkGeneric(
        robot_base_dir_list=[args.data_root],
        action_horizon=NUM_ACTIONS_CHUNK,  # same action-chunk constant the model code uses
        batch_transform=None,
        batch_size=1,
        image_aug=False,  # no augmentation for testing
        skip_images=False,  # images are required inputs for prediction
    )

    total_length = len(dataset.flattened_indices)
    print(f"Dataset contains {total_length} samples")

    all_mses = []
    num_skipped = 0
    sample_indices = random.sample(range(total_length), min(args.num_samples, total_length))

    print(f"Testing server with {len(sample_indices)} random samples...")
    for idx in sample_indices:
        # Fetch the raw sample through the dataset's __fetch_idx__ hook
        # directly, rather than via dataset[idx].
        try:
            data_sample = dataset.__fetch_idx__(idx)
        except Exception as e:
            print(f"Error fetching sample at index {idx}: {e}")
            num_skipped += 1
            continue

        ground_truth_actions = data_sample["action"]
        observation = prepare_observation(data_sample)

        prediction = get_action_from_server(observation, server_endpoint)
        if prediction is None:
            print(f"Warning: Failed to get prediction for sample {idx}, skipping")
            num_skipped += 1
            continue

        try:
            mse = compute_mse(prediction, ground_truth_actions)
        except ValueError as e:
            # e.g. prediction and ground truth have incompatible shapes:
            # report and skip this sample rather than aborting the whole run.
            print(f"Warning: Could not compare prediction for sample {idx}: {e}, skipping")
            num_skipped += 1
            continue
        all_mses.append(mse)

        if len(all_mses) % 10 == 0:
            print(f"Processed {len(all_mses)}/{len(sample_indices)} samples")
            print(f"Current average MSE: {np.mean(all_mses):.6f}")

        # Small delay to avoid overwhelming the server
        if args.request_delay > 0:
            time.sleep(args.request_delay)

    if num_skipped:
        print(f"Skipped {num_skipped} of {len(sample_indices)} samples due to errors")
    _report_results(all_mses)


if __name__ == "__main__":
    main()