"""
eval_with_deploy.py

Evaluate pretrained OpenVLA by calling the deployed server API.
This ensures we use the exact same data preprocessing as in deployment.
"""

# ruff: noqa: E402
import json_numpy

json_numpy.patch()


import json
import pickle
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import requests
import torch
from tqdm import tqdm

from prismatic.vla.constants import ACTION_DIM
from prismatic.vla.datasets.rlds.oxe.mixtures import OXE_NAMED_MIXTURES


class DeployedOpenVLAEvaluator:
    """Minimal HTTP client for a deployed OpenVLA inference server.

    Predictions are requested by POSTing a JSON payload to ``<server_url>/act``.
    Failures are reported on stdout and surfaced to the caller as ``None``
    rather than as exceptions.
    """

    def __init__(self, server_url: str = "http://0.0.0.0:8000", timeout: Optional[float] = 60.0):
        """Store the endpoint URLs and probe the server once.

        Args:
            server_url: Base URL of the server. A trailing ``/`` or a trailing
                ``/act`` path is tolerated and stripped before the endpoint is built.
            timeout: Seconds to wait for each HTTP request before giving up.
                ``None`` waits indefinitely, which is what happened before this
                parameter existed.
        """
        self.server_url = server_url
        self.timeout = timeout

        # Only strip a *trailing* "/act": a substring replace would also mangle
        # hosts such as "http://actuator:8000".
        base_url = server_url.rstrip("/")
        if base_url.endswith("/act"):
            base_url = base_url[: -len("/act")]
        self.act_endpoint = f"{base_url}/act"

        # Best-effort connectivity probe: any HTTP response (even an error status)
        # counts as reachable; a transport failure only triggers a warning.
        try:
            requests.get(base_url, timeout=self.timeout)
        except requests.RequestException:
            print(f"Warning: Could not connect to server at {server_url}")
            print("Make sure the server is running with: python vla-scripts/deploy.py")

    def call_api(
        self, image: np.ndarray, instruction: str, unnorm_key: Optional[str] = None
    ) -> Optional[np.ndarray]:
        """Request an action prediction for one observation.

        Args:
            image: Observation image, sent as the ``"image"`` field of the payload.
                NOTE(review): JSON-encoding a numpy array presumably relies on the
                module-level ``json_numpy.patch()`` call - confirm.
            instruction: Language instruction, sent as ``"instruction"``.
            unnorm_key: Optional value for the ``"unnorm_key"`` field; the field is
                omitted entirely when this is ``None``.

        Returns:
            The decoded JSON body of a 200 response, or ``None`` when the request
            fails, times out, cannot be encoded/decoded, or returns a non-200 status.
        """
        payload = {
            "image": image,
            "instruction": instruction,
        }
        if unnorm_key is not None:
            payload["unnorm_key"] = unnorm_key

        try:
            response = requests.post(self.act_endpoint, json=payload, timeout=self.timeout)
            if response.status_code != 200:
                print(f"API error: {response.status_code}")
                return None
            return response.json()
        except (requests.RequestException, TypeError, ValueError) as e:
            # RequestException: transport errors and timeouts.
            # TypeError / ValueError: payload could not be serialized or the
            # response body was not valid JSON.
            print(f"Error calling API: {e}")
            return None


