import numpy as np
from PIL import Image
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode
from typing import Optional, Tuple, Dict, Any


def _to_pil(x):
    """Convert a single image array to a PIL image.

    Accepts a 2-D ``(H, W)`` array (becomes mode ``"L"``) or a 3-D
    ``(H, W, 3)`` array (becomes mode ``"RGB"``). Non-uint8 input is rounded,
    clipped to ``[0, 255]`` and cast to uint8 first. It is assumed to already
    be on a 0-255 scale, so a float image in ``[0, 1]`` would come out nearly
    black.

    Raises:
        ValueError: for any other shape.
    """
    x = np.asarray(x)
    if not (x.ndim == 2 or (x.ndim == 3 and x.shape[2] == 3)):
        raise ValueError(f"Unsupported image shape: {x.shape}")
    if x.dtype != np.uint8:
        # Forcing mode="L"/"RGB" on a wider dtype (e.g. int64) makes Pillow
        # read the raw buffer as 8-bit pixels and scramble the image, so
        # normalise to uint8 explicitly.
        x = np.clip(np.rint(x.astype(np.float64)), 0, 255).astype(np.uint8)
    # For uint8 data Pillow infers "L" from (H, W) and "RGB" from (H, W, 3).
    return Image.fromarray(x)


def _from_pil(img_pil, like_array):
    """Turn a PIL image back into a uint8 NumPy array laid out like ``like_array``.

    A 2-D ``like_array`` selects grayscale mode ``"L"``; any other rank
    selects ``"RGB"``.
    """
    target_mode = "L" if like_array.ndim == 2 else "RGB"
    converted = img_pil.convert(target_mode)
    return np.array(converted, dtype=np.uint8)


def _alloc_like_batch(X, per_sample):
    """Allocate an uninitialised uint8 output batch with ``per_sample`` rows per input.

    Args:
        X: image batch, either ``(N, H, W)`` grayscale or ``(N, H, W, C)``
            channels-last.
        per_sample: number of output rows reserved for each input image.

    Returns:
        ``np.empty`` array of dtype uint8 and shape
        ``(N * per_sample, *X.shape[1:])``. Its contents are uninitialised and
        must be fully overwritten by the caller.

    Raises:
        ValueError: if ``X`` is not 3-D or 4-D.
    """
    X = np.asarray(X)
    if X.ndim not in (3, 4):
        raise ValueError(
            f"Expected a batch of shape (N, H, W) or (N, H, W, C), got {X.shape}"
        )
    # Keep the per-image shape, including the real channel count, instead of
    # assuming 3 channels. Outputs are uint8 because the converters produce uint8.
    return np.empty((len(X) * per_sample,) + X.shape[1:], dtype=np.uint8)


# Order in which _make_geom_pipeline chains the geometric ops (after the optional flip).
_GEOM_CANONICAL = [
    "Rotate",
    "Translate",
    "Scale",
]


def _make_single_geom_transform(name, rng, params):
    """Create the geometric op called ``name`` as a ``T.Lambda`` over PIL images.

    Recognised names are ``"Rotate"``, ``"Translate"`` and ``"Scale"``. Any
    other name raises ``ValueError(name)``. ``params`` may override
    ``angle_range``, ``translate_frac``, ``scale_range``, ``interpolation`` and
    ``fill``.

    Rotate and Scale draw their magnitude from ``rng`` once, when the op is
    built. Translate instead draws its two offsets from ``rng`` each time the
    returned op is applied, because the pixel shift is scaled by the image size.
    """
    angle_range: Tuple[float, float] = params.get("angle_range", (-15.0, 15.0))
    translate_frac: float = params.get("translate_frac", 1.0 / 32.0)
    scale_range: Tuple[float, float] = params.get("scale_range", (0.95, 1.05))
    interp = params.get("interpolation", InterpolationMode.BILINEAR)
    fill_value = params.get("fill", 0)

    def _affine(img, shift, zoom):
        # Pure translate/scale: zero rotation and zero shear.
        return TF.affine(
            img,
            angle=0.0,
            translate=shift,
            scale=zoom,
            shear=(0.0, 0.0),
            interpolation=interp,
            fill=fill_value,
        )

    if name == "Rotate":
        theta = float(rng.uniform(angle_range[0], angle_range[1]))
        return T.Lambda(
            lambda img: TF.rotate(
                img, angle=theta, interpolation=interp, fill=fill_value
            )
        )

    if name == "Translate":
        def _shift(img: Image.Image):
            width, height = img.size
            # Horizontal offset is drawn before the vertical one.
            dx = int(np.round(rng.uniform(-translate_frac, translate_frac) * width))
            dy = int(np.round(rng.uniform(-translate_frac, translate_frac) * height))
            return _affine(img, (dx, dy), 1.0)
        return T.Lambda(_shift)

    if name == "Scale":
        zoom = float(rng.uniform(scale_range[0], scale_range[1]))
        return T.Lambda(lambda img: _affine(img, (0, 0), zoom))

    raise ValueError(name)


