from collections import Counter, defaultdict
from typing import Iterable, Dict, List
import numpy as np
import pickle
import gzip
import pathlib

try:
    # Probe only: a profiler such as line_profiler's kernprof may already
    # provide a global/builtin `profile` decorator; if so, keep it.
    profile
except NameError:
    def profile(fn):
        """No-op stand-in for a profiler's ``@profile``: return *fn* unchanged."""
        return fn
    
def train_ngram(
    corpus: Iterable[str],
    sym_to_id: Dict[str, int],
    vocab_size: int,
    max_order: int = 5,
    discount: float = 0.75
):
    """Collect n-gram statistics for orders 1..max_order from a symbol sequence.

    Every symbol of *corpus* is mapped through *sym_to_id*. The id sequence is
    left-padded with ``max_order - 1`` ``None`` entries so that every position
    has a full-length history; histories touching the start of the sequence
    therefore contain ``None``.

    Args:
        corpus: iterable of symbols; each must be a key of *sym_to_id*.
        sym_to_id: mapping from symbol to integer id.
        vocab_size: stored unchanged in the result.
        max_order: highest n-gram order to count (must be >= 1).
        discount: absolute discount ``D``, stored as ``float``.

    Returns:
        A dict with:
          ``counts[n-1][ctx][sid]`` occurrences of ``sid`` right after the
              history ``ctx`` (a tuple of length ``n-1``);
          ``totals[n-1][ctx]`` sum of ``counts[n-1][ctx].values()``;
          ``distincts[n-1][ctx]`` number of distinct ``sid`` seen after ``ctx``;
          ``p_cont[sid]`` relative frequency of ``sid`` in the corpus. Despite
              the name, every occurrence counts once, so this is the unigram
              frequency, not a Kneser-Ney continuation count over distinct
              left contexts;
          ``vocab_size``, ``max_order`` and ``D`` (the discount).

    Raises:
        ValueError: if ``max_order < 1``.
        KeyError: if a symbol of *corpus* is missing from *sym_to_id*.
    """
    # An explicit check survives ``python -O``. With max_order == 0 the old
    # assert-less path died later with an unrelated UnboundLocalError.
    if max_order < 1:
        raise ValueError(f"max_order must be >= 1, got {max_order!r}")

    PAD = None
    counts: List[Dict] = [defaultdict(Counter) for _ in range(max_order)]
    totals: List[Dict] = [defaultdict(int) for _ in range(max_order)]
    distincts: List[Dict] = [defaultdict(int) for _ in range(max_order)]
    cont_cnt = Counter()

    ids = [sym_to_id[s] for s in corpus]
    window = [PAD] * (max_order - 1) + ids

    for i in range(max_order - 1, len(window)):
        sid = window[i]
        for n in range(1, max_order + 1):
            ctx = tuple(window[i - (n - 1):i])
            followers = counts[n - 1][ctx]
            followers[sid] += 1
            totals[n - 1][ctx] += 1
            # The count was just incremented, so == 1 means "first time sid
            # follows ctx", i.e. one more distinct follower of ctx.
            if followers[sid] == 1:
                distincts[n - 1][ctx] += 1
        cont_cnt[sid] += 1

    total_cont = sum(cont_cnt.values())
    # Empty corpus -> empty dict (the comprehension never divides by zero).
    p_cont = {sid: c / total_cont for sid, c in cont_cnt.items()}

    return dict(
        counts=counts,
        totals=totals,
        distincts=distincts,
        p_cont=p_cont,
        vocab_size=vocab_size,
        max_order=max_order,
        D=float(discount),
    )