def get_dataset_and_statistics(dataset_name: str, data_root_dir_seen: Path) -> Tuple[object, Dict]:
    """Load an evaluation dataset and return it together with its statistics.

    Datasets listed in the ``oxe_magic_soup_plus`` mixture are read from
    ``data_root_dir_seen``; every other dataset is read from the public
    ``gs://gresearch/robotics/`` bucket.

    Args:
        dataset_name: Name of the dataset, passed to the loader as ``data_mix``.
        data_root_dir_seen: Root directory holding the "seen" datasets.

    Returns:
        A ``(dataset, stats)`` tuple, where ``dataset`` is the object produced by
        ``get_vla_dataset_and_collator`` and ``stats`` is its
        ``dataset_statistics`` attribute.
    """
    from transformers import AutoProcessor

    from prismatic.models.backbones.llm.prompting import PurePromptBuilder
    from prismatic.vla.materialize import get_vla_dataset_and_collator

    seen_ds = {d[0] for d in OXE_NAMED_MIXTURES["oxe_magic_soup_plus"]}
    data_root = data_root_dir_seen if dataset_name in seen_ds else "gs://gresearch/robotics/"

    # The processor is only needed because the loader requires an image
    # transform and a tokenizer; no model is run in this script.
    processor = AutoProcessor.from_pretrained("openvla/openvla-7b", trust_remote_code=True)

    # The action tokenizer and collator returned by the loader are not used here.
    dataset, _action_tokenizer, _collator = get_vla_dataset_and_collator(
        data_root_dir=data_root,
        data_mix=dataset_name,
        image_transform=processor.image_processor.apply_transform,
        tokenizer=processor.tokenizer,
        prompt_builder_fn=PurePromptBuilder,
        default_image_resolution=(3, 224, 224),
        shuffle_buffer_size=1000,
        train=False,
        episodic=True,
        image_aug=False,
    )

    return dataset, dataset.dataset_statistics


