"""SMILES and SELFIES tokenization for molecular sequence models.

This module provides tokenization utilities for converting SMILES and SELFIES
strings into sequences of tokens suitable for sequence-based VAE models.
"""

import json
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Union

# Regex that matches exactly one SMILES token at a given position.
# Alternation order matters: earlier alternatives win, so bracket atoms and
# two-letter symbols are tried before the single-letter fallback.
# NOTE(review): "Sn" outside brackets also matches S followed by aromatic n
# (e.g. "CSn1cccc1" yields an "Sn" token) -- confirm this is intended.
SMILES_TOKEN_PATTERN = re.compile(
    r"(\[[^\]]+\])"  # bracketed atom, e.g. [NH4+] (the only capturing group)
    r"|Br|Cl"  # two-letter halogens
    r"|Si|Se|Na|Li|Mg|Al|Ca|Fe|Zn|Cu|Mn|Hg|Ag|Au|Sn|Pb|Bi|As"  # other two-letter symbols
    r"|\%\d{2}"  # '%' followed by exactly two digits, e.g. %12
    r"|\d"  # a single digit
    r"|=|#|-|\+|\\|/|:|\."  # bond characters, '+', and '.'
    r"|\(|\)"  # branch open / close
    r"|@@|@"  # chirality markers; '@@' comes first so it is not split in two
    r"|[A-Za-z]"  # any remaining single ASCII letter
    r"|\*"  # literal '*' (wildcard atom in SMILES)
    r"|\$"  # literal '$' (quadruple bond in OpenSMILES)
)

# Reserved tokens. build_vocab() copies this list to the front of every
# vocabulary, so the positions here become the ids 0-3.
SPECIAL_TOKENS = ["<pad>", "<bos>", "<eos>", "<unk>"]


def _normalize_representation(representation: str) -> str:
    """Normalize representation name.

    Supported:
      - "smiles"
      - "selfies"
    """
    rep = (representation or "smiles").strip().lower()
    if rep not in {"smiles", "selfies"}:
        raise ValueError(
            f"Unknown representation '{representation}'. Expected 'smiles' or 'selfies'."
        )
    return rep


def _require_selfies():
    """Import selfies lazily so SMILES-only usage doesn't require the dependency."""
    try:
        import selfies as sf  # type: ignore
    except Exception as e:  # pragma: no cover
        raise ImportError(
            "SELFIES support requires the `selfies` package. "
            "Install it (e.g. `uv add selfies` or `pip install selfies`)."
        ) from e
    return sf


