
import jax
import jax.numpy as jnp

import jax.lax


def random_flip(images, labels, key):
    """Mirror each image left-to-right with probability 0.5.

    Args:
        images: batch of images; axis 2 is the axis that gets reversed
            (the width axis for NHWC input).
        labels: returned untouched.
        key: JAX PRNG key used for the per-sample coin flips.

    Returns:
        Tuple ``(images_out, labels)``.
    """
    n = images.shape[0]
    # One coin per sample, expanded so it broadcasts over H, W, C.
    should_flip = jax.random.bernoulli(key, p=0.5, shape=(n,)).reshape(n, 1, 1, 1)
    mirrored = jnp.flip(images, axis=2)
    return jnp.where(should_flip, mirrored, images), labels


def random_crop(images, labels, key, pad=4):
    """Randomly crop a batch of images after reflect-padding them.

    Each image is padded by ``pad`` pixels on every spatial side (reflect mode)
    and an ``(h, w)`` window is cut out at an independent random offset. This is
    equivalent to a random shift in ``[-pad, pad]`` along each spatial axis.

    Args:
        images: array of shape (batch, h, w, c).
        labels: returned untouched.
        key: JAX PRNG key.
        pad: padding (and maximum shift) in pixels.

    Returns:
        Tuple ``(cropped_images, labels)``; ``cropped_images`` has the input shape.
    """
    batch_size, h, w, c = images.shape
    padded = jnp.pad(images, ((0, 0), (pad, pad), (pad, pad), (0, 0)), mode='reflect')

    key_x, key_y = jax.random.split(key)
    # randint's maxval is exclusive; valid window offsets are 0..2*pad inclusive,
    # so use 2*pad + 1 to make the maximal right/down shift reachable.
    crop_x = jax.random.randint(key_x, (batch_size,), 0, 2 * pad + 1)
    crop_y = jax.random.randint(key_y, (batch_size,), 0, 2 * pad + 1)

    def crop(img, x, y):
        # Start indices are (row, col, channel) = (y, x, 0).
        return jax.lax.dynamic_slice(img, (y, x, 0), (h, w, c))

    cropped_images = jax.vmap(crop)(padded, crop_x, crop_y)
    return cropped_images, labels


def mixup(images, labels, key, alpha=0.2):
    """Apply MixUp augmentation with a per-sample mixing ratio.

    Each sample is blended with a partner drawn from a random permutation of the
    batch: ``x = lam * x_i + (1 - lam) * x_j``. The same ``lam`` and partner are
    used for the labels.

    Args:
        images: array whose leading axis is the batch axis.
        labels: array whose leading axis is the batch axis, e.g. one-hot labels
            of shape (batch, num_classes) or soft targets of shape (batch,).
        key: JAX PRNG key.
        alpha: concentration of the Beta(alpha, alpha) distribution for ``lam``.

    Returns:
        Tuple ``(mixed_images, mixed_labels)`` with the input shapes.
    """
    batch_size = images.shape[0]
    # Independent keys: reusing one key for both draws correlates lam and the
    # permutation.
    key_lam, key_perm = jax.random.split(key)

    lam = jax.random.beta(key_lam, alpha, alpha, (batch_size,))
    indices = jax.random.permutation(key_perm, batch_size)

    # Reshape lam so it broadcasts over every non-batch axis of each array.
    # A bare (batch,) vector would broadcast against the *last* axis of labels.
    lam_img = lam.reshape((batch_size,) + (1,) * (images.ndim - 1))
    lam_lbl = lam.reshape((batch_size,) + (1,) * (labels.ndim - 1))

    mixed_images = lam_img * images + (1 - lam_img) * images[indices]
    mixed_labels = lam_lbl * labels + (1 - lam_lbl) * labels[indices]

    return mixed_images, mixed_labels


