import numpy as np
from poset import topo_sort


def is_edge(P, a, b) -> bool:
    """Report whether ``b`` appears among the successors recorded for ``a``.

    ``P`` is a mapping whose ``"succs"`` entry is indexable by ``a`` and
    yields a container supporting the ``in`` operator.
    """
    successors_of_a = P["succs"][a]
    return b in successors_of_a


def max_stats_prefix(P, lamb, perm, k, swap_with=None, in_prefix=None, prefix_buf=None):
    """Count and weigh the elements of a prefix that have no successor in it.

    The prefix is ``perm[:k + 1]``.  When ``swap_with`` is given, the prefix
    is instead ``perm[:k]`` followed by ``swap_with`` (i.e. the element at
    position ``k`` is replaced).

    Args:
        P: mapping with a ``"succs"`` entry; ``P["succs"][v]`` iterates over
            the successors of vertex ``v``.
        lamb: per-vertex weights, indexed by vertex id.
        perm: ordering of the vertices (sliceable sequence of vertex ids).
        k: index of the last prefix position.
        swap_with: optional vertex id to place at position ``k`` instead of
            ``perm[k]``.
        in_prefix: optional boolean scratch array indexed by vertex id.  It
            is fully cleared and rewritten on every call.  If omitted, a
            fresh array of length ``len(perm)`` is allocated, which assumes
            ``perm`` lists every vertex ``0 .. len(perm) - 1``.
        prefix_buf: optional integer scratch array with at least ``k + 1``
            slots; only used (and overwritten) when ``swap_with`` is given.
            If omitted, a fresh one is allocated.

    Returns:
        ``(cnt, lam_sum)`` as floats: the number of prefix elements with no
        successor inside the prefix, and the sum of ``lamb`` over those
        elements.  ``cnt`` is floored to ``1.0`` and ``lam_sum`` to ``1e-12``
        when non-positive, so callers can divide by either value.
    """
    succs = P["succs"]
    if swap_with is None:
        prefix = perm[:k + 1]
    else:
        # The buffer defaults to None; allocate it on demand so the default
        # is actually usable instead of raising TypeError.
        if prefix_buf is None:
            prefix_buf = np.empty(k + 1, dtype=np.int64)
        if k > 0:
            prefix_buf[:k] = perm[:k]
        prefix_buf[k] = int(swap_with)
        prefix = prefix_buf[:k + 1]

    if in_prefix is None:
        in_prefix = np.zeros(len(perm), dtype=bool)
    else:
        # Caller-owned scratch space may hold marks from a previous call.
        in_prefix[:] = False
    for v in prefix:
        in_prefix[int(v)] = True

    cnt = 0.0
    lam_sum = 0.0
    for v in prefix:
        v = int(v)
        # v counts only if none of its successors is also in the prefix.
        if not any(in_prefix[int(w)] for w in succs[v]):
            cnt += 1.0
            lam_sum += float(lamb[v])

    # Keep both values strictly positive so ratios built from them are defined.
    if cnt <= 0.0:
        cnt = 1.0
    if lam_sum <= 0.0:
        lam_sum = 1e-12
    return cnt, lam_sum