def prepare_ngram_arrays(nd):
    """Attach the NumPy lookup tables used by the fast n-gram routine.

    Mutates *nd* in place and returns it. Adds:
      ``p_cont_ids`` / ``p_cont_vals``: int32 ids and float32 values of
          ``nd["p_cont"]``, in the dict's iteration order;
      ``ctx_to_idx[n]``: history tuple -> dense index, per order;
      ``counts_ids[n][idx]`` / ``counts_vals[n][idx]``: int32 arrays of the
          follower ids and their counts for that history.
    """
    p_cont = nd["p_cont"]
    nd["p_cont_ids"] = np.fromiter(p_cont.keys(), np.int32)
    nd["p_cont_vals"] = np.fromiter(p_cont.values(), np.float32)

    order = nd["max_order"]
    ctx_maps = nd["ctx_to_idx"] = [{} for _ in range(order)]
    id_maps = nd["counts_ids"] = [{} for _ in range(order)]
    val_maps = nd["counts_vals"] = [{} for _ in range(order)]

    for level in range(order):
        level_counts = nd["counts"][level]
        for idx, (history, followers) in enumerate(level_counts.items()):
            ctx_maps[level][history] = idx
            id_maps[level][idx] = np.fromiter(followers.keys(), np.int32)
            val_maps[level][idx] = np.fromiter(followers.values(), np.int32)
    return nd

@profile
def ngram_probs(
    ctx: tuple[int, ...],
    nd,
    alpha: float = 1e-5,
    scratch: np.ndarray | None = None
) -> np.ndarray:
    """Return the next-symbol distribution for the history *ctx*.

    Interpolated absolute discounting, built bottom-up from order 1:

        P_n(w | h) = max(c(h, w) - D, 0) / c(h) + lam(h) * P_{n-1}(w | h')
        lam(h)     = D * distinct(h) / c(h)

    where ``h'`` drops the oldest symbol of ``h``. The order-0 base is
    ``alpha / V`` everywhere plus ``(1 - alpha) * nd["p_cont"]``. Orders whose
    history was never seen in training are skipped, keeping the lower-order
    estimate.

    Args:
        ctx: previously seen symbol ids, oldest first (tuple or other
            sequence). Only the last ``max_order - 1`` are used. Training keys
            sequence-initial histories with ``None`` padding, so callers must
            pass those ``None`` values to match them.
        nd: model dict from ``train_ngram`` after ``prepare_ngram_arrays``.
        alpha: probability mass spread uniformly over the vocabulary.
        scratch: optional work buffer of length ``vocab_size``. It is
            overwritten; a new float32 buffer is allocated if it is missing or
            has the wrong length.

    Returns:
        A fresh array of length ``vocab_size``, normalised to sum to 1.
    """
    V = nd["vocab_size"]
    max_order = nd["max_order"]
    D = nd["D"]
    if scratch is None or scratch.shape[0] != V:
        scratch = np.empty(V, dtype=np.float32)

    # Order-0 base distribution.
    scratch.fill(alpha / V)
    scratch[nd["p_cont_ids"]] += (1.0 - alpha) * nd["p_cont_vals"]

    # tuple() so that list histories are hashable dict keys.
    ctx = tuple(ctx)[-(max_order - 1):] if max_order > 1 else ()
    for n in range(1, max_order + 1):
        subctx = ctx[-(n - 1):] if n > 1 else ()
        ctx_idx = nd["ctx_to_idx"][n - 1].get(subctx)
        if ctx_idx is None:
            continue  # unseen history: keep the lower-order estimate
        ids = nd["counts_ids"][n - 1][ctx_idx]
        cnts = nd["counts_vals"][n - 1][ctx_idx]
        total = nd["totals"][n - 1][subctx]
        lam = (D * nd["distincts"][n - 1][subctx]) / total

        # astype() always copies here, so the model's stored count arrays are
        # never modified, whatever their dtype.
        discounted = np.maximum(cnts.astype(np.float32) - D, 0.0) / total

        # Interpolate for EVERY symbol: scale the whole lower-order estimate
        # by lam, then add the discounted mass of the observed followers.
        # Updating only `ids` would leave unseen symbols with unscaled mass.
        scratch *= lam
        scratch[ids] += discounted

    scratch /= scratch.sum(dtype=np.float64)
    return scratch.copy()

def _save_ngram(obj: dict, path: pathlib.Path) -> None:
    path.parent.mkdir(parents=True, exist_ok=True)
    with gzip.open(path, "wb") as f:
        pickle.dump(obj, f, protocol=pickle.HIGHEST_PROTOCOL)
    print(f"ngram written to {path} ("
          f"{path.stat().st_size/1024:.1f} KB)")

def _load_ngram(path: pathlib.Path) -> dict:
    with gzip.open(path, "rb") as f:
        print(f"ngram loaded from {path}")
        return pickle.load(f)

