##### MARK: Inference Client #####

import numpy as np
import json
from typing import Dict, Any
import torch
from io import BytesIO
import zmq
import cv2

class TorchSerializer:
    """Convert request/response dicts to and from raw bytes via ``torch.save``.

    SECURITY: ``from_bytes`` calls ``torch.load`` with ``weights_only=False``,
    which unpickles arbitrary objects and can run code embedded in the
    payload. Only decode bytes received from a trusted peer.
    """

    @staticmethod
    def to_bytes(data: dict) -> bytes:
        """Serialize *data* with ``torch.save`` and return the encoded bytes."""
        with BytesIO() as stream:
            torch.save(data, stream)
            # Read the contents before the context manager closes the buffer.
            return stream.getvalue()

    @staticmethod
    def from_bytes(data: bytes) -> dict:
        """Decode bytes produced by ``to_bytes`` back into a Python object.

        Uses ``weights_only=False`` (full unpickling) so arbitrary objects in
        the payload can be restored; see the class-level security note.
        """
        return torch.load(BytesIO(data), weights_only=False)

class RobotInferenceClient:
    def __init__(self, host: str = "localhost", port: int = 5555, timeout_ms: int = 15000, 
                 stats_path: str = None):
        self.context = zmq.Context()
        self.host = host
        self.port = port
        self.timeout_ms = timeout_ms
        self.stats_path = stats_path
        self.stats = None
        assert stats_path is not None, "Stats path must be provided"
        self.stats = self.load_stats(stats_path)
        self._init_socket()

    def _init_socket(self):
        """Initialize or reinitialize the socket with current settings"""
        self.socket = self.context.socket(zmq.REQ)
        self.socket.connect(f"tcp://{self.host}:{self.port}")

    def load_stats(self, stats_path: str) -> dict:
        """
        Load statistics from JSON file for de-normalization.
        
        Args:
            stats_path: Path to the statistics JSON file
            
        Returns:
            Dictionary containing the statistics
        """
        try:
            with open(stats_path, 'r') as f:
                stats = json.load(f)
            return stats
        except Exception as e:
            print(f"Error loading statistics from {stats_path}: {e}")
            return None

    def denormalize_actions(self, action: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
        """
        De-normalize actions using min-max scaling with quantiles.
        
        Args:
            action: Dictionary of normalized actions in the range [-1, 1]
            
        Returns:
            Dictionary of de-normalized actions
        """
        if not self.stats or "dataset_statistics" not in self.stats:
            print("Warning: No statistics loaded for denormalization")
            return action
        
        denormalized_action = {}
        
        # Get min/max values from quantiles (q01/q99)
        q01 = np.array(self.stats["dataset_statistics"]["q01"])
        q99 = np.array(self.stats["dataset_statistics"]["q99"])
        
        # De-normalize PSM1 actions using inverse of min-max scaling
        # original formula: normalized = 2 * (x - min) / (max - min) - 1
        # inverse: x = (normalized + 1) / 2 * (max - min) + min
        if "action.psm1" in action:
            action_dim = action["action.psm1"].shape[1]
            denorm_psm1 = np.zeros_like(action["action.psm1"])
            
            for i in range(action_dim):
                min_val = q01[i]
                max_val = q99[i]
                denorm_psm1[:, i] = (action["action.psm1"][:, i] + 1) / 2 * (max_val - min_val) + min_val
                
            denormalized_action["action.psm1"] = denorm_psm1
        
        # De-normalize PSM2 actions
        if "action.psm2" in action:
            action_dim = action["action.psm2"].shape[1]
            denorm_psm2 = np.zeros_like(action["action.psm2"])
            
            for i in range(action_dim):
                min_val = q01[i + action_dim]  # PSM2 stats are in the second half
                max_val = q99[i + action_dim]
                denorm_psm2[:, i] = (action["action.psm2"][:, i] + 1) / 2 * (max_val - min_val) + min_val
                
            denormalized_action["action.psm2"] = denorm_psm2
        
        return denormalized_action

    def ping(self) -> bool:
        try:
            self.call_endpoint("ping", requires_input=False)
            return True
        except zmq.error.ZMQError:
            self._init_socket()  # Recreate socket for next attempt
            return False

    def kill_server(self):
        """
        Kill the server.
        """
        self.call_endpoint("kill", requires_input=False)

    def get_action(self, observations: Dict[str, Any]) -> Dict[str, Any]:
        """
        Get normalized actions from the server and denormalize them if stats are available.
        
        Args:
            observations: Dictionary of observations
            
        Returns:
            Dictionary of actions (denormalized if stats are available)
        """
        # assert state is all zeros
        assert np.all(observations["state.psm1"] == 0), "PSM1 state must be all zeros"
        assert np.all(observations["state.psm2"] == 0), "PSM2 state must be all zeros"
        # Assert images have a batch dimension
        assert observations["video.main"].ndim == 4, "Main image must have a batch dimension"
        assert observations["video.endo_psm1"].ndim == 4, "PSM1 image must have a batch dimension"
        assert observations["video.endo_psm2"].ndim == 4, "PSM2 image must have a batch dimension"

        for key, value in observations.items():
            if key.startswith("video."):
                observations[key] = self.resize_with_padding(value)

        normalized_action = self.call_endpoint("get_action", observations)
        # denormalize actions with q01/q99
        denormalized_action = self.denormalize_actions(normalized_action)

        return denormalized_action
        
    def get_modality_config(self) -> Dict[str, Any]:
        return self.call_endpoint("get_modality_config", requires_input=False)

    def call_endpoint(
        self, endpoint: str, data: dict | None = None, requires_input: bool = True
    ) -> dict:
        """
        Call an endpoint on the server.

        Args:
            endpoint: The name of the endpoint.
            data: The input data for the endpoint.
            requires_input: Whether the endpoint requires input data.
        """
        request: dict = {"endpoint": endpoint}
        if requires_input:
            request["data"] = data

        self.socket.send(TorchSerializer.to_bytes(request))
        message = self.socket.recv()
        if message == b"ERROR":
            raise RuntimeError("Server error")
        return TorchSerializer.from_bytes(message)

    def __del__(self):
        """Cleanup resources on destruction"""
        self.socket.close()
        self.context.term()

    def resize_with_padding(image, target_width=224, target_height=224):
        """
        Resize an image to target dimensions while preserving aspect ratio using padding.
        
        Args:
            image: Input image
            target_width: Desired width
            target_height: Desired height
            
        Returns:
            Resized and padded image with dimensions (target_height, target_width)
        """
        # Get original dimensions
        h, w = image.shape[:2]
        
        # Calculate target aspect ratio and original aspect ratio
        target_aspect = target_width / target_height
        aspect = w / h
        
        # Determine new dimensions while preserving aspect ratio
        if aspect > target_aspect:
            # Image is wider than target aspect ratio
            new_w = target_width
            new_h = int(new_w / aspect)
        else:
            # Image is taller than target aspect ratio
            new_h = target_height
            new_w = int(new_h * aspect)
        
        # Resize image while preserving aspect ratio
        resized = cv2.resize(image, (new_w, new_h))
        
        # Create black canvas of target size
        padded = np.zeros((target_height, target_width, 3), dtype=np.uint8)
        
        # Calculate padding offsets to center the image
        pad_x = (target_width - new_w) // 2
        pad_y = (target_height - new_h) // 2
        
        # Place resized image on the canvas
        padded[pad_y:pad_y+new_h, pad_x:pad_x+new_w] = resized
        
        return padded