import dataclasses
import dataclasses
from typing import ClassVar

import einops
import numpy as np

from openpi import transforms


def _parse_image(image) -> np.ndarray:
    """Return `image` as a numpy array, converting float data to uint8 and CHW layout to HWC.

    - Floating-point input is multiplied by 255 and cast to uint8. There is no clipping, so the values
      presumably arrive in [0, 1].
    - If the first axis has length 3, the array is treated as channel-first and rearranged to channel-last.

    Anything else is returned as an array without further changes.
    """
    arr = np.asarray(image)
    if np.issubdtype(arr.dtype, np.floating):
        arr = (255 * arr).astype(np.uint8)
    is_channel_first = arr.shape[0] == 3
    if not is_channel_first:
        return arr
    return einops.rearrange(arr, "c h w -> h w c")


def _decode_r1(data: dict) -> dict:
    """Normalize a raw R1 sample in place and return it.

    `data["state"]` is converted to a numpy array, and every camera image in `data["images"]` is passed
    through `_parse_image`. The input dict itself is mutated, and the same object is returned.
    """
    state_array = np.asarray(data["state"])

    decoded_images = {}
    for camera, raw_image in data["images"].items():
        decoded_images[camera] = _parse_image(raw_image)

    data["images"] = decoded_images
    data["state"] = state_array
    return data


@dataclasses.dataclass(frozen=True)
class R1Inputs(transforms.DataTransformFn):
    """Inputs for the R1 policy.

    Expected inputs:
    - images: dict[name, img] where img is [channel, height, width]; `_decode_r1` converts it to channel-last.
      "cam_head" is required. "cam_left_wrist" and "cam_right_wrist" are optional; when missing they are
      replaced with black images and their mask is set to False.
    - state: [state_dim], zero-padded to `action_dim`. Earlier docs said [14], while R1Outputs keeps 21
      action dims — TODO confirm.
    - actions (optional, training only): [action_horizon, D], zero-padded to `action_dim`.
    - prompt (optional): passed through unchanged.

    Optional transition fields:
    - next_state: zero-padded like `state`. When present, the output also carries "next_base_0_rgb",
      "next_left_wrist_0_rgb" and "next_right_wrist_0_rgb" image slots. These are filled from
      "next_cam_head", "next_cam_left_wrist" and "next_cam_right_wrist".
    - next_actions: zero-padded like `actions`.
    - rewards, terminal: passed through unchanged.
    """

    # The action dimension of the model. Used to zero-pad state, actions, next_state and next_actions.
    action_dim: int

    # The expected camera names. Missing wrist cameras are replaced with black images and the corresponding
    # `image_mask` is set to False. NOTE: membership is not enforced here (unlike R1SingleArmInputs),
    # because the next_state path also accepts "next_*" camera names.
    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_head", "cam_left_wrist", "cam_right_wrist")

    @staticmethod
    def _add_optional_images(
        images: dict,
        image_masks: dict,
        in_images: dict,
        template: np.ndarray,
        slots: tuple[tuple[str, str], ...],
    ) -> None:
        """Fill each (model slot, source camera) pair from `in_images`.

        If the source camera is absent, the slot gets a black image shaped like `template` and a False mask.
        """
        for dest, source in slots:
            if source in in_images:
                images[dest] = in_images[source]
                image_masks[dest] = np.True_
            else:
                images[dest] = np.zeros_like(template)
                image_masks[dest] = np.False_

    def __call__(self, data: dict) -> dict:
        data = _decode_r1(data)
        # Zero-pad the state up to the model action dim.
        state = transforms.pad_to_dim(data["state"], self.action_dim)

        in_images = data["images"]
        # The head camera is required; a missing one raises KeyError.
        base_image = in_images["cam_head"]

        images = {"base_0_rgb": base_image}
        image_masks = {"base_0_rgb": np.True_}
        self._add_optional_images(
            images,
            image_masks,
            in_images,
            base_image,
            (("left_wrist_0_rgb", "cam_left_wrist"), ("right_wrist_0_rgb", "cam_right_wrist")),
        )

        inputs = {
            "image": images,
            "image_mask": image_masks,
            "state": state,
        }

        if "next_state" in data:
            inputs["next_state"] = transforms.pad_to_dim(data["next_state"], self.action_dim)

            # Without a "next_cam_head" frame, fall back to the current head frame. "cam_head" is required
            # above, so this slot is always marked valid.
            # NOTE(review): a copied current frame is masked True — confirm that is intended.
            images["next_base_0_rgb"] = in_images.get("next_cam_head", base_image)
            image_masks["next_base_0_rgb"] = np.True_
            self._add_optional_images(
                images,
                image_masks,
                in_images,
                base_image,
                (
                    ("next_left_wrist_0_rgb", "next_cam_left_wrist"),
                    ("next_right_wrist_0_rgb", "next_cam_right_wrist"),
                ),
            )

        # Actions are only available during training; zero-pad them to the model action dim.
        if "actions" in data:
            inputs["actions"] = transforms.pad_to_dim(data["actions"], self.action_dim)

        if "next_actions" in data:
            inputs["next_actions"] = transforms.pad_to_dim(data["next_actions"], self.action_dim)

        # Fields forwarded unchanged when present.
        for key in ("rewards", "terminal", "prompt"):
            if key in data:
                inputs[key] = data[key]

        return inputs


