from pathlib import Path

from models import ModelType

from ..payloads import Payload
from ..tasks import Task
from .llm_overseer import (
    LLMClusteringOverseer,
    LLMPayloadDetectorOverseer,
    LLMPayloadGuesserOverseer,
    SusRatingType,
)
from .overseer import Overseer, OverseerEvaluation, OverseerType

# Public API of this package. Every entry must be a name actually bound in
# this module: `from <package> import *` raises AttributeError for any listed
# name that does not exist, so only imported or locally defined names appear
# here (there is no "CodingOverseer" in this module).
__all__ = [
    "Overseer",
    "get_task_overseer",
    "OverseerEvaluation",
    "get_overseer_by_type",
    "OverseerType",
    "SusRatingType",
    "LLMClusteringOverseer",
]


def get_task_overseer(
    model_type: ModelType,
    payload: Payload,
    task: Task,
    *args,
    clustering=True,
    **kwargs,
) -> Overseer:
    """Build the LLM overseer used for a task.

    Args:
        model_type: Passed through to the overseer constructor.
        payload: Passed through to the overseer constructor.
        task: Passed through to the overseer constructor.
        *args: Extra positional arguments forwarded to the constructor.
        clustering: Keyword-only switch. When truthy (the default) an
            ``LLMClusteringOverseer`` is built; otherwise an
            ``LLMPayloadGuesserOverseer`` is built.
        **kwargs: Extra keyword arguments forwarded to the constructor.

    Returns:
        The newly constructed overseer instance.
    """
    # Both classes are called with the same arguments, so choose the class
    # first and construct it in a single place.
    overseer_cls = LLMClusteringOverseer if clustering else LLMPayloadGuesserOverseer
    return overseer_cls(model_type, payload, task, *args, **kwargs)


def get_overseer_by_type(
    model_type: ModelType,
    payload: Payload,
    task: Task,
    overseer_type: OverseerType,
    save_dir: Path = None,
    sus_rating_type: SusRatingType = SusRatingType.SCORE,
    sus_threshold: int = 3,
    *args,
    **kwargs,
) -> Overseer:
    overseer = None
    match overseer_type:
        case OverseerType.PAYLOAD_DETECTION:
            assert (
                save_dir is not None
            ), "Save directory must be provided for payload detection overseer"
            overseer = LLMPayloadDetectorOverseer(
                model_type,
                payload,
                task,
                save_dir=save_dir,
                sus_rating_type=sus_rating_type,
                sus_threshold=sus_threshold,
            )
        case OverseerType.CLUSTERING:
            overseer = LLMClusteringOverseer(model_type, payload, task, *args, **kwargs)
        case OverseerType.PAYLOAD_GUESSER:
            overseer = LLMPayloadGuesserOverseer(
                model_type, payload, task, *args, **kwargs
            )
        case _:
            raise ValueError("Overseer type not valid")
    return overseer
