from datetime import datetime, timedelta
from typing import Optional
from uuid import UUID

import oasst_backend.models.db_payload as db_payload
from loguru import logger
from oasst_backend.config import settings
from oasst_backend.models import ApiClient, Task
from oasst_backend.models.payload_column_type import PayloadContainer
from oasst_backend.user_repository import UserRepository
from oasst_backend.utils.database_utils import CommitMode, managed_tx_method
from oasst_shared.exceptions.oasst_api_error import OasstError, OasstErrorCode
from oasst_shared.schemas import protocol as protocol_schema
from oasst_shared.utils import utcnow
from sqlmodel import Session, delete, false, func, not_, or_
from starlette.status import HTTP_404_NOT_FOUND


def validate_frontend_message_id(message_id: str) -> None:
    """Check that *message_id* is a non-empty string.

    Raises:
        OasstError: with code ``INVALID_FRONTEND_MESSAGE_ID`` when the value
            is not a ``str`` instance or is empty.
    """
    # TODO: Should it be replaced with fastapi/pydantic validation?
    problem = None
    if not isinstance(message_id, str):
        problem = f"message_id must be string, not {type(message_id)}"
    elif not message_id:
        problem = "message_id must not be empty"

    # Single raise site: both failure modes share the same error code.
    if problem is not None:
        raise OasstError(problem, OasstErrorCode.INVALID_FRONTEND_MESSAGE_ID)


def delete_expired_tasks(session: Session) -> int:
    """Delete every task whose expiry date has passed and that is not done.

    The statement is executed through *session*; this function does not
    commit by itself.

    Returns:
        The ``rowcount`` reported for the executed DELETE statement.
    """
    past_expiry = Task.expiry_date < utcnow()
    still_open = Task.done == false()
    outcome = session.exec(delete(Task).where(past_expiry, still_open))
    removed = outcome.rowcount
    logger.info(f"Deleted {removed} expired tasks.")
    return removed


