from pydantic import BaseModel, Field, ConfigDict, model_validator


class NumpyFolderConfig(BaseModel):
    """Immutable, strictly-typed settings for a single folder-backed sample source.

    The model is frozen, validates in strict mode (no type coercion), and
    rejects unknown keys.

    Attributes:
        folder: Folder location as a string; required.
            NOTE(review): presumably the directory holding the NumPy files --
            confirm against the code that consumes this config.
        num_samples: Number of samples; required.
        start_index: Starting index; defaults to 0.
            NOTE(review): presumably the offset of the first sample to use --
            confirm against the consumer.
    """

    model_config: ConfigDict = ConfigDict(
        extra='forbid',
        frozen=True,
        strict=True,
        validate_assignment=True,
    )

    # A bare annotation declares a required field, same as ``Field()`` with no default.
    folder: str
    num_samples: int
    start_index: int = 0


class NumpyFolderNoiseImageSigmaConfig(BaseModel):
    """Configuration pairing a noise source and an image source with a time step.

    The model is frozen, validates in strict mode, and rejects unknown keys.
    After field validation, the noise and image sources are required to
    declare the same ``num_samples``.

    Attributes:
        noise: Folder settings for the noise source.
        image: Folder settings for the image source.
        time_step: Integer time step; required.
            NOTE(review): the class name suggests this selects a sigma /
            noise level -- confirm against the consumer.
    """
    model_config: ConfigDict = ConfigDict(frozen=True, strict=True, validate_assignment=True, extra='forbid')
    noise: NumpyFolderConfig = Field()
    image: NumpyFolderConfig = Field()
    time_step: int = Field()

    @model_validator(mode='after')
    def validate_lengths(self) -> 'NumpyFolderNoiseImageSigmaConfig':
        """Ensure the noise and image sources declare the same sample count.

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: If the two ``num_samples`` values differ. Pydantic
                reports this to the caller as a ``ValidationError``.
        """
        # Raise explicitly rather than ``assert``: assert statements are
        # stripped when Python runs with -O, which would silently disable
        # this consistency check.
        if self.noise.num_samples != self.image.num_samples:
            raise ValueError(
                f'noise and image must have the same number of samples: '
                f'{self.noise.num_samples} != {self.image.num_samples}'
            )
        return self

    @property
    def num_samples(self) -> int:
        """Return the sample count shared by both sources (read from ``noise``)."""
        return self.noise.num_samples


class NumpyFolderNoiseImageLabelSigmaConfig(BaseModel):
    """Configuration grouping noise, image and label sources with a time step.

    The model is frozen, validates in strict mode, and rejects unknown keys.
    After field validation, all three sources are required to declare the
    same ``num_samples``.

    Attributes:
        noise: Folder settings for the noise source.
        image: Folder settings for the image source.
        label: Folder settings for the label source.
        time_step: Integer time step; required.
            NOTE(review): the class name suggests this selects a sigma /
            noise level -- confirm against the consumer.
    """
    model_config: ConfigDict = ConfigDict(frozen=True, strict=True, validate_assignment=True, extra='forbid')
    noise: NumpyFolderConfig = Field()
    image: NumpyFolderConfig = Field()
    label: NumpyFolderConfig = Field()
    time_step: int = Field()

    @model_validator(mode='after')
    def validate_lengths(self) -> 'NumpyFolderNoiseImageLabelSigmaConfig':
        """Ensure the noise, image and label sources declare the same sample count.

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: If the three ``num_samples`` values are not all equal.
                Pydantic reports this to the caller as a ``ValidationError``.
        """
        noise_count = self.noise.num_samples
        image_count = self.image.num_samples
        label_count = self.label.num_samples
        # Raise explicitly rather than ``assert``: assert statements are
        # stripped when Python runs with -O, which would silently disable
        # this consistency check.
        if not (noise_count == image_count == label_count):
            # Report each count by name; a chained "a != b != c" rendering is
            # misleading when only one of the three values differs.
            raise ValueError(
                f'noise, image and label must have the same number of samples: '
                f'noise={noise_count}, image={image_count}, label={label_count}'
            )
        return self

    @property
    def num_samples(self) -> int:
        """Return the sample count shared by all sources (read from ``noise``)."""
        return self.noise.num_samples


