from typing import (
    Dict,
    Iterable,
    List,
    Optional,
    Sequence,
    Tuple,
    Any,
    Set,
    DefaultDict,
    Union,
)
from collections import defaultdict
from pathlib import Path
import os
import json
import json5
import logging

import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from matplotlib.lines import Line2D
from matplotlib.ticker import MultipleLocator,  AutoMinorLocator

# Sentinel actor ID that the helpers in this module filter out of every ID set,
# label map and draw call, so it never gets a lane in the plot.
_SYSTEM_ACTOR_ID = -1
# Module-level logger, named after this module.
_log = logging.getLogger(__name__)

# ───────────────────────── Helpers ─────────────────────────


def _coerce_actor_id(v: Any) -> Optional[int]:
    """Accept: 5, '5', '5:Role', 'Role:5' -> 5. Else -> None."""
    if isinstance(v, int):
        return v
    if isinstance(v, float) and v.is_integer():
        return int(v)
    if isinstance(v, str):
        s = v.strip()
        if ":" in s:
            left, right = s.split(":", 1)
            for token in (left, right):
                try:
                    return int(token)
                except Exception:
                    pass
        try:
            return int(s)
        except Exception:
            return None
    return None


def _extract_label_pair(token: Any) -> Optional[Tuple[int, str]]:
    """
    Try to extract (id, role) from strings like 'LegalTeam:5' or '5:LegalTeam'.
    Returns None if it can't.
    """
    if not isinstance(token, str) or ":" not in token:
        return None
    a, b = token.split(":", 1)
    for left, right in ((a, b), (b, a)):
        try:
            aid = int(left)
            role = right.strip()
            if role:
                return aid, role
        except Exception:
            pass
    return None


def _normalize_details(det_raw: Any) -> Dict[str, Any]:
    """Ensure Details is a dict; parse json-ish strings; else wrap as text."""
    if isinstance(det_raw, dict):
        return det_raw
    if isinstance(det_raw, str):
        s = det_raw.strip()
        if s[:1] in "{[" and s[-1:] in "}]":
            if json5 is not None:
                try:
                    return json5.loads(s)
                except Exception:
                    pass
            try:
                return json.loads(s)
            except Exception:
                return {"text": det_raw}
        return {"text": det_raw}
    return {"text": str(det_raw)}


def _collect_actor_ids_from_events(
    events: Sequence[Tuple[float, str, Any]],
) -> Set[int]:
    """
    Gather every actor ID referenced by the events' Details.

    Looks at the ``actor_id``, ``from`` and ``to`` fields plus each entry of a
    list/tuple ``recipients`` field. Values that cannot be coerced to an ID,
    and the system actor ID, are left out.
    """
    found: Set[int] = set()
    for _t, _kind, payload in events:
        details = _normalize_details(payload)

        candidates = [details.get(field) for field in ("actor_id", "from", "to")]
        recipients = details.get("recipients") or []
        if isinstance(recipients, (list, tuple)):
            candidates.extend(recipients)

        for candidate in candidates:
            actor = _coerce_actor_id(candidate)
            if actor is None or actor == _SYSTEM_ACTOR_ID:
                continue
            found.add(actor)
    return found


def _collect_actor_ids_from_lifespans(
    lifespans: Optional[Dict[Any, Tuple[float, float]]],
) -> Set[int]:
    """
    Return the actor IDs found among the keys of ``lifespans``.

    Keys are coerced with ``_coerce_actor_id``; keys that do not coerce, and
    the system actor ID, are skipped. A missing or empty mapping yields an
    empty set.
    """
    if not lifespans:
        return set()
    coerced = (_coerce_actor_id(key) for key in lifespans.keys())
    return {
        actor
        for actor in coerced
        if actor is not None and actor != _SYSTEM_ACTOR_ID
    }