class TaskRepository:
    def __init__(
        self,
        db: Session,
        api_client: ApiClient,
        client_user: Optional[protocol_schema.User],
        user_repository: UserRepository,
    ):
        self.db = db
        self.api_client = api_client
        self.user_repository = user_repository
        self.user = self.user_repository.lookup_client_user(client_user, create_missing=True)
        self.user_id = self.user.id if self.user else None

    def store_task(
        self,
        task: protocol_schema.Task,
        message_tree_id: UUID = None,
        parent_message_id: UUID = None,
        collective: bool = False,
    ) -> Task:
        payload: db_payload.TaskPayload
        match type(task):
            case protocol_schema.SummarizeStoryTask:
                payload = db_payload.SummarizationStoryPayload(story=task.story)

            case protocol_schema.RateSummaryTask:
                payload = db_payload.RateSummaryPayload(
                    full_text=task.full_text, summary=task.summary, scale=task.scale
                )

            case protocol_schema.InitialPromptTask:
                payload = db_payload.InitialPromptPayload(hint=task.hint)

            case protocol_schema.PrompterReplyTask:
                payload = db_payload.PrompterReplyPayload(conversation=task.conversation, hint=task.hint)

            case protocol_schema.AssistantReplyTask:
                payload = db_payload.AssistantReplyPayload(type=task.type, conversation=task.conversation)

            case protocol_schema.RankInitialPromptsTask:
                payload = db_payload.RankInitialPromptsPayload(type=task.type, prompt_messages=task.prompt_messages)

            case protocol_schema.RankPrompterRepliesTask:
                payload = db_payload.RankPrompterRepliesPayload(
                    type=task.type,
                    conversation=task.conversation,
                    reply_messages=task.reply_messages,
                    ranking_parent_id=task.ranking_parent_id,
                    message_tree_id=task.message_tree_id,
                    reveal_synthetic=task.reveal_synthetic,
                )

            case protocol_schema.RankAssistantRepliesTask:
                payload = db_payload.RankAssistantRepliesPayload(
                    type=task.type,
                    conversation=task.conversation,
                    reply_messages=task.reply_messages,
                    ranking_parent_id=task.ranking_parent_id,
                    message_tree_id=task.message_tree_id,
                    reveal_synthetic=task.reveal_synthetic,
                )

            case protocol_schema.LabelInitialPromptTask:
                payload = db_payload.LabelInitialPromptPayload(
                    type=task.type,
                    message_id=task.message_id,
                    prompt=task.prompt,
                    valid_labels=task.valid_labels,
                    mandatory_labels=task.mandatory_labels,
                    mode=task.mode,
                )

            case protocol_schema.LabelPrompterReplyTask:
                payload = db_payload.LabelPrompterReplyPayload(
                    type=task.type,
                    message_id=task.message_id,
                    conversation=task.conversation,
                    valid_labels=task.valid_labels,
                    mandatory_labels=task.mandatory_labels,
                    mode=task.mode,
                )

            case protocol_schema.LabelAssistantReplyTask:
                payload = db_payload.LabelAssistantReplyPayload(
                    type=task.type,
                    message_id=task.message_id,
                    conversation=task.conversation,
                    valid_labels=task.valid_labels,
                    mandatory_labels=task.mandatory_labels,
                    mode=task.mode,
                )

            case _:
                raise OasstError(f"Invalid task type: {type(task)=}", OasstErrorCode.INVALID_TASK_TYPE)

        if not collective and settings.TASK_VALIDITY_MINUTES > 0:
            expiry_date = utcnow() + timedelta(minutes=settings.TASK_VALIDITY_MINUTES)
        else:
            expiry_date = None

        task_model = self.insert_task(
            payload=payload,
            id=task.id,
            message_tree_id=message_tree_id,
            parent_message_id=parent_message_id,
            collective=collective,
            expiry_date=expiry_date,
        )
        assert task_model.id == task.id
        return task_model

    @managed_tx_method(CommitMode.COMMIT)
    def bind_frontend_message_id(self, task_id: UUID, frontend_message_id: str) -> None:
        validate_frontend_message_id(frontend_message_id)

        # find task
        task: Task = self.db.query(Task).filter(Task.id == task_id, Task.api_client_id == self.api_client.id).first()
        if task is None:
            raise OasstError(f"Task for {task_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND)

        if task.ack and task.frontend_message_id == frontend_message_id:
            return  # ACK is idempotent if called with the same frontend_message_id

        if task.expired:
            raise OasstError("Task already expired.", OasstErrorCode.TASK_EXPIRED)

        if task.done or task.ack is not None:
            raise OasstError("Task already updated.", OasstErrorCode.TASK_ALREADY_UPDATED)

        task.frontend_message_id = frontend_message_id
        task.ack = True
        self.db.add(task)

    @managed_tx_method(CommitMode.COMMIT)
    def close_task(self, frontend_message_id: str, allow_personal_tasks: bool = False):
        """
        Mark task as done. No further messages will be accepted for this task.
        """
        validate_frontend_message_id(frontend_message_id)
        task = self.task_repository.fetch_task_by_frontend_message_id(frontend_message_id)

        if not task:
            raise OasstError(
                f"Task for {frontend_message_id=} not found", OasstErrorCode.TASK_NOT_FOUND, HTTP_404_NOT_FOUND
            )
        if task.expired:
            raise OasstError("Task already expired", OasstErrorCode.TASK_EXPIRED)
        if not allow_personal_tasks and not task.collective:
            raise OasstError("This is not a collective task", OasstErrorCode.TASK_NOT_COLLECTIVE)
        if task.done:
            raise OasstError("Already closed", OasstErrorCode.TASK_ALREADY_DONE)

        task.done = True
        self.db.add(task)

    @managed_tx_method(CommitMode.COMMIT)
    def insert_task(
        self,
        payload: db_payload.TaskPayload,
        id: UUID = None,
        message_tree_id: UUID = None,
        parent_message_id: UUID = None,
        collective: bool = False,
        expiry_date: datetime = None,
    ) -> Task:
        c = PayloadContainer(payload=payload)
        task = Task(
            id=id,
            user_id=self.user_id,
            payload_type=type(payload).__name__,
            payload=c,
            api_client_id=self.api_client.id,
            message_tree_id=message_tree_id,
            parent_message_id=parent_message_id,
            collective=collective,
            expiry_date=expiry_date,
        )
        logger.debug(f"inserting {task=}")
        self.db.add(task)
        return task

    def fetch_task_by_frontend_message_id(self, message_id: str) -> Task:
        validate_frontend_message_id(message_id)
        task = (
            self.db.query(Task)
            .filter(Task.api_client_id == self.api_client.id, Task.frontend_message_id == message_id)
            .one_or_none()
        )
        return task

    def fetch_task_by_id(self, task_id: UUID) -> Task:
        task = self.db.query(Task).filter(Task.api_client_id == self.api_client.id, Task.id == task_id).one_or_none()
        return task

    def fetch_recent_reply_tasks(
        self,
        max_age: timedelta = timedelta(minutes=5),
        done: bool = False,
        skipped: bool = False,
        limit: int = 100,
    ) -> list[Task]:
        qry = self.db.query(Task).filter(
            Task.created_date > func.current_timestamp() - max_age,
            or_(Task.payload_type == "AssistantReplyPayload", Task.payload_type == "PrompterReplyPayload"),
        )
        if done is not None:
            qry = qry.filter(Task.done == done)
        if skipped is not None:
            qry = qry.filter(Task.skipped == skipped)
        if limit:
            qry = qry.limit(limit)
        return qry.all()

    def delete_expired(self) -> int:
        return delete_expired_tasks(self.db)

    def fetch_pending_tasks_of_user(
        self,
        user_id: UUID,
        max_age: timedelta = timedelta(minutes=5),
        limit: int = 100,
    ) -> list[Task]:
        qry = (
            self.db.query(Task)
            .filter(
                Task.user_id == user_id,
                Task.created_date > func.current_timestamp() - max_age,
                not_(Task.done),
                not_(Task.skipped),
            )
            .order_by(Task.created_date)
        )
        if limit:
            qry = qry.limit(limit)
        return qry.all()
