# lmkit/sparse/fragment_mapper.py
from __future__ import annotations
import re
from typing import Any, Dict, List, Tuple, Optional

from rdkit import Chem
from rdkit.Chem.rdmolfiles import MolFragmentToSmiles

# ──────────────────────────────────────────────────────────────────────────────
#                               TOKENIZATION (mapped SMILES)
# ──────────────────────────────────────────────────────────────────────────────

_TOKEN_REGEX = re.compile(
    r"""
    \[[^\[\]]+\] |        # bracket atom (includes :mapnum)
    Br|Cl           |     # two-letter atoms (harmless in mapped strings)
    \%\d{2}         |     # ring closure like %10
    \d              |     # ring closure 1..9
    =|#|-|:|\/|\\   |     # bond types / stereobonds
    \(|\)           |     # branches
    \.              |     # dot (disconnected components)
    @@?|~|\?|>|<|\*|\$|   # misc SMILES tokens that sometimes appear
    [A-Za-z]              # single-letter atoms (rare in mapped case)
    """,
    re.X,
)


def tokenize_with_spans(s: str) -> Tuple[List[str], List[Tuple[int, int]]]:
    """Split *s* into lexer tokens that together cover every character.

    Each match of ``_TOKEN_REGEX`` becomes one token. Any run of characters
    the regex skips (between matches or after the last one) is emitted as a
    token of its own, so ``"".join(tokens) == s``.

    Returns:
        ``(tokens, spans)``: parallel lists where ``spans[i]`` is the
        half-open ``(start, end)`` character range of ``tokens[i]`` in *s*.
    """
    spans: List[Tuple[int, int]] = []
    cursor = 0
    for match in _TOKEN_REGEX.finditer(s):
        start, end = match.span()
        if cursor < start:
            # Unrecognised gap before this match.
            spans.append((cursor, start))
        spans.append((start, end))
        cursor = end
    if cursor < len(s):
        # Unrecognised tail after the final match.
        spans.append((cursor, len(s)))
    tokens = [s[lo:hi] for lo, hi in spans]
    return tokens, spans


def _is_atom_token(tok: str) -> bool:
    return tok.startswith("[") and tok.endswith("]")


def _parens_opened_in_segment(
    tokens: List[str], start_tok_idx: int, end_tok_idx: int
) -> int:
    bal = 0
    for k in range(start_tok_idx, end_tok_idx + 1):
        t = tokens[k]
        if t == "(":
            bal += 1
        elif t == ")":
            bal = max(0, bal - 1)
    return bal


def _end_char_with_balanced_parens(
    tokens: List[str],
    spans: List[Tuple[int, int]],
    seg_start_tok: int,
    end_tok_idx: int,
) -> int:
    """Return the exclusive character end of a mapped-SMILES segment.

    The segment runs from token *seg_start_tok* to token *end_tok_idx* (its
    last matched atom). The end is extended in two phases:

    1. Absorb the tokens right after *end_tok_idx* (ring-closure digits,
       bond symbols, ...). Stop at the first atom, ``"."``, ``"("`` or ``")"``.
    2. Absorb up to N consecutive ``")"`` tokens, where N is the number of
       branches left open inside the segment, so the substring closes the
       branches it opened.

    Returns:
        ``spans[k][1]`` for the last token *k* absorbed (``end_tok_idx`` if
        nothing was absorbed).
    """
    n_tokens = len(tokens)
    cursor = end_tok_idx + 1
    last_kept = end_tok_idx

    # Phase 1: trailing non-atom, non-structural tokens.
    while cursor < n_tokens:
        tok = tokens[cursor]
        if tok in (".", "(", ")") or _is_atom_token(tok):
            break
        last_kept = cursor
        cursor += 1

    # Phase 2: close only the branches this segment opened.
    pending = _parens_opened_in_segment(tokens, seg_start_tok, end_tok_idx)
    while pending > 0 and cursor < n_tokens and tokens[cursor] == ")":
        last_kept = cursor
        pending -= 1
        cursor += 1

    return spans[last_kept][1]


