from typing_extensions import Self
from typing import Literal, Optional
from pydantic import model_validator

from src.settings.base import ExtraFieldsNotAllowedBaseModel
from src.settings.datasets.base import BaseDatasetSettings, DatasetType, MultiDatasetSettings


class ChatPromptTemplate(ExtraFieldsNotAllowedBaseModel):
    """Prefix/suffix templates used for chat messages.

    Templates may be supplied globally (``prefix_template`` /
    ``suffix_template``) or per role (``role_prefix_templates`` /
    ``role_suffix_templates``). ``check_prefix_and_suffix`` requires at least
    one truthy prefix source and at least one truthy suffix source.
    """

    # Global templates; None or "" counts as "not set" for validation.
    prefix_template: Optional[str] = None
    suffix_template: Optional[str] = None
    # NOTE(review): presumably maps role names to tags used when rendering —
    # confirm at the call site that consumes this template.
    role_tag_mapping: Optional[dict[str, str]] = {}

    # Per-role templates (presumably keyed by role name — confirm with callers).
    # The literal `{}` defaults are not shared between instances: pydantic
    # copies non-hashable defaults for each model it creates.
    role_prefix_templates: Optional[dict[str, str]] = {}
    role_suffix_templates: Optional[dict[str, str]] = {}

    @model_validator(mode="after")
    def check_prefix_and_suffix(self) -> Self:
        """Ensure both a prefix source and a suffix source are configured.

        Returns:
            The validated model itself.

        Raises:
            ValueError: if neither ``prefix_template`` nor
                ``role_prefix_templates`` is truthy, or if neither
                ``suffix_template`` nor ``role_suffix_templates`` is truthy.
                The prefix pair is checked first. Pydantic surfaces this as a
                ``ValidationError`` to the caller.
        """
        # (global template, per-role templates, error message) — checked in order.
        requirements = (
            (
                self.prefix_template,
                self.role_prefix_templates,
                "You need to set either global 'prefix_template' or 'role_prefix_templates'.",
            ),
            (
                self.suffix_template,
                self.role_suffix_templates,
                "You need to set either global 'suffix_template' or 'role_suffix_templates'.",
            ),
        )
        for global_template, per_role_templates, message in requirements:
            # Empty strings / empty dicts are treated the same as None.
            if not (global_template or per_role_templates):
                raise ValueError(message)
        return self


class ChatDatasetSettings(BaseDatasetSettings):
    """Settings for a chat-type dataset.

    ``dataset_type`` is restricted by a ``Literal`` annotation to
    ``DatasetType.CHAT``. ``prompt_template`` is validated as a
    :class:`ChatPromptTemplate`. Field declaration order is significant for
    pydantic (schema / dump order) and is intentionally preserved.
    """

    # Only DatasetType.CHAT passes validation for this settings class.
    dataset_type: Literal[DatasetType.CHAT] = DatasetType.CHAT

    # Boolean switches. Their effect is implemented by the code that consumes
    # these settings (not visible here). The names suggest loss masking to the
    # last replica / to answers and random truncation — confirm there.
    only_last_replica_loss: bool = False
    only_answer_loss: bool = True
    random_cut: bool = False

    # Nullable, defaults to None.
    keep_end: Optional[bool] = None
    # No default: this field is required and must be passed explicitly,
    # even when the intended value is None.
    max_tokens_count: Optional[int]
    # Required; must satisfy ChatPromptTemplate's prefix/suffix validation.
    prompt_template: ChatPromptTemplate
    ignore_system_prompt: bool = False


class ChatMultiDatasetSettings(ChatDatasetSettings, MultiDatasetSettings):
    """Chat dataset settings combined with ``MultiDatasetSettings``.

    Adds nothing of its own. ``ChatDatasetSettings`` is listed first, so it
    precedes ``MultiDatasetSettings`` in the MRO for any overlapping attribute.
    """
