# forward_forward/data/transforms.py

import random
from collections.abc import Mapping

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F
from omegaconf import DictConfig

class ApplyTransform(torch.utils.data.Dataset):
    """Helper dataset that applies a transform on top of a shared base dataset.

    Wrapping the same base dataset in several ``ApplyTransform`` instances
    allows each wrapper to use a different transform.
    """

    def __init__(self, dataset, transform=None):
        """Store the base dataset and the optional transform.

        Args:
            dataset: Indexable object whose items unpack as ``(x, y)`` and
                which supports ``len()``.
            transform: Optional callable applied to ``x``. ``None`` returns
                samples unchanged.
        """
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(x, y)`` at ``index`` with the transform applied to ``x``."""
        x, y = self.dataset[index]
        # Compare against None instead of relying on truthiness: a callable
        # that defines __bool__ or __len__ can evaluate as falsy and would
        # otherwise be skipped silently.
        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def __len__(self):
        """Return the number of samples in the base dataset."""
        return len(self.dataset)

class RandomFixedRotation:
    """Rotate an image by an angle picked from a fixed list, with probability ``p``."""

    def __init__(self, degrees: list[int], p: float = 0.01):
        """Remember the candidate angles and the probability of rotating.

        Args:
            degrees: Candidate rotation angles; one is chosen per rotation.
            p: Probability that a given call rotates the image.
        """
        self.p = p
        self.degrees = degrees

    def __call__(self, img):
        """Return ``img`` rotated by a randomly chosen angle, or unchanged."""
        should_rotate = random.random() < self.p
        if not should_rotate:
            return img
        chosen_angle = random.choice(self.degrees)
        return F.rotate(img, chosen_angle)

    def __repr__(self):
        """Return a constructor-like description of this transform."""
        cls_name = type(self).__name__
        return f"{cls_name}(degrees={self.degrees}, p={self.p})"

class AddGaussianNoise:
    """Add noise drawn from a normal distribution to a tensor."""

    def __init__(self, mean=0.0, std=0.1):
        """Store the offset (``mean``) and scale (``std``) of the noise."""
        self.mean, self.std = mean, std

    def __call__(self, tensor):
        """Return ``tensor`` plus standard-normal noise scaled by ``std`` and shifted by ``mean``."""
        unit_noise = torch.randn_like(tensor)
        scaled_noise = unit_noise * self.std
        # Same left-to-right order as a single expression would use, so the
        # floating-point result is unchanged.
        return tensor + scaled_noise + self.mean

    def __repr__(self):
        """Return a constructor-like description of this transform."""
        return "{}(mean={}, std={})".format(type(self).__name__, self.mean, self.std)

def build_transform(aug_cfg: DictConfig, image_size: int) -> T.Compose:
    """Build a torchvision transform pipeline from an augmentation config.

    Every augmentation is optional and is appended only when its config entry
    enables it. The order is: horizontal flip, padded random crop, color
    jitter, random rotation, fixed-angle rotation, random grayscale,
    ``T.ToTensor()`` (always appended), then Gaussian noise.

    Args:
        aug_cfg: Config read through ``.get``. Recognised keys:
            ``hflip`` (flip probability), ``shift`` (crop padding),
            ``color_jitter`` (mapping with ``brightness``, ``contrast``,
            ``saturation``, ``hue``), ``rotation`` (degrees for
            ``T.RandomRotation``), ``fixed_rot`` (mapping with ``degrees``
            and ``p``), ``grayscale`` (probability) and ``gaussian_noise``
            (mapping with ``mean`` and ``std``). Scalar keys that are
            missing, null or not greater than zero disable the augmentation.
        image_size: Crop size passed to ``T.RandomCrop`` when ``shift`` is set.

    Returns:
        A ``T.Compose`` wrapping the selected transforms.
    """
    transforms = []

    # Horizontal flip. `or` maps an explicit null in the config to "disabled"
    # instead of failing on a `None > 0` comparison.
    hflip_prob = aug_cfg.get("hflip", 0.0) or 0.0
    if hflip_prob > 0:
        transforms.append(T.RandomHorizontalFlip(p=hflip_prob))

    # Random crop with padding (shift)
    padding = aug_cfg.get("shift", 0) or 0
    if padding > 0:
        transforms.append(T.RandomCrop(image_size, padding=padding, padding_mode="reflect"))

    # Color jitter. Nested OmegaConf nodes are DictConfig objects, which
    # implement the Mapping interface but are not dict subclasses, so the
    # check must be against Mapping or the section would be ignored.
    jitter_cfg = aug_cfg.get("color_jitter")
    if isinstance(jitter_cfg, Mapping):
        transforms.append(T.ColorJitter(
            brightness=jitter_cfg.get("brightness", 0.0),
            contrast=jitter_cfg.get("contrast", 0.0),
            saturation=jitter_cfg.get("saturation", 0.0),
            hue=jitter_cfg.get("hue", 0.0),
        ))

    # Random rotation
    rotation = aug_cfg.get("rotation", 0) or 0
    if rotation > 0:
        transforms.append(T.RandomRotation(degrees=rotation))

    # Optional fixed rotation (90°, 270°) with probability
    fixed_rot = aug_cfg.get("fixed_rot")
    if isinstance(fixed_rot, Mapping):
        # Copy into a plain list so the transform holds an ordinary sequence
        # rather than a config container.
        degrees = list(fixed_rot.get("degrees", [90, 270]))
        prob = fixed_rot.get("p", 0.01)
        transforms.append(RandomFixedRotation(degrees, prob))

    # Random grayscale
    grayscale_prob = aug_cfg.get("grayscale", 0.0) or 0.0
    if grayscale_prob > 0:
        transforms.append(T.RandomGrayscale(p=grayscale_prob))

    transforms.append(T.ToTensor())

    # Optional Gaussian noise, added after ToTensor.
    # NOTE: `std` defaults to 0.0 here, so a section that sets neither key
    # adds no noise.
    noise_cfg = aug_cfg.get("gaussian_noise")
    if isinstance(noise_cfg, Mapping):
        transforms.append(AddGaussianNoise(
            mean=noise_cfg.get("mean", 0.0),
            std=noise_cfg.get("std", 0.0),
        ))

    return T.Compose(transforms)