def _substrings_for_match_mapped(
    atom_token_idxs: List[int],
    tokens: List[str],
    spans: List[Tuple[int, int]],
    s: str,
) -> List[str]:
    """Cut the mapped SMILES *s* into substrings covering one match's atoms.

    The matched atom tokens are visited in string order. A new substring
    starts whenever an atom token that is *not* part of the match lies
    between two consecutive matched atoms. Each substring spans from its
    first matched atom up to the end given by
    ``_end_char_with_balanced_parens``.

    Args:
        atom_token_idxs: indices into *tokens* of the matched atoms.
        tokens, spans: output of ``tokenize_with_spans(s)``.
        s: the atom-mapped SMILES string.

    Returns:
        Substrings in left-to-right order; empty if no atoms were given.
    """
    if not atom_token_idxs:
        return []
    matched = set(atom_token_idxs)
    ordered = sorted(atom_token_idxs)

    def _emit(first_tok: int, last_tok: int) -> str:
        begin = spans[first_tok][0]
        finish = _end_char_with_balanced_parens(tokens, spans, first_tok, last_tok)
        return s[begin:finish]

    pieces: List[str] = []
    run_start = prev = ordered[0]
    for tok_idx in ordered[1:]:
        gap_has_foreign_atom = any(
            _is_atom_token(tokens[k]) and k not in matched
            for k in range(prev + 1, tok_idx)
        )
        if gap_has_foreign_atom:
            pieces.append(_emit(run_start, prev))
            run_start = tok_idx
        prev = tok_idx
    pieces.append(_emit(run_start, prev))
    return pieces


# ──────────────────────────────────────────────────────────────────────────────
#                         MAP CONSTRUCTION (atom-mapped SMILES)
# ──────────────────────────────────────────────────────────────────────────────


def atom_mapped_smiles_and_maps(mol: Chem.Mol):
    """Write *mol* as atom-mapped SMILES and index where each atom appears.

    Works on a copy (``Chem.Mol(mol)``), so the caller's molecule is left
    untouched. Atom *i* receives map number ``i + 1``; 0 is avoided because
    RDKit suppresses a map number of 0. The SMILES is written
    non-canonically with stereo information.

    Returns:
        ``(mapped, tokens, spans, atom_to_token, atom_to_span)``:

        - ``mapped``: the atom-mapped SMILES string.
        - ``tokens`` and ``spans``: ``tokenize_with_spans(mapped)``.
        - ``atom_to_token``: 0-based atom index to index in ``tokens``.
        - ``atom_to_span``: 0-based atom index to ``(start, end)`` in ``mapped``.
    """
    work = Chem.Mol(mol)
    for atom in work.GetAtoms():
        atom.SetAtomMapNum(atom.GetIdx() + 1)  # avoid :0 (RDKit suppresses 0)
    mapped = Chem.MolToSmiles(work, canonical=False, isomericSmiles=True)

    tokens, spans = tokenize_with_spans(mapped)
    atom_to_token: Dict[int, int] = {}
    atom_to_span: Dict[int, Tuple[int, int]] = {}

    for tok_idx, (tok, span) in enumerate(zip(tokens, spans)):
        if not (tok.startswith("[") and tok.endswith("]")):
            continue
        found = re.search(r":(\d+)\]", tok)
        if found is None:
            continue
        atom_idx = int(found.group(1)) - 1  # back to RDKit's 0-based index
        atom_to_token[atom_idx] = tok_idx
        atom_to_span[atom_idx] = span

    return mapped, tokens, spans, atom_to_token, atom_to_span


# ──────────────────────────────────────────────────────────────────────────────
#                 ORIGINAL-SMILES atom spans + segment → token indices
# ──────────────────────────────────────────────────────────────────────────────

_ORIG_ATOM_RE = re.compile(r"\[[^\[\]]+\]|Br|Cl|I|F|B|C|N|O|P|S|se|as|te|b|c|n|o|p|s")


def original_atom_spans(smiles: str) -> List[Tuple[int, int]]:
    """Return the ``(start, end)`` character span of every atom in *smiles*.

    Spans come from ``_ORIG_ATOM_RE`` and are listed in left-to-right order.
    """
    return [found.span() for found in _ORIG_ATOM_RE.finditer(smiles)]


