# lmkit/experiments/smiles_events.py
from __future__ import annotations
from dataclasses import dataclass
from typing import List, Tuple, Optional, Dict

import numpy as np
from tokenizers import Tokenizer


@dataclass
class RingEvent:
    """One matched ring-bond opener/closer pair found in a row of tokens."""
    # Row index within the batch. extract_events_for_batch emits -1 here and
    # extract_events overwrites it with the real row number.
    batch: int
    open_idx: int  # token index of opener (last token of the ID, e.g. digit or final digit of %nn)
    close_idx: int  # token index of closer (last token of the ID)
    pred_idx: int  # index where the closer is predicted (close_idx - 1, clamped at 0)
    ring_id: str  # digits of the ring ID with any "%" removed, e.g. "1", "10"
    distance: int  # close_idx - open_idx, measured in tokens


@dataclass
class ParenEvent:
    """One matched "(" / ")" pair found in a row of tokens."""
    # Row index within the batch. extract_events_for_batch emits -1 here and
    # extract_events overwrites it with the real row number.
    batch: int
    open_idx: int  # token index of the "(" token
    close_idx: int  # token index of the matching ")" token
    pred_idx: int  # close_idx - 1
    depth: int  # number of "(" still open once this pair is closed (0 = outermost pair)
    distance: int  # close_idx - open_idx, measured in tokens


@dataclass
class Events:
    """Ring and parenthesis events gathered over all rows of a batch."""
    rings: List[RingEvent]  # ring events, appended row by row in batch order
    parens: List[ParenEvent]  # parenthesis events, appended row by row in batch order


def _tok_str(tokenizer: Tokenizer, tid: int) -> str:
    """
    Return the token string for id `tid`.

    `tid` is coerced with int() so numpy integer scalars are accepted.
    Always returns a str: an id the tokenizer cannot map yields "".
    """
    # tokenizers.Tokenizer API: id_to_token returns None for an id that is not
    # in the vocabulary. Normalise that to "" so callers that invoke str
    # methods (e.g. .isdigit()) do not fail with AttributeError; "" matches
    # none of the parenthesis / ring patterns and is therefore ignored.
    tok = tokenizer.id_to_token(int(tid))
    return "" if tok is None else tok


def _is_digit_token(ts: str) -> bool:
    """
    Return True if `ts` is a non-empty run of ASCII digits.

    Covers "0".."9" and also cases where the tokenizer packs "10".."99" as
    one token.
    """
    # str.isdigit() alone also accepts non-ASCII digit characters such as
    # superscripts ("\u00b2") or Arabic-Indic digits, which are never SMILES
    # ring IDs, so require ASCII as well.
    return ts.isascii() and ts.isdigit()


def _read_ring_id(tokens_str: List[str], i: int) -> Tuple[Optional[str], int]:
    """
    Parse a ring ID starting at position i. Returns (ring_id, end_index).
    - Single-digit: "1".."9" (and possibly "0")
    - Multi-digit: "%10" may appear as ["%10"], ["%", "10"] OR ["%", "1", "0"]
    We normalize ring_id to the digits string, e.g. "1", "10", "23".
    The returned end_index is the index of the **last** token of the ring ID.
    If no ring ID starts at position i, returns (None, i).
    """
    ts = tokens_str[i]
    if ts == "%":
        # Try compact digit token after '%'
        if i + 1 < len(tokens_str):
            nxt = tokens_str[i + 1]
            # e.g., ["%", "10"]
            if _is_digit_token(nxt) and len(nxt) >= 2:
                return nxt, i + 1
            # e.g., ["%", "1", "0"]
            if i + 2 < len(tokens_str):
                nxt2 = tokens_str[i + 2]
                if _is_digit_token(nxt) and _is_digit_token(nxt2):
                    return nxt + nxt2, i + 2
        # A lone '%' (or '%' followed by fewer than two digits) is not a ring ID
        return None, i
    # Whole "%nn" emitted as ONE token, e.g. ["%10"]. Require at least two
    # digits, mirroring the split ["%", "10"] form above.
    if ts.startswith("%"):
        digits = ts[1:]
        if len(digits) >= 2 and _is_digit_token(digits):
            return digits, i
        return None, i
    # Single or packed digits without '%'
    if _is_digit_token(ts):
        return ts, i
    return None, i


def extract_events_for_batch(
    tokenizer: Tokenizer, ids_row: np.ndarray
) -> Tuple[List[RingEvent], List[ParenEvent]]:
    """
    Extract matched ring-bond and parenthesis events from one row of token ids.

    ids_row: (T,) input ids including BOS/EOS/PAD as produced by lmkit.tools.data
    We operate on token **strings** to be tokenizer-agnostic.

    Returns (rings, parens). Every event is created with batch=-1; the caller
    is responsible for filling in the real row index. Openers that never find
    a matching closer, and ")" tokens with no pending "(", produce no event.
    """
    toks = [_tok_str(tokenizer, int(t)) for t in ids_row]
    rings: List[RingEvent] = []
    parens: List[ParenEvent] = []

    # --- ring open/close tracking
    ring_stack: Dict[str, List[int]] = {}  # ring_id -> list of opener indices

    # --- parentheses stack
    paren_stack: List[int] = []

    # Index of the last token already consumed by a multi-token ring ID
    # (e.g. the "10" of ["%", "10"]). Those tokens must not be parsed a second
    # time as ring IDs of their own; doing so would immediately "close" the
    # ring that was just opened (distance 0) and desynchronise all later
    # open/close pairing for that ID.
    ring_consumed_upto = -1

    for i, ts in enumerate(toks):
        # Parentheses
        if ts == "(":
            paren_stack.append(i)
        elif ts == ")":
            if paren_stack:
                j = paren_stack.pop()
                pred_idx = max(0, i - 1)
                parens.append(
                    ParenEvent(
                        batch=-1,
                        open_idx=j,
                        close_idx=i,
                        pred_idx=pred_idx,
                        # stack length after the pop = number of enclosing open parens
                        depth=len(paren_stack),
                        distance=i - j,
                    )
                )

        # Rings
        if i <= ring_consumed_upto:
            # this token is the tail of a "%nn" ID handled at an earlier index
            continue
        rid, end_i = _read_ring_id(toks, i)
        if rid is None:
            continue
        ring_consumed_upto = end_i
        anchor = end_i  # we anchor the event on the **last** digit token
        stack = ring_stack.setdefault(rid, [])
        if stack:
            # second occurrence of this ID: it closes the pending opener
            j = stack.pop()
            pred_idx = max(0, anchor - 1)
            rings.append(
                RingEvent(
                    batch=-1,
                    open_idx=j,
                    close_idx=anchor,
                    pred_idx=pred_idx,
                    ring_id=rid,
                    distance=anchor - j,
                )
            )
        else:
            # first occurrence of this ID: remember it as the opener
            stack.append(anchor)

    # discard unterminated opens
    return rings, parens


def extract_events(tokenizer: Tokenizer, batch_ids: np.ndarray) -> Events:
    """
    Collect ring and parenthesis events for every row of `batch_ids`.

    Each row is handed to extract_events_for_batch; the events it returns are
    stamped with the row's index in `batch` and appended, in row order, to the
    combined result.
    """
    collected = Events(rings=[], parens=[])
    for row in range(batch_ids.shape[0]):
        row_rings, row_parens = extract_events_for_batch(tokenizer, batch_ids[row])
        # per-row extraction leaves batch unset (-1); record the row number here
        for event in (*row_rings, *row_parens):
            event.batch = row
        collected.rings += row_rings
        collected.parens += row_parens
    return collected
