import cv2
import numpy as np

def apply_clahe(image, clip_limit, tile_grid_size):
    """
    Equalize local contrast of an RGB image with CLAHE.

    The image is converted to LAB and CLAHE is run only on the L (lightness)
    channel, so the colour channels A and B are passed through untouched.

    Args:
        image (np.ndarray): RGB image.
        clip_limit (float): CLAHE contrast clip limit.
        tile_grid_size (tuple): (cols, rows) tile grid passed to OpenCV.

    Returns:
        np.ndarray: RGB image with equalized lightness.
    """
    lab_image = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
    lightness, chroma_a, chroma_b = cv2.split(lab_image)

    equalizer = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    equalized_lab = cv2.merge((equalizer.apply(lightness), chroma_a, chroma_b))

    return cv2.cvtColor(equalized_lab, cv2.COLOR_LAB2RGB)

def apply_rgb_shift(image, r_shift, g_shift, b_shift):
    """
    Add a constant offset to each RGB channel, saturating to [0, 255].

    Work is done on a widened int16 copy so that intermediate sums can go
    below 0 or above 255 before clipping; the input array is not modified.

    Args:
        image (np.ndarray): RGB image indexed as [row, col, channel].
        r_shift, g_shift, b_shift: Offsets added to channels 0, 1 and 2.

    Returns:
        np.ndarray: Shifted image as uint8.
    """
    widened = image.astype(np.int16)
    for channel, offset in enumerate((r_shift, g_shift, b_shift)):
        widened[:, :, channel] = np.clip(widened[:, :, channel] + offset, 0, 255)
    return widened.astype(np.uint8)

def apply_hsv_shift(image, h_shift, s_shift, v_shift):
    """
    Shift hue, saturation and value of an RGB image.

    Hue wraps around modulo 180 (OpenCV's hue range for 8-bit images);
    saturation and value are clipped to [0, 255]. Arithmetic is done in
    int32 to avoid uint8 wrap-around before clipping.

    Args:
        image (np.ndarray): RGB image.
        h_shift: Offset added to hue (wrapped mod 180).
        s_shift: Offset added to saturation (clipped).
        v_shift: Offset added to value (clipped).

    Returns:
        np.ndarray: RGB image after the HSV shift.
    """
    planes = cv2.cvtColor(image, cv2.COLOR_RGB2HSV).astype(np.int32)
    planes[..., 0] = np.mod(planes[..., 0] + h_shift, 180)
    for index, delta in ((1, s_shift), (2, v_shift)):
        planes[..., index] = np.clip(planes[..., index] + delta, 0, 255)
    return cv2.cvtColor(planes.astype(np.uint8), cv2.COLOR_HSV2RGB)

def apply_unsharp_mask_with_strength(image, edge_strength):
    """
    Sharpen an image with an unsharp mask whose parameters scale with strength.

    Larger ``edge_strength`` increases the sharpening gain, the Gaussian
    sigma and the (odd) kernel size. A non-positive strength returns the
    input unchanged.

    Args:
        image (np.ndarray): Input image.
        edge_strength (float): Sharpening strength; <= 0 disables it.

    Returns:
        np.ndarray: Sharpened image as uint8 (or ``image`` itself if disabled).
    """
    if edge_strength <= 0:
        return image

    gain = 1.0 + 1.5 * edge_strength
    blur_sigma = 1.0 + 4.0 * edge_strength

    # GaussianBlur needs an odd kernel size.
    kernel = int(3 + 4 * edge_strength)
    if kernel % 2 == 0:
        kernel += 1

    smoothed = cv2.GaussianBlur(image, (kernel, kernel), blur_sigma)
    # gain * image + (1 - gain) * smoothed == image + (gain - 1) * (image - smoothed)
    result = cv2.addWeighted(image, gain, smoothed, 1 - gain, 0)
    return np.clip(result, 0, 255).astype(np.uint8)

def image_augmentation(image_path, hp):
    """
    Load an image and apply the augmentations configured in ``hp``.

    Stages run in a fixed order: HSV shift, RGB shift, CLAHE (on the LAB
    lightness channel), then unsharp masking. A stage is skipped when its
    parameters are absent or would leave the image unchanged (all-zero
    shifts, non-positive ``clahe_clip`` or ``edge_strength``).

    Args:
        image_path (str): Path to the image file.
        hp (dict): Augmentation parameters. Recognised (optional) keys:
            ``hsv_hue_shift``, ``hsv_sat_shift``, ``hsv_val_shift``,
            ``r_shift``, ``g_shift``, ``b_shift``, ``clahe_clip``,
            ``clahe_grid`` (default 8), ``edge_strength``.

    Returns:
        np.ndarray: Augmented RGB image.

    Raises:
        ValueError: If the image cannot be loaded from ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        raise ValueError(f"Cannot load image from {image_path}")
    # cv2.imread returns BGR; every helper below works in RGB.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # HSV shift
    if any(hp.get(k, 0) != 0 for k in ("hsv_hue_shift", "hsv_sat_shift", "hsv_val_shift")):
        image = apply_hsv_shift(
            image,
            h_shift=hp.get("hsv_hue_shift", 0),
            s_shift=hp.get("hsv_sat_shift", 0),
            v_shift=hp.get("hsv_val_shift", 0),
        )

    # RGB shift
    if any(hp.get(k, 0) != 0 for k in ("r_shift", "g_shift", "b_shift")):
        image = apply_rgb_shift(
            image,
            r_shift=hp.get("r_shift", 0),
            g_shift=hp.get("g_shift", 0),
            b_shift=hp.get("b_shift", 0),
        )

    # CLAHE (Contrast Limited Adaptive Histogram Equalization)
    clahe_clip = hp.get("clahe_clip", 0)
    if clahe_clip > 0:
        # NOTE(review): "clahe_grid" is scaled by 4 (default 8 -> 32x32 tiles);
        # presumably the hyperparameter is expressed in units of 4 tiles — confirm.
        # OpenCV's tileGridSize must be integers; hyperparameters coming from a
        # search or config file may be floats/numpy scalars, so round and cast.
        grid_size = int(round(4 * hp.get("clahe_grid", 8)))
        image = apply_clahe(
            image,
            clip_limit=clahe_clip,
            tile_grid_size=(grid_size, grid_size),
        )

    # Unsharp masking (the helper is a no-op for edge_strength <= 0).
    image = apply_unsharp_mask_with_strength(image, edge_strength=hp.get("edge_strength", 0.0))

    return image

def select_hp(hp_all, task):
    """
    Extract sub-dictionary of hyperparameters for grounding or segmentation.

    Keys are expected to look like ``"<task><sep><name>"`` (e.g.
    ``"grd_clahe_clip"``). Every key that starts with ``task`` is kept and
    re-keyed by stripping the task prefix plus the single separator
    character that follows it (``"grd_clahe_clip"`` -> ``"clahe_clip"``).

    Args:
        hp_all (dict): Dictionary containing all hyperparameters.
        task (str): Either 'grd' or 'seg'.

    Returns:
        dict: Subset of hyperparameters specific to the task.
    """
    # Skip the task prefix and the one-character separator after it.
    offset = len(task) + 1
    return {
        key[offset:]: value
        for key, value in hp_all.items()
        if key.startswith(task)
    }
