from dataclasses import dataclass, field
from typing import Dict, Optional, Tuple, Union

from omegaconf import MISSING


@dataclass
class AugmentationConfig:
    """Base schema holding the fields shared by every augmentation config.

    ``name`` defaults to OmegaConf's ``MISSING`` marker, so a value has to be
    supplied; each subclass visible in this file does that by overriding it.
    """

    # Mandatory identifier of the augmentation (no usable default here).
    name: str = MISSING
    # Not overridden by any subclass visible in this file; presumably the
    # config group name -- confirm against where these configs are registered.
    group: str = "augmentation"
    knn_embeddings_with_identity: bool = False
    # ``None`` by default; a subclass may replace it with a concrete integer.
    test_time_augmentations: Optional[int] = None
    same_on_batch_works: bool = True
    # "kornia" unless a subclass overrides it (several switch to "custom").
    augmentation_type: str = "kornia"
    align_content_augmentations: bool = False


@dataclass
class SimclrAugmentationConfig(AugmentationConfig):
    """Config for the ``simclr_augmentations`` entry.

    Adds the numeric settings and on/off switches declared below to the base
    schema and gives ``test_time_augmentations`` a default of 40.
    """

    name: str = "simclr_augmentations"
    # NOTE(review): presumably a strength multiplier for the color jitter, as
    # in common SimCLR implementations -- confirm against the consumer.
    s: float = 1.0
    # Lower / upper bound pair; presumably the random-crop scale range --
    # confirm against the consumer.
    scale_lb: float = 0.08
    scale_ub: float = 1.0
    csi_eval: bool = False
    use_blur: bool = True
    use_color_jitter: bool = True
    # default_factory builds a fresh dict for every instance, so instances
    # never share (and cannot accidentally mutate) one mapping.
    color_jitter_params: Dict[str, float] = field(
        default_factory=lambda: {
            "p": 0.8,
            "saturation": 0.8,
            "brightness": 0.8,
            "contrast": 0.8,
            "hue": 0.2,
        }
    )
    use_horizontal_flip: bool = True
    # Narrows the inherited optional field to a plain int with a default.
    test_time_augmentations: int = 40


@dataclass
class RandomRotationConfig(AugmentationConfig):
    """Config for the ``random_rotation`` entry.

    Keeps the inherited ``augmentation_type`` ("kornia") and adds two
    rotation-related settings.
    """

    name: str = "random_rotation"
    # Off by default. NOTE(review): presumably limits the sampled angle to
    # 180 degrees -- confirm against the consumer.
    max_180_degrees: bool = False
    # Defaults to 0.0, i.e. no offset unless configured.
    angle_offset: float = 0.0


@dataclass
class RandomInvertConfig(AugmentationConfig):
    """Config for the ``random_invert`` entry; adds no fields of its own."""

    name: str = "random_invert"
    # Replaces the inherited "kornia" default.
    augmentation_type: str = "custom"


@dataclass
class RandomEqualizeConfig(AugmentationConfig):
    """Config for the ``random_equalize`` entry."""

    # Overrides of inherited fields (their position in the generated
    # ``__init__`` is fixed by the base class).
    name: str = "random_equalize"
    augmentation_type: str = "custom"

    # Field introduced by this class; disabled by default.
    # NOTE(review): presumably selects CLAHE-style equalization -- confirm
    # against the consumer.
    clahe: bool = False


@dataclass
class RandomGrayscaleConfig(AugmentationConfig):
    """Config for the ``random_grayscale`` entry.

    Keeps the inherited ``augmentation_type`` ("kornia").
    """

    name: str = "random_grayscale"
    # NOTE(review): presumably the probability of applying the transform --
    # confirm against the consumer.
    p: float = 0.2


@dataclass
class GaussianBlurConfig(AugmentationConfig):
    """Config for the ``random_gaussian_blur`` entry.

    Keeps the inherited ``augmentation_type`` ("kornia").
    """

    name: str = "random_gaussian_blur"
    # No default: marked with OmegaConf's ``MISSING``, so a value has to be
    # supplied. NOTE(review): presumably the blur kernel size -- confirm.
    size: int = MISSING


