# organisation/env/clinical_trial/core/memory.py

from typing import Deque, List, Optional
from collections import deque
import logging

from .messages import Message
from .llm_client import get_llm_chat, create_llm_client
from organisation.env.config import LLM_ENGINE

# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)

# One shared client for memory utilities (Qwen/vLLM path requires a client).
# It is created once, as a side effect of importing this module, from the
# configured LLM_ENGINE, and is passed to get_llm_chat by RelevanceChecker
# and Summarizer below.
# If you ever switch to Azure here, pass credentials via env/config and recreate.
_MEMORY_LLM_CLIENT = create_llm_client(LLM_ENGINE)


class RelevanceChecker:
    """
    Asks the LLM which part of a context, if any, is relevant to an agenda.

    The model is instructed to answer with either a verbatim snippet of the
    context or the single word NULL. `check` sends the request and
    `_normalize_reply` turns the raw answer into an ``Optional[str]``.
    """

    SYSTEM_PROMPT = """You are a helpful assistant.
Read a chunk of text (Context) and an Agenda.
Return EITHER:
  - the EXACT relevant snippet copied from Context (no extra words), OR
  - the single word NULL (uppercase) if nothing is relevant.
Never return MATCH, MATCHES, MATCHED, YES, RELEVANT, or TRUE.
"""

    # Replies (compared case-insensitively) that mean "nothing relevant".
    # NULL is the answer the prompt asks for; the others are the bare tokens
    # the prompt forbids, which carry no snippet and are treated as no match.
    _NO_MATCH_TOKENS = frozenset(
        {"NULL", "MATCH", "MATCHES", "MATCHED", "YES", "RELEVANT", "TRUE"}
    )

    # Characters removed from both ends of a reply: whitespace (including
    # newlines, so a fenced one-word answer is unwrapped) plus quotes/backticks.
    _WRAPPER_CHARS = "`\"' \t\r\n"

    @staticmethod
    def _normalize_reply(raw: Optional[str]) -> Optional[str]:
        """
        Convert a raw model reply into a snippet or ``None``.

        Args:
            raw: The reply text, or ``None`` when the model returned no content.

        Returns:
            The reply with surrounding whitespace, quotes and backticks
            removed, or ``None`` when the reply is empty or is one of the
            no-match tokens.
        """
        cleaned = (raw or "").strip(RelevanceChecker._WRAPPER_CHARS)
        if not cleaned or cleaned.upper() in RelevanceChecker._NO_MATCH_TOKENS:
            return None
        return cleaned

    @staticmethod
    def check(context: str, agenda: str) -> Optional[str]:
        """
        Ask the LLM for the snippet of `context` that is relevant to `agenda`.

        Args:
            context: Text to search.
            agenda: Topic the snippet must be relevant to.

        Returns:
            The cleaned snippet returned by the model, or ``None`` when the
            model reports nothing relevant (or replies with nothing usable).

        Raises:
            Whatever `get_llm_chat` raises; errors are not caught here.
        """
        messages = [
            {"role": "system", "content": RelevanceChecker.SYSTEM_PROMPT},
            {"role": "user", "content": f"Context:\n{context}\n\nAgenda: {agenda}"},
        ]
        resp = get_llm_chat(
            client=_MEMORY_LLM_CLIENT,
            messages=messages,
            tools=None,
            tool_choice="none",
            engine_name=LLM_ENGINE,
        )
        return RelevanceChecker._normalize_reply(resp.choices[0].message.content)


class MessageBuffer:
    """
    Bounded FIFO of the most recent messages.

    Backed by a deque created with the given ``maxlen``, so appending to a
    full buffer drops the oldest entry.
    """

    def __init__(self, maxlen: int):
        self._buffer: Deque[Message] = deque(maxlen=maxlen)

    def add(self, msg: Message) -> None:
        """Append `msg` as the newest entry."""
        self._buffer.append(msg)

    def is_full(self) -> bool:
        """Return True when the number of held messages equals the capacity."""
        capacity = self._buffer.maxlen
        held = len(self._buffer)
        return held == capacity

    def peek_all(self) -> List[Message]:
        """Return a copy of the held messages, oldest first, without removing them."""
        return [*self._buffer]

    def flush(self) -> List[Message]:
        """Empty the buffer and return what it held, oldest first."""
        drained = [*self._buffer]
        self._buffer.clear()
        return drained


class Summarizer:
    """
    Wraps LLM summarization so it can be tested or swapped out independently.
    """

    SYSTEM_PROMPT = """\
You are a helpful assistant. Your role is to summarize messages while retaining relevant information. \
You will receive a list of messages and a summary. \
Your task is to produce a new summary that includes the relevant information from the messages. \
Always include key informations such as study ID or results. \
If the messages contain new information, include it; otherwise return the existing summary unchanged.
"""

    @staticmethod
    def summarize(previous_summary: str, messages: List[Message]) -> str:
        """
        Fold `messages` into `previous_summary` using the LLM.

        Args:
            previous_summary: The summary accumulated so far (may be empty).
            messages: The batch to merge in; each message is sent to the
                model as its ``repr``.

        Returns:
            The model's new summary. If the LLM call raises, or the model
            returns no (or only blank) content, `previous_summary` is
            returned unchanged so existing memory is never overwritten with
            an empty value.
        """
        batch_text = "\n".join(repr(m) for m in messages)
        llm_messages = [
            {"role": "system", "content": Summarizer.SYSTEM_PROMPT},
            {
                "role": "user",
                "content": f"Previous summary:\n{previous_summary}\n\nMessages:\n{batch_text}",
            },
        ]
        try:
            resp = get_llm_chat(
                client=_MEMORY_LLM_CLIENT,
                messages=llm_messages,
                tools=None,
                tool_choice="none",
                engine_name=LLM_ENGINE,
            )
            new_summary = resp.choices[0].message.content
        except Exception:
            # Best effort: a failed summarization must not break the caller.
            logger.exception("Summarizer failed; keeping old summary")
            return previous_summary

        # The content field can come back as None or blank; returning that
        # would make the caller replace its summary with nothing.
        if not new_summary or not new_summary.strip():
            logger.warning("Summarizer returned no content; keeping old summary")
            return previous_summary
        return new_summary


