from pathlib import Path
from typing import Annotated, Sequence, Optional

from pydantic import Field

from src.settings.base import ExtraFieldsNotAllowedBaseModel
from src.settings.datasets.chat import ChatMultiDatasetSettings
from src.settings.model import PreTrainedAdaptersModelSettings, PreTrainedModelSettings
from src.settings.tf.tokenizer import TokenizerSettings

# Annotated alias used as the type of dataset settings for inference experiments.
# The attached pydantic Field metadata names 'dataset_type' as the discriminator key.
# NOTE(review): only a single variant is listed here, while discriminators are meant
# for unions — confirm this validates under the pinned pydantic version.
INFERENCE_DATASETS_SETTINGS = Annotated[ChatMultiDatasetSettings, Field(discriminator='dataset_type')]


class SingleModelInferenceSettings(ExtraFieldsNotAllowedBaseModel):
    """Settings for a single model entry of an inference experiment.

    ``model_settings`` and ``tokenizer_settings`` are required; every other
    field has a default and may be omitted.
    """

    # Required: either adapter-based or plain pre-trained model settings.
    model_settings: PreTrainedAdaptersModelSettings | PreTrainedModelSettings

    # Optional second model description drawn from the same two setting types.
    # NOTE(review): presumably the SFT model paired with `model_settings` —
    # confirm against the code that consumes this field.
    sft_settings: Optional[
        PreTrainedAdaptersModelSettings | PreTrainedModelSettings
    ] = Field(default=None)

    # Required tokenizer configuration.
    tokenizer_settings: TokenizerSettings

    # Off by default; presumably switches generation to the vLLM backend — TODO confirm.
    use_vllm: bool = Field(default=False)

    # Both sizes default to 1; how they relate to each other is decided by the
    # consumer of these settings and is not visible in this module.
    batch: int = Field(default=1)
    micro_batch: int = Field(default=1)


class InferenceExperimentSettings(ExtraFieldsNotAllowedBaseModel):
    """Settings for an inference experiment over one or more model configurations.

    All three fields are required; none of them declares a default.
    NOTE(review): the base class is defined elsewhere; by its name it rejects
    unknown keys — confirm in ``src.settings.base``.
    """

    # One SingleModelInferenceSettings entry per model configuration.
    inference_settings: Sequence[SingleModelInferenceSettings]

    # Dataset configuration, typed through the module-level INFERENCE_DATASETS_SETTINGS alias.
    dataset_settings: INFERENCE_DATASETS_SETTINGS
    # Filesystem path; presumably where inference outputs are written — confirm against the consumer.
    save_path: Path