@dataclasses.dataclass(frozen=True)
class R1Outputs(transforms.DataTransformFn):
    """Outputs for the R1 policy.

    Trims the model's action output to its first 21 dims.
    """

    def __call__(self, data: dict) -> dict:
        # Keep only the first 21 dims of every action step. The remaining dims are presumably the zero
        # padding added by R1Inputs — TODO confirm 21 matches the robot's action layout.
        return {"actions": np.asarray(data["actions"][:, :21])}


@dataclasses.dataclass(frozen=True)
class R1SingleArmInputs(transforms.DataTransformFn):
    """Inputs for the single-arm R1 policy (head + right-wrist cameras).

    Expected inputs:
    - images: dict[name, img] where img is [channel, height, width]; `_decode_r1` converts it to channel-last.
      Every name must be in EXPECTED_CAMERAS, otherwise ValueError is raised, and "cam_head" is required.
    - state: [state_dim], zero-padded to `action_dim`. Earlier docs said [14] — TODO confirm.
    - actions (optional, training only): [action_horizon, D], zero-padded to `action_dim`.
    - prompt (optional): passed through unchanged.

    "cam_left_wrist" is not an accepted camera name. As a result, the "left_wrist_0_rgb" slot is always a
    black image with its mask set to False.
    """

    # The action dimension of the model. Used to zero-pad state and actions.
    action_dim: int

    # The accepted camera names; any other name raises ValueError. A missing "cam_right_wrist" is replaced
    # with a black image and the corresponding `image_mask` is set to False.
    EXPECTED_CAMERAS: ClassVar[tuple[str, ...]] = ("cam_head", "cam_right_wrist")

    def __call__(self, data: dict) -> dict:
        data = _decode_r1(data)
        # Zero-pad the state up to the model action dim.
        state = transforms.pad_to_dim(data["state"], self.action_dim)

        in_images = data["images"]
        # Reject camera names outside EXPECTED_CAMERAS. Missing expected names are allowed here.
        unexpected = [name for name in in_images if name not in self.EXPECTED_CAMERAS]
        if unexpected:
            raise ValueError(
                f"Expected image names to be a subset of {self.EXPECTED_CAMERAS}, got {tuple(in_images)} "
                f"(unexpected: {tuple(unexpected)})"
            )

        # The head camera is required; a missing one raises KeyError.
        base_image = in_images["cam_head"]

        image_dict = {
            "base_0_rgb": base_image,
        }
        image_mask_dict = {
            "base_0_rgb": np.True_,
        }

        # Fill the wrist slots, using black placeholders with a False mask for absent cameras.
        extra_image_names = {
            "left_wrist_0_rgb": "cam_left_wrist",
            "right_wrist_0_rgb": "cam_right_wrist",
        }
        for dest, source in extra_image_names.items():
            if source in in_images:
                image_dict[dest] = in_images[source]
                image_mask_dict[dest] = np.True_
            else:
                image_dict[dest] = np.zeros_like(base_image)
                image_mask_dict[dest] = np.False_

        inputs = {
            "image": image_dict,
            "image_mask": image_mask_dict,
            "state": state,
        }

        # Actions are only available during training; zero-pad them to the model action dim.
        if "actions" in data:
            inputs["actions"] = transforms.pad_to_dim(data["actions"], self.action_dim)

        if "prompt" in data:
            inputs["prompt"] = data["prompt"]

        return inputs


@dataclasses.dataclass(frozen=True)
class R1SingleArmOutputs(transforms.DataTransformFn):
    """Outputs for the single-arm R1 policy: the model's actions trimmed to their first 7 dims."""

    def __call__(self, data: dict) -> dict:
        # Drop every action dim past the 7th at each step.
        trimmed = data["actions"][:, :7]
        return {"actions": np.asarray(trimmed)}