def _extend_right_in_text(s: str, seg_start: int, end_char: int) -> int:
    i = end_char
    last = end_char
    # include ring/bond markers
    while i < len(s):
        ch = s[i]
        if ch == "%" and i + 2 < len(s) and s[i + 1 : i + 3].isdigit():
            i += 3
            last = i
            continue
        if ch.isdigit():
            i += 1
            last = i
            continue
        if ch in "=#:/\\":
            i += 1
            last = i
            continue
        if ch in ".[(" or ch == "[" or ch.isalpha():
            break
        break
    # balance right parens
    opened = 0
    for ch in s[seg_start:last]:
        if ch == "(":
            opened += 1
        elif ch == ")":
            opened = max(0, opened - 1)
    while i < len(s) and opened > 0 and s[i] == ")":
        i += 1
        last = i
    return last if last > end_char else end_char


def _segments_for_match_in_original(
    match_atoms: Tuple[int, ...], atom_spans: List[Tuple[int, int]], s: str
) -> List[Tuple[Tuple[int, int], str]]:
    """Split one match into contiguous text segments of the original SMILES.

    Matched atom indices are sorted and grouped into runs of consecutive
    indices. Each run becomes one segment, from the start of its first atom
    to the end of its last atom, then extended by ``_extend_right_in_text``.

    Args:
        match_atoms: RDKit atom indices of the match.
        atom_spans: per-atom ``(start, end)`` spans in *s*, indexed by atom
            index. This presumes that atom *i*'s span is ``atom_spans[i]``;
            the caller is responsible for that alignment.
        s: the original SMILES string.

    Returns:
        A list of ``((start, end), text)`` pairs, or ``[]`` for an empty match.
    """
    if not match_atoms:
        return []
    ordered = sorted(match_atoms)

    # (first_atom, last_atom) for each run of consecutive atom indices.
    runs: List[Tuple[int, int]] = []
    run_first = run_last = ordered[0]
    for atom_idx in ordered[1:]:
        if atom_idx == run_last + 1:
            run_last = atom_idx
            continue
        runs.append((run_first, run_last))
        run_first = run_last = atom_idx
    runs.append((run_first, run_last))

    result: List[Tuple[Tuple[int, int], str]] = []
    for first, last in runs:
        start = atom_spans[first][0]
        end = _extend_right_in_text(s, start, atom_spans[last][1])
        result.append(((start, end), s[start:end]))
    return result


# ──────────────────────────────────────────────────────────────────────────────
#                     LM tokenizer offsets → token index mapper
# ──────────────────────────────────────────────────────────────────────────────

# Optional dependency: the HuggingFace `tokenizers` package. In this module the
# class is referenced only inside string type annotations, so any import
# failure (deliberately not just ImportError) simply leaves the name as None.
try:
    from tokenizers import Tokenizer as _HFTokenizer
except Exception:
    _HFTokenizer = None  # type: ignore


def tokenize_text_with_offsets(
    tokenizer: "_HFTokenizer",
    text: str,
    *,
    add_special_tokens: bool = False,
) -> Dict[str, Any]:
    """Encode *text* and return its ids, token strings and character offsets.

    Args:
        tokenizer: any object whose ``encode(text, add_special_tokens=...)``
            result exposes ``.ids``, ``.tokens`` and ``.offsets``.
        text: the string to encode.
        add_special_tokens: forwarded to ``tokenizer.encode``.

    Returns:
        ``{"ids": ..., "tokens": ..., "offsets": ...}``, taken directly from
        the encoding.
    """
    encoding = tokenizer.encode(text, add_special_tokens=add_special_tokens)
    return dict(
        ids=encoding.ids,
        tokens=encoding.tokens,
        offsets=encoding.offsets,
    )


def token_indices_for_span(
    offsets: List[Tuple[int, int]], start: int, end: int
) -> List[int]:
    """Return indices of tokens whose character range overlaps ``[start, end)``.

    Tokens with an empty or inverted offset (``tok_end <= tok_start``, e.g.
    ``(0, 0)``) never match.
    """
    return [
        idx
        for idx, (tok_start, tok_end) in enumerate(offsets)
        if tok_start < tok_end and tok_end > start and tok_start < end
    ]


