import dataclasses
from typing import ClassVar

import einops
import numpy as np

from openpi import transforms


def _parse_image(image) -> np.ndarray:
    """Coerce `image` to an array, converting floats to uint8 and channel-first to channel-last.

    Floating-point inputs are multiplied by 255 and cast to uint8. The input is treated as
    channel-first ([c, h, w]) only when its leading dimension is exactly 3; in that case the
    channel axis is moved to the end. Any other input is returned without rearranging.
    """
    pixels = np.asarray(image)

    is_float = np.issubdtype(pixels.dtype, np.floating)
    if is_float:
        scaled = 255 * pixels
        pixels = scaled.astype(np.uint8)

    # A leading dimension of 3 is taken to be the channel axis.
    channels_first = pixels.shape[0] == 3
    if not channels_first:
        return pixels
    return einops.rearrange(pixels, "c h w -> h w c")


@dataclasses.dataclass(frozen=True)
class CalvinInputs(transforms.DataTransformFn):
    """Inputs for the Calvin policy.

    Expected inputs:
    - images: dict[name, img] where img is [channel, height, width]. name must be in
      EXPECTED_CAMERAS, and "cam_head" is required.
    - state: array-like; padded to `action_dim` with `transforms.pad_to_dim`.
    - actions: optional; padded to `action_dim` the same way when present.
    - prompt: optional; copied to the output unchanged when present.

    The returned dict has "image" and "image_mask" (both keyed by model camera slot), "state",
    and, when supplied in the input, "actions" and "prompt".
    """

    # The action dimension of the model. Will be used to pad state and actions.
    action_dim: int

    # Not read by `__call__`; retained so existing constructor calls that pass it keep working.
    adapt_to_pi: bool = True

    # The expected cameras names. All input cameras must be in this set. Missing cameras will be
    # replaced with black images and the corresponding `image_mask` will be set to False.
    # NOTE: no left-wrist camera is listed, so the left-wrist slots below are always filled with
    # a masked-out placeholder.
    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = (
        "cam_head",
        "cam_right_wrist",
        "next_cam_head",
        "next_cam_right_wrist",
    )

    def __call__(self, data: dict) -> dict:
        """Build the model input dict from one raw sample.

        Raises:
            ValueError: if `data["images"]` contains a camera not in EXPECTED_CAMERAS.
            KeyError: if "cam_head" is missing from the images.
        """
        data = _decode_calvin(data)

        # Pad the state up to the model action dim.
        state = transforms.pad_to_dim(data["state"], self.action_dim)

        in_images = data["images"]
        if set(in_images) - set(self.EXPECTED_CAMERAS):
            raise ValueError(f"Expected images to contain {self.EXPECTED_CAMERAS}, got {tuple(in_images)}")

        # The current frame is always emitted (its base image is assumed to exist); the "next"
        # frame is emitted only when its base image was supplied.
        frame_prefixes = ("", "next_") if "next_cam_head" in in_images else ("",)

        image_dict = {}
        image_mask_dict = {}
        for prefix in frame_prefixes:
            base_image = in_images[f"{prefix}cam_head"]
            image_dict[f"{prefix}base_0_rgb"] = base_image
            image_mask_dict[f"{prefix}base_0_rgb"] = np.True_

            # Add the extra images, substituting a black image (mask False) for absent cameras.
            # The placeholder copies the shape and dtype of this frame's own base image.
            for side in ("left", "right"):
                dest = f"{prefix}{side}_wrist_0_rgb"
                source = f"{prefix}cam_{side}_wrist"
                if source in in_images:
                    image_dict[dest] = in_images[source]
                    image_mask_dict[dest] = np.True_
                else:
                    image_dict[dest] = np.zeros_like(base_image)
                    image_mask_dict[dest] = np.False_

        inputs = {
            "image": image_dict,
            "image_mask": image_mask_dict,
            "state": state,
        }

        # Actions are only available during training.
        if "actions" in data:
            # We are padding from 7 to the model action dim.
            # For pi0-FAST, this is a no-op (since action_dim = 7).
            inputs["actions"] = transforms.pad_to_dim(data["actions"], self.action_dim)

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs


@dataclasses.dataclass(frozen=True)
class CalvinOutputs(transforms.DataTransformFn):
    """Outputs for the Calvin policy.

    Keeps only the first 7 action dimensions of the `actions` entry.
    """

    # Not read by `__call__`; retained so existing constructor calls that pass it keep working.
    adapt_to_pi: bool = True

    def __call__(self, data: dict) -> dict:
        """Return {"actions": ...} holding the first 7 columns of `data["actions"]` as an array."""
        # Slice before converting, so `data["actions"]` must itself support 2-D indexing.
        trimmed = data["actions"][:, :7]
        return {"actions": np.asarray(trimmed)}

def _decode_calvin(data: dict) -> dict:
    """Return a decoded copy of a raw sample without modifying the caller's dict.

    - "state" is converted to a numpy array.
    - every entry of "images" is converted to an array, floating-point images are scaled by 255
      and cast to uint8, and the layout is changed from [c, h, w] to [h, w, c].

    All other keys are carried over unchanged. The result is a shallow copy: decoding the same
    sample twice therefore no longer rearranges already-rearranged images.

    Raises:
        KeyError: if "state" or "images" is missing from `data`.
    """

    def convert_image(img):
        img = np.asarray(img)

        if np.issubdtype(img.dtype, np.floating):
            img = (255 * img).astype(np.uint8)

        return einops.rearrange(img, "c h w -> h w c")

    decoded = dict(data)
    decoded["state"] = np.asarray(data["state"])
    decoded["images"] = {name: convert_image(img) for name, img in data["images"].items()}
    return decoded