@dataclass
class RandomAffineConfig(AugmentationConfig):
    """Config for the ``random_affine`` entry.

    Keeps the inherited ``augmentation_type`` ("kornia"). Every optional
    field defaults to ``None``.
    NOTE(review): the fields are presumably forwarded to an affine transform's
    parameters of the same names -- confirm against the consumer.
    """

    name: str = "random_affine"
    degrees: int = 0
    translate: Optional[float] = None
    # A pair of floats, e.g. (0.08, 1.0). ``Tuple[float]`` would describe a
    # one-element tuple, which contradicts that example.
    scale: Optional[Tuple[float, float]] = None
    shear: Optional[int] = None


@dataclass
class RandomPerspectiveConfig(AugmentationConfig):
    """Config for the ``random_perspective`` entry.

    Keeps the inherited ``augmentation_type`` ("kornia").
    """

    name: str = "random_perspective"
    # Defaults to 0.5. NOTE(review): presumably the strength of the
    # perspective distortion -- confirm against the consumer.
    distortion_scale: float = 0.5


@dataclass
class RandomLowFrequencyContextTransformConfig(AugmentationConfig):
    """Config for the ``low_frequency_context_augmentation`` entry."""

    # Overrides of inherited fields (their position in the generated
    # ``__init__`` is fixed by the base class).
    name: str = "low_frequency_context_augmentation"
    augmentation_type: str = "custom"

    # Field introduced by this class. NOTE(review): unit and meaning are not
    # visible here (presumably a side length in pixels) -- confirm.
    box_size: int = 8


@dataclass
class RandomPhaseScrambleConfig(AugmentationConfig):
    """Config for the ``random_phase_scramble_augmentation`` entry."""

    # Overrides of inherited fields (their position in the generated
    # ``__init__`` is fixed by the base class).
    name: str = "random_phase_scramble_augmentation"
    augmentation_type: str = "custom"

    # Field introduced by this class. NOTE(review): unit and meaning are not
    # visible here (presumably a side length in pixels) -- confirm.
    box_size: int = 4


@dataclass
class RandomIntensityTranslationConfig(AugmentationConfig):
    """Config for the ``random_intensity_translation`` entry.

    Adds no fields of its own.
    """

    name: str = "random_intensity_translation"
    # Replaces the inherited "kornia" default.
    augmentation_type: str = "custom"


@dataclass
class ContextFlipConfig(AugmentationConfig):
    """Config for the ``random_flip`` entry; adds no fields of its own."""

    name: str = "random_flip"
    # Replaces the inherited "kornia" default.
    augmentation_type: str = "custom"


@dataclass
class ContextFlipInvertConfig(AugmentationConfig):
    """Config for the ``random_flip_invert`` entry; adds no fields of its own."""

    name: str = "random_flip_invert"
    # Replaces the inherited "kornia" default.
    augmentation_type: str = "custom"


@dataclass
class ContextFlipEqualizeConfig(AugmentationConfig):
    """Config for the ``random_flip_equalize`` entry; adds no fields of its own."""

    name: str = "random_flip_equalize"
    # Replaces the inherited "kornia" default.
    augmentation_type: str = "custom"


@dataclass
class ContextInvertEqualizeConfig(AugmentationConfig):
    """Config for the ``random_invert_equalize`` entry; adds no fields of its own."""

    name: str = "random_invert_equalize"
    # Replaces the inherited "kornia" default.
    augmentation_type: str = "custom"


@dataclass
class ContextFlipInvertEqualizeConfig(AugmentationConfig):
    """Config for the ``random_flip_invert_equalize`` entry.

    Adds no fields of its own.
    """

    name: str = "random_flip_invert_equalize"
    # Replaces the inherited "kornia" default.
    augmentation_type: str = "custom"