# ──────────────────────────────────────────────────────────────────────────────
#                                      MAIN
# ──────────────────────────────────────────────────────────────────────────────


def compile_smarts(smarts_by_name: Dict[str, str]) -> Dict[str, Chem.Mol]:
    """Compile each named SMARTS pattern into an RDKit query molecule.

    Raises:
        ValueError: on the first pattern RDKit cannot parse.
    """
    compiled: Dict[str, Chem.Mol] = {}
    for label, pattern in smarts_by_name.items():
        query = Chem.MolFromSmarts(pattern)
        if query is None:
            raise ValueError(f"Invalid SMARTS for {label}: {pattern}")
        compiled[label] = query
    return compiled


def _fragment_smiles_or_none(mol: "Chem.Mol", match_atoms: Tuple[int, ...]) -> Optional[str]:
    """Return RDKit's SMILES for the matched atoms, or None if RDKit fails."""
    try:
        return MolFragmentToSmiles(
            mol, atomsToUse=list(match_atoms), isomericSmiles=True
        )
    except Exception:
        # Best effort: a failed fragment SMILES must not abort the whole scan.
        return None


def _original_segments_for_match(
    match_atoms: Tuple[int, ...],
    orig_atom_spans: List[Tuple[int, int]],
    smiles: str,
    tokenizer: "_HFTokenizer",
    token_info: Dict[str, Any],
) -> List[Dict[str, Any]]:
    """Locate one match in the input SMILES and attach the LM tokens it overlaps."""
    ids = token_info["ids"]
    toks_from_enc = token_info["tokens"]
    offsets = token_info["offsets"]

    segments: List[Dict[str, Any]] = []
    for (start, end), txt in _segments_for_match_in_original(
        match_atoms, orig_atom_spans, smiles
    ):
        idxs = token_indices_for_span(offsets, start, end)
        seg_token_ids = [int(ids[i]) for i in idxs]
        seg_tokens_from_encoding = [toks_from_enc[i] for i in idxs]

        # Round-trip each id through the vocabulary as a consistency check.
        seg_tokens_decoded = [tokenizer.id_to_token(tid) for tid in seg_token_ids]
        mismatch_at = [
            tok_idx
            for tok_idx, enc_tok, dec_tok in zip(
                idxs, seg_tokens_from_encoding, seg_tokens_decoded
            )
            if enc_tok != dec_tok
        ]

        segments.append(
            {
                "char_span": (start, end),
                "text": txt,
                "token_indices": idxs,  # offsets-based indices (no specials)
                "token_ids": seg_token_ids,
                "token_texts_from_encoding": seg_tokens_from_encoding,
                "token_texts_decoded": seg_tokens_decoded,
                "decoded_match": not mismatch_at,
                "mismatch_at": mismatch_at,  # positions in the *full* token list
            }
        )
    return segments


