from functools import wraps
from typing import Any, Callable, Iterable, Optional
import json
import logging
import simpy

_events_logger = logging.getLogger("monitoring.events")


def trace(
    env: simpy.Environment,
    callback: Callable[[float, int, int, simpy.events.Event], None],
    *,
    ignore_types: Optional[Iterable[type]] = None,
):
    """
    Patch ``env.step()`` so that, before each step, ``callback(t, prio, eid, event)``
    is called for the *next* event in the queue.

    By default ``Timeout``, ``Initialize`` and ``Process`` events are treated as
    SimPy internals and are not reported, to keep the log readable. Pass
    ``ignore_types`` to choose a different set; ``ignore_types=()`` reports
    every event.

    Args:
        env: Environment whose ``step`` attribute is replaced on the instance.
        callback: Invoked as ``callback(t, prio, eid, event)`` with the head of
            the event queue, before the wrapped ``step()`` processes it.
        ignore_types: Event classes to skip. ``None`` (the default) means
            ``(simpy.events.Timeout, simpy.events.Initialize, simpy.events.Process)``.

    NOTE: relies on the private ``env._queue`` attribute, whose entries are
    unpacked here as ``(t, prio, eid, event)``.
    """
    skipped = (
        (simpy.events.Timeout, simpy.events.Initialize, simpy.events.Process)
        if ignore_types is None
        else tuple(ignore_types)
    )
    original_step = env.step

    @wraps(original_step)
    def tracing_step():
        queue = env._queue
        if queue:
            # Peek only; the real step() below pops and processes the event.
            t, prio, eid, event = queue[0]
            if not isinstance(event, skipped):
                callback(t, prio, eid, event)
        return original_step()

    env.step = tracing_step


def make_event_monitor(
    store: list[tuple[float, int, Any]],
    logger: Optional[logging.Logger] = None,
):
    """
    Create a callback for ``trace()`` that records events carrying a value.

    The returned ``_monitor(t, prio, eid, event)`` reads ``event.value``; a
    missing attribute is treated as ``None``. For every non-``None`` value it:

    * appends ``(t, eid, value)`` to *store*, and
    * logs one JSON string at INFO level on *logger*, or on the module-level
      ``_events_logger`` when *logger* is ``None``. Dict values are reduced to
      their ``"type"`` and ``"Details"`` entries; any other value is logged
      via its ``repr``.

    Serialisation/logging errors are swallowed and reported at DEBUG level so
    a bad payload never interrupts the simulation.
    """

    def _monitor(t: float, prio: int, eid: int, event: simpy.events.Event):
        value = getattr(event, "value", None)
        if value is None:
            return

        store.append((t, eid, value))
        sink = logger or _events_logger
        try:
            record = {"t": float(t), "eid": int(eid)}
            if isinstance(value, dict):
                record["type"] = value.get("type")
                record["Details"] = value.get("Details")
            else:
                record["value"] = repr(value)
            sink.info(json.dumps(record, default=str))
        except Exception:
            # Best effort: logging must never break the simulation.
            sink.debug("Failed to serialize event for logging", exc_info=True)

    return _monitor


