from dataclasses import dataclass
from typing import Optional

from omegaconf import MISSING

from .augmentations import *
from .enums import *


@dataclass()
class LossConfig:
    """Base structured config for loss definitions.

    ``name`` has no usable default here (OmegaConf's ``MISSING`` marker), so
    it is expected to be set by a subclass or by whoever builds the config.
    ``group`` defaults to ``"loss"``.
    """

    # Loss identifier; intentionally left as OmegaConf's MISSING marker.
    name: str = MISSING
    # Group label for this config (presumably the config group it lives in).
    group: str = "loss"


@dataclass  # eq left at its default (True); see class docstring for why.
class ContextContrastingLossConfig(LossConfig):
    """Structured config for the ``context_contrasting_loss``.

    Extends :class:`LossConfig`: overrides ``name`` and keeps the inherited
    ``group = "loss"``. ``augmentation_class`` defaults to OmegaConf's
    ``MISSING`` marker, so it has no usable default and is expected to be
    provided by the caller / composed config.

    The dataclass-generated ``__eq__`` compares every field below. Declaring
    this class with ``eq=False`` does NOT fall back to identity comparison:
    it inherits ``LossConfig.__eq__``, which compares only ``name`` and
    ``group``, so configs differing in e.g. ``temperature`` would compare
    equal. Hashability is unaffected either way (instances are unhashable,
    as ``LossConfig`` is a mutable dataclass with ``eq=True``).

    NOTE(review): the per-field comments below are inferred from field names
    and defaults; confirm them against the loss implementation.
    """

    name: str = "context_contrasting_loss"
    # Enum member tagging which loss this config describes.
    utils: AnomalyLoss = AnomalyLoss.context_contrasting_loss
    # Presumably the number of augmented "context" views per sample.
    n_context_augs: int = 2
    # Augmentation config; no default (OmegaConf MISSING).
    augmentation_class: AugmentationConfig = MISSING
    # Similarity function name; "cos" presumably means cosine similarity.
    similarity_metric: str = "cos"
    # Content-loss mode and weight.
    content_loss_mode: str = "separate"
    content_loss_weight: float = 1.0
    # Presumably annealing of the content term: start value, duration in
    # epochs, and schedule shape.
    content_annealing_start: Optional[float] = 0.0
    content_annealing_epochs: Optional[int] = 2048
    content_annealing_schedule: str = "linear"
    align_contexts: bool = False
    # Batch size used when computing the validation loss.
    val_loss_batch_size: int = 32
    use_true_negatives: bool = False
    # Presumably the output size of a projection head (Optional: None allowed).
    projection_dim: Optional[int] = 128
    use_mean: bool = False
    # Presumably the softmax temperature of the contrastive objective.
    temperature: float = 0.07
    # Presumably excludes the positive pair from the contrastive denominator.
    positive_free_denominator: bool = False
    compact_clusters: bool = False
    hierarchy: bool = False