def _make_geom_pipeline(rng, params: Dict[str, Any]):
    """Compose a freshly sampled geometric pipeline for one output image.

    Whether to flip horizontally is decided here with a single ``rng.rand()``
    draw. That draw happens only when ``params["hflip_prob"]`` is positive; a
    missing or ``None`` value counts as 0. The flip, if chosen, runs first.
    It is followed by one op per name in ``_GEOM_CANONICAL``, in that order.
    """
    flip_p = float(params.get("hflip_prob", 0.0) or 0.0)
    # Short-circuit keeps the RNG untouched when flipping is disabled.
    flip_now = flip_p > 0.0 and bool(rng.rand() < flip_p)
    steps = (
        [T.Lambda(lambda im: im.transpose(Image.FLIP_LEFT_RIGHT))] if flip_now else []
    )
    for op_name in _GEOM_CANONICAL:
        steps.append(_make_single_geom_transform(op_name, rng, params))
    return T.Compose(steps)


def geom_module(
    X,
    y,
    per_sample=1,
    *,
    angle_range: Tuple[float, float] = (-15.0, 15.0),
    translate_frac: float = 1.0 / 32.0,
    scale_range: Tuple[float, float] = (0.95, 1.05),
    interpolation=InterpolationMode.BILINEAR,
    fill=0,
    hflip_prob: Optional[float] = None,
    rng: Optional[np.random.RandomState] = None,
):
    """Return ``per_sample`` geometrically augmented copies of every input image.

    Each copy gets its own freshly built pipeline (optional horizontal flip,
    then rotate, translate and scale), applied to the PIL version of the image.

    Args:
        X: image batch, ``(N, H, W)`` grayscale or ``(N, H, W, 3)`` RGB.
        y: labels; ``y[i]`` is passed through ``int()`` and repeated per copy.
        per_sample: number of augmented copies emitted per input.
        angle_range, translate_frac, scale_range, interpolation: forwarded to
            the geometric ops.
        fill: value for uncovered pixels. For RGB batches, 0 or ``None`` is
            replaced by the approximate ImageNet mean ``(123, 116, 103)``.
        hflip_prob: flip probability. ``None`` means 0.5 for RGB batches and
            0.0 otherwise.
        rng: random source. ``None`` uses the global ``np.random`` module.

    Returns:
        ``(X_out, y_out)``. ``X_out`` is a uint8 array with ``N * per_sample``
        rows and ``y_out`` has ``y``'s dtype. Copies of input ``i`` occupy rows
        ``i * per_sample`` through ``(i + 1) * per_sample - 1``.
    """
    images = np.asarray(X)
    labels = np.asarray(y)
    n_inputs = len(images)
    is_rgb = images.ndim == 4 and images.shape[-1] == 3

    if is_rgb and (fill == 0 or fill is None):
        # Approximate ImageNet mean in 0-255 space
        fill = (123, 116, 103)
    if hflip_prob is None:
        hflip_prob = 0.5 if is_rgb else 0.0

    X_out = _alloc_like_batch(images, per_sample)
    y_out = np.empty((n_inputs * per_sample,), dtype=labels.dtype)
    if rng is None:
        rng = np.random
    geom_params = dict(
        angle_range=angle_range,
        translate_frac=translate_frac,
        scale_range=scale_range,
        interpolation=interpolation,
        fill=fill,
        hflip_prob=float(hflip_prob),
    )

    for i, img_np in enumerate(images):
        img_pil = _to_pil(img_np)
        label = int(labels[i])
        base = i * per_sample
        for rep in range(per_sample):
            # Rebuilt per copy so every copy gets fresh random parameters.
            pipeline = _make_geom_pipeline(rng, geom_params)
            X_out[base + rep] = _from_pil(pipeline(img_pil), img_np)
            y_out[base + rep] = label
    return X_out, y_out


