from typing import Any

import torch
from pydantic import BaseModel, Field, ConfigDict

from src.losses.losses import get_loss


class LossConfig(BaseModel):
    """Validated, immutable description of a loss: a registry name plus its parameters.

    Instances are frozen, strictly type-checked, and reject unknown keys. Call
    :meth:`get_loss` to build the actual loss module.
    """

    # Left unannotated: BaseModel declares ``model_config`` as a ClassVar, and an
    # instance-level annotation here is reported by type checkers as an invalid override.
    model_config = ConfigDict(frozen=True, strict=True, validate_assignment=True, extra='forbid')

    # Identifier forwarded to the loss factory; required (no default).
    name: str
    # Parameters forwarded to the loss factory; defaults to an empty dict.
    params: dict[str, Any] = Field(default_factory=dict)

    def get_loss(self) -> torch.nn.Module:
        """Build the loss by delegating to the module-level ``get_loss`` factory.

        Returns:
            Whatever ``get_loss(self.name, params)`` returns (annotated as a
            ``torch.nn.Module``).
        """
        # ``get_loss`` here resolves to the imported module-level function, not this
        # method: class-body names are not visible inside method bodies.
        # A shallow copy is passed so that a factory which mutates its argument cannot
        # alter this (nominally frozen) config, or make later calls behave differently.
        return get_loss(self.name, dict(self.params))
