import dataclasses

import einops
import numpy as np

from openpi import transforms
from openpi.models import model as _model
from scipy.spatial.transform import Rotation as R


def make_dvrk_example() -> dict:
    """Creates a random input example for the dVRK policy.

    Returns:
        A dict with:
        - "state": a zero vector of length 16.
        - "images": a dict of four random uint8 images of shape (3, 224, 224),
          keyed "left", "right", "endo_psm1" and "endo_psm2".
        - "prompt": a placeholder instruction string.
    """
    # NOTE(review): DvrkInputs in this file reads flat keys ("left_image",
    # "endo_psm1_image", "endo_psm2_image") rather than a nested "images" dict.
    # Presumably a repack step elsewhere bridges the two layouts - confirm.
    camera_names = ("left", "right", "endo_psm1", "endo_psm2")
    frames = {}
    for name in camera_names:
        frames[name] = np.random.randint(256, size=(3, 224, 224), dtype=np.uint8)

    example = {"state": np.zeros((16,))}
    example["images"] = frames
    example["prompt"] = "do something"
    return example

def _parse_image(image) -> np.ndarray:
    """Convert an image to an array, casting floats to uint8 and moving channels last.

    Args:
        image: Array-like image. Floating-point input is scaled by 255, i.e. it
            is treated as lying in [0, 1]. Non-float input keeps its dtype.

    Returns:
        The image as an ndarray. Floating-point input is returned as uint8.
        If the leading axis has length 3, the array is treated as (C, H, W)
        and rearranged to (H, W, C); otherwise the layout is left untouched.
    """
    image = np.asarray(image)
    if np.issubdtype(image.dtype, np.floating):
        # Clip before casting: a float outside [0, 255] does not saturate when
        # cast to uint8 (it wraps around or is platform-dependent), so a value
        # slightly out of range would otherwise turn into an arbitrary pixel.
        # The cast still truncates, so in-range inputs are unaffected.
        image = np.clip(255 * image, 0, 255).astype(np.uint8)
    if image.shape[0] == 3:
        image = einops.rearrange(image, "c h w -> h w c")
    return image


@dataclasses.dataclass(frozen=True)
class DvrkInputs(transforms.DataTransformFn):
    """Input transform that maps a flat dVRK sample dict to the model input dict."""

    # The action dimension of the model. State and actions are passed through
    # transforms.pad_to_dim with this value.
    action_dim: int

    # Determines which model will be used. Part of the constructor interface;
    # __call__ does not consult it.
    model_type: _model.ModelType = _model.ModelType.PI0

    # "mask" or "heatmap" adds data["targeting_image"] as an extra camera view.
    # Any other value (including the default "none") adds nothing.
    targeting_strategy: str = "none"

    def __call__(self, data: dict) -> dict:
        """Build the model input dict from one sample.

        Expected inputs:
        - state: proprioceptive state vector.
        - left_image, endo_psm1_image, endo_psm2_image: camera images, each
          converted by _parse_image.
        - targeting_image: required only when targeting_strategy is "mask" or
          "heatmap".
        - actions, actions_is_pad: optional; when "actions" is present,
          "actions_is_pad" must be present as well.
        - prompt: optional; copied through unchanged.

        Returns:
            A dict with "state", "image" and "image_mask", plus "actions",
            "actions_is_pad" and "prompt" when the corresponding inputs exist.

        Raises:
            KeyError: if a required key is missing from data.
        """
        state = transforms.pad_to_dim(data["state"], self.action_dim)

        # Possibly need to parse images to uint8 (H,W,C) since LeRobot automatically
        # stores as float32 (C,H,W), gets skipped for policy inference
        images = {
            "base_0_rgb": _parse_image(data["left_image"]),
            "left_wrist_0_rgb": _parse_image(data["endo_psm1_image"]),
            "right_wrist_0_rgb": _parse_image(data["endo_psm2_image"]),
        }
        if self.targeting_strategy in ("mask", "heatmap"):
            images["targeting_0_rgb"] = _parse_image(data["targeting_image"])

        inputs = {
            "state": state,
            "image": images,
            # Every view supplied above is real, so the mask is True for each;
            # deriving it from `images` keeps the two dicts in sync.
            "image_mask": {name: np.True_ for name in images},
        }

        # Actions are only available during training.
        if "actions" in data:
            inputs["actions_is_pad"] = data["actions_is_pad"]
            inputs["actions"] = transforms.pad_to_dim(data["actions"], self.action_dim)

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs


@dataclasses.dataclass(frozen=True)
class DvrkOutputs(transforms.DataTransformFn):
    """Output transform for the dVRK policy: trims the action array's last axis."""

    def __call__(self, data: dict) -> dict:
        """Return {"actions": ...} holding all rows and the first 14 columns of data["actions"]."""
        # Slice first, then convert, so only the retained columns are materialised.
        trimmed = data["actions"][:, :14]
        return {"actions": np.asarray(trimmed)}

######### MARK: - Helper functions

def _joint_flip_mask() -> np.ndarray:
    """Return a length-14 vector of sign multipliers (+1 / -1).

    Entries at indices 1, 2, 8 and 9 are -1; all others are +1.

    NOTE(review): the original description was "used to convert between aloha
    and pi joint angles"; nothing in the visible part of this file calls it -
    confirm whether it is still needed for dVRK.
    """
    flipped = (1, 2, 8, 9)
    signs = [-1 if idx in flipped else 1 for idx in range(14)]
    return np.array(signs)


def _normalize(x, min_val, max_val):
    """Linearly map x from the range [min_val, max_val] onto [0, 1].

    min_val maps to 0 and max_val maps to 1; values outside the range are not
    clamped.
    """
    span = max_val - min_val
    offset = x - min_val
    return offset / span


def _unnormalize(x, min_val, max_val):
    """Linearly map x from [0, 1] back onto the range [min_val, max_val].

    0 maps to min_val and 1 maps to max_val; values outside [0, 1] are not
    clamped.
    """
    span = max_val - min_val
    scaled = x * span
    return scaled + min_val