# Order in which _make_opt_pipeline chains the photometric ops.
_OPT_CANONICAL = [
    "Brightness",
    "Contrast",
    "Blur",
]


def _make_single_opt_transform(name, rng, params):
    """Create the photometric op called ``name`` as a ``T.Lambda`` over PIL images.

    Recognised names are ``"Brightness"``, ``"Contrast"`` and ``"Blur"``. Any
    other name raises ``ValueError(name)``. ``params`` may override
    ``brightness_range``, ``contrast_range``, ``blur_kernel`` and
    ``blur_sigma``.

    Every random magnitude, including the blur sigma, is drawn from ``rng``
    when the op is built, so seeding ``rng`` reproduces the op.

    ``blur_sigma`` is either a positive number (fixed sigma, no draw) or a
    ``(min, max)`` pair with ``0 < min <= max`` (sigma drawn uniformly).

    Raises:
        ValueError: unknown ``name`` or malformed ``blur_sigma``.
    """
    brightness_range: Tuple[float, float] = params.get("brightness_range", (0.7, 1.3))
    contrast_range: Tuple[float, float] = params.get("contrast_range", (0.7, 1.3))
    blur_kernel: int = params.get("blur_kernel", 3)
    blur_sigma: Tuple[float, float] = params.get("blur_sigma", (0.5, 1.5))
    if name == "Brightness":
        factor = float(rng.uniform(brightness_range[0], brightness_range[1]))
        def _f(img):
            return TF.adjust_brightness(img, factor)
        return T.Lambda(_f)
    if name == "Contrast":
        factor = float(rng.uniform(contrast_range[0], contrast_range[1]))
        def _f(img):
            return TF.adjust_contrast(img, factor)
        return T.Lambda(_f)
    if name == "Blur":
        # T.GaussianBlur samples sigma from torch's global generator, which
        # ``rng`` does not control. Draw sigma from ``rng`` here and apply a
        # fixed-sigma blur instead. A (min, max) range costs one extra rng draw.
        if np.ndim(blur_sigma) == 0:
            sigma = float(blur_sigma)
            if sigma <= 0.0:
                raise ValueError(f"blur_sigma must be positive, got {blur_sigma!r}")
        else:
            if len(blur_sigma) != 2:
                raise ValueError(
                    f"blur_sigma must be a number or a (min, max) pair, got {blur_sigma!r}"
                )
            sigma_min, sigma_max = float(blur_sigma[0]), float(blur_sigma[1])
            if not 0.0 < sigma_min <= sigma_max:
                raise ValueError(
                    f"blur_sigma range must satisfy 0 < min <= max, got {blur_sigma!r}"
                )
            sigma = float(rng.uniform(sigma_min, sigma_max))

        def _f(img):
            return TF.gaussian_blur(img, kernel_size=blur_kernel, sigma=sigma)
        return T.Lambda(_f)
    raise ValueError(name)


def _make_opt_pipeline(rng, params: Dict[str, Any]):
    """Chain one freshly built photometric op per name in ``_OPT_CANONICAL``, in order."""
    transforms = []
    for op_name in _OPT_CANONICAL:
        transforms.append(_make_single_opt_transform(op_name, rng, params))
    return T.Compose(transforms)