@dataclass
class Vocab:
    """Token <-> integer-id mapping for a tokenized sequence vocabulary.

    Attributes:
        token_to_id: Maps each token string to its integer id.
        id_to_token: Token strings in id order. ``load`` rebuilds
            ``token_to_id`` from it, so it is the only part persisted.

    The ``*_id`` properties look up the special tokens in ``token_to_id`` and
    raise ``KeyError`` if the corresponding token is missing.
    """

    token_to_id: Dict[str, int]
    id_to_token: List[str]

    @property
    def pad_id(self) -> int:
        """Id of the ``<pad>`` token."""
        return self.token_to_id["<pad>"]

    @property
    def bos_id(self) -> int:
        """Id of the ``<bos>`` token."""
        return self.token_to_id["<bos>"]

    @property
    def eos_id(self) -> int:
        """Id of the ``<eos>`` token."""
        return self.token_to_id["<eos>"]

    @property
    def unk_id(self) -> int:
        """Id of the ``<unk>`` token."""
        return self.token_to_id["<unk>"]

    def save(self, path: Union[str, Path]) -> None:
        """Save the vocabulary to a JSON file.

        Missing parent directories are created. Only ``id_to_token`` is
        written, under the key ``"id_to_token"``.

        Args:
            path: Destination file path.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Pin the encoding so the file does not depend on the platform locale.
        with open(path, "w", encoding="utf-8") as f:
            json.dump({"id_to_token": self.id_to_token}, f, indent=2)

    @classmethod
    def load(cls, path: Union[str, Path]) -> "Vocab":
        """Load a vocabulary from a JSON file written by :meth:`save`.

        Args:
            path: File containing a JSON object with an ``"id_to_token"`` list.

        Returns:
            A ``Vocab`` whose ``token_to_id`` is rebuilt from list positions.

        Raises:
            KeyError: If the file has no ``"id_to_token"`` key.
        """
        path = Path(path)
        # JSON is UTF-8 by convention; do not rely on the locale default,
        # which would mis-decode non-ASCII tokens on some platforms.
        with open(path, encoding="utf-8") as f:
            data = json.load(f)
        id_to_token = data["id_to_token"]
        token_to_id = {t: i for i, t in enumerate(id_to_token)}
        return cls(token_to_id=token_to_id, id_to_token=id_to_token)


def tokenize_smiles(smiles: str) -> List[str]:
    """Split a SMILES string into tokens.

    The string is scanned left to right. At each position the module-level
    ``SMILES_TOKEN_PATTERN`` is tried; its match (bracketed atom, two-letter
    symbol, ring label, bond character, parenthesis, chirality marker or
    single letter) becomes the next token. A character the pattern does not
    recognise is emitted as its own one-character token, so concatenating the
    tokens always reproduces the input.

    Args:
        smiles: Input SMILES string.

    Returns:
        List of token strings in input order.

    Example:
        >>> tokenize_smiles("CCO")
        ['C', 'C', 'O']
        >>> tokenize_smiles("c1ccccc1")
        ['c', '1', 'c', 'c', 'c', 'c', 'c', '1']
    """
    tokens = []
    pos = 0
    end = len(smiles)
    while pos < end:
        hit = SMILES_TOKEN_PATTERN.match(smiles, pos)
        if hit is not None:
            tokens.append(hit.group(0))
            pos = hit.end()
            continue
        # Unrecognised character: keep it as a standalone token.
        tokens.append(smiles[pos])
        pos += 1
    return tokens


def tokenize_selfies(selfies: str) -> List[str]:
    """Split a SELFIES string into its symbols.

    Delegates to ``split_selfies`` from the optional ``selfies`` package and
    collects the result into a list.

    Args:
        selfies: Input SELFIES string.

    Returns:
        List of SELFIES symbol strings.

    Raises:
        ImportError: If the ``selfies`` package cannot be imported.
    """
    selfies_lib = _require_selfies()
    symbols = selfies_lib.split_selfies(selfies)
    return list(symbols)


def smiles_to_selfies(smiles: str) -> Optional[str]:
    """Convert a SMILES string to SELFIES.

    Args:
        smiles: Input SMILES string.

    Returns:
        The result of ``selfies.encoder``, or ``None`` if the encoder raised.

    Raises:
        ImportError: If the ``selfies`` package cannot be imported.
    """
    selfies_lib = _require_selfies()
    try:
        encoded = selfies_lib.encoder(smiles)
    except Exception:
        # Best-effort conversion: any encoder failure is reported as None.
        encoded = None
    return encoded


def selfies_to_smiles(selfies: str) -> Optional[str]:
    """Convert a SELFIES string to SMILES.

    Args:
        selfies: Input SELFIES string.

    Returns:
        The result of ``selfies.decoder``, or ``None`` if the decoder raised.

    Raises:
        ImportError: If the ``selfies`` package cannot be imported.
    """
    selfies_lib = _require_selfies()
    try:
        decoded = selfies_lib.decoder(selfies)
    except Exception:
        # Best-effort conversion: any decoder failure is reported as None.
        decoded = None
    return decoded


def detokenize(tokens: List[str]) -> str:
    """Join a token list back into a single string.

    Tokens are concatenated with no separator, which is how both SMILES and
    SELFIES token lists map back to their string form.

    Args:
        tokens: Token strings in order.

    Returns:
        The concatenated string (empty for an empty list).
    """
    separator = ""
    return separator.join(tokens)


def build_vocab(
    sequences: Iterable[str],
    min_freq: int = 1,
    representation: str = "smiles",
) -> Vocab:
    """Build a vocabulary from a collection of sequences.

    Creates a vocabulary by:
    1. Tokenizing all sequences
    2. Counting token frequencies
    3. Adding special tokens (<pad>, <bos>, <eos>, <unk>) at ids 0-3
    4. Adding tokens that meet the minimum frequency threshold, ordered by
       descending frequency and then alphabetically so ids are deterministic

    Args:
        sequences: Iterable of SMILES or SELFIES strings
        min_freq: Minimum frequency for a token to be included (default: 1)
        representation: "smiles" or "selfies" (determines tokenization method)

    Returns:
        Vocab object with token_to_id and id_to_token mappings

    Raises:
        ValueError: If ``representation`` is not a supported name.
        ImportError: If "selfies" is requested but the package is missing.

    Example:
        >>> sequences = ["CCO", "CCCO", "CCO"]
        >>> vocab = build_vocab(sequences, min_freq=1)
        >>> len(vocab.id_to_token)  # 4 special tokens + 2 unique tokens (C, O)
        6
    """
    from collections import Counter

    rep = _normalize_representation(representation)
    tok_fn = tokenize_smiles if rep == "smiles" else tokenize_selfies

    counts = Counter()
    for seq in sequences:
        counts.update(tok_fn(seq))

    # Counter keys are already unique, so the only possible duplicates are
    # data tokens that collide with a special token. A set makes that check
    # O(1) instead of rescanning the growing token list.
    reserved = set(SPECIAL_TOKENS)
    tokens = list(SPECIAL_TOKENS)
    ranked = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
    tokens.extend(
        tok for tok, freq in ranked if freq >= min_freq and tok not in reserved
    )
    token_to_id = {t: i for i, t in enumerate(tokens)}
    return Vocab(token_to_id=token_to_id, id_to_token=tokens)


def encode(
    sequence: str,
    vocab: Vocab,
    max_len: int,
    representation: str = "smiles",
) -> List[int]:
    """Encode a sequence string into a fixed-length list of token IDs.

    Encodes a SMILES or SELFIES string by:
    1. Tokenizing the sequence
    2. Converting tokens to IDs using the vocabulary (unknown tokens map to
       the ``<unk>`` id)
    3. Adding BOS and EOS tokens
    4. Padding or truncating to max_len

    Args:
        sequence: SMILES or SELFIES string to encode
        vocab: Vocabulary object with token mappings
        max_len: Output length, including BOS/EOS. Must be at least 1; use at
            least 2 so both BOS and EOS fit (with ``max_len=1`` the result is
            just ``[EOS]``).
        representation: "smiles" or "selfies" (determines tokenization)

    Returns:
        List of token IDs of length max_len, with format:
        [BOS, token_ids..., EOS, PAD, PAD, ...]
        When truncated, the last position is always EOS.

    Raises:
        ValueError: If ``max_len`` is less than 1 or ``representation`` is
            not a supported name.

    Example:
        >>> vocab = build_vocab(["CCO"], min_freq=1)
        >>> encode("CCO", vocab, max_len=10)  # BOS=1, C=4, O=5, EOS=2, PAD=0
        [1, 4, 4, 5, 2, 0, 0, 0, 0, 0]
    """
    # A non-positive max_len would make the truncation slice below use a
    # negative bound and return a list longer than max_len.
    if max_len < 1:
        raise ValueError(f"max_len must be at least 1, got {max_len}")

    rep = _normalize_representation(representation)
    tok_fn = tokenize_smiles if rep == "smiles" else tokenize_selfies
    toks = tok_fn(sequence)

    unk_id = vocab.unk_id
    ids = [vocab.bos_id]
    ids.extend(vocab.token_to_id.get(t, unk_id) for t in toks)
    ids.append(vocab.eos_id)

    if len(ids) > max_len:
        # Keep the leading tokens and force the final slot to EOS.
        ids = ids[: max_len - 1] + [vocab.eos_id]
    else:
        ids.extend([vocab.pad_id] * (max_len - len(ids)))
    return ids


def encode_smiles(smiles: str, vocab: Vocab, max_len: int) -> List[int]:
    """Encode a SMILES string; kept as a backward-compatible alias.

    Equivalent to calling ``encode`` with ``representation="smiles"``.
    """
    return encode(
        sequence=smiles,
        vocab=vocab,
        max_len=max_len,
        representation="smiles",
    )


def encode_selfies(selfies: str, vocab: Vocab, max_len: int) -> List[int]:
    """Encode a SELFIES string.

    Equivalent to calling ``encode`` with ``representation="selfies"``, so it
    needs the optional ``selfies`` dependency.
    """
    return encode(
        sequence=selfies,
        vocab=vocab,
        max_len=max_len,
        representation="selfies",
    )


def decode_ids(ids: List[int], vocab: Vocab, representation: str = "smiles") -> str:
    """Decode a list of token IDs back into a sequence string.

    Walks the ids in order, dropping ``<pad>``, ``<bos>`` and ``<unk>``, and
    stops at the first ``<eos>``; ids after it are not examined. The remaining
    tokens are concatenated into a SMILES or SELFIES string.

    Args:
        ids: List of token IDs (typically from model output). Each item is
            passed through ``int()``, so integer-like scalars are accepted.
        vocab: Vocabulary object with ID to token mappings
        representation: "smiles" or "selfies". Only validated here; it does
            not change how ids are decoded.

    Returns:
        Decoded SMILES or SELFIES string

    Raises:
        ValueError: If ``representation`` is not a supported name.
        IndexError: If an id reached before ``<eos>`` is negative or not less
            than the vocabulary size.

    Example:
        >>> vocab = build_vocab(["CCO"], min_freq=1)
        >>> ids = [1, 4, 4, 5, 2, 0, 0]  # BOS, C, C, O, EOS, PAD, PAD
        >>> decode_ids(ids, vocab)
        'CCO'
    """
    # Validate the name even though decoding itself does not depend on it, so
    # typos at call sites fail the same way they do in encode().
    _normalize_representation(representation)

    id_to_token = vocab.id_to_token
    vocab_size = len(id_to_token)
    toks = []
    for raw_id in ids:
        idx = int(raw_id)
        # Plain list indexing would let negative ids (e.g. a -100 ignore
        # index) wrap around and silently decode to unrelated tokens.
        if not 0 <= idx < vocab_size:
            raise IndexError(
                f"token id {idx} is out of range for a vocabulary of {vocab_size} tokens"
            )
        tok = id_to_token[idx]
        if tok == "<eos>":
            break
        if tok in ("<pad>", "<bos>", "<unk>"):
            # <unk> has no string form, so it is dropped rather than emitted.
            continue
        toks.append(tok)
    return detokenize(toks)
