import numpy as np
# imgaug still references the ``np.bool`` / ``np.complex`` aliases that were
# removed from NumPy 1.24. Only fill in the names NumPy does not already
# define: NumPy 2.x provides ``np.bool`` again as its own boolean scalar type,
# and overwriting it with the builtin would change it for every other library
# in the process.
if "bool" not in vars(np):
    np.bool = bool
if "complex" not in vars(np):
    np.complex = complex

from PIL import Image
import imgaug.augmenters as iaa
import torchvision.transforms as T


class ImgAugTransform:
    """Adapter that applies an imgaug-style augmenter to PIL images.

    The wrapped object only needs an ``augment_image(array)`` method; it is
    fed the image as a NumPy array and its result is turned back into a PIL
    image.
    """

    def __init__(self, aug):
        # Augmenter invoked on every call.
        self.aug = aug

    def __call__(self, img: Image.Image) -> Image.Image:
        """Return a new PIL image holding the augmented version of ``img``."""
        pixels = np.array(img)
        return Image.fromarray(self.aug.augment_image(pixels))

# Spatial augment options: OneOf picks a single child per image from Cutout,
# MotionBlur and PerspectiveTransform.
spatial_augmenter = iaa.OneOf([
    iaa.Cutout(nb_iterations=1, size=0.2, squared=True),
    iaa.MotionBlur(k=(3, 7)),
    iaa.PerspectiveTransform(scale=(0.01, 0.1)),
])
spatial_transform = ImgAugTransform(spatial_augmenter)
# The wrapper is already callable, so it is handed to RandomApply directly
# (applied with probability 0.5). Wrapping it in T.Lambda(lambda ...) adds
# nothing and makes the pipeline unpicklable, which breaks DataLoader workers
# started with the "spawn" method.
spatial_apply = T.RandomApply([spatial_transform], p=0.5)

# Weather effects used by the channel branch: OneOf selects a single child
# (Fog, Snowflakes or Rain) each time it is applied.
weather_augmenter = iaa.OneOf([iaa.Fog(), iaa.Snowflakes(), iaa.Rain()])

# Channel augment options: OneOf picks either GammaContrast or the weather
# selector.
# NOTE(review): this comment previously also listed Grayscale, but no
# grayscale augmenter is part of the list - confirm that omission is intended.
channel_augmenter = iaa.OneOf([
    iaa.GammaContrast((0.8, 1.5)),
    weather_augmenter,
])
channel_transform = ImgAugTransform(channel_augmenter)
# The wrapper is already callable, so it is handed to RandomApply directly
# (applied with probability 0.5). Wrapping it in T.Lambda(lambda ...) adds
# nothing and makes the pipeline unpicklable, which breaks DataLoader workers
# started with the "spawn" method.
channel_apply = T.RandomApply([channel_transform], p=0.5)

def get_train_transform(flag: str, image_size: tuple, mean: list, std: list) -> T.Compose:
    """
    Build the training pipeline selected by ``flag``.

    Every pipeline starts with a bilinear ``T.Resize(image_size)`` and ends
    with ``T.ToTensor()`` followed by ``T.Normalize(mean, std)``. The flag
    decides what is inserted in between:
      - 'n': nothing (no augmentation)
      - 'r': RandAugment (num_ops=3)
      - 'rs': RandAugment + spatial branch
      - 'rc': RandAugment + channel branch
      - 'rsc': RandAugment + spatial branch + channel branch
      - 'rsc15': same as 'rsc' but RandAugment uses magnitude=15

    Args:
        flag: one of the flag strings listed above.
        image_size: target size forwarded to ``T.Resize``.
        mean: per-channel means forwarded to ``T.Normalize``.
        std: per-channel standard deviations forwarded to ``T.Normalize``.

    Returns:
        The assembled ``T.Compose`` pipeline.

    Raises:
        ValueError: if ``flag`` is not one of the supported values.
    """
    bilinear = T.InterpolationMode.BILINEAR

    # (flag, extra RandAugment kwargs or None to skip RandAugment,
    #  add spatial branch, add channel branch)
    recipes = (
        ('n', None, False, False),
        ('r', {}, False, False),
        ('rs', {}, True, False),
        ('rc', {}, False, True),
        ('rsc', {}, True, True),
        ('rsc15', {'magnitude': 15}, True, True),
    )

    pipeline = [T.Resize(image_size, interpolation=bilinear)]

    for name, randaug_kwargs, use_spatial, use_channel in recipes:
        if flag == name:
            if randaug_kwargs is not None:
                pipeline.append(
                    T.RandAugment(num_ops=3, interpolation=bilinear, **randaug_kwargs)
                )
            if use_spatial:
                pipeline.append(spatial_apply)
            if use_channel:
                pipeline.append(channel_apply)
            break
    else:
        # No recipe matched the requested flag.
        raise ValueError(f"Unknown augmentation flag: {flag}")

    pipeline.append(T.ToTensor())
    pipeline.append(T.Normalize(mean=mean, std=std))
    return T.Compose(pipeline)