@torch.no_grad()
def evaluate_trajectory_with_deploy(
    evaluator: "DeployedOpenVLAEvaluator",
    trajectory: Dict,
    eval_stats: Dict,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Query the deployed server for every step of one trajectory.

    Both the server prediction and the dataset action are treated as normalized
    values. Each is additionally mapped through
    ``0.5 * (a + 1) * (q99 - q01) + q01`` (which sends ``[-1, 1]`` onto
    ``[q01, q99]``) on the dimensions selected by ``mask``; unmasked dimensions
    pass through unchanged.

    Args:
        evaluator: Client whose ``call_api(image, instruction, unnorm_key=...)``
            returns the predicted action, or ``None`` on failure.
        trajectory: Episode dict read via ``trajectory["action"]``,
            ``trajectory["observation"]["image_primary"]`` and
            ``trajectory["task"]["language_instruction"]``. Index ``[step][0]`` is
            taken from the first two; the first instruction entry is decoded from
            bytes and reused for every step.
        eval_stats: Statistics dict providing ``eval_stats["action"]`` with the
            keys ``"q99"``, ``"q01"`` and ``"mask"``.

    Returns:
        ``(pred_norm, pred_unnorm, gt_norm, gt_unnorm)``, each stacked along a new
        leading step axis.

    Raises:
        RuntimeError: If the server fails to return a prediction for a step.
    """
    # Loop-invariant values: the statistics and instruction are the same for all steps.
    stats_action = eval_stats["action"]
    high = stats_action["q99"]
    low = stats_action["q01"]
    mask = stats_action["mask"]
    instruction = trajectory["task"]["language_instruction"][0].decode()

    def _unnormalize(action: np.ndarray) -> np.ndarray:
        # Rescale masked dimensions to [q01, q99]; leave the others untouched.
        return np.where(mask, 0.5 * (action + 1.0) * (high - low) + low, action)

    all_pred_norm = []
    all_pred_unnorm = []
    all_gt_norm = []
    all_gt_unnorm = []

    for step in range(trajectory["action"].shape[0]):
        image = trajectory["observation"]["image_primary"][step][0]
        gt_norm = trajectory["action"][step][0]

        # NOTE(review): the literal key "dummy" is sent as unnorm_key; how the
        # server interprets it is not visible from this file - confirm.
        pred_action = evaluator.call_api(image, instruction, unnorm_key="dummy")
        if pred_action is None:
            # Explicit raise instead of `assert`, which is stripped under `python -O`.
            raise RuntimeError(f"Failed to get prediction for step {step}")

        # No-op for arrays; also accepts a plain list should the JSON decode to one.
        pred_norm = np.asarray(pred_action)

        all_pred_norm.append(pred_norm)
        all_pred_unnorm.append(_unnormalize(pred_norm))
        all_gt_norm.append(gt_norm)
        all_gt_unnorm.append(_unnormalize(gt_norm))

    return (np.stack(all_pred_norm), np.stack(all_pred_unnorm), np.stack(all_gt_norm), np.stack(all_gt_unnorm))


def evaluate_dataset_with_deploy(
    server_url: str,
    dataset_name: str,
    data_root_dir_seen: Path,
    num_episodes: int = 20,
) -> Dict:
    """Evaluate up to ``num_episodes`` episodes of a dataset against the server.

    Args:
        server_url: Base URL of the deployed OpenVLA server.
        dataset_name: Dataset to load and evaluate.
        data_root_dir_seen: Root directory for datasets that are stored locally.
        num_episodes: Maximum number of episodes to evaluate; values below zero
            are treated as zero.

    Returns:
        Dict mapping the episode index (0-based, in iteration order) to a dict
        with the keys ``"pred_norm"``, ``"pred_unnorm"``, ``"gt_norm"`` and
        ``"gt_unnorm"``.
    """
    from itertools import islice

    evaluator = DeployedOpenVLAEvaluator(server_url)

    dataset, eval_stats = get_dataset_and_statistics(dataset_name, data_root_dir_seen)

    # islice stops *before* pulling episode number num_episodes + 1 from the
    # iterator, so no extra episode is loaded just to be discarded.
    episodes = islice(dataset.dataset.as_numpy_iterator(), max(num_episodes, 0))

    results = {}
    for i, trajectory in enumerate(tqdm(episodes, desc="Evaluating episodes", total=num_episodes)):
        # The per-trajectory evaluation either returns four arrays or raises,
        # so there is no "missing result" case to filter here.
        pred_norm, pred_unnorm, gt_norm, gt_unnorm = evaluate_trajectory_with_deploy(evaluator, trajectory, eval_stats)
        results[i] = {
            "pred_norm": pred_norm,
            "pred_unnorm": pred_unnorm,
            "gt_norm": gt_norm,
            "gt_unnorm": gt_unnorm,
        }

    return results


def main():
    """Command-line entry point: evaluate each requested dataset and pickle the results.

    When ``--dataset_names`` is omitted, the datasets of the ``oxe_magic_soup_plus``
    mixture are used, followed by those ``rtx_franka`` entries not already in it.
    One ``<dataset>.pkl`` file is written per dataset into ``--output_dir``;
    datasets whose result file already exists are skipped.
    """
    import argparse

    cli = argparse.ArgumentParser(description="Evaluate OpenVLA via deployed server API")
    cli.add_argument("--server_url", type=str, default="http://0.0.0.0:8000", help="URL of deployed OpenVLA server")
    cli.add_argument("--data_root_dir_seen", type=str, default="DATA_ROOT_DIR", help="Root dir for seen datasets")
    cli.add_argument("--dataset_names", type=str, nargs="+", default=None, help="Datasets to evaluate")
    cli.add_argument("--num_episodes", type=int, default=20, help="Number of episodes per dataset")
    cli.add_argument("--output_dir", type=str, default="eval-results/validation/deployed_openvla")
    args = cli.parse_args()

    # Resolve which datasets to run.
    if args.dataset_names is not None:
        selected = args.dataset_names
    else:
        seen = [entry[0] for entry in OXE_NAMED_MIXTURES["oxe_magic_soup_plus"]]
        unseen = [name for name, _ in OXE_NAMED_MIXTURES["rtx_franka"] if name not in seen]
        selected = seen + unseen

    # This dataset is known to be problematic and is always excluded.
    selected = [name for name in selected if name != "berkeley_rpt_converted_externally_to_rlds"]

    out_dir = Path(args.output_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    seen_root = Path(args.data_root_dir_seen)
    banner = "=" * 50

    print(f"Evaluating {len(selected)} datasets via API at {args.server_url}")

    for ds in tqdm(selected, desc="Datasets"):
        print("\n" + banner)
        print(f"Evaluating dataset: {ds}")
        print(banner)

        result_path = out_dir / f"{ds}.pkl"

        # An existing result file marks the dataset as already evaluated.
        if result_path.exists():
            print(f"Skipping {ds} because result file already exists")
            continue

        results = evaluate_dataset_with_deploy(
            server_url=args.server_url,
            dataset_name=ds,
            data_root_dir_seen=seen_root,
            num_episodes=args.num_episodes,
        )

        with open(result_path, "wb") as f:
            pickle.dump(results, f)

        print(f"Saved: {result_path}")

    print("\nEvaluation complete!")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
