from enum import Enum
from dataclasses import dataclass
from typing import Literal
from pathlib import Path
from llm.system_prompts import (
    SYSTEM_PROMPT_SQL,
    SYSTEM_PROMPT_REGEX,
    SYSTEM_PROMPT_FOL,
    SYSTEM_PROMPT_GSM_PLUS,
    SYSTEM_PROMPT_COQ,

    SYSTEM_PROMPT_LLM_AS_A_JUDGE_SCORE_SQL,
    SYSTEM_PROMPT_LLM_AS_A_JUDGE_SCORE_REGEX,
    SYSTEM_PROMPT_LLM_AS_A_JUDGE_SCORE_FOL,
    SYSTEM_PROMPT_LLM_AS_A_JUDGE_SCORE_GSM_PLUS,
    SYSTEM_PROMPT_LLM_AS_A_JUDGE_SCORE_COQ,

    SYSTEM_PROMPT_LLM_AS_A_JUDGE_EQUIV_SQL,
    SYSTEM_PROMPT_LLM_AS_A_JUDGE_EQUIV_REGEX,
    SYSTEM_PROMPT_LLM_AS_A_JUDGE_EQUIV_FOL,
    SYSTEM_PROMPT_LLM_AS_A_JUDGE_EQUIV_GSM_PLUS,
    SYSTEM_PROMPT_LLM_AS_A_JUDGE_EQUIV_COQ

)


class Domain(Enum):
    """Benchmark task domain.

    Each member's value is the lowercase identifier returned by ``str()``.
    The methods map a domain to the system prompts imported from
    ``llm.system_prompts``.
    """

    SQL = "sql"
    RegEx = "regex"
    GSM_PLUS = "gsm_plus"
    Coq = "coq"
    FOL = "fol"

    def __str__(self) -> str:
        """Return the member's string value (e.g. ``"sql"``)."""
        return self.value

    def to_system_prompt(self) -> str:
        """Return the solver system prompt for this domain.

        Raises:
            ValueError: If no prompt is registered for this domain (only
                possible if a member is added without updating the mapping).
        """
        # The mapping is built inside the method on purpose: a plain dict
        # assigned in the Enum class body would itself become an enum member.
        prompts = {
            Domain.SQL: SYSTEM_PROMPT_SQL,
            Domain.RegEx: SYSTEM_PROMPT_REGEX,
            Domain.GSM_PLUS: SYSTEM_PROMPT_GSM_PLUS,
            Domain.Coq: SYSTEM_PROMPT_COQ,
            Domain.FOL: SYSTEM_PROMPT_FOL,
        }
        try:
            return prompts[self]
        except KeyError:
            raise ValueError(
                f"No system prompt registered for domain {self!r}"
            ) from None

    def to_system_prompt_for_judge(self, judge_type: Literal['nl', 'nl-base', 'equivalence']) -> str:
        """Return the LLM-as-a-judge system prompt for this domain.

        Args:
            judge_type: ``'nl'`` and ``'nl-base'`` both select the scoring
                prompt; ``'equivalence'`` selects the equivalence prompt.

        Raises:
            ValueError: If ``judge_type`` is not one of the accepted values,
                or if no prompt is registered for this domain.
        """
        # Validate judge_type before touching any prompt, so a bad argument
        # is always reported with a message naming the rejected value.
        if judge_type in ("nl", "nl-base"):
            prompts = {
                Domain.SQL: SYSTEM_PROMPT_LLM_AS_A_JUDGE_SCORE_SQL,
                Domain.RegEx: SYSTEM_PROMPT_LLM_AS_A_JUDGE_SCORE_REGEX,
                Domain.GSM_PLUS: SYSTEM_PROMPT_LLM_AS_A_JUDGE_SCORE_GSM_PLUS,
                Domain.Coq: SYSTEM_PROMPT_LLM_AS_A_JUDGE_SCORE_COQ,
                Domain.FOL: SYSTEM_PROMPT_LLM_AS_A_JUDGE_SCORE_FOL,
            }
        elif judge_type == "equivalence":
            prompts = {
                Domain.SQL: SYSTEM_PROMPT_LLM_AS_A_JUDGE_EQUIV_SQL,
                Domain.RegEx: SYSTEM_PROMPT_LLM_AS_A_JUDGE_EQUIV_REGEX,
                Domain.GSM_PLUS: SYSTEM_PROMPT_LLM_AS_A_JUDGE_EQUIV_GSM_PLUS,
                Domain.Coq: SYSTEM_PROMPT_LLM_AS_A_JUDGE_EQUIV_COQ,
                Domain.FOL: SYSTEM_PROMPT_LLM_AS_A_JUDGE_EQUIV_FOL,
            }
        else:
            raise ValueError(
                f"Unknown judge_type {judge_type!r}; "
                "expected 'nl', 'nl-base' or 'equivalence'"
            )
        try:
            return prompts[self]
        except KeyError:
            raise ValueError(
                f"No {judge_type!r} judge prompt registered for domain {self!r}"
            ) from None

