from pydantic import model_validator

from src.settings.base import ExtraFieldsNotAllowedBaseModel


class GeneratorTransformersSettings(ExtraFieldsNotAllowedBaseModel):
    """Generation parameters for a transformers-based text generator.

    Judging by the base class name, unknown fields are rejected at
    construction time (NOTE(review): confirm in ``src.settings.base``).
    Field names appear to mirror Hugging Face ``generate()`` keyword
    arguments; how they are consumed is defined outside this module.
    """

    num_beams: int = 1
    max_new_tokens: int = 15
    repetition_penalty: float = 1.0
    num_return_sequences: int = 1
    do_sample: bool = True
    top_p: float = 1.0
    top_k: int = 50
    temperature: float = 1.0
    # Either a single stop string or a list of them.
    stop_strings: str | list[str] = '</s>'
    stop_token_ids: list[int] | None = None
    # Must be enabled together with ``logprobs`` (enforced by the validator below).
    return_logits: bool = False
    # Must be set if and only if ``return_logits`` is truthy.
    logprobs: int | None = None

    @model_validator(mode='after')
    def correct_dataset_sampling_values(self) -> 'GeneratorTransformersSettings':
        """Ensure ``return_logits`` and ``logprobs`` are configured together.

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: if ``return_logits`` is truthy but ``logprobs`` is None,
                or ``return_logits`` is falsy but ``logprobs`` is set.
        """
        wants_logits = bool(self.return_logits)
        has_logprobs = self.logprobs is not None

        # The two settings are consistent exactly when they agree.
        if wants_logits != has_logprobs:
            if wants_logits:
                raise ValueError("If 'return_logits' is True, 'logprobs' must be specified.")
            raise ValueError("If 'return_logits' is False, 'logprobs' should not be specified.")

        return self