def make_linear_extension_sampler(
    *,
    P,
    lamb,
    seed: int = 0,
    laziness: float = 0.1,
    burnin_steps: int = 20000,
    steps_between: int = 1000,
):
    """Build a Markov-chain sampler that proposes adjacent swaps in an ordering.

    The chain starts from ``topo_sort(P, rng)`` and repeatedly proposes to
    swap two neighbouring positions.  A swap is rejected outright when
    ``is_edge(P, a, b)`` holds for the pair; otherwise it is accepted with a
    probability computed from ``max_stats_prefix`` (or always, when all
    entries of ``lamb`` are close to equal).

    Args:
        P: poset description; must provide ``P["n"]`` (number of elements)
            and ``P["succs"]`` (read by ``is_edge`` / ``max_stats_prefix``).
        lamb: per-element weights, indexed by element id.
        seed: seed for the ``numpy.random.RandomState`` driving the chain.
        laziness: probability that a step does nothing at all.
        burnin_steps: number of steps run immediately, before returning.
        steps_between: number of steps run on each call to ``draw_perm``.

    Returns:
        ``(draw_perm, state)``.  ``draw_perm()`` advances the chain by
        ``state["steps_between"]`` steps and returns a copy of the current
        ordering as an ``int64`` array.  ``state`` is the mutable dict the
        chain works on; it also tracks ``"total_steps"`` and
        ``"total_moves"`` (accepted swaps).
    """
    rng = np.random.RandomState(seed)
    n = P["n"]
    perm0 = topo_sort(P, rng)
    # pos0 is the inverse of perm0: pos0[v] is the index of v in perm0.
    pos0 = np.empty(n, dtype=np.int64)
    pos0[perm0] = np.arange(n, dtype=np.int64)

    # Scratch buffers reused by every max_stats_prefix call.
    in_prefix = np.zeros(n, dtype=bool)
    prefix_buf = np.empty(n, dtype=np.int64)
    # Guard the lamb[0] lookup so an empty poset does not raise IndexError.
    uniform_lamb = n == 0 or bool(np.allclose(lamb, lamb[0]))

    state = {
        "P": P,
        "lamb": lamb,
        "perm": perm0,
        "pos": pos0,
        "rng": rng,
        "in_prefix": in_prefix,
        "prefix_buf": prefix_buf,
        "laziness": float(laziness),
        "uniform_lamb": uniform_lamb,
        "total_steps": 0,
        "total_moves": 0,
        "steps_between": int(steps_between),
    }

    def mcmc_step():
        """Attempt one adjacent swap; return True iff the ordering changed."""
        perm = state["perm"]
        pos = state["pos"]
        rng = state["rng"]
        lamb = state["lamb"]
        laziness = state["laziness"]
        uniform = state["uniform_lamb"]
        in_pref = state["in_prefix"]
        pref_buf = state["prefix_buf"]
        n = len(perm)

        # With fewer than two elements there is no adjacent pair to swap, and
        # rng.randint(0, n - 1) would raise ValueError (empty range).
        if n < 2:
            return False

        if laziness > 0.0 and rng.rand() < laziness:
            return False

        # randint's upper bound is exclusive, so k + 1 is always a valid index.
        k = int(rng.randint(0, n - 1))
        a = int(perm[k])
        b = int(perm[k + 1])

        # Swapping would put b before a, which is not allowed for an edge a -> b.
        if is_edge(P, a, b):
            return False

        if uniform:
            # When every weight is the same positive value, lam_sum equals
            # cnt * weight in max_stats_prefix, so the ratio below is 1.
            perm[k], perm[k + 1] = perm[k + 1], perm[k]
            pos[a], pos[b] = k + 1, k
            return True

        # Only the prefix ending at position k differs between the current
        # and the proposed ordering; compare its statistics.
        cnt, lsum = max_stats_prefix(P, lamb, perm, k, None, in_pref, pref_buf)
        cntp, lsump = max_stats_prefix(P, lamb, perm, k, b, in_pref, pref_buf)

        r = (cntp / cnt) * (lsum / lsump)
        if not np.isfinite(r):
            r = 0.0

        if rng.rand() < min(1.0, float(r)):
            perm[k], perm[k + 1] = perm[k + 1], perm[k]
            pos[a], pos[b] = k + 1, k
            return True

        return False

    def mcmc_take(steps: int):
        """Run ``steps`` chain steps and update the counters in ``state``."""
        steps = int(steps)
        moves = sum(1 for _ in range(steps) if mcmc_step())
        state["total_steps"] += steps
        state["total_moves"] += moves

    if int(burnin_steps) > 0:
        mcmc_take(int(burnin_steps))

    def draw_perm():
        """Advance the chain and return a copy of the current ordering."""
        mcmc_take(state["steps_between"])
        return np.asarray(state["perm"], dtype=np.int64).copy()

    return draw_perm, state