def opt_module(
    X,
    y,
    per_sample=1,
    *,
    brightness_range: Tuple[float, float] = (0.7, 1.3),
    contrast_range: Tuple[float, float] = (0.7, 1.3),
    blur_kernel: int = 3,
    blur_sigma: Tuple[float, float] = (0.5, 1.5),
    rng: Optional[np.random.RandomState] = None,
):
    """Return ``per_sample`` photometrically augmented copies of every input image.

    Each copy gets its own freshly built brightness -> contrast -> blur
    pipeline, applied to the PIL version of the image.

    Args:
        X: image batch, ``(N, H, W)`` grayscale or ``(N, H, W, 3)`` RGB.
        y: labels; ``y[i]`` is passed through ``int()`` and repeated per copy.
        per_sample: number of augmented copies emitted per input.
        brightness_range, contrast_range, blur_kernel, blur_sigma: forwarded
            to the photometric ops.
        rng: random source. ``None`` uses the global ``np.random`` module.

    Returns:
        ``(X_out, y_out)``. ``X_out`` is a uint8 array with ``N * per_sample``
        rows and ``y_out`` has ``y``'s dtype. Copies of input ``i`` occupy rows
        ``i * per_sample`` through ``(i + 1) * per_sample - 1``.
    """
    images = np.asarray(X)
    labels = np.asarray(y)
    n_inputs = len(images)
    X_out = _alloc_like_batch(images, per_sample)
    y_out = np.empty((n_inputs * per_sample,), dtype=labels.dtype)
    source = np.random if rng is None else rng
    opt_params = dict(
        brightness_range=brightness_range,
        contrast_range=contrast_range,
        blur_kernel=blur_kernel,
        blur_sigma=blur_sigma,
    )

    for i, img_np in enumerate(images):
        img_pil = _to_pil(img_np)
        label = int(labels[i])
        base = i * per_sample
        for rep in range(per_sample):
            # Rebuilt per copy so every copy gets fresh random parameters.
            pipeline = _make_opt_pipeline(source, opt_params)
            X_out[base + rep] = _from_pil(pipeline(img_pil), img_np)
            y_out[base + rep] = label
    return X_out, y_out


def copy_module(X, y, per_sample=1):
    """Repeat every input image and label ``per_sample`` times without changes.

    Returns ``(X_out, y_out)``. ``X_out`` is a uint8 array with
    ``N * per_sample`` rows, and input values are cast on assignment.
    ``y_out`` keeps ``y``'s dtype and holds ``int(y[i])`` for every copy of
    input ``i``. Copies of input ``i`` occupy rows ``i * per_sample`` through
    ``(i + 1) * per_sample - 1``.
    """
    images = np.asarray(X)
    labels = np.asarray(y)
    n_inputs = len(images)
    X_out = _alloc_like_batch(images, per_sample)
    y_out = np.empty((n_inputs * per_sample,), dtype=labels.dtype)
    for i in range(n_inputs):
        start, stop = i * per_sample, (i + 1) * per_sample
        # Broadcasting fills the whole run of copies in one assignment.
        X_out[start:stop] = images[i]
        y_out[start:stop] = int(labels[i])
    return X_out, y_out


