import json
import datetime
from pathlib import Path

from common.task import Task, CoqTask, FolTask, GsmPlusTask, SqlTask, Domain, RegexTask
from task_loading.coq_file import CoqFile

# Directory containing this module (symlinks resolved).
SCRIPT_DIR = Path(__file__).resolve().parent
# Data directory: a "data" folder two levels above this module's directory.
DATA_DIR = SCRIPT_DIR.parent.parent.joinpath("data")

# Per-domain inputs read by the loader functions below.
SQL_JSONL = DATA_DIR.joinpath("sql.jsonl")
FOL_JSONL = DATA_DIR.joinpath("fol.jsonl")
REGEX_JSONL = DATA_DIR.joinpath("regex.jsonl")
GSM_PLUS_DIR = DATA_DIR.joinpath("gsm-plus")
COQ_SPECS_DIR = DATA_DIR.joinpath("coq-specs")


class DateEncoder(json.JSONEncoder):
    """JSON encoder that writes dates as ``{"__date__": <ISO-8601 string>}``."""

    def default(self, obj):
        """Return a JSON-serialisable stand-in for *obj*.

        ``date`` and ``datetime`` instances become a one-key dict holding
        their ``isoformat()`` text; everything else is handed to the base
        class implementation.
        """
        # datetime.datetime is a subclass of datetime.date, so a single
        # isinstance check covers both kinds of value.
        if not isinstance(obj, datetime.date):
            return super().default(obj)
        iso_text = obj.isoformat()
        return {"__date__": iso_text}


def date_decoder(obj):
    """``object_hook`` for ``json.loads`` that restores ``__date__`` markers.

    Args:
        obj: A dict produced by the JSON parser.

    Returns:
        ``obj`` unchanged when it has no ``"__date__"`` key. Otherwise a
        ``datetime.date`` when the stored text is a plain date, or a
        ``datetime.datetime`` when it also carries a time component.

    Raises:
        ValueError: If the stored text is neither a valid ISO date nor a
            valid ISO datetime.
    """
    if "__date__" not in obj:
        return obj

    iso_text = obj["__date__"]
    try:
        return datetime.date.fromisoformat(iso_text)
    except ValueError:
        # DateEncoder also serialises datetime values, whose isoformat()
        # text includes a time part that date.fromisoformat() rejects.
        return datetime.datetime.fromisoformat(iso_text)


def load_all_tasks() -> list[Task]:
    """Return the tasks of every domain in one list.

    The per-domain loaders are invoked in the order SQL, FOL, GSM-Plus,
    Coq, regex, and their results are concatenated in that same order.
    """
    combined: list[Task] = [
        *load_sql_tasks(),
        *load_fol_tasks(),
        *load_gsm_plus_tasks(),
        *load_coq_tasks(),
        *load_regex_tasks(),
    ]
    return combined


