import cv2
import torch
import numpy as np
from PIL import Image

# Function to create a Color Cube
def create_color_cube(image):
    """
    Build a "Color Cube" by stacking an RGB image with its HSV and YCrCb encodings.
    Args:
        image: Input RGB image as a NumPy array of shape (H, W, 3).
    Returns:
        NumPy array of shape (H, W, 9): the input channels first, then the
        result of cv2.COLOR_RGB2HSV, then the result of cv2.COLOR_RGB2YCrCb
        (OpenCV orders the latter as Y, Cr, Cb).
    """
    # The untouched input supplies the first three planes; each OpenCV
    # conversion below contributes three more, in the order listed.
    planes = [image]
    for conversion in (cv2.COLOR_RGB2HSV, cv2.COLOR_RGB2YCrCb):
        planes.append(cv2.cvtColor(image, conversion))

    # Join everything along the channel (last) axis.
    return np.concatenate(planes, axis=-1)

# Custom Transform Class for Color Cube
class ColorCubeTransform:
    """Callable transform that turns an RGB image into a normalized 9-channel tensor."""

    def __call__(self, img):
        """
        Convert an image to the Color Cube format using RGB, HSV, and YCrCb color spaces.
        Args:
            img: Input RGB image (PIL image or NumPy array of shape (H, W, 3)).
                A fourth (alpha) channel, if present, is dropped.
        Returns:
            A float torch tensor of shape (9, H, W) with every channel divided by 255.
        """
        rgb = np.array(img)  # Convert PIL image to NumPy array (channel order is kept)

        if rgb.ndim == 3 and rgb.shape[2] == 4:
            # Drop alpha so the cube always has exactly 9 channels.
            rgb = np.ascontiguousarray(rgb[..., :3])

        # No RGB->BGR swap here: create_color_cube applies cv2.COLOR_RGB2* codes,
        # which expect RGB-ordered input. Swapping first would exchange the red
        # and blue planes and corrupt the HSV / YCrCb channels derived from them.
        color_cube = create_color_cube(rgb)

        # (H, W, 9) -> (9, H, W), then scale by 255.
        return torch.tensor(color_cube).permute(2, 0, 1).float() / 255.0

# Optional additional functions for other color space conversions can be placed here:
def lab(img):
    """
    Re-encode an RGB image in LAB and return it as an 8-bit PIL image.
    """
    # Divide by 255 before converting (assumes pixel values on a 0-255 scale).
    unit_rgb = np.float32(img) / 255
    lab_pixels = cv2.cvtColor(unit_rgb, cv2.COLOR_RGB2LAB)

    # Rescale the first channel by 255/100 and offset the other two by 128
    # so all three can be stored as unsigned bytes.
    lightness = lab_pixels[:, :, 0]
    lab_pixels[:, :, 0] = lightness * 255 / 100
    lab_pixels[:, :, 1:] += 128

    return Image.fromarray(lab_pixels.astype(np.uint8))

def ycbcr(img):
    """
    Convert image to YCbCr color space.

    The input is an RGB image with pixel values on the 0-255 scale. The
    returned PIL image holds the channels in OpenCV's order: Y, Cr, Cb.
    """
    pixels = np.array(img)
    if pixels.dtype != np.uint8:
        # Round and saturate so the conversion below runs on 8-bit data.
        pixels = np.clip(np.rint(pixels), 0, 255).astype(np.uint8)

    # Convert on uint8 input: for floating-point input OpenCV offsets the
    # chroma channels by 0.5 instead of 128, which for 0-255 data produces
    # negative Cr/Cb values that are mangled by a later cast to uint8.
    converted = cv2.cvtColor(pixels, cv2.COLOR_RGB2YCR_CB)
    return Image.fromarray(converted)

def hsv(img):
    """
    Re-encode an RGB image in HSV and return it as an 8-bit PIL image.
    """
    # Divide by 255 before converting (assumes pixel values on a 0-255 scale).
    unit_rgb = np.float32(img) / 255
    hsv_pixels = cv2.cvtColor(unit_rgb, cv2.COLOR_RGB2HSV)

    # Per-channel gains: hue is halved, saturation and value are scaled by 255.
    channel_gain = np.array([0.5, 255, 255])
    as_bytes = (hsv_pixels * channel_gain).astype(np.uint8)

    return Image.fromarray(as_bytes)
