from enum import Enum
from typing import Optional

from models import ChatModel, ModelType

from ..tasks import Task, TaskType
from .coding_paraphraser import CodingParaphraser
from .cot_qa_paraphrasers import (
    CoTQAContextAwareLLMParaphraser,
    CoTQALLMParaphraser,
)
from .llm_paraphraser import LLMParaphraser
from .paraphraser import Paraphraser

# Paraphraser classes exposed through `from <this package> import *`.
# NOTE(review): ParaphraserType and the helper functions defined further down
# in this module (get_task_paraphraser, get_task_paraphraser_by_model,
# get_paraphrasing_type_by_paraphraser) are not listed here — confirm whether
# leaving them out of the star-import surface is intentional.
__all__ = [
    "Paraphraser",
    "LLMParaphraser",
    "CoTQALLMParaphraser",
    "CoTQAContextAwareLLMParaphraser",
    "CodingParaphraser",
]


class ParaphraserType(Enum):
    """Kind of paraphraser requested from the factory functions in this module.

    The string values are short codes for each kind.
    """

    # No paraphraser: the factories return None for this member.
    NoParaphraser = "np"
    # Task-specific paraphraser: LLMParaphraser, CoTQALLMParaphraser or
    # CodingParaphraser, chosen by the task's TASK_TYPE.
    StandardParaphraser = "sp"
    # Context-aware paraphraser: CoTQAContextAwareLLMParaphraser, which the
    # factories only build for TaskType.MCQ_COT_QA.
    ContextAwareParaphraser = "cap"


def get_task_paraphraser(
    task: Task, paraphraser_type: ParaphraserType, *args, **kwargs
) -> Optional[Paraphraser]:
    """Build the paraphraser matching a task and a requested paraphraser kind.

    Args:
        task: Task whose ``TASK_TYPE`` selects the concrete paraphraser class.
            It is not read when ``paraphraser_type`` is ``NoParaphraser``.
        paraphraser_type: Which family of paraphraser to build.
        *args: Positional arguments forwarded unchanged to the chosen
            paraphraser's constructor.
        **kwargs: Keyword arguments forwarded unchanged to the chosen
            paraphraser's constructor.

    Returns:
        A new paraphraser instance, or ``None`` when no paraphraser is
        requested (see also the review notes in the body about the other
        paths that yield ``None``).

    Raises:
        ValueError: If a context-aware paraphraser is requested for a task
            type other than ``TaskType.MCQ_COT_QA``.
    """
    if paraphraser_type == ParaphraserType.NoParaphraser:
        return None

    if paraphraser_type == ParaphraserType.StandardParaphraser:
        task_type = task.TASK_TYPE
        if task_type == TaskType.SIMPLE_QA:
            return LLMParaphraser(*args, **kwargs)
        if task_type == TaskType.MCQ_COT_QA:
            return CoTQALLMParaphraser(*args, **kwargs)
        if task_type == TaskType.CODING:
            return CodingParaphraser(*args, **kwargs)
        # NOTE(review): any other task type yields None here, which callers
        # cannot tell apart from NoParaphraser; the context-aware branch
        # raises instead. Confirm whether this silent fallback is intended.
        return None

    if paraphraser_type == ParaphraserType.ContextAwareParaphraser:
        if task.TASK_TYPE == TaskType.MCQ_COT_QA:
            return CoTQAContextAwareLLMParaphraser(*args, **kwargs)
        raise ValueError(
            f"No context-aware paraphraser for task type {task.TASK_TYPE}"
        )

    # NOTE(review): an unrecognised paraphraser_type also results in None.
    return None


def get_task_paraphraser_by_model(
    task: Task,
    paraphraser_type: ParaphraserType,
    model: ModelType | ChatModel,
    *args,
    **kwargs,
) -> Optional[Paraphraser]:
    """Build the paraphraser for a task, passing a model to the LLM paraphrasers.

    Args:
        task: Task whose ``TASK_TYPE`` selects the concrete paraphraser class.
            It is not read when ``paraphraser_type`` is ``NoParaphraser``.
        paraphraser_type: Which family of paraphraser to build.
        model: Passed as ``model=`` to ``LLMParaphraser``,
            ``CoTQALLMParaphraser`` and ``CoTQAContextAwareLLMParaphraser``.
            It is not passed to ``CodingParaphraser``.
        *args: Forwarded only to ``CodingParaphraser``; ignored otherwise.
        **kwargs: Forwarded only to ``CodingParaphraser``; ignored otherwise.

    Returns:
        A new paraphraser instance, or ``None`` when no paraphraser is
        requested (see also the review notes in the body about the other
        paths that yield ``None``).

    Raises:
        ValueError: If a context-aware paraphraser is requested for a task
            type other than ``TaskType.MCQ_COT_QA``.
    """
    if paraphraser_type == ParaphraserType.NoParaphraser:
        return None

    if paraphraser_type == ParaphraserType.StandardParaphraser:
        task_type = task.TASK_TYPE
        if task_type == TaskType.SIMPLE_QA:
            return LLMParaphraser(model=model)
        if task_type == TaskType.MCQ_COT_QA:
            return CoTQALLMParaphraser(model=model)
        if task_type == TaskType.CODING:
            # NOTE(review): unlike the branches above, this one receives the
            # extra arguments and not `model` — presumably CodingParaphraser
            # takes no model; confirm against its constructor.
            return CodingParaphraser(*args, **kwargs)
        # NOTE(review): any other task type yields None here, which callers
        # cannot tell apart from NoParaphraser; the context-aware branch
        # raises instead. Confirm whether this silent fallback is intended.
        return None

    if paraphraser_type == ParaphraserType.ContextAwareParaphraser:
        if task.TASK_TYPE == TaskType.MCQ_COT_QA:
            return CoTQAContextAwareLLMParaphraser(model=model)
        raise ValueError(
            f"No context-aware paraphraser for task type {task.TASK_TYPE}"
        )

    # NOTE(review): an unrecognised paraphraser_type also results in None.
    return None


def get_paraphrasing_type_by_paraphraser(
    paraphraser: Optional[Paraphraser],
) -> ParaphraserType:
    """Map a paraphraser instance (or ``None``) back to its ``ParaphraserType``.

    Args:
        paraphraser: A paraphraser instance, or ``None`` when no paraphraser
            is in use.

    Returns:
        ``NoParaphraser`` for ``None``, ``ContextAwareParaphraser`` for a
        ``CoTQAContextAwareLLMParaphraser``, and ``StandardParaphraser`` for
        an ``LLMParaphraser``, ``CoTQALLMParaphraser`` or
        ``CodingParaphraser``.

    Raises:
        ValueError: If ``paraphraser`` is none of the recognised classes.
    """
    if paraphraser is None:
        return ParaphraserType.NoParaphraser

    # The context-aware class is tested before the standard ones so that it
    # wins if it also satisfies one of the standard isinstance checks.
    if isinstance(paraphraser, CoTQAContextAwareLLMParaphraser):
        return ParaphraserType.ContextAwareParaphraser

    standard_classes = (LLMParaphraser, CoTQALLMParaphraser, CodingParaphraser)
    if isinstance(paraphraser, standard_classes):
        return ParaphraserType.StandardParaphraser

    raise ValueError(f"Unknown paraphraser type: {paraphraser}")