@dataclass()
class Task:
    """Fields common to every benchmark task.

    The domain-specific subclasses in this module add their own fields after
    these, so this positional order is part of every subclass constructor.
    """

    # Identifier of the task within its domain (presumably unique per domain -- TODO confirm).
    id_in_domain: int
    # Domain this task belongs to.
    domain: Domain
    # Natural-language statement of the task.
    natural_language: str
    # Expected answer for the task (its format presumably depends on the domain -- verify).
    answer: str
    # Model-produced solution; may be None (e.g. not generated yet -- assumption).
    llm_solution: str | None


@dataclass()
class SqlTask(Task):
    """SQL-domain task: the base ``Task`` fields plus schema and constraint data."""

    # Origin of the task (exact meaning not visible here -- TODO confirm).
    source: str
    # Schema description as a list of dicts (dict layout not shown in this module).
    schema: list[dict]
    # Structured constraints as a list of dicts (dict layout not shown in this module).
    constraints: list[dict]
    # Presumably a natural-language rendering of the constraints alone -- verify.
    nl_constraints_only: str
    # Integer bound associated with the task (meaning not visible here -- TODO confirm).
    bound: int


@dataclass()
class FolTask(Task):
    """First-order-logic task: the base ``Task`` fields plus its symbol table."""

    # One entry per symbol, each shaped like [name, arity].
    symbols: list[list[str | int]]
    # Presumably a natural-language description of the symbols alone -- verify.
    nl_symbols_only: str


@dataclass()
class GsmPlusTask(Task):
    """GSM-Plus task: the base ``Task`` fields plus variable and constraint data."""

    # Presumably a natural-language rendering of the constraints alone -- verify.
    nl_constraints_only: str
    # Maps each variable name to one of the type names "float", "int" or "str".
    variable_types: dict[str, Literal["float", "int", "str"]]
    # Constraint expressions as strings.
    constraints: list[str]
    # Additional constraint strings kept separate from `constraints`
    # (presumably hidden from the model -- TODO confirm).
    private_constraints: list[str]
    # Filesystem path of the question file.
    question_path: Path
    # Identifier of the task in its original source dataset (assumption).
    original_id: str
    # Task variant: "p1", "p2" or "symbolic".
    gsm_mode: Literal["p1", "p2", "symbolic"]


@dataclass()
class CoqTask(Task):
    """Coq task: the base ``Task`` fields plus its source file data."""

    # Path of the Coq source file the task comes from.
    filepath: Path
    # Full text content (presumably of `filepath` -- verify).
    full_content: str
    # Presumably a natural-language description of the context alone -- verify.
    nl_context_only: str


@dataclass()
class RegexTask(Task):
    """Regular-expression task: the base ``Task`` fields plus its source id."""

    # Integer identifier of the task in its original dataset (assumption -- TODO confirm).
    original_id: int
