import numpy as np
from poset import is_edge, topo_sort, max_stats_prefix, group_max_stats_from_counts


def make_mcmc_state(
    P,
    lamb,
    rng=None,
    laziness=0.1,
    use_prefix_cache=False,
    init_perm=None,
):
    """Build the mutable state dict that drives the adjacent-swap chain.

    Parameters
    ----------
    P : mapping
        Poset description. Must provide ``P["n"]`` (number of elements) and
        may provide ``"group_ids"`` / ``"group_succ"`` (read with ``P.get``).
        It is also handed to ``topo_sort`` when no ``init_perm`` is given.
    lamb : array-like of float
        Per-element weights; converted with ``np.asarray(..., dtype=float)``.
    rng : optional
        Random source stored in the state and passed to ``topo_sort``.
        Defaults to the ``np.random`` module.
    laziness : float
        Stored as ``float(laziness)`` under ``"laziness"``.
    use_prefix_cache : bool
        When true and ``P`` carries both ``group_ids`` and ``group_succ``,
        per-group running counts and weight sums over prefixes of the
        initial permutation are precomputed.
    init_perm : array-like of int, optional
        Starting permutation of ``0..n-1``. It is copied, so the caller's
        array is never modified by the chain.
        NOTE(review): compatibility with the poset order is not checked here.

    Returns
    -------
    dict
        Keys: ``P``, ``lamb``, ``perm``, ``pos`` (inverse of ``perm``),
        ``rng``, ``in_prefix`` / ``prefix_buf`` (scratch buffers of length n),
        ``laziness``, ``uniform_lamb``, ``total_steps``, ``total_moves``, and
        the cache entries ``group_ids``, ``group_succ``, ``num_groups``,
        ``grp_cnt_prefix``, ``grp_lam_prefix`` (all ``None`` when the cache is
        not built).

    Raises
    ------
    ValueError
        If ``init_perm`` is not 1-D, has the wrong length, or is not a
        permutation of ``0..n-1``.
    """
    if rng is None:
        rng = np.random
    lamb = np.asarray(lamb, dtype=float)
    n = P["n"]

    if init_perm is not None:
        # np.array copies: the chain swaps entries of ``perm`` in place, so
        # aliasing the caller's array would mutate their input and let two
        # states built from the same init_perm share (and corrupt) ``perm``.
        perm0 = np.array(init_perm, dtype=np.int64)
        if perm0.ndim != 1:
            raise ValueError(f"init_perm must be 1-D, got shape {perm0.shape}")
        if perm0.shape[0] != n:
            raise ValueError(f"init_perm length {perm0.shape[0]} != n {n}")
        # Without this check duplicates / out-of-range entries would leave
        # parts of ``pos0`` (allocated with np.empty) uninitialised.
        if not np.array_equal(np.sort(perm0), np.arange(n, dtype=np.int64)):
            raise ValueError("init_perm must be a permutation of 0..n-1")
    else:
        # Normalise to the same dtype the init_perm branch produces.
        perm0 = np.asarray(topo_sort(P, rng), dtype=np.int64)

    # pos0 is the inverse permutation: pos0[perm0[j]] == j.
    pos0 = np.empty(n, dtype=np.int64)
    pos0[perm0] = np.arange(n, dtype=np.int64)

    in_prefix = np.zeros(n, dtype=bool)
    prefix_buf = np.empty(n, dtype=np.int64)
    # An empty weight vector is treated as uniform instead of failing on lamb[0].
    uniform_lamb = bool(lamb.size == 0 or np.allclose(lamb, lamb[0]))

    state = {
        "P": P,
        "lamb": lamb,
        "perm": perm0,
        "pos": pos0,
        "rng": rng,
        "in_prefix": in_prefix,
        "prefix_buf": prefix_buf,
        "laziness": float(laziness),
        "uniform_lamb": uniform_lamb,
        "total_steps": 0,
        "total_moves": 0,
        "group_ids": None,
        "group_succ": None,
        "num_groups": None,
        "grp_cnt_prefix": None,
        "grp_lam_prefix": None,
    }

    group_ids = P.get("group_ids", None)
    group_succ = P.get("group_succ", None)

    if use_prefix_cache and group_ids is not None and group_succ is not None:
        group_ids = np.asarray(group_ids, dtype=np.int64)
        num_groups = int(group_ids.max()) + 1 if group_ids.size else 0

        # Column j holds, for every group, the number of members (and their
        # summed weight) among perm0[0..j] inclusive. Built as a one-hot
        # (group, position) matrix followed by a running sum along positions.
        cnt_onehot = np.zeros((num_groups, n), dtype=np.int64)
        lam_onehot = np.zeros((num_groups, n), dtype=float)
        if n:
            cols = np.arange(n)
            gid_at = group_ids[perm0]
            cnt_onehot[gid_at, cols] = 1
            lam_onehot[gid_at, cols] = lamb[perm0]
        grp_cnt_prefix = np.cumsum(cnt_onehot, axis=1, dtype=np.int64)
        grp_lam_prefix = np.cumsum(lam_onehot, axis=1, dtype=float)

        state["group_ids"] = group_ids
        state["group_succ"] = group_succ
        state["num_groups"] = num_groups
        state["grp_cnt_prefix"] = grp_cnt_prefix
        state["grp_lam_prefix"] = grp_lam_prefix

    return state