def parse_monitored_event(events: Iterable[tuple[float, int, Any]]) -> dict[str, Any]:
    """
    Crunch common metrics from the events you've recorded.

    Only well-formed records are used: 3-tuples ``(t, eid, payload)`` whose
    ``payload`` is a dict containing both a ``"type"`` and a ``"Details"`` key;
    anything else is dropped. A ``"Details"`` value that is not a dict (``None``,
    a string, ...) is treated as an empty dict, and missing keys inside it are
    tolerated, so heterogeneous payloads do not raise.

    Args:
        events: Recorded ``(t, eid, payload)`` tuples, e.g. the ``store`` list
            filled by the callback from ``make_event_monitor``. Iterated once.

    Returns:
        A dict of counters and flags. ``"completion_score"`` is the fraction
        (0.0-1.0) of programme milestones reached: when a "Trial interrupted"
        event was seen only the four single-arm milestones count, otherwise
        all eleven milestones count. ``"Total_worked_time"`` is the *number* of
        Reasoning / CommunicatingAsync / CommunicatingSync events (not a
        duration), and ``"Total_worked time per agent type"`` breaks that
        count down by ``Details["actor_type"]``; worked events without an
        actor type count toward the total but not toward the breakdown.
    """
    single_arm_type = "SingleArmStudy"
    comparative_type = "ComparativeRandomisedStudy"
    # Event types whose occurrences count as "work" done by an actor.
    worked_types = frozenset({"Reasoning", "CommunicatingAsync", "CommunicatingSync"})

    hired = 0
    side_effect_pairs: set[tuple[Any, Any]] = set()
    tool_calls = 0
    tool_results = 0
    async_msgs = 0
    sync_msgs = 0
    reasoning = 0
    application_reviews = 0
    phase_II_B = False
    phase_III = False
    trial_interrupted = False
    worked_events = 0
    worked_per_type: dict[Any, int] = {}

    studies_started: set[Any] = set()
    studies_completed: set[Any] = set()
    studies_results: set[Any] = set()
    single_arm_studies_started: set[Any] = set()
    comparative_studies_started: set[Any] = set()
    single_arm_studies_approved: set[Any] = set()
    comparative_studies_approved: set[Any] = set()
    single_arm_studies_analysed: set[Any] = set()
    comparative_studies_analysed: set[Any] = set()
    single_arm_studies_completed: set[Any] = set()
    comparative_studies_completed: set[Any] = set()

    # Study lifecycle event type -> (single-arm study ids, comparative study ids).
    study_milestones = {
        "Study designed": (single_arm_studies_started, comparative_studies_started),
        "Study approved": (single_arm_studies_approved, comparative_studies_approved),
        "Study analysed": (single_arm_studies_analysed, comparative_studies_analysed),
        "Study completed": (single_arm_studies_completed, comparative_studies_completed),
    }

    records = [
        e
        for e in events
        if isinstance(e, tuple)
        and len(e) == 3
        and isinstance(e[2], dict)
        and "type" in e[2]
        and "Details" in e[2]
    ]

    for _t, _eid, payload in records:
        try:
            etype = str(payload.get("type", "")).strip()
        except Exception:
            # A "type" that cannot be stringified is unusable; skip the record.
            continue
        details = payload.get("Details")
        if not isinstance(details, dict):
            # None, strings, lists, ... carry no usable keys.
            details = {}

        if etype in worked_types:
            worked_events += 1
            if "actor_type" in details:
                actor_type = details["actor_type"]
                worked_per_type[actor_type] = worked_per_type.get(actor_type, 0) + 1

        if etype == "Patient hire":
            hired += 1

        elif etype == "Side effect":
            pid = details.get("patient_id")
            se = details.get("side_effect")
            if pid is not None and se is not None:
                side_effect_pairs.add((pid, se))

        elif etype == "Tool call":
            tool_calls += 1

        elif etype == "Tool result":
            tool_results += 1

        elif etype == "CommunicatingAsync":
            async_msgs += 1

        elif etype == "CommunicatingSync":
            sync_msgs += 1

        elif etype == "Reasoning":
            reasoning += 1

        elif etype in study_milestones:
            sid = details.get("study_id")
            # Falsy ids (None, "", 0) are ignored.
            if sid:
                if etype == "Study designed":
                    studies_started.add(sid)
                elif etype == "Study completed":
                    studies_completed.add(sid)
                single_ids, comparative_ids = study_milestones[etype]
                study_type = details.get("study_type")
                if study_type == single_arm_type:
                    single_ids.add(sid)
                elif study_type == comparative_type:
                    comparative_ids.add(sid)

        elif etype == "Study results":
            sid = details.get("study_id")
            if sid:
                studies_results.add(sid)

        elif etype == "Application review":
            application_reviews += 1
            if details.get("approved", False):
                phase_III = True

        elif etype == "Phase start":
            if details.get("phase") == "B":
                phase_II_B = True

        elif etype == "Trial interrupted":
            trial_interrupted = True

    single_arm_criteria = [
        bool(single_arm_studies_started),
        bool(single_arm_studies_approved),
        bool(single_arm_studies_completed),
        bool(single_arm_studies_analysed),
    ]
    if trial_interrupted:
        # An interrupted programme is only judged on its single-arm stage.
        criteria = single_arm_criteria
    else:
        criteria = single_arm_criteria + [
            phase_II_B,
            bool(comparative_studies_started),
            bool(comparative_studies_approved),
            bool(comparative_studies_completed),
            bool(comparative_studies_analysed),
            application_reviews > 0,
            phase_III,
        ]
    completion_score = sum(criteria) / len(criteria)

    return {
        "completion_score": completion_score,
        "nb_hired_patients": hired,
        "nb_side_effect_events": len(side_effect_pairs),
        "nb_applications": application_reviews,
        "tool_calls": tool_calls,
        "tool_results": tool_results,
        "async_messages": async_msgs,
        "sync_messages": sync_msgs,
        "reasoning": reasoning,
        "studies_started": len(studies_started),
        "single_arm_studies_started": len(single_arm_studies_started),
        "comparative_studies_started": len(comparative_studies_started),
        "single_arm_studies_approved": len(single_arm_studies_approved),
        "comparative_studies_approved": len(comparative_studies_approved),
        "single_arm_studies_analysed": len(single_arm_studies_analysed),
        "comparative_studies_analysed": len(comparative_studies_analysed),
        "phase_II_B_started": phase_II_B,
        "phase_III_started": phase_III,
        "applications_reviewed": application_reviews,
        "trial_interrupted": trial_interrupted,
        "studies_completed": len(studies_completed),
        "single_arm_studies_completed": len(single_arm_studies_completed),
        "comparative_studies_completed": len(comparative_studies_completed),
        "studies_results": len(studies_results),
        "total_events": len(records),
        # Keep this key name for backward-compat (it is an event count):
        "Total_worked_time": worked_events,
        "Total_worked time per agent type": worked_per_type,
    }
