from typing import Any, Dict, List, Literal, Optional, Tuple, Union

from pydantic import BaseModel, Field

from nightjar import nj_llm_factory


class Task:
    """A single to-do item whose completion flag can be flipped in place.

    Attributes:
        task_id: Identifier matched against ``MarkTaskCompleted.task_id`` in ``main``.
        description: Human-readable text included in the LLM prompt.
        completed: Whether the task is done.
    """

    def __init__(self, task_id: str, description: str, completed: bool):
        self.task_id = task_id
        self.description = description
        self.completed = completed

    def __str__(self):
        # Tuple-like rendering, e.g. ``('T001', 'Clean the house', False)``.
        return f"('{self.task_id}', '{self.description}', {self.completed})"

    def __repr__(self):
        # Containers (e.g. the task list stored in ``run``'s outputs and printed by
        # the script entry point) render their items with ``repr``; without this
        # they would show only ``<Task object at 0x...>``.
        return (
            f"{type(self).__name__}("
            f"{self.task_id!r}, {self.description!r}, {self.completed!r})"
        )


class MarkTaskCompleted(BaseModel):
    # Structured command: set ``completed = True`` on the task whose ``task_id``
    # matches (applied by ``main``).
    # Documented with comments instead of a docstring on purpose: pydantic copies a
    # model docstring into its JSON-schema ``description``, which may change what
    # ``nj_llm`` asks the model for. NOTE(review): confirm how nj_llm uses the schema.
    # Constant tag naming the command kind; defaulted so it never has to be supplied.
    type: Literal["mark_task_completed"] = "mark_task_completed"
    # Id of the task to complete; compared with ``Task.task_id``.
    task_id: str


# Every command type the LLM may emit. With a single member, ``Union[X]`` collapses
# to ``X`` itself; presumably it is kept as a Union so more command models can be
# added later. NOTE(review): with several members, consider discriminating on ``type``.
Command = Union[MarkTaskCompleted]


class TaskCompletionLLMResult(BaseModel):
    # Structured reply requested from the LLM (passed as ``output_format`` in ``main``).
    # Comments rather than a docstring: pydantic would copy a docstring into the
    # model's JSON-schema ``description``.
    # Commands to apply to the task list; defaults to an empty list.
    commands: List[Command] = Field(default_factory=list)
    # Optional natural-language reply; ``main`` returns ``""`` when this is falsy.
    message: Optional[str] = None


def main(tasks: List[Task], nj_llm) -> str:
    """Ask the LLM which tasks are housework and mark those tasks completed.

    Each task is rendered on its own line as tab-separated ``task_id``,
    ``description`` and ``completed=<bool>``, and the listing is embedded in the
    prompt inside a ``<tasks>`` element. The structured reply is then applied in
    place: for every ``MarkTaskCompleted`` command, the first task whose
    ``task_id`` matches gets ``completed = True``. Commands naming an unknown
    ``task_id`` are ignored.

    Args:
        tasks: Tasks to review; matching entries are mutated in place.
        nj_llm: Callable invoked with the prompt and
            ``output_format=TaskCompletionLLMResult``; its return value is used
            as a ``TaskCompletionLLMResult``.

    Returns:
        The LLM's ``message``, or ``""`` when the message is ``None`` or empty.
    """
    listing = "\n".join(
        f"{task.task_id}\t{task.description}\tcompleted={task.completed}"
        for task in tasks
    )
    instructions = (
        "Review the <tasks> list and identify all tasks related to housework "
        "(e.g., cleaning, chores, home maintenance, laundry, dishes, vacuuming). "
        "For each housework-related task, add a `MarkTaskCompleted` command with the "
        "matching task_id. If a natural-language response is appropriate, include it in "
        "`message`; otherwise set `message` to None.\n"
    )
    reply: TaskCompletionLLMResult = nj_llm(
        instructions + f"<tasks>{listing}</tasks>",
        output_format=TaskCompletionLLMResult,
    )

    # Apply the structured commands: only the first task with a matching id is touched.
    mark_commands = (c for c in reply.commands if isinstance(c, MarkTaskCompleted))
    for command in mark_commands:
        target = next((t for t in tasks if t.task_id == command.task_id), None)
        if target is not None:
            target.completed = True

    return "" if not reply.message else reply.message


#### Tests ####


def run(
    model_name: str,
) -> Tuple[Dict[str, Tuple[Any, Any]], Dict[str, Any], Dict[str, bool], Dict[str, str]]:
    """Run the housework-completion scenario against ``model_name``.

    Builds three sample tasks, lets ``main`` mark housework tasks as completed
    via the LLM, and checks each task's final ``completed`` flag.

    Args:
        model_name: Model identifier forwarded to ``nj_llm_factory``.

    Returns:
        ``(outputs, errors, hard_results, usage)``:
        ``outputs["test_0"]`` is the (mutated) task list when ``main`` succeeds;
        ``errors["test_0"]`` holds the exception if ``main`` or a check raised
        (the first check failure is kept); ``hard_results`` maps each
        ``test_0_completion_<i>`` check to pass/fail; ``usage`` is the second
        value returned by ``nj_llm_factory``.
    """
    nj_llm, usage = nj_llm_factory(model_name, max_calls=100)
    tasks_data = [
        ("T001", "Clean the house", False),
        ("T002", "Prepare dinner", True),
        ("T003", "Go to gym", False),
    ]
    tasks = [Task(task_id, desc, done) for task_id, desc, done in tasks_data]
    # Expected final ``completed`` flag for each task index, in order.
    expected_completed = [True, True, False]
    outputs = {}
    errors = {}
    # Every check starts as failed; it only passes once it has actually run.
    hard_results = {
        f"test_0_completion_{i}": False for i in range(len(expected_completed))
    }

    try:
        main(tasks, nj_llm)
        outputs["test_0"] = tasks
    except Exception as e:
        errors["test_0"] = e
    else:
        for i, expected in enumerate(expected_completed):
            try:
                hard_results[f"test_0_completion_{i}"] = tasks[i].completed == expected
            except Exception as e:
                # Keep the first failure instead of letting later checks overwrite it.
                errors.setdefault("test_0", e)

    return outputs, errors, hard_results, usage


if __name__ == "__main__":
    # ``run`` has a required ``model_name`` parameter, so take it from the command
    # line; calling ``run()`` with no arguments raises TypeError.
    import sys

    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} MODEL_NAME")
    results, errors, hard_results, _ = run(sys.argv[1])
    print(results)
    print(hard_results)
    print(errors)