def find_fragments_and_tokens(
    smiles: str,
    queries: Dict[str, Chem.Mol],
    *,
    tokenizer: Optional["_HFTokenizer"] = None,
    include_fragment_smiles: bool = True,
    add_special_tokens_for_offsets: bool = False,
) -> Dict[str, Any]:
    """Find every substructure match in *smiles* and locate it in two string views.

    Views produced:

    * **mapped**: substrings of a non-canonical atom-mapped SMILES of the
      molecule.
    * **original**: character segments of *smiles* itself, each with the
      tokenizer's tokens that overlap it. This view is only produced when
      *tokenizer* is given.

    Args:
        smiles: input SMILES.
        queries: name to compiled SMARTS query (see ``compile_smarts``).
        tokenizer: optional tokenizer exposing ``encode`` and ``id_to_token``.
        include_fragment_smiles: also compute RDKit's SMILES of each match.
        add_special_tokens_for_offsets: forwarded to the tokenizer's ``encode``.

    Returns:
        A dict with keys:

        - ``input_smiles``, ``mapped_smiles``, ``mapped_tokens`` and
          ``atom_to_token_mapped``.
        - ``tokenization_original``: ``{ids, tokens, offsets}`` or None.
        - ``original_view_aligned``: whether the atoms found by regex in
          *smiles* line up one-to-one with RDKit's atoms. When False, the
          original-view segments are left empty.
        - ``fragments``: query name to a list of occurrences. Only queries
          with at least one match appear.

    Raises:
        ValueError: if RDKit cannot parse *smiles*.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES: {smiles}")

    # A) mapped-SMILES view
    mapped, mapped_tokens, mapped_spans, atom_to_token, _ = atom_mapped_smiles_and_maps(
        mol
    )

    # B) original-SMILES tokenization + offsets
    token_info: Optional[Dict[str, Any]] = None
    if tokenizer is not None:
        token_info = tokenize_text_with_offsets(
            tokenizer, smiles, add_special_tokens=add_special_tokens_for_offsets
        )

    # The original view maps RDKit atom index i to the i-th atom-looking
    # substring of `smiles`. That only holds when both counts agree. For
    # example, MolFromSmiles removes explicit "[H]" atoms by default, which
    # shifts every later index. On disagreement we skip the view rather than
    # report wrong spans or hit an IndexError.
    orig_atom_spans = original_atom_spans(smiles)
    orig_view_aligned = bool(orig_atom_spans) and (
        len(orig_atom_spans) == mol.GetNumAtoms()
    )

    out: Dict[str, Any] = {
        "input_smiles": smiles,
        "mapped_smiles": mapped,
        "mapped_tokens": [
            {"text": t, "span": s, "is_atom": _is_atom_token(t)}
            for t, s in zip(mapped_tokens, mapped_spans)
        ],
        "atom_to_token_mapped": atom_to_token,
        "tokenization_original": token_info,  # {ids,tokens,offsets} if tokenizer provided
        "original_view_aligned": orig_view_aligned,
        "fragments": {},
    }

    for name, q in queries.items():
        matches = list(mol.GetSubstructMatches(q, useChirality=True))
        if not matches:
            continue

        occs: List[Dict[str, Any]] = []
        for match_atoms in matches:
            # mapped view
            mapped_atom_toks = [
                atom_to_token[a] for a in match_atoms if a in atom_to_token
            ]
            mapped_substrings = _substrings_for_match_mapped(
                mapped_atom_toks, mapped_tokens, mapped_spans, mapped
            )

            # RDKit fragment smiles
            frag_smiles = (
                _fragment_smiles_or_none(mol, match_atoms)
                if include_fragment_smiles
                else None
            )

            # original view
            original_segments: List[Dict[str, Any]] = []
            if token_info is not None and orig_view_aligned:
                original_segments = _original_segments_for_match(
                    match_atoms, orig_atom_spans, smiles, tokenizer, token_info
                )

            occs.append(
                {
                    "match_atom_indices": list(match_atoms),
                    "mapped": {
                        "atom_token_indices": mapped_atom_toks,
                        "substrings": mapped_substrings,
                    },
                    "original": {"segments": original_segments},
                    "fragment_smiles": frag_smiles,
                }
            )
        out["fragments"][name] = occs

    return out


# ──────────────────────────────────────────────────────────────────────────────
#                               EXAMPLE (optional)
# ──────────────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Small demo: a few functional-group SMARTS over a handful of molecules,
    # showing only the mapped-SMILES substrings (no LM tokenizer attached).
    demo_queries = compile_smarts(
        {
            "phenyl": "c1ccccc1",
            "ester_strict": "[CX3](=O)[OX2H0][#6]",
            "carboxylic_acid": "[CX3](=O)[OX2H1]",
        }
    )
    demo_smiles = (
        "CC(=O)OC1=CC=CC=C1C(=O)O",
        "NCC(=O)OC",
        "Cc1ccccc1N",
        "O=C(O)c1ccccc1",
    )

    print("Run this with your LM tokenizer to see token ids/texts per segment.")
    for smi in demo_smiles:
        result = find_fragments_and_tokens(smi, demo_queries, tokenizer=None)
        print("\nSMILES:", result["input_smiles"])
        for frag_name, occurrences in result["fragments"].items():
            print(f"  Fragment: {frag_name}  (n={len(occurrences)})")
            for occurrence in occurrences:
                print("    mapped substrings:", occurrence["mapped"]["substrings"])