def cutmix(images, labels, key, alpha=1.0):
    """Apply CutMix augmentation to a batch of images and labels.

    For each sample, a rectangle taken from a partner image is pasted in. The
    partner comes from a random permutation of the batch. Labels are mixed in
    proportion to the pixel area each source actually contributes, after the
    box has been clipped to the image border.

    Args:
        images: array of shape (batch, height, width, channels).
        labels: array whose leading axis is the batch axis, e.g. one-hot labels
            of shape (batch, num_classes).
        key: JAX PRNG key.
        alpha: concentration of the Beta(alpha, alpha) distribution from which
            the (pre-clipping) mixing ratio is drawn.

    Returns:
        Tuple ``(mixed_images, mixed_labels)`` with the input shapes.
    """
    batch_size, h, w, _ = images.shape
    key_lam, key_x, key_y, key_perm = jax.random.split(key, 4)

    lam = jax.random.beta(key_lam, alpha, alpha, (batch_size,))  # per-sample ratio

    # Box side lengths chosen so that the box covers about (1 - lam) of the image.
    cut_rat = jnp.sqrt(1.0 - lam)
    cut_w = jnp.minimum((w * cut_rat).astype(jnp.int32), w)
    cut_h = jnp.minimum((h * cut_rat).astype(jnp.int32), h)

    cx = jax.random.randint(key_x, (batch_size,), 0, w)
    cy = jax.random.randint(key_y, (batch_size,), 0, h)

    # Top-left corner, then the exclusive bottom-right corner clipped to the border.
    x1 = jnp.clip(cx - cut_w // 2, 0, w - 1)
    y1 = jnp.clip(cy - cut_h // 2, 0, h - 1)
    x2 = jnp.minimum(x1 + cut_w, w)
    y2 = jnp.minimum(y1 + cut_h, h)

    indices = jax.random.permutation(key_perm, batch_size)

    # Build a (batch, h, w) box mask from coordinate grids rather than slicing.
    # lax.dynamic_slice requires static slice sizes, but the box size here is a
    # traced, per-sample value.
    rows = jnp.arange(h)[None, :, None]
    cols = jnp.arange(w)[None, None, :]
    in_box = (
        (rows >= y1[:, None, None]) & (rows < y2[:, None, None])
        & (cols >= x1[:, None, None]) & (cols < x2[:, None, None])
    )
    mixed_images = jnp.where(in_box[..., None], images[indices], images)

    # Clipping can shrink the box, so derive the label weight from the real area.
    box_area = (x2 - x1) * (y2 - y1)
    lam_adj = 1.0 - box_area / (h * w)
    lam_lbl = lam_adj.reshape((batch_size,) + (1,) * (labels.ndim - 1))
    mixed_labels = lam_lbl * labels + (1.0 - lam_lbl) * labels[indices]

    return mixed_images, mixed_labels



def random_flip_horizontal(rng, images):
    """Mirror each image along axis 2 with probability 0.5.

    Args:
        rng: JAX PRNG key; it is split, and one half drives the coin flips.
        images: batch of images (batch first; axis 2 is reversed).

    Returns:
        Tuple ``(images_out, next_rng)`` where ``next_rng`` is the unused half
        of the split, for threading through further augmentations.
    """
    key_flip, next_rng = jax.random.split(rng)
    batch = images.shape[0]
    coin = jax.random.bernoulli(key_flip, p=0.5, shape=(batch, 1, 1, 1))
    mirrored = jnp.flip(images, axis=2)
    flipped = jnp.where(coin, mirrored, images)
    return flipped, next_rng

def random_translate(rng, images, max_shift=4):
    """Randomly translate each image by an independent (dy, dx) shift.

    Images are reflect-padded by ``max_shift`` pixels, and an original-sized
    window is cut at offset ``(dy, dx)``. Each shift component is drawn
    independently and uniformly from ``[-max_shift, max_shift]``.

    Args:
        rng: JAX PRNG key.
        images: array of shape (batch, height, width, channels).
        max_shift: maximum absolute shift in pixels along each spatial axis.

    Returns:
        Tuple ``(translated_images, next_rng)``.
    """
    translate_rng, rng = jax.random.split(rng)
    # Separate keys for each axis. Sharing one key makes dx == dy, which
    # restricts every translation to the diagonal.
    key_dx, key_dy = jax.random.split(translate_rng)
    batch_size, height, width, channels = images.shape

    dx = jax.random.randint(key_dx, (batch_size,), -max_shift, max_shift + 1)
    dy = jax.random.randint(key_dy, (batch_size,), -max_shift, max_shift + 1)

    pad_size = max_shift
    padded_images = jnp.pad(
        images, ((0, 0), (pad_size, pad_size), (pad_size, pad_size), (0, 0)), mode='reflect'
    )

    def translate(img, dx, dy):
        """Cut the shifted window; start indices are (row, col, channel)."""
        return jax.lax.dynamic_slice(
            img, (dy + pad_size, dx + pad_size, 0), (height, width, channels)
        )

    translated_images = jax.vmap(translate)(padded_images, dx, dy)
    return translated_images, rng

def random_brightness(rng, images, max_delta=0.2):
    """Add a random per-image offset to every pixel, then clip to [0, 1].

    Args:
        rng: JAX PRNG key; split, with one half used for the offsets.
        images: batch of images (batch first); the clip assumes values in [0, 1].
        max_delta: offsets are drawn uniformly from [-max_delta, max_delta).

    Returns:
        Tuple ``(images_out, next_rng)``.
    """
    key_delta, next_rng = jax.random.split(rng)
    n = images.shape[0]
    shift = jax.random.uniform(key_delta, (n, 1, 1, 1), minval=-max_delta, maxval=max_delta)
    brightened = images + shift
    return jnp.clip(brightened, 0.0, 1.0), next_rng

def random_contrast(rng, images, lower=0.8, upper=1.2):
    """Scale each image's deviation from its per-channel spatial mean.

    Args:
        rng: JAX PRNG key; split, with one half used for the scale factors.
        images: array of shape (batch, height, width, channels); the final clip
            assumes values in [0, 1].
        lower: inclusive lower bound of the per-image scale factor.
        upper: exclusive upper bound of the per-image scale factor.

    Returns:
        Tuple ``(images_out, next_rng)``.
    """
    key_factor, next_rng = jax.random.split(rng)
    scale = jax.random.uniform(key_factor, (images.shape[0], 1, 1, 1), minval=lower, maxval=upper)
    # Mean over the spatial axes only, so each channel keeps its own centre.
    channel_mean = images.mean(axis=(1, 2), keepdims=True)
    deviation = images - channel_mean
    adjusted = deviation * scale + channel_mean
    return jnp.clip(adjusted, 0.0, 1.0), next_rng


def random_cutout(rng, images, size=8):
    """Zero out one random ``size x size`` square in each image.

    The square's centre is drawn independently along each spatial axis. Its
    top-left corner is then clipped so that the whole square stays inside the
    image; near a border the square is shifted rather than truncated.

    Args:
        rng: JAX PRNG key.
        images: array of shape (batch, height, width, channels).
        size: side length of the square, in pixels.

    Returns:
        Tuple ``(images_out, next_rng)``.

    Raises:
        ValueError: if ``size`` exceeds the image height or width.
    """
    cutout_rng, rng = jax.random.split(rng)
    # Separate keys per axis. Sharing one key makes the row and column centres
    # identical, which pins every cutout to the diagonal.
    key_row, key_col = jax.random.split(cutout_rng)
    batch_size, height, width, channels = images.shape
    if size > height or size > width:
        raise ValueError(
            f"cutout size {size} exceeds image size ({height}, {width})"
        )

    center_row = jax.random.randint(key_row, (batch_size,), 0, height)
    center_col = jax.random.randint(key_col, (batch_size,), 0, width)

    def apply_cutout(img, r, c):
        """Multiply ``img`` by a mask that is zero on the square, one elsewhere."""
        top = jnp.clip(r - size // 2, 0, height - size)
        left = jnp.clip(c - size // 2, 0, width - size)
        hole = jnp.zeros((size, size, channels), dtype=img.dtype)
        mask = jax.lax.dynamic_update_slice(jnp.ones_like(img), hole, (top, left, 0))
        return img * mask

    images = jax.vmap(apply_cutout)(images, center_row, center_col)
    return images, rng

def autoaugment_cifar10(images, labels, rng):
    """
    Run a fixed chain of JAX-native augmentations over a batch of images.

    Despite the name, this does not apply a learned AutoAugment policy. It
    always runs, in order: horizontal flip, translation (max 4 px), brightness
    (+/-0.2), contrast (x0.8-1.2) and an 8x8 cutout. The PRNG key is threaded
    through each step.

    Args:
        images: jnp.ndarray of shape (batch_size, height, width, channels).
        labels: passed through unchanged.
        rng: JAX PRNG key.

    Returns:
        Tuple ``(augmented_images, labels)``.
    """
    pipeline = (
        random_flip_horizontal,
        lambda key, x: random_translate(key, x, max_shift=4),
        lambda key, x: random_brightness(key, x, max_delta=0.2),
        lambda key, x: random_contrast(key, x, lower=0.8, upper=1.2),
        lambda key, x: random_cutout(key, x, size=8),
    )
    for step in pipeline:
        images, rng = step(rng, images)
    return images, labels


def normalize_images(images: jnp.ndarray, mean: tuple[float, ...], std: tuple[float, ...]) -> jnp.ndarray:
    """
    Normalize images channel-wise: ``(images - mean) / std``.

    ``mean`` and ``std`` are broadcast against the last (channel) axis. Any
    number of channels and any number of leading axes are supported, e.g. an
    NHWC batch or a single HWC image.

    Args:
        images: jnp.ndarray whose last axis is the channel axis.
        mean: per-channel means (length = number of channels, or 1 to apply
            the same value to every channel).
        std: per-channel standard deviations, same length rules as ``mean``.

    Returns:
        Normalized images with the same shape as ``images``.
    """
    # A 1-D (C,) array broadcasts over the trailing axis of any-rank input.
    # Reshaping to (1, 1, 1, 3) would hard-code 3 channels and add a spurious
    # batch axis to unbatched input.
    mean = jnp.asarray(mean)
    std = jnp.asarray(std)

    return (images - mean) / std