def _infer_labels_from_events(
    events: Sequence[Tuple[float, str, Any]],
) -> Dict[int, str]:
    """
    Build an id -> role map from 'Role:ID' / 'ID:Role' string tokens.

    Tokens are looked for in the ``actor_id``, ``from`` and ``to`` fields and
    in each entry of a list/tuple ``recipients`` field of every event's
    Details. The first role seen for an ID wins; the system actor ID is
    never recorded.
    """
    labels: Dict[int, str] = {}

    def _remember(candidate: Any) -> None:
        # Only string tokens can carry a role; everything else is ignored.
        if not isinstance(candidate, str):
            return
        pair = _extract_label_pair(candidate)
        if not pair:
            return
        actor, role = pair
        if actor != _SYSTEM_ACTOR_ID:
            labels.setdefault(actor, role)

    for _t, _kind, payload in events:
        details = _normalize_details(payload)
        for field in ("actor_id", "from", "to"):
            _remember(details.get(field))
        recipients = details.get("recipients") or []
        if isinstance(recipients, (list, tuple)):
            for entry in recipients:
                _remember(entry)
    return labels


def _lane_labels(
    actor_ids: Sequence[int], id_to_label: Optional[Dict[int, str]] = None
) -> List[str]:
    labels = []
    for aid in actor_ids:
        role = (id_to_label.get(aid) if id_to_label else None) or "Actor"
        labels.append(f"{role}:{aid}")
    return labels


def _apply_time_axis_style(ax, sim_end: float, tick_step: Optional[float]) -> None:
    """
    Style the x (time) axis of ``ax``: limits, ticks, grid and tick labels.

    - The axis spans ``[0, max(sim_end, 1.0)]``.
    - Major ticks sit at multiples of ``tick_step``; when ``tick_step`` is
      None a step of 2.0 is used. Each major interval is split in two by
      minor ticks.
    - Light vertical grid lines are drawn for both major and minor ticks.
    - Tick labels are rotated 90 degrees, top/centre aligned.

    NOTE(review): earlier documentation of this helper said the default step
    was 1.0, but the code uses 2.0 — confirm which is intended.
    """
    x_max = max(float(sim_end), 1.0)
    major_step = 2.0 if tick_step is None else tick_step

    ax.set_xlim(0.0, x_max)
    ax.xaxis.set_major_locator(MultipleLocator(major_step))
    ax.xaxis.set_minor_locator(AutoMinorLocator(2))

    # Major grid lines are slightly heavier than minor ones.
    for which, line_width, opacity in (("major", 0.6, 0.35), ("minor", 0.4, 0.15)):
        ax.grid(True, which=which, axis="x", linewidth=line_width, alpha=opacity)

    # Vertical tick labels, anchored at their top centre.
    ax.tick_params(axis="x", labelrotation=90)
    for tick_label in ax.get_xticklabels():
        tick_label.set_verticalalignment("top")
        tick_label.set_horizontalalignment("center")


# ───────────────────────── Plot (continuous time) ─────────────────────────