def load_sql_tasks() -> list[SqlTask]:
    """Load the SQL tasks stored in ``SQL_JSONL``.

    Every non-blank line of the file is parsed as one JSON object, with
    ``__date__`` markers restored by ``date_decoder``. Besides the fields
    copied straight from the record, each task gets an
    ``nl_constraints_only`` string holding the record's constraints and
    schema re-serialised as JSON text.

    Returns:
        One ``SqlTask`` per record, in file order.

    Raises:
        KeyError: If a record lacks one of the expected keys.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    tasks = []
    # Pin the encoding so that parsing does not depend on the platform's
    # default locale encoding.
    with open(SQL_JSONL, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            j = json.loads(line, object_hook=date_decoder)
            # Dates were decoded into date objects above, so DateEncoder is
            # needed to turn them back into JSON text here.
            nl_constraints_only = 'Constraint(s) as JSON: ' + json.dumps(j["constraints"], cls=DateEncoder)
            nl_constraints_only += '\nTable Schema(s) as JSON: ' + json.dumps(j["schema"], cls=DateEncoder)
            task = SqlTask(
                id_in_domain=j["id"],
                domain=Domain.SQL,
                natural_language=j["problem_description"],
                answer=j["solution"],
                source=j["source"],
                schema=j["schema"],
                constraints=j["constraints"],
                nl_constraints_only=nl_constraints_only,
                bound=j["bound"],
                llm_solution=None,
            )
            tasks.append(task)
    return tasks


def load_fol_tasks() -> list[FolTask]:
    """Load the first-order-logic tasks stored in ``FOL_JSONL``.

    Every non-blank line of the file is parsed as one JSON object. A
    sentence listing the allowed symbols is built from the record's
    ``symbols`` entries (the first two items of each entry, rendered as
    ``"<first> (<second>)"``). That sentence is stored on its own as
    ``nl_symbols_only`` and is also appended to the task's natural-language
    text.

    Returns:
        One ``FolTask`` per record, in file order.

    Raises:
        KeyError: If a record lacks one of the expected keys.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    tasks = []
    # Pin the encoding so that parsing does not depend on the platform's
    # default locale encoding.
    with open(FOL_JSONL, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            j = json.loads(line, object_hook=date_decoder)
            natural_language = j["natural_language"]

            symbol_descriptions = [f"{symbol[0]} ({symbol[1]})" for symbol in j["symbols"]]
            nl_symbols_only = 'You must only use the following symbols: ' + ', '.join(symbol_descriptions)
            natural_language += ' ' + nl_symbols_only

            task = FolTask(
                id_in_domain=j["id"],
                domain=Domain.FOL,
                natural_language=natural_language,
                nl_symbols_only=nl_symbols_only,
                answer=j["fol_expression"],
                symbols=j["symbols"],
                llm_solution=None,
            )
            tasks.append(task)
    return tasks

def load_gsm_plus_tasks() -> list[GsmPlusTask]:
    """Load the GSM-Plus tasks found under ``GSM_PLUS_DIR``.

    The sub-directories ``p1``, ``p2`` and ``symbolic`` are visited in that
    order, and the ``*.json`` files inside each are read in sorted order.
    Tasks receive consecutive ``id_in_domain`` values starting at 1 that
    keep counting across the sub-directories.

    Returns:
        One ``GsmPlusTask`` per JSON file.
    """
    loaded = []
    for mode in ("p1", "p2", "symbolic"):
        for path in sorted((GSM_PLUS_DIR / mode).glob("*.json")):
            with path.open("r", encoding="utf-8") as handle:
                record = json.load(handle)

            prompt = record["question"]
            public = record['constraints']
            private = record["private_constraints"]

            # Only the public constraints are shown in the prompt text;
            # private constraints are deliberately left out of it.
            if public:
                prompt += '\nConstraints: ' + ', '.join(public)

            # The constraints-only string lists the public constraints first,
            # then the private ones, behind a single "Constraints: " prefix.
            groups = []
            if public:
                groups.append(', '.join(public))
            if private:
                groups.append(', '.join(private))
            if groups:
                constraints_text = 'Constraints: ' + ', '.join(groups)
            else:
                constraints_text = ''

            task = GsmPlusTask(
                # One task is appended per file, so this yields 1, 2, 3, ...
                id_in_domain=len(loaded) + 1,
                domain=Domain.GSM_PLUS,
                natural_language=prompt,
                nl_constraints_only=constraints_text,
                answer=record["answer"],
                variable_types=record["types"],
                constraints=public,
                private_constraints=private,
                question_path=path,
                original_id=path.name,
                gsm_mode=mode,
                llm_solution=None,
            )
            loaded.append(task)
    return loaded


def load_coq_tasks() -> list[CoqTask]:
    """Load the Coq specification tasks found in ``COQ_SPECS_DIR``.

    Every ``*.v`` file is read in sorted path order through
    ``CoqFile.from_file``. The task id is the file stem with all leading
    ``"p"`` characters stripped, converted to ``int``.

    Returns:
        One ``CoqTask`` per ``*.v`` file.

    Raises:
        AssertionError: If a file yields no problem spec.
        ValueError: If the stripped file stem is not an integer.
    """
    tasks = []
    for path in sorted(COQ_SPECS_DIR.glob("*.v")):
        id_in_domain = path.stem.lstrip("p")
        coq_file = CoqFile.from_file(path)

        natural_language, nl_context_only = coq_file.nl_description_prompt()
        answer = coq_file.problem_spec

        # This is not possible unless there's a corruption in the data files!
        # Raised explicitly (rather than via `assert`) so the check is still
        # performed when Python runs with -O, which strips assert statements.
        if answer is None:
            raise AssertionError(f"Problem spec missing in {path}")

        # Important!! We have to match to the original signature. The answer
        # signature starts with 'problem_spec', but the generated specifications
        # by the LLM must always have 'generated_spec'. We set the answer
        # to match to that.
        answer = answer.replace('problem_spec', 'generated_spec')

        tasks.append(
            CoqTask(
                id_in_domain=int(id_in_domain),
                domain=Domain.Coq,
                natural_language=natural_language,
                nl_context_only=nl_context_only,
                answer=answer,
                filepath=coq_file.filepath,
                full_content=coq_file.full_content,
                llm_solution=None,
            )
        )
    return tasks


def load_regex_tasks() -> list[RegexTask]:
    """Load the regex tasks stored in ``REGEX_JSONL``.

    Every non-blank line of the file is parsed as one JSON object, and the
    record's ``id`` is converted to ``int`` for ``id_in_domain``.

    Returns:
        One ``RegexTask`` per record, in file order.

    Raises:
        KeyError: If a record lacks one of the expected keys.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    tasks = []
    # Pin the encoding so that parsing does not depend on the platform's
    # default locale encoding.
    with open(REGEX_JSONL, "r", encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            j = json.loads(line, object_hook=date_decoder)
            task = RegexTask(
                id_in_domain=int(j["id"]),
                original_id=j["original_id"],
                domain=Domain.RegEx,
                natural_language=j["description"],
                answer=j["regex"],
                llm_solution=None,
            )
            tasks.append(task)
    return tasks
