from typing import Optional

from pydantic import Field, BaseModel, ConfigDict

from external.edm.training.networks import EDMPrecond
from src.models.models.edm import create_edm_model
from src.utils.load import load_from_state_dict


class EDMModelConfig(BaseModel):
    """Immutable, strictly validated recipe for building an EDM network.

    Attributes:
        name: Identifier forwarded to ``create_edm_model`` to construct the
            network. Required.
        load_path: Optional location handed to ``load_from_state_dict``;
            defaults to ``None``. How ``None`` is treated is up to that
            helper (presumably "skip loading" — TODO confirm).
        load_keys: Key list handed to ``load_from_state_dict``; each instance
            gets its own empty list by default.
    """

    # Instances reject attribute assignment after construction (frozen), do not
    # coerce input types (strict), and refuse unknown keys instead of ignoring them.
    # NOTE(review): because frozen=True refuses every assignment,
    # validate_assignment appears to have no practical effect here.
    model_config: ConfigDict = ConfigDict(
        extra='forbid',
        frozen=True,
        strict=True,
        validate_assignment=True,
    )

    name: str
    load_path: Optional[str] = None
    # default_factory so every instance owns a distinct list object. Note that
    # frozen only blocks rebinding the attribute; the list contents stay mutable.
    load_keys: list[str] = Field(default_factory=list)

    def get_model(self) -> EDMPrecond:
        """Build the network named by ``self.name`` and pass it through the loader.

        The freshly created network is handed to ``load_from_state_dict``
        together with ``self.load_path`` and ``self.load_keys``.

        Returns:
            Whatever ``load_from_state_dict`` returns for that network
            (annotated as ``EDMPrecond``).
        """
        network: EDMPrecond = create_edm_model(self.name)
        loaded = load_from_state_dict(
            entity=network,
            load_path=self.load_path,
            load_keys=self.load_keys,
        )
        return loaded
