from typing import Optional

import torch.nn
from pydantic import BaseModel, ConfigDict, Field

from src.models.discriminators.style_gan_xl import StyleGANXLDiscriminator, StyleGANXLDiscriminatorFeatureExtractor, \
    create_style_gan_xl_discriminator_and_discriminator_feature_extractor
from src.utils.load import load_from_state_dict


class StyleGANXLDiscriminatorConfig(BaseModel):
    """Immutable configuration for building a StyleGAN-XL discriminator.

    Holds the construction hyper-parameters plus optional checkpoint locations
    for the discriminator and its feature extractor. ``get_discriminator``
    turns this configuration into the two modules.
    """

    # frozen: instances cannot be mutated; strict: inputs are not coerced;
    # extra='forbid': unknown keys (e.g. typos in a config file) are rejected.
    model_config: ConfigDict = ConfigDict(frozen=True, strict=True, validate_assignment=True, extra='forbid')
    # Optional checkpoint path for the discriminator modules. Annotated Optional so
    # that an explicit None (e.g. `null` in a config file) validates under strict
    # mode exactly like the default does; a plain `str` annotation rejected it.
    discriminator_load_path: Optional[str] = Field(default=None)
    # Keys forwarded to load_from_state_dict together with discriminator_load_path.
    discriminator_load_keys: list[str] = Field(default_factory=list)
    # Optional checkpoint path for the feature-extractor modules (see note above).
    discriminator_feature_extractor_load_path: Optional[str] = Field(default=None)
    # Keys forwarded to load_from_state_dict together with the feature-extractor path.
    discriminator_feature_extractor_load_keys: list[str] = Field(default_factory=list)
    # Required; forwarded to the factory. Presumably toggles class-conditional
    # discrimination — confirm against create_style_gan_xl_discriminator_and_...
    conditional: bool = Field()
    # Forwarded to the factory as its `c_out` keyword argument.
    c_out: int = Field(default=64)
    # The three settings below are not read by get_discriminator. From their names
    # they look like augmentation settings (probability, shift ratio, cutout ratio)
    # consumed elsewhere — TODO confirm at the call site.
    prob_aug: float = Field(default=1.0)
    shift_ratio: float = Field(default=0.125)
    cutout_ratio: float = Field(default=0.2)

    def get_discriminator(self, image_size: int) -> \
            tuple[StyleGANXLDiscriminator, StyleGANXLDiscriminatorFeatureExtractor]:
        """Build the discriminator and its feature extractor from this config.

        Args:
            image_size: Passed as the first positional argument to
                ``create_style_gan_xl_discriminator_and_discriminator_feature_extractor``.

        Returns:
            ``(discriminator, discriminator_feature_extractor)``. Each wraps its
            module dict, after that dict has been passed through ``load_from_state_dict``.
        """
        # The factory returns two module containers: one for the discriminator,
        # one for the feature extractor.
        discriminator_dict, discriminator_feature_extractor_dict = \
            create_style_gan_xl_discriminator_and_discriminator_feature_extractor(
                image_size, self.conditional, c_out=self.c_out)
        # Restore weights when a checkpoint is configured. The load path may be None
        # (the default); presumably load_from_state_dict then returns the entity
        # unchanged — NOTE(review): confirm in src.utils.load.
        discriminator_dict: torch.nn.ModuleDict = load_from_state_dict(
            entity=discriminator_dict,
            load_path=self.discriminator_load_path,
            load_keys=self.discriminator_load_keys
        )
        discriminator_feature_extractor_dict: torch.nn.ModuleDict = load_from_state_dict(
            entity=discriminator_feature_extractor_dict,
            load_path=self.discriminator_feature_extractor_load_path,
            load_keys=self.discriminator_feature_extractor_load_keys
        )
        # Wrap the (possibly restored) module dicts in their public classes.
        discriminator_feature_extractor: StyleGANXLDiscriminatorFeatureExtractor = \
            StyleGANXLDiscriminatorFeatureExtractor(discriminator_feature_extractor_dict)
        discriminator: StyleGANXLDiscriminator = StyleGANXLDiscriminator(discriminator_dict)
        return discriminator, discriminator_feature_extractor
