from dataclasses import dataclass

from hydra.core.config_store import ConfigStore
from omegaconf import MISSING


@dataclass
class LossConfig:
    """Base structured config shared by every loss schema in this module.

    Subclasses set ``_target_`` to the dotted import path of the loss class.
    ``reduction`` defaults to ``"mean"`` and may be overridden per subclass.
    """

    # No concrete target at the base level: kept as omegaconf's MISSING
    # placeholder, and overridden by each subclass in this module.
    _target_: str = MISSING
    # Reduction mode string; presumably forwarded as the loss constructor's
    # ``reduction`` keyword on instantiation — confirm against the caller.
    reduction: str = "mean"


@dataclass
class CrossEntropyLossConfig(LossConfig):
    """Schema targeting ``torch.nn.CrossEntropyLoss``.

    Inherits the base ``reduction`` default (``"mean"``).
    """

    _target_: str = "torch.nn.CrossEntropyLoss"


@dataclass
class BCEWithLogitsLossConfig(LossConfig):
    """Schema targeting ``torch.nn.BCEWithLogitsLoss``.

    Inherits the base ``reduction`` default (``"mean"``).
    """

    _target_: str = "torch.nn.BCEWithLogitsLoss"


@dataclass
class MultiLabelMarginLossConfig(LossConfig):
    """Schema targeting ``torch.nn.MultiLabelMarginLoss``.

    Inherits the base ``reduction`` default (``"mean"``).
    """

    _target_: str = "torch.nn.MultiLabelMarginLoss"


@dataclass
class MultiLabelSoftMarginLossConfig(LossConfig):
    """Schema targeting ``torch.nn.MultiLabelSoftMarginLoss``.

    Inherits the base ``reduction`` default (``"mean"``).
    """

    _target_: str = "torch.nn.MultiLabelSoftMarginLoss"


@dataclass
class MSELossConfig(LossConfig):
    """Schema targeting ``torch.nn.MSELoss``.

    Inherits the base ``reduction`` default (``"mean"``).
    """

    _target_: str = "torch.nn.MSELoss"


@dataclass
class RMSELossConfig(LossConfig):
    """Schema targeting the project-local ``losses.RMSELoss``.

    Inherits the base ``reduction`` default (``"mean"``).
    """

    # NOTE(review): not a torch.nn class — assumes a top-level ``losses``
    # module exposing ``RMSELoss`` is importable at instantiation time and
    # accepts a ``reduction`` argument; TODO confirm.
    _target_: str = "losses.RMSELoss"


@dataclass
class KLDivLossConfig(LossConfig):
    """Schema targeting ``torch.nn.KLDivLoss``.

    Unlike the other schemas here, it overrides the inherited ``reduction``
    default, setting it to ``"batchmean"``.
    """

    _target_: str = "torch.nn.KLDivLoss"
    # PyTorch's KLDivLoss docs recommend "batchmean" (sum divided by batch
    # size) because it matches the mathematical KL divergence. The base
    # default "mean" averages over every element instead.
    reduction: str = "batchmean"


def register_loss_configs() -> None:
    """Store every loss schema of this module in Hydra's ConfigStore.

    All entries go into the ``loss`` group under a ``base_``-prefixed name.
    They are registered in the order listed in ``registrations`` below.
    """
    store = ConfigStore.instance()
    # (entry name, dataclass schema) pairs; the tuple order is the
    # registration order.
    registrations = (
        ("base_cross_entropy_loss", CrossEntropyLossConfig),
        ("base_bce_with_logits_loss", BCEWithLogitsLossConfig),
        ("base_multilabel_margin_loss", MultiLabelMarginLossConfig),
        ("base_multilabel_soft_margin_loss", MultiLabelSoftMarginLossConfig),
        ("base_mse_loss", MSELossConfig),
        ("base_rmse_loss", RMSELossConfig),
        ("base_kldiv_loss", KLDivLossConfig),
    )
    for entry_name, schema in registrations:
        store.store(group="loss", name=entry_name, node=schema)