def tilt_mcmc_step(state):
    """Attempt one adjacent-transposition move on ``state["perm"]`` in place.

    A position ``k`` is drawn and the elements ``a = perm[k]`` and
    ``b = perm[k + 1]`` are considered for swapping:

    * with probability ``laziness`` the step does nothing;
    * if ``is_edge(P, a, b)`` holds the swap is rejected;
    * if the weights are uniform (``state["uniform_lamb"]``) the swap is
      always applied;
    * otherwise a ratio ``r`` is computed from statistics of the prefix
      ending at ``k`` before and after the swap, and the swap is applied
      with probability ``min(1, r)``.

    ``perm``, ``pos`` and (when present) the per-group prefix cache are
    updated together whenever a swap is applied.

    Parameters
    ----------
    state : dict
        State produced by ``make_mcmc_state``. ``state["rng"]`` must expose
        ``rand()`` and ``randint(low, high)`` with an exclusive ``high``.

    Returns
    -------
    bool
        ``True`` if the swap was applied, ``False`` otherwise.
    """
    P = state["P"]
    lamb = state["lamb"]
    perm = state["perm"]
    pos = state["pos"]
    rng = state["rng"]
    in_pref = state["in_prefix"]
    pref_buf = state["prefix_buf"]
    laziness = state["laziness"]
    uniform_lamb = state["uniform_lamb"]

    group_ids = state.get("group_ids", None)
    group_succ = state.get("group_succ", None)
    grp_cnt_prefix = state.get("grp_cnt_prefix", None)
    grp_lam_prefix = state.get("grp_lam_prefix", None)

    n = len(perm)

    # With fewer than two elements there is no adjacent pair to swap, and
    # randint(0, n - 1) would raise because its range is empty.
    if n < 2:
        return False

    if laziness > 0.0 and rng.rand() < float(laziness):
        return False

    # randint's upper bound is exclusive, so k is in [0, n - 2] and k + 1 is valid.
    k = int(rng.randint(0, n - 1))
    a = int(perm[k])
    b = int(perm[k + 1])

    if is_edge(P, a, b):
        return False

    use_prefix_cache = (
        group_ids is not None
        and group_succ is not None
        and grp_cnt_prefix is not None
        and grp_lam_prefix is not None
    )

    if uniform_lamb:
        perm[k], perm[k + 1] = perm[k + 1], perm[k]
        pos[a], pos[b] = k + 1, k
        if use_prefix_cache:
            # Keep the cache consistent with perm even though this branch
            # never reads it: only column k changes (a leaves the prefix
            # ending at k, b enters it).
            gid_a = int(group_ids[a])
            gid_b = int(group_ids[b])
            grp_cnt_prefix[gid_a, k] -= 1
            grp_lam_prefix[gid_a, k] -= float(lamb[a])
            grp_cnt_prefix[gid_b, k] += 1
            grp_lam_prefix[gid_b, k] += float(lamb[b])
        return True

    if use_prefix_cache:
        # Column k describes the prefix perm[0..k], which currently ends in a.
        cnt_col = grp_cnt_prefix[:, k]
        lam_col = grp_lam_prefix[:, k]
        cnt, lsum = group_max_stats_from_counts(
            cnt_col,
            lam_col,
            group_succ,
        )
        # Proposed prefix: same as above but with a replaced by b.
        cnt_p = cnt_col.copy()
        lam_p = lam_col.copy()
        gid_a = int(group_ids[a])
        gid_b = int(group_ids[b])
        cnt_p[gid_a] -= 1
        lam_p[gid_a] -= float(lamb[a])
        cnt_p[gid_b] += 1
        lam_p[gid_b] += float(lamb[b])
        cntp, lsump = group_max_stats_from_counts(
            cnt_p,
            lam_p,
            group_succ,
        )
    else:
        # NOTE(review): presumably the sixth argument selects the element that
        # replaces perm[k] in the prefix (None = current prefix) — confirm
        # against max_stats_prefix.
        cnt, lsum = max_stats_prefix(P, lamb, perm, pos, k, None, in_pref, pref_buf)
        cntp, lsump = max_stats_prefix(P, lamb, perm, pos, k, b, in_pref, pref_buf)

    # NOTE(review): if the helpers return plain Python numbers, a zero cnt or
    # lsump raises ZeroDivisionError here instead of yielding a non-finite r.
    r = (cntp / cnt) * (lsum / lsump)
    if not np.isfinite(r):
        r = 0.0

    if rng.rand() < min(1.0, float(r)):
        perm[k], perm[k + 1] = perm[k + 1], perm[k]
        pos[a], pos[b] = k + 1, k
        if use_prefix_cache:
            grp_cnt_prefix[:, k] = cnt_p
            grp_lam_prefix[:, k] = lam_p
        return True
    return False


def mcmc_take(state, steps):
    """Advance the chain by ``steps`` steps and return a snapshot of the permutation.

    Parameters
    ----------
    state : dict
        Chain state as built by ``make_mcmc_state``; mutated in place.
    steps : int
        Number of ``tilt_mcmc_step`` calls to perform. Non-positive values
        perform no steps and leave the counters untouched.

    Returns
    -------
    numpy.ndarray
        An int64 *copy* of ``state["perm"]`` after the steps. Returning a copy
        means samples collected from successive calls stay distinct instead of
        all aliasing the live, in-place-mutated state array.

    Side effects
    ------------
    ``state["total_steps"]`` grows by the number of steps performed and
    ``state["total_moves"]`` by the number of steps that returned ``True``.
    """
    # Clamp so a negative request cannot decrement the step counter.
    n_steps = max(int(steps), 0)
    moves = sum(1 for _ in range(n_steps) if tilt_mcmc_step(state))
    state["total_steps"] += n_steps
    state["total_moves"] += moves
    # np.array (unlike np.asarray) always copies, decoupling the sample from state.
    return np.array(state["perm"], dtype=np.int64)