def plot_simulation_timeline_from_events(
    events: Sequence[Tuple[float, str, Any]],
    sim_end: float,
    lifespans: Optional[Dict[Any, Tuple[float, float]]] = None,
    originals: Optional[Iterable[int]] = None,
    id_to_label: Optional[Dict[int, str]] = None,
    *,
    figsize: Optional[Tuple[int, int]] = None,
    tick_step: Optional[float] = 1.0,
):
    """
    Draw a continuous-time timeline with one horizontal lane per actor.

    What is drawn:
      - Lifespan bars (start/end ticks + markers) BEHIND arrows, with a green
        up-triangle for actors spawned after t=0 and a red down-triangle for
        actors deleted before ``sim_end``,
      - Reasoning (red, above the lane) and tool call/result (blue, below the
        lane) self-loops, with de-overlapped labels for tool calls,
      - Async messages (green arrows) & meetings/sync messages (purple arrows),
      - Study start/completion marks (gray vertical line + rotated text).

    Event types are matched case-insensitively after stripping; event types
    not handled below are ignored.

    Args:
        events: Sequence of ``(t, event_type, details)``. ``details`` is
            normalised with ``_normalize_details``; the fields read are
            ``actor_id``, ``from``, ``to``, ``recipients``, ``tool`` and
            ``study_id``.
        sim_end: End of the simulated time range; used for the x-axis extent
            and to decide whether an actor was deleted early.
        lifespans: Mapping of actor key -> ``(t0, t1)``. Keys are coerced with
            ``_coerce_actor_id``; entries without a lane are skipped.
        originals: Actor IDs treated as original: they get a darker bar and
            are never flagged as spawned.
        id_to_label: Caller-supplied id -> role labels. These take precedence
            over roles found in the events (``actor_type`` first, then
            'Role:ID' tokens).
        figsize: Figure size passed to ``plt.subplots``. When None, a very
            wide figure is used whose height depends on the number of lanes.
        tick_step: Major tick spacing on the time axis, forwarded to
            ``_apply_time_axis_style`` (which picks its own default for None).

    Returns:
        ``(fig, ax)`` — the created matplotlib figure and axes. The figure is
        not saved or closed here.
    """
    originals = set(originals or [])
    lifespans = lifespans or {}

    # Build labels: start with env-supplied labels, then enrich from the log.
    # setdefault is used below, so earlier sources win over later ones.
    label_map = dict(id_to_label or {})

    # 1) Prefer explicit actor_type from the log (Details.actor_type)
    for aid, role in _labels_from_events_actor_type(events).items():
        label_map.setdefault(aid, role)

    # 2) (Optional) also support "Role:ID" tokens found in events
    for aid, role in _infer_labels_from_events(events).items():
        label_map.setdefault(aid, role)

    # Lanes ONLY for labeled, positive IDs (drops 0 / system / unlabeled).
    # Lanes are ordered by ascending actor ID; lane index == y coordinate.
    lanes_list = sorted(
        aid
        for aid, role in label_map.items()
        if isinstance(aid, int) and aid > 0 and aid != _SYSTEM_ACTOR_ID and role
    )
    lanes = {aid: idx for idx, aid in enumerate(lanes_list)}
    nlanes = len(lanes_list)

    # Figure / axes — default size when the caller gives none (inches).
    if figsize is None:
        width = 100.0  # very wide by default
        # 0.45" per lane (counting at least 4 lanes), clamped to [1.5, 3.0];
        # with the 4-lane floor the effective range is [1.8, 3.0].
        height = max(1.5, min(3.0, 0.45 * max(4, nlanes)))
        figsize = (width, height)

    fig, ax = plt.subplots(figsize=figsize)
    ax.set_xlabel("Simulation time (hours)")
    ax.set_yticks(list(lanes.values()))
    ax.set_yticklabels(_lane_labels(lanes_list, label_map))

    # Apply time ticks/grid
    _apply_time_axis_style(ax, sim_end, tick_step)

    # Vertical extent: half a lane of padding below the first and above the
    # last lane; a unit-high band when there are no lanes at all.
    ax.set_ylim(-0.5, nlanes - 0.5 if nlanes > 0 else 0.5)

    # Lifespans (bars behind; ticks & markers above)
    for k, (t0, t1) in lifespans.items():
        aid = _coerce_actor_id(k)
        if aid is None or aid == _SYSTEM_ACTOR_ID or aid not in lanes:
            continue
        y_c = lanes[aid]  # lane centre
        y = y_c - 0.3  # bottom edge of the 0.6-high bar
        x0, x1 = float(t0), float(t1)
        width = max(0.0, (x1 - x0))  # never negative, even if t1 < t0
        # Original actors get a darker gray than actors not in `originals`.
        face = (0.7, 0.7, 0.7, 1.0) if aid in originals else (0.85, 0.85, 0.85, 1.0)

        # bar (behind everything)
        ax.add_patch(
            Rectangle(
                (x0, y),
                width,
                0.6,
                facecolor=face,
                edgecolor="black",
                linewidth=0.7,
                zorder=0,
            )
        )

        # vertical ticks (above bar)
        ax.vlines([x0], y_c - 0.28, y_c + 0.28, colors="black", linewidth=1.0, zorder=3)
        ax.vlines([x1], y_c - 0.28, y_c + 0.28, colors="black", linewidth=1.0, zorder=3)

        # edge diamonds (always visible)
        ax.scatter(
            [x0],
            [y_c],
            marker="D",
            s=30,
            color="white",
            edgecolors="black",
            linewidths=0.8,
            zorder=4,
        )
        ax.scatter(
            [x1],
            [y_c],
            marker="D",
            s=30,
            color="white",
            edgecolors="black",
            linewidths=0.8,
            zorder=4,
        )

        # spawn/delete flags (triangles above lane); the 1e-9 tolerance keeps
        # actors that start at t=0 / end at sim_end from being flagged.
        spawned = (aid not in originals) and (t0 > 0.0 + 1e-9)
        deleted = t1 < sim_end - 1e-9
        if spawned:
            ax.scatter(
                [x0],
                [y_c + 0.34],
                marker="^",
                s=120,
                color="#2ca02c",
                edgecolor="black",
                linewidths=0.8,
                zorder=7,
                clip_on=False,
            )
        if deleted:
            ax.scatter(
                [x1],
                [y_c + 0.34],
                marker="v",
                s=120,
                color="#d62728",
                edgecolor="black",
                linewidths=0.8,
                zorder=7,
                clip_on=False,
            )

    # ── Self-loop spacing controls (de-overlap labels) ─────────────────────
    TOOL_BIN = 0.5  # hours per time bin (continuous time)
    TOOL_STACK_STEP = 0.16  # vertical step between stacked arcs
    TOOL_LABEL_DY = 0.14  # base vertical gap from arc to label
    TOOL_LABEL_DY_STEP = 0.10  # extra vertical gap per stacked label
    TOOL_LABEL_DX = 0.00  # base horizontal label offset
    TOOL_LABEL_DX_STEP = 0.05  # extra horizontal offset per stacked label

    # lane index -> time-bin index -> number of self-loops drawn so far.
    # The counter is shared by upward and downward loops of the same lane.
    lane_bin_counts: DefaultDict[int, DefaultDict[int, int]] = defaultdict(
        lambda: defaultdict(int)
    )

    def _stack_info(aid: int, t: float, upwards: bool) -> Tuple[float, int]:
        """Return (vertical_offset, index_in_bin) and increment the bin counter."""
        if aid not in lanes:
            return 0.0, 0
        b = int(float(t) // TOOL_BIN)
        lane = lanes[aid]
        idx = lane_bin_counts[lane][b]
        lane_bin_counts[lane][b] += 1
        # Loops start 0.6 away from the lane centre and move further out by
        # TOOL_STACK_STEP for each loop already in the same bin.
        base = +0.6 if upwards else -0.6
        return base + (
            TOOL_STACK_STEP * idx if upwards else -TOOL_STACK_STEP * idx
        ), idx

    # Draw helpers
    def mark_point(
        t: float, aid: Optional[int], marker: str, size: int, color: str, z: int = 3
    ):
        """Draw a single marker on the actor's lane at time t (no-op without a lane)."""
        if aid is None or aid == _SYSTEM_ACTOR_ID or aid not in lanes:
            return
        y = lanes[aid]
        ax.scatter([t], [y], marker=marker, s=size, color=color, zorder=z)

    def draw_selfloop(
        t: float,
        aid: Optional[int],
        up: bool,
        color: str,
        z: int,
        label: Optional[str] = None,
    ):
        """Draw a curved arrow that starts off-lane and ends on the actor's lane at t."""
        if aid is None or aid == _SYSTEM_ACTOR_ID or aid not in lanes:
            return
        y = lanes[aid]
        y_off, idx = _stack_info(aid, t, up)
        y_from = y + y_off
        y_to = y

        # arc (arrow head lands on the lane; curvature sign follows direction)
        ax.annotate(
            "",
            xy=(t, y_to),
            xytext=(t, y_from),
            arrowprops=dict(
                arrowstyle="->",
                color=color,
                lw=2.0,
                connectionstyle=f"arc3,rad={0.7 if up else -0.7}",
            ),
            zorder=z,
        )

        # label (nudged based on stack index); an empty label draws nothing
        if label:
            label_x = t + (TOOL_LABEL_DX + idx * TOOL_LABEL_DX_STEP)
            label_y = y_from + (TOOL_LABEL_DY + idx * TOOL_LABEL_DY_STEP) * (
                1 if up else -1
            )
            ax.text(
                label_x,
                label_y,
                str(label),
                fontsize=8,
                color=color,
                ha="left",
                va="bottom" if up else "top",
                zorder=z + 1,
                clip_on=False,
                bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="none", alpha=0.8),
            )

    # Study markers: include "designed" as a start marker (▶), not just "started"
    START_MARKERS = {"study started", "study designed"}
    END_MARKERS = {"study completed"}  # keep end as ■; add more if you want

    # Render events, dispatching on the lower-cased event type.
    for t, etype, det_raw in events:
        et = (etype or "").strip()
        det = _normalize_details(det_raw)
        aid = _coerce_actor_id(det.get("actor_id"))
        _from = _coerce_actor_id(det.get("from"))
        _to = _coerce_actor_id(det.get("to"))
        recips = det.get("recipients") or []
        et_lower = et.lower()

        # Reasoning: red self-loop above the lane.
        if et_lower == "reasoning":
            draw_selfloop(t, aid, up=True, color="red", z=4)
            continue
        # Tool call: blue self-loop below the lane, labelled with the tool name.
        if et_lower == "tool call":
            draw_selfloop(
                t, aid, up=False, color="blue", z=5, label=str(det.get("tool") or "")
            )
            continue
        # Tool result: blue self-loop below the lane, unlabelled.
        if et_lower == "tool result":
            draw_selfloop(t, aid, up=False, color="blue", z=4)
            continue

        # Sync message / meeting: purple arrow from the `from` lane to the
        # `to` lane; drawn only when both ends have lanes.
        if et_lower in {"sync message", "sync"}:
            if (
                _from is not None
                and _to is not None
                and _from != _SYSTEM_ACTOR_ID
                and _to != _SYSTEM_ACTOR_ID
                and _from in lanes
                and _to in lanes
            ):
                y_from, y_to = lanes[_from], lanes[_to]
                # Curvature sign depends on which actor ID is smaller.
                rad = 0.3 if _from < _to else -0.3
                ax.annotate(
                    "",
                    xy=(t, y_to),
                    xytext=(t, y_from),
                    arrowprops=dict(
                        arrowstyle="->",
                        lw=2.0,
                        color="purple",
                        connectionstyle=f"arc3,rad={rad}",
                    ),
                    zorder=3,
                )
            continue

        # Async message: one green arrow from the sender (actor_id) to each
        # recipient that has a lane.
        if et_lower in {"communicatingasync", "async"}:
            if aid is None or aid == _SYSTEM_ACTOR_ID:
                continue
            if isinstance(recips, (list, tuple)):
                for r in recips:
                    rid = _coerce_actor_id(r)
                    if (
                        rid is None
                        or rid == _SYSTEM_ACTOR_ID
                        or rid not in lanes
                        or aid not in lanes
                    ):
                        continue
                    y_from, y_to = lanes[aid], lanes[rid]
                    rad = 0.3 if aid < rid else -0.3
                    ax.annotate(
                        "",
                        xy=(t, y_to),
                        xytext=(t, y_from),
                        arrowprops=dict(
                            arrowstyle="->",
                            lw=2.0,
                            color="green",
                            connectionstyle=f"arc3,rad={rad}",
                        ),
                        zorder=3,
                    )
            else:
                # NOTE(review): `recips` falls back to [] when "recipients" is
                # missing or empty, so this branch only runs when the field
                # holds a truthy non-list value (e.g. a plain string). An
                # async event without recipients draws nothing — confirm
                # that is intended.
                mark_point(t, aid, marker="o", size=25, color="green", z=3)
            continue

        # Study marks: now includes "Study designed" as start
        if et_lower in START_MARKERS | END_MARKERS:
            study_id = det.get("study_id") or ""
            # Thin vertical line across (almost) the full axes height.
            ax.axvline(
                t,
                ymin=0.02,
                ymax=0.98,
                linewidth=0.6,
                color="gray",
                alpha=0.6,
                zorder=1,
            )
            is_start = et_lower in START_MARKERS
            prefix = "▶ " if is_start else "■ "
            # Rotated text anchored at the top edge of the lanes.
            ax.text(
                t,
                (nlanes - 0.5) if nlanes else 0.0,
                prefix + str(study_id),
                rotation=90,
                fontsize=7,
                va="top",
                ha="left",
                color="gray",
            )
            continue

        # "patient hire" events are recognised but deliberately draw nothing
        # (this branch has no effect beyond moving on to the next event).
        if et_lower == "patient hire":
            continue

    # Title / legend / layout
    ax.set_title("Simulation timeline (continuous)")
    if nlanes == 0:
        # No lanes: clear the y ticks and keep a unit-high band.
        ax.set_yticks([])
        ax.set_ylim(-0.5, 0.5)

    # Proxy line handles: the arrows and markers above are not labelled
    # artists, so the legend is built by hand.
    legend_items = [
        Line2D([0], [0], color="green", lw=2, label="Async communication"),
        Line2D([0], [0], color="purple", lw=2, label="Meeting (sync)"),
        Line2D([0], [0], color="red", lw=2, label="Reasoning (self-loop)"),
        Line2D([0], [0], color="blue", lw=2, label="Tool call/result (self-loop)"),
        Line2D([0], [0], color="gray", lw=1, label="Study start (▶) / completion (■)"),
        Line2D([0], [0], color="black", lw=2, label="Actor lifespan (bar + markers)"),
    ]
    # Legend sits outside the axes, to the right.
    ax.legend(
        handles=legend_items,
        loc="center left",
        bbox_to_anchor=(1.02, 0.5),
        borderaxespad=0.0,
        frameon=True,
    )

    # Leave room on the right for the legend and below for rotated tick labels.
    fig.subplots_adjust(right=0.78, bottom=0.12)
    return fig, ax