def poison_module(
    X,
    y,
    per_sample=1,
    poisoning_prob=0.5,
    rng=None,
    aug_type="geom",
    geom_kwargs: Optional[Dict[str, Any]] = None,
    opt_kwargs: Optional[Dict[str, Any]] = None,
):
    """Emit ``per_sample`` augmented outputs per input, some using a donor image from another class.

    For each output, a ``rng.rand() < poisoning_prob`` draw decides whether it
    is "poisoned".
    - A poisoned output takes its image from a donor sample while keeping the
      label ``y[i]``. The donor class is chosen uniformly among the other
      classes present, then a sample is chosen uniformly within it. If no
      other class is present, the sample's own class is used.
    - A clean output uses ``X[i]`` itself.
    The chosen image is then augmented according to ``aug_type``.

    Args:
        X: image batch, ``(N, H, W)`` grayscale or ``(N, H, W, 3)`` RGB.
        y: length-``N`` integer class labels. Only labels in ``0..max(y)``
            can act as donor classes.
        per_sample: number of outputs per input.
        poisoning_prob: probability that an output is poisoned.
        rng: random source. ``None`` uses the global ``np.random`` module.
        aug_type: ``"geom"``, ``"opt"``, ``"geom+opt"`` or ``"none"``.
            ``None`` is treated like ``"none"``.
        geom_kwargs / opt_kwargs: parameter overrides for the geometric /
            photometric pipelines. Missing keys fall back to the pipeline
            builders' own defaults (e.g. no flip and fill 0), not to
            ``geom_module``'s RGB-specific defaults.

    Returns:
        ``(X_out, y_out, ind_poisoned, parent_idx_in_input)``, each with
        ``N * per_sample`` rows:
        - ``X_out``: the uint8 images;
        - ``y_out``: the labels;
        - ``ind_poisoned``: a bool mask of poisoned outputs;
        - ``parent_idx_in_input``: the index in ``X`` each image was taken from.

    Raises:
        ValueError: if ``aug_type`` is unsupported or ``len(y) != len(X)``.
    """
    if aug_type is None:
        aug_type = "none"
    if aug_type not in ("geom", "opt", "geom+opt", "none"):
        # Reject typos instead of silently skipping augmentation.
        raise ValueError(
            f"Unsupported aug_type {aug_type!r}; expected 'geom', 'opt', 'geom+opt' or 'none'"
        )
    if rng is None:
        rng = np.random
    X = np.asarray(X)
    y = np.asarray(y)
    N = len(X)
    if len(y) != N:
        # A longer y would let donor indices point past the end of X.
        raise ValueError(f"X and y must have the same length, got {N} and {len(y)}")

    X_out = _alloc_like_batch(X, per_sample)
    y_out = np.empty((N * per_sample,), dtype=y.dtype)
    ind_poisoned = np.empty((N * per_sample,), dtype=bool)
    parent_idx_in_input = np.empty((N * per_sample,), dtype=np.int64)

    # y.max() raises on an empty array, so an empty batch simply has no classes.
    num_classes = int(y.max()) + 1 if N > 0 else 0
    idx_by_cls = {c: np.where(y == c)[0] for c in range(num_classes)}
    # Classes with at least one sample, ascending; computed once, not per input.
    present_classes = [c for c in range(num_classes) if len(idx_by_cls[c]) > 0]
    geom_params = geom_kwargs or {}
    opt_params = opt_kwargs or {}

    k = 0
    for i in range(N):
        yi = int(y[i])
        donor_classes = [c for c in present_classes if c != yi]
        if not donor_classes:
            # Single-class data: "poisoned" outputs fall back to the own class.
            donor_classes = [yi]
        for _ in range(per_sample):
            poisoned = bool(rng.rand() < poisoning_prob)
            if poisoned:
                donor_c = int(rng.choice(donor_classes))
                donor_idx = int(rng.choice(idx_by_cls[donor_c]))
            else:
                donor_idx = i

            img_np = X[donor_idx]
            img_pil = _to_pil(img_np)

            if aug_type == "geom":
                img_pil = _make_geom_pipeline(rng, geom_params)(img_pil)
            elif aug_type == "opt":
                img_pil = _make_opt_pipeline(rng, opt_params)(img_pil)
            elif aug_type == "geom+opt":
                # Both pipelines draw from rng when built, before either is applied.
                geom_pipeline = _make_geom_pipeline(rng, geom_params)
                opt_pipeline = _make_opt_pipeline(rng, opt_params)
                img_pil = opt_pipeline(geom_pipeline(img_pil))
            # "none" -> no transform

            X_out[k] = _from_pil(img_pil, img_np)
            y_out[k] = yi
            ind_poisoned[k] = poisoned
            parent_idx_in_input[k] = donor_idx
            k += 1

    return X_out, y_out, ind_poisoned, parent_idx_in_input