class SummarizingMemory:
    """
    Keeps a short-term buffer of raw messages plus a long-term summary.
    When the buffer fills, it is rolled into the summary via the Summarizer.

    Every message is also appended to a full history list that roll-ups never
    shrink; `get_all_messages` returns a copy of it.
    """

    def __init__(
        self,
        actor,
        short_term_window: int = 15,
    ):
        """
        Args:
            actor: Owner of this memory. `parsing_buffer` reads its
                ``org_role`` and ``actor_id`` to recognise the actor's own
                messages.
            short_term_window: Capacity of the short-term buffer; reaching it
                triggers a roll-up into the summary.
        """
        self._buffer = MessageBuffer(maxlen=short_term_window)
        self._summary: str = ""
        self._all_messages: List[Message] = []
        self.actor = actor

    def add(self, msg: Message) -> None:
        """Record `msg`, rolling the buffer into the summary once it is full."""
        self._all_messages.append(msg)
        self._buffer.add(msg)
        if self._buffer.is_full():
            self._rollup()
            # Lazy %-style args: nothing is formatted unless DEBUG is enabled,
            # and an unexpected timestamp value cannot raise out of add().
            logger.debug(
                "Memory roll-up complete for actor_id=%s at t=%.1f",
                getattr(self.actor, "actor_id", "?"),
                getattr(msg, "timestamp", -1),
            )

    def _rollup(self) -> None:
        """Flush the short-term buffer and merge its messages into the summary."""
        batch = self._buffer.flush()
        if not batch:
            # Nothing to summarize (e.g. a zero-capacity buffer always reports
            # full); skip the pointless LLM call.
            return
        try:
            self._summary = Summarizer.summarize(self._summary, batch)
        except Exception:
            # Summarizer already logs; leave buffer cleared so we don't re-roll infinitely
            logger.exception("Summarizer failed during memory roll-up")

    def _sent_by_actor(self, sender) -> bool:
        """Return True when `sender` names this memory's own actor."""
        # Kept as a short-circuiting chain so later actor attributes are only
        # read when the earlier comparisons fail.
        return (
            sender == self.actor.org_role
            or sender == f"{self.actor.actor_id}:{self.actor.org_role}"
            or sender == self.actor.actor_id
        )

    @staticmethod
    def _stamp(msg) -> str:
        """Return the ``Timestamp=<int>`` prefix shared by all rendered lines."""
        return f"Timestamp={int(msg.timestamp)}"

    def _render(self, msg) -> str:
        """Render one message as a single human-readable entry."""
        kind = msg.comm_type
        if kind == "tool_call":
            return f"{self._stamp(msg)}, Tool call: {msg.content}"
        if kind == "tool_result":
            return f"{self._stamp(msg)}, Tool result: {msg.content}"
        if kind == "reasoning":
            return f"{self._stamp(msg)}, Reasoning: {msg.content}"
        if kind == "async":
            if self._sent_by_actor(msg.sender):
                return f"{self._stamp(msg)}, You sent an email to {msg.recipient}: {msg.content}"
            return f"{self._stamp(msg)}, Email from {msg.sender}: {msg.content}"
        if kind == "sync":
            speaker = "you" if self._sent_by_actor(msg.sender) else msg.sender
            return f"{self._stamp(msg)}, During the meeting, {speaker} said '{msg.content}'"
        if kind == "agenda":
            return f"{self._stamp(msg)}, You entered a meeting with the following agenda:\n {msg.content}"
        # Unknown kinds fall back to the message's repr (no timestamp needed).
        return repr(msg)

    def parsing_buffer(self, buffer):
        """
        Parses the buffer and returns a string representation of its contents.

        Args:
            buffer: Iterable of messages exposing ``comm_type`` and, for the
                known kinds, ``timestamp``, ``content``, ``sender`` and
                ``recipient``.

        Returns:
            One rendered entry per message, joined with newlines. Messages
            sent by this memory's actor are phrased in the second person.
        """
        return "\n".join(self._render(msg) for msg in buffer)

    def retrieve_context(self) -> str:
        """
        Returns the combined long-term summary plus any un-rolled short-term messages.

        When there is no summary yet, only the rendered recent messages are
        returned (an empty string if there are none).
        """
        # If buffer happens to be full again, roll it up before retrieving
        if self._buffer.is_full():
            self._rollup()
            logger.debug("Memory roll-up performed during context retrieval")

        recent = self.parsing_buffer(self._buffer.peek_all())
        if not self._summary:
            return recent
        return f"###Old messages:\n{self._summary}\n\n###Recent messages:\n{recent}\n###End of the messages.\n"

    def get_summary(self) -> str:
        """Return the current long-term summary."""
        return self._summary

    def get_all_messages(self) -> List[Message]:
        """Return a copy of every message ever added, in insertion order."""
        return list(self._all_messages)

    def get_relevant(self, agenda: str) -> Optional[str]:
        """
        Return the part of the current context relevant to `agenda`, if any.

        Returns:
            The snippet selected by RelevanceChecker, or ``None`` when nothing
            is relevant. An empty context yields ``None`` without an LLM call.
        """
        ctx = self.retrieve_context()
        if not ctx.strip():
            # Nothing remembered yet, so nothing can be relevant.
            return None
        return RelevanceChecker.check(ctx, agenda)