class NumpyFolderNoiseConfig(BaseModel):
    """Configuration wrapping a single noise source.

    The model is frozen, validates in strict mode, and rejects unknown keys.

    Attributes:
        noise: Folder settings for the noise source; required.
    """

    model_config: ConfigDict = ConfigDict(
        extra='forbid',
        frozen=True,
        strict=True,
        validate_assignment=True,
    )

    # A bare annotation declares a required field, same as ``Field()`` with no default.
    noise: NumpyFolderConfig

    @property
    def num_samples(self) -> int:
        """Return the sample count declared by the noise source."""
        noise_source = self.noise
        return noise_source.num_samples


class NumpyFolderNoiseLabelConfig(BaseModel):
    """Configuration pairing a noise source with a label source.

    The model is frozen, validates in strict mode, and rejects unknown keys.
    After field validation, the noise and label sources are required to
    declare the same ``num_samples``.

    Attributes:
        noise: Folder settings for the noise source.
        label: Folder settings for the label source.
    """
    model_config: ConfigDict = ConfigDict(frozen=True, strict=True, validate_assignment=True, extra='forbid')
    noise: NumpyFolderConfig = Field()
    label: NumpyFolderConfig = Field()

    @model_validator(mode='after')
    def validate_lengths(self) -> 'NumpyFolderNoiseLabelConfig':
        """Ensure the noise and label sources declare the same sample count.

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: If the two ``num_samples`` values differ. Pydantic
                reports this to the caller as a ``ValidationError``.
        """
        # Raise explicitly rather than ``assert``: assert statements are
        # stripped when Python runs with -O, which would silently disable
        # this consistency check.
        if self.noise.num_samples != self.label.num_samples:
            raise ValueError(
                f'noise and label must have the same number of samples: '
                f'{self.noise.num_samples} != {self.label.num_samples}'
            )
        return self

    @property
    def num_samples(self) -> int:
        """Return the sample count shared by both sources (read from ``noise``)."""
        return self.noise.num_samples


class NumpyFolderImageConfig(BaseModel):
    """Configuration wrapping a single image source.

    The model is frozen, validates in strict mode, and rejects unknown keys.

    Attributes:
        image: Folder settings for the image source; required.
    """

    model_config: ConfigDict = ConfigDict(
        extra='forbid',
        frozen=True,
        strict=True,
        validate_assignment=True,
    )

    # A bare annotation declares a required field, same as ``Field()`` with no default.
    image: NumpyFolderConfig

    @property
    def num_samples(self) -> int:
        """Return the sample count declared by the image source."""
        image_source = self.image
        return image_source.num_samples


class NumpyFolderImageLabelConfig(BaseModel):
    """Configuration pairing an image source with a label source.

    The model is frozen, validates in strict mode, and rejects unknown keys.
    After field validation, the image and label sources are required to
    declare the same ``num_samples``.

    Attributes:
        image: Folder settings for the image source.
        label: Folder settings for the label source.
    """
    model_config: ConfigDict = ConfigDict(frozen=True, strict=True, validate_assignment=True, extra='forbid')
    image: NumpyFolderConfig = Field()
    label: NumpyFolderConfig = Field()

    @model_validator(mode='after')
    def validate_lengths(self) -> 'NumpyFolderImageLabelConfig':
        """Ensure the image and label sources declare the same sample count.

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: If the two ``num_samples`` values differ. Pydantic
                reports this to the caller as a ``ValidationError``.
        """
        # Raise explicitly rather than ``assert``: assert statements are
        # stripped when Python runs with -O, which would silently disable
        # this consistency check.
        if self.image.num_samples != self.label.num_samples:
            raise ValueError(
                f'image and label must have the same number of samples: '
                f'{self.image.num_samples} != {self.label.num_samples}'
            )
        return self

    @property
    def num_samples(self) -> int:
        """Return the sample count shared by both sources (read from ``image``)."""
        return self.image.num_samples