# ───────────────────────── New helpers for log→plot ─────────────────────────


def _events_from_log(log_path: Union[str, Path]) -> list[tuple[float, str, dict]]:
    """
    Parse events_epXX.log (JSONL *with logging prefixes*) -> [(t, type, Details), ...].
    """
    path = str(log_path)
    events: list[tuple[float, str, dict]] = []
    if not os.path.exists(path):
        _log.debug("events log not found: %s", path)
        return events

    with open(path, "r", encoding="utf-8") as f:
        for raw in f:
            line = raw.strip()
            if not line:
                continue
            brace = line.find("{")
            if brace < 0:
                continue
            try:
                obj = json.loads(line[brace:])
                t = float(obj.get("t", 0.0))
                etype = (obj.get("type") or "").strip()
                details = obj.get("Details") or {}
                events.append((t, etype, details))
            except Exception:
                continue
    return events


def _lifespans_from_env_and_info(env, info: dict) -> dict[int, tuple[float, float]]:
    """
    Build {actor_id: (t0, t1)} using env.simulation.{actor_created, actor_deleted}
    with a fallback to memory keys if needed.
    """
    sim_end = float(info.get("metrics", {}).get("sim_time", 0.0))
    lifespans: dict[int, tuple[float, float]] = {}

    try:
        created = getattr(env.simulation, "actor_created", {}) or {}
        deleted = getattr(env.simulation, "actor_deleted", {}) or {}
        ids = set(map(int, created.keys())) | set(map(int, deleted.keys()))
        for aid in ids:
            if aid <= 0:
                continue
            t0 = float(created.get(aid, 0.0))
            t1 = float(deleted.get(aid, sim_end))
            lifespans[int(aid)] = (t0, t1)
    except Exception:
        pass

    if not lifespans:
        mem = info.get("full_memory_state", {}) or {}
        for k in mem.keys():
            try:
                aid = int(k)
            except Exception:
                continue
            lifespans[aid] = (0.0, sim_end)

    return lifespans


def _labels_from_env(env) -> dict[int, str]:
    """Infer {actor_id: role} from env.simulation.actors if available."""
    mapping: dict[int, str] = {}
    try:
        for a in getattr(env.simulation, "actors", []) or []:
            aid = getattr(a, "actor_id", None)
            role = getattr(a, "org_role", None)
            if isinstance(aid, int) and role:
                mapping[aid] = str(role)
    except Exception:
        pass
    return mapping


def _labels_from_events_actor_type(
    events: Sequence[Tuple[float, str, Any]],
) -> Dict[int, str]:
    """
    Build ``{actor_id: role}`` from each event's ``actor_id``/``actor_type``.

    Only positive integer IDs paired with a non-empty string ``actor_type``
    are recorded, and the first role seen for an ID is kept.
    """
    roles: Dict[int, str] = {}
    for _t, _kind, payload in events:
        details = _normalize_details(payload)
        actor = _coerce_actor_id(details.get("actor_id"))
        actor_type = details.get("actor_type")

        if not isinstance(actor, int) or actor <= 0:
            continue
        if not isinstance(actor_type, str) or not actor_type:
            continue
        if actor not in roles:
            roles[actor] = actor_type
    return roles


def save_timeline_from_episode_logs(
    output_dir: Union[str, Path],
    episode: int,
    env,
    info: dict,
    *,
    figsize: Optional[tuple[int, int]] = None,
    tick_step: Optional[float] = None,
    filename: Optional[str] = None,
    also_pdf: bool = True,
) -> Optional[Path]:
    """
    High-level one-call entrypoint used by run_policy.py:
      - reads ``events_ep<NN>.log`` from ``output_dir``,
      - derives lifespans + labels from env/info,
      - renders the timeline into ``output_dir`` (PNG and, optionally, PDF).

    Args:
        output_dir: Directory holding the events log; outputs go here too.
        episode: Episode number, zero-padded to two digits in file names.
        env: Environment passed to the lifespan/label helpers.
        info: Episode info dict; ``metrics.sim_time`` gives the time-axis end.
        figsize: Optional figure size forwarded to the plot function.
        tick_step: Optional major tick spacing forwarded to the plot function.
        filename: Output file name for the image; defaults to
            ``timeline_ep<NN>.png``.
        also_pdf: When True, also save ``<stem of filename>.pdf`` in
            ``output_dir``.

    Returns:
        Path to the saved image, or None if no events were found.

    Raises:
        Whatever ``Figure.savefig`` raises; the figure is closed regardless.
    """
    out_dir = Path(output_dir)
    events_path = out_dir / f"events_ep{episode:02d}.log"
    events = _events_from_log(events_path)
    if not events:
        _log.info("No events parsed from %s; skipping timeline.", events_path)
        return None

    sim_end = float(info.get("metrics", {}).get("sim_time", 0.0))
    lifespans = _lifespans_from_env_and_info(env, info)
    id_to_label = _labels_from_env(env)
    # Actors alive from t=0 count as originals (never flagged as spawned).
    originals = {aid for aid, (t0, _t1) in lifespans.items() if float(t0) <= 1e-9}

    fig, _ax = plot_simulation_timeline_from_events(
        events=events,
        sim_end=sim_end,
        lifespans=lifespans,
        originals=originals,
        id_to_label=id_to_label,
        figsize=figsize,
        tick_step=tick_step,
    )

    png_name = filename or f"timeline_ep{episode:02d}.png"
    out_png = out_dir / png_name
    out_pdf = (out_dir / f"{Path(png_name).stem}.pdf") if also_pdf else None

    try:
        fig.savefig(out_png, dpi=150, bbox_inches="tight")
        if out_pdf is not None:
            # vector export is great for zooming+printing
            fig.savefig(out_pdf, bbox_inches="tight")
    finally:
        # Always release the figure: pyplot keeps a reference to every open
        # figure, so a failed save would otherwise leak it.
        plt.close(fig)

    if out_pdf is not None:
        _log.info("Saved timeline: %s (and %s)", out_png, out_pdf)
    else:
        _log.info("Saved timeline: %s", out_png)
    return out_png


# Public API of this module; the underscore-prefixed helpers are internal.
__all__ = [
    "plot_simulation_timeline_from_events",
    "save_timeline_from_episode_logs",
]